diff --git a/core/src/main/java/org/apache/calcite/plan/hep/HepPlanner.java b/core/src/main/java/org/apache/calcite/plan/hep/HepPlanner.java index e9cb6da02adc..31d1ca0f86ad 100644 --- a/core/src/main/java/org/apache/calcite/plan/hep/HepPlanner.java +++ b/core/src/main/java/org/apache/calcite/plan/hep/HepPlanner.java @@ -51,7 +51,9 @@ import org.apache.calcite.util.graph.Graphs; import org.apache.calcite.util.graph.TopologicalOrderIterator; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Multimap; import org.checkerframework.checker.nullness.qual.Nullable; @@ -66,6 +68,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; @@ -114,6 +117,18 @@ public class HepPlanner extends AbstractRelOptPlanner { private final List materializations = new ArrayList<>(); + /** + * Cache for fired rules for by fixed operands and rule. + */ + private final Multimap, RelOptRule> fireRuleCache = HashMultimap.create(); + + /** + * Track the RelNode ID to multiple RelNode IDs in fireRuleCache for GC. + */ + private final Multimap> fireRuleCacheIndex = HashMultimap.create(); + + private boolean enableFireRuleCache = false; + //~ Constructors ----------------------------------------------------------- /** @@ -173,6 +188,8 @@ public HepPlanner( removeRule(rule); } this.materializations.clear(); + this.fireRuleCache.clear(); + this.fireRuleCacheIndex.clear(); } @Override public RelNode changeTraits(RelNode rel, RelTraitSet toTraits) { @@ -195,6 +212,17 @@ public HepPlanner( return buildFinalPlan(requireNonNull(root, "'root' must not be null")); } + /** + * Enables or disables the fire-rule cache. + * + *

If enabled, a rule will not fire twice on the same {@code RelNode::getId()}. + * + * @param enable true to enable; false is default value. + */ + public void setEnableFireRuleCache(boolean enable) { + enableFireRuleCache = enable; + } + /** Top-level entry point for a program. Initializes state and then invokes * the program. */ private void executeProgram(HepProgram program) { @@ -519,6 +547,14 @@ private Iterator getGraphIterator( nodeChildren, parents); + List relIds = null; + if (enableFireRuleCache) { + relIds = call.getRelList().stream().map(RelNode::getId).collect(Collectors.toList()); + if (fireRuleCache.get(relIds).contains(rule)) { + return null; + } + } + // Allow the rule to apply its own side-conditions. if (!rule.matches(call)) { return null; @@ -526,6 +562,13 @@ private Iterator getGraphIterator( fireRule(call); + if (relIds != null) { + fireRuleCache.put(relIds, rule); + for (Integer relId : relIds) { + fireRuleCacheIndex.put(relId, relIds); + } + } + if (!call.getResults().isEmpty()) { return applyTransformationResults( vertex, @@ -982,6 +1025,15 @@ private void collectGarbage() { // Clean up metadata cache too. sweepSet.forEach(this::clearCache); + + if (enableFireRuleCache) { + sweepSet.forEach(rel -> { + for (List relIds : fireRuleCacheIndex.get(rel.getCurrentRel().getId())) { + fireRuleCache.removeAll(relIds); + } + fireRuleCacheIndex.removeAll(rel.getCurrentRel().getId()); + }); + } } private void assertNoCycles() { diff --git a/core/src/test/java/org/apache/calcite/test/HepPlannerTest.java b/core/src/test/java/org/apache/calcite/test/HepPlannerTest.java index 262bba031392..50d412db3d0e 100644 --- a/core/src/test/java/org/apache/calcite/test/HepPlannerTest.java +++ b/core/src/test/java/org/apache/calcite/test/HepPlannerTest.java @@ -366,11 +366,29 @@ private void assertIncludesExactlyOnce(String message, String digest, } @Test void testRuleApplyCount() { - final long applyTimes1 = checkRuleApplyCount(HepMatchOrder.ARBITRARY); - assertThat(applyTimes1, is(316L)); + long applyTimes = checkRuleApplyCount(HepMatchOrder.ARBITRARY, false); + assertThat(applyTimes, is(316L)); - final long applyTimes2 = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST); - assertThat(applyTimes2, is(87L)); + applyTimes = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST, false); + assertThat(applyTimes, is(87L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.TOP_DOWN, false); + assertThat(applyTimes, is(295L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.BOTTOM_UP, false); + assertThat(applyTimes, is(296L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.ARBITRARY, true); + assertThat(applyTimes, is(65L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.DEPTH_FIRST, true); + assertThat(applyTimes, is(65L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.TOP_DOWN, true); + assertThat(applyTimes, is(65L)); + + applyTimes = checkRuleApplyCount(HepMatchOrder.BOTTOM_UP, true); + assertThat(applyTimes, is(65L)); } @Test void testMaterialization() { @@ -387,7 +405,7 @@ private void assertIncludesExactlyOnce(String message, String digest, assertThat(planner.getMaterializations(), empty()); } - private long checkRuleApplyCount(HepMatchOrder matchOrder) { + private long checkRuleApplyCount(HepMatchOrder matchOrder, boolean enableFireRuleCache) { final HepProgramBuilder programBuilder = HepProgram.builder(); programBuilder.addMatchOrder(matchOrder); programBuilder.addRuleInstance(CoreRules.FILTER_REDUCE_EXPRESSIONS); @@ -397,6 +415,7 @@ private long checkRuleApplyCount(HepMatchOrder matchOrder) { HepPlanner planner = new HepPlanner(programBuilder.build()); planner.addListener(listener); planner.setRoot(sql(COMPLEX_UNION_TREE).toRel()); + planner.setEnableFireRuleCache(enableFireRuleCache); planner.findBestExp(); return listener.getApplyTimes(); }