diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java index 8e0b89ceba3cbd..5539e894cc70a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifier.java @@ -19,6 +19,8 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; @@ -81,7 +83,7 @@ public class GraphSimplifier { */ public GraphSimplifier(HyperGraph graph) { this.graph = graph; - edgeSize = graph.getEdges().size(); + edgeSize = graph.getJoinEdges().size(); for (int i = 0; i < edgeSize; i++) { BestSimplification bestSimplification = new BestSimplification(); simplifications.add(bestSimplification); @@ -91,7 +93,7 @@ public GraphSimplifier(HyperGraph graph) { cacheStats.put(node.getNodeMap(), dPhyperNode.getGroup().getStatistics()); cacheCost.put(node.getNodeMap(), dPhyperNode.getRowCount()); } - validEdges = graph.getEdges().stream() + validEdges = graph.getJoinEdges().stream() .filter(e -> { for (Slot slot : e.getJoin().getConditionSlot()) { boolean contains = false; @@ -136,8 +138,8 @@ private void initFirstStep() { public boolean isTotalOrder() { for (int i = 0; i < edgeSize; i++) { for (int j = i + 1; j < edgeSize; j++) { - Edge edge1 = graph.getEdge(i); - Edge edge2 = graph.getEdge(j); + Edge edge1 = graph.getJoinEdge(i); + Edge edge2 = graph.getJoinEdge(j); List superset = 
new ArrayList<>(); tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset); tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset); @@ -342,8 +344,8 @@ private void updatePriorityQueue(int index) { } private Optional makeSimplificationStep(int edgeIndex1, int edgeIndex2) { - Edge edge1 = graph.getEdge(edgeIndex1); - Edge edge2 = graph.getEdge(edgeIndex2); + JoinEdge edge1 = graph.getJoinEdge(edgeIndex1); + JoinEdge edge2 = graph.getJoinEdge(edgeIndex2); if (edge1.isSub(edge2) || edge2.isSub(edge1) || circleDetector.checkCircleWithEdge(edgeIndex1, edgeIndex2) || circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1) @@ -358,8 +360,8 @@ private Optional makeSimplificationStep(int edgeIndex1, int || !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) { return Optional.empty(); } - Edge edge1Before2; - Edge edge2Before1; + JoinEdge edge1Before2; + JoinEdge edge2Before1; List superBitset = new ArrayList<>(); if (tryGetSuperset(left1, left2, superBitset)) { // (common Join1 right1) Join2 right2 @@ -394,36 +396,34 @@ private Optional makeSimplificationStep(int edgeIndex1, int return Optional.of(simplificationStep); } - private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) { + private JoinEdge constructEdge(long leftNodes, JoinEdge edge, long rightNodes) { LogicalJoin join; - if (graph.getEdges().size() > 64 * 63 / 8) { + if (graph.getJoinEdges().size() > 64 * 63 / 8) { // If there are too many edges, it is advisable to return the "edge" directly // to avoid lengthy enumeration time. 
join = edge.getJoin(); } else { BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes); List hashConditions = validEdgesMap.stream() - .mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts()) + .mapToObj(i -> graph.getJoinEdge(i).getJoin().getHashJoinConjuncts()) .flatMap(Collection::stream) .collect(Collectors.toList()); List otherConditions = validEdgesMap.stream() - .mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts()) + .mapToObj(i -> graph.getJoinEdge(i).getJoin().getOtherJoinConjuncts()) .flatMap(Collection::stream) .collect(Collectors.toList()); join = edge.getJoin().withJoinConjuncts(hashConditions, otherConditions); } - Edge newEdge = new Edge( - join, - edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes()); - newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes()); - newEdge.setRightRequiredNodes(edge.getRightRequiredNodes()); - newEdge.addLeftNode(leftNodes); - newEdge.addRightNode(rightNodes); + JoinEdge newEdge = new JoinEdge(join, edge.getIndex(), + edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes(), + edge.getLeftRequiredNodes(), edge.getRightRequiredNodes()); + newEdge.addLeftExtendNode(leftNodes); + newEdge.addRightExtendNode(rightNodes); return newEdge; } - private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) { + private void deriveStats(JoinEdge edge, long leftBitmap, long rightBitmap) { // The bitmap may differ from the edge's reference slots. // Taking into account the order: edge1<{1} - {2}> edge2<{1,3} - {4}>. 
// Actually, we are considering the sequence {1,3} - {2} - {4} @@ -438,7 +438,7 @@ private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) { cacheStats.put(bitmap, joinStats); } - private double calCost(Edge edge, long leftBitmap, long rightBitmap) { + private double calCost(JoinEdge edge, long leftBitmap, long rightBitmap) { long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap); Preconditions.checkArgument(cacheStats.containsKey(leftBitmap) && cacheStats.containsKey(rightBitmap) && cacheStats.containsKey(bitmap), @@ -461,7 +461,7 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) { return cost; } - private @Nullable Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { + private @Nullable JoinEdge threeLeftJoin(long bitmap1, JoinEdge edge1, long bitmap2, JoinEdge edge2, long bitmap3) { // (plan1 edge1 plan2) edge2 plan3 // if the left and right is overlapping, just return null. Preconditions.checkArgument( @@ -471,7 +471,7 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) { if (LongBitmap.isOverlap(newLeft, bitmap3)) { return null; } - Edge newEdge = constructEdge(newLeft, edge2, bitmap3); + JoinEdge newEdge = constructEdge(newLeft, edge2, bitmap3); deriveStats(edge1, bitmap1, bitmap2); deriveStats(newEdge, newLeft, bitmap3); @@ -481,15 +481,16 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) { return newEdge; } - private @Nullable Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) { - Preconditions.checkArgument( - cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3)); + private @Nullable JoinEdge threeRightJoin(long bitmap1, JoinEdge edge1, long bitmap2, + JoinEdge edge2, long bitmap3) { + Preconditions.checkArgument(cacheStats.containsKey(bitmap1) + && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3)); // plan1 edge1 (plan2 edge2 plan3) 
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3); if (LongBitmap.isOverlap(bitmap1, newRight)) { return null; } - Edge newEdge = constructEdge(bitmap1, edge1, newRight); + JoinEdge newEdge = constructEdge(bitmap1, edge1, newRight); deriveStats(edge2, bitmap2, bitmap3); deriveStats(newEdge, bitmap1, newRight); @@ -498,8 +499,8 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) { return newEdge; } - private SimplificationStep orderJoin(Edge edge1Before2, - Edge edge2Before1, int edgeIndex1, int edgeIndex2) { + private SimplificationStep orderJoin(JoinEdge edge1Before2, + JoinEdge edge2Before1, int edgeIndex1, int edgeIndex2) { double cost1Before2 = calCost(edge1Before2, edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes()); double cost2Before1 = calCost(edge2Before1, @@ -515,16 +516,16 @@ private SimplificationStep orderJoin(Edge edge1Before2, step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2, edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes(), - graph.getEdge(edgeIndex2).getLeftExtendedNodes(), - graph.getEdge(edgeIndex2).getRightExtendedNodes()); + graph.getJoinEdge(edgeIndex2).getLeftExtendedNodes(), + graph.getJoinEdge(edgeIndex2).getRightExtendedNodes()); } else { if (cost2Before1 != 0) { benefit = cost1Before2 / cost2Before1; } // choose edge2Before1 step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.getLeftExtendedNodes(), - edge2Before1.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(), - graph.getEdge(edgeIndex1).getRightExtendedNodes()); + edge2Before1.getRightExtendedNodes(), graph.getJoinEdge(edgeIndex1).getLeftExtendedNodes(), + graph.getJoinEdge(edgeIndex1).getRightExtendedNodes()); } return step; } @@ -545,9 +546,9 @@ private boolean tryGetSuperset(long bitmap1, long bitmap2, List superset) */ private void extractJoinDependencies() { for (int i = 0; i < edgeSize; i++) { - Edge edge1 = graph.getEdge(i); + Edge edge1 = 
graph.getJoinEdge(i); for (int j = i + 1; j < edgeSize; j++) { - Edge edge2 = graph.getEdge(j); + Edge edge2 = graph.getJoinEdge(j); if (edge1.isSub(edge2)) { Preconditions.checkArgument(circleDetector.tryAddDirectedEdge(i, j), "Edge %s violates Edge %s", edge1, edge2); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index e8e5acd2cddeca..17c77c855ef383 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -19,11 +19,15 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.FilterEdge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -32,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.PlanUtils; @@ -53,7 +58,8 @@ * It's used for join ordering */ public class HyperGraph { - private final List edges = new ArrayList<>(); + private final List joinEdges = new ArrayList<>(); + private final List filterEdges = new ArrayList<>(); private final List nodes = new ArrayList<>(); private final HashMap slotToNodeMap = new HashMap<>(); // record all edges that can be placed on the subgraph @@ -69,8 +75,8 @@ public class HyperGraph { this.finalOutputs = ImmutableSet.copyOf(finalOutputs); } - public List getEdges() { - return edges; + public List getJoinEdges() { + return joinEdges; } public List getNodes() { @@ -81,8 +87,12 @@ public long getNodesMap() { return LongBitmap.newBitmapBetween(0, nodes.size()); } - public Edge getEdge(int index) { - return edges.get(index); + public JoinEdge getJoinEdge(int index) { + return joinEdges.get(index); + } + + public FilterEdge getFilterEdge(int index) { + return filterEdges.get(index); } public AbstractNode getNode(int index) { @@ -176,17 +186,14 @@ public HashMap> getComplexProject() { return complexProject; } - private void addEdgeOfInfo(Edge edge) { + private void addEdgeOfInfo(JoinEdge edge) { long nodeMap = calNodeMap(edge.getInputSlots()); Preconditions.checkArgument(LongBitmap.getCardinality(nodeMap) > 1, "edge must have more than one ends"); - this.edges.add(new Edge(edge.getJoin(), edges.size(), null, null, null)); long left = LongBitmap.newBitmap(LongBitmap.nextSetBit(nodeMap, 0)); long right = LongBitmap.newBitmapDiff(nodeMap, left); - edge.setLeftRequiredNodes(left); - edge.setLeftExtendedNodes(left); - edge.setRightRequiredNodes(right); - edge.setRightExtendedNodes(right); + this.joinEdges.add(new JoinEdge(edge.getJoin(), joinEdges.size(), + null, null, 0, left, right)); } /** @@ -194,7 +201,8 @@ private void addEdgeOfInfo(Edge edge) { * * @param join The join plan */ - public BitSet addEdge(LogicalJoin join, 
Pair leftEdgeNodes, Pair rightEdgeNodes) { + private BitSet addJoin(LogicalJoin join, + Pair leftEdgeNodes, Pair rightEdgeNodes) { HashMap, Pair, List>> conjuncts = new HashMap<>(); for (Expression expression : join.getHashJoinConjuncts()) { // TODO: avoid calling calculateEnds if calNodeMap's results are same @@ -217,59 +225,77 @@ public BitSet addEdge(LogicalJoin join, Pair leftEdgeNodes, BitSet curJoinEdges = new BitSet(); for (Map.Entry, Pair, List>> entry : conjuncts .entrySet()) { - LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, + LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, entry.getValue().second, JoinHint.NONE, join.getMarkJoinSlotReference(), Lists.newArrayList(join.left(), join.right())); - Edge edge = new Edge(singleJoin, edges.size(), leftEdgeNodes.first, rightEdgeNodes.first, - LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second)); Pair ends = entry.getKey(); - edge.setLeftRequiredNodes(ends.first); - edge.setLeftExtendedNodes(ends.first); - edge.setRightRequiredNodes(ends.second); - edge.setRightExtendedNodes(ends.second); + JoinEdge edge = new JoinEdge(singleJoin, joinEdges.size(), leftEdgeNodes.first, rightEdgeNodes.first, + LongBitmap.newBitmapUnion(leftEdgeNodes.second, rightEdgeNodes.second), ends.first, ends.second); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } curJoinEdges.set(edge.getIndex()); - edges.add(edge); + joinEdges.add(edge); } - curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges)); - curJoinEdges.stream().forEach(i -> edges.get(i).addCurJoinEdges(curJoinEdges)); - curJoinEdges.stream().forEach(i -> makeConflictRules(edges.get(i))); + curJoinEdges.stream().forEach(i -> joinEdges.get(i).addCurJoinEdges(curJoinEdges)); + curJoinEdges.stream().forEach(i -> makeJoinConflictRules(joinEdges.get(i))); + curJoinEdges.stream().forEach(i -> 
makeFilterConflictRules(joinEdges.get(i))); return curJoinEdges; // In MySQL, each edge is reversed and store in edges again for reducing the branch miss // We don't implement this trick now. } + private BitSet addFilter(LogicalFilter filter, Pair childEdgeNodes) { + FilterEdge edge = new FilterEdge(filter, filterEdges.size(), childEdgeNodes.first, childEdgeNodes.second, + childEdgeNodes.second); + filterEdges.add(edge); + BitSet bitSet = new BitSet(); + bitSet.set(edge.getIndex()); + return bitSet; + } + + private void makeFilterConflictRules(JoinEdge joinEdge) { + long leftSubNodes = joinEdge.getLeftSubNodes(joinEdges); + long rightSubNodes = joinEdge.getRightSubNodes(joinEdges); + filterEdges.forEach(e -> { + if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes) + && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) { + e.addRejectJoin(joinEdge); + } + if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes) + && !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) { + e.addRejectJoin(joinEdge); + } + }); + } + // Make edge with CD-C algorithm in // On the correct and complete enumeration of the core search - private void makeConflictRules(Edge edgeB) { + private void makeJoinConflictRules(JoinEdge edgeB) { BitSet leftSubTreeEdges = subTreeEdges(edgeB.getLeftChildEdges()); BitSet rightSubTreeEdges = subTreeEdges(edgeB.getRightChildEdges()); long leftRequired = edgeB.getLeftRequiredNodes(); long rightRequired = edgeB.getRightRequiredNodes(); for (int i = leftSubTreeEdges.nextSetBit(0); i >= 0; i = leftSubTreeEdges.nextSetBit(i + 1)) { - Edge childA = edges.get(i); + JoinEdge childA = joinEdges.get(i); if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) { - leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(edges)); + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges)); } if 
(!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) { - leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(edges)); + leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges)); } } for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) { - Edge childA = edges.get(i); + JoinEdge childA = joinEdges.get(i); if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) { - rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(edges)); + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges)); } if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) { - rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(edges)); + rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges)); } } - edgeB.setLeftRequiredNodes(leftRequired); - edgeB.setRightRequiredNodes(rightRequired); edgeB.setLeftExtendedNodes(leftRequired); edgeB.setRightExtendedNodes(rightRequired); } @@ -277,7 +303,7 @@ private void makeConflictRules(Edge edgeB) { private BitSet subTreeEdge(Edge edge) { long subTreeNodes = edge.getSubTreeNodes(); BitSet subEdges = new BitSet(); - edges.stream() + joinEdges.stream() .filter(e -> LongBitmap.isSubset(subTreeNodes, e.getReferenceNodes())) .forEach(e -> subEdges.set(e.getIndex())); return subEdges; @@ -286,7 +312,7 @@ private BitSet subTreeEdge(Edge edge) { private BitSet subTreeEdges(BitSet edgeSet) { BitSet bitSet = new BitSet(); edgeSet.stream() - .mapToObj(i -> subTreeEdge(edges.get(i))) + .mapToObj(i -> subTreeEdge(joinEdges.get(i))) .forEach(bitSet::or); return bitSet; } @@ -301,15 +327,19 @@ private Pair calculateEnds(long allNodes, Pair leftEdg if (left == 0) { Preconditions.checkArgument(leftEdgeNodes.first.cardinality() > 0, "the number of the table which expression reference is less 2"); - Pair llEdgesNodes = 
edges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges); - Pair lrEdgesNodes = edges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges); + Pair llEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( + joinEdges); + Pair lrEdgesNodes = joinEdges.get(leftEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( + joinEdges); return calculateEnds(allNodes, llEdgesNodes, lrEdgesNodes); } if (right == 0) { Preconditions.checkArgument(rightEdgeNodes.first.cardinality() > 0, "the number of the table which expression reference is less 2"); - Pair rlEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes(edges); - Pair rrEdgesNodes = edges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes(edges); + Pair rlEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getLeftEdgeNodes( + joinEdges); + Pair rrEdgesNodes = joinEdges.get(rightEdgeNodes.first.nextSetBit(0)).getRightEdgeNodes( + joinEdges); return calculateEnds(allNodes, rlEdgesNodes, rrEdgesNodes); } return Pair.of(left, right); @@ -329,7 +359,7 @@ public BitSet getEdgesInOperator(long left, long right) { public BitSet getEdgesInTree(long treeNodesMap) { if (!treeEdgesCache.containsKey(treeNodesMap)) { BitSet edgesMap = new BitSet(); - for (Edge edge : edges) { + for (Edge edge : joinEdges) { if (LongBitmap.isSubset(edge.getReferenceNodes(), treeNodesMap)) { edgesMap.set(edge.getIndex()); } @@ -364,7 +394,7 @@ private List flatChildren() { for (AbstractNode node : nodes) { res = flatChild((StructInfoNode) node, res); } - for (Edge edge : edges) { + for (JoinEdge edge : joinEdges) { res.forEach(g -> g.addEdgeOfInfo(edge)); } return res; @@ -376,12 +406,12 @@ private List flatChild(StructInfoNode infoNode, List hyp return hyperGraphs; } return hyperGraphs.stream().flatMap(g -> - infoNode.getGraphs().stream().map(subGraph -> { - HyperGraph hyperGraph = new HyperGraph(g.finalOutputs); - hyperGraph.addStructInfo(g); - 
hyperGraph.addStructInfo(subGraph); - return hyperGraph; - }) + infoNode.getGraphs().stream().map(subGraph -> { + HyperGraph hyperGraph = new HyperGraph(g.finalOutputs); + hyperGraph.addStructInfo(g); + hyperGraph.addStructInfo(subGraph); + return hyperGraph; + }) ).collect(Collectors.toList()); } @@ -410,7 +440,7 @@ private Pair buildDPhyperGraph(GroupExpression groupExpression) { LogicalJoin join = (LogicalJoin) groupExpression.getPlan(); Pair left = this.buildDPhyperGraph(groupExpression.child(0).getLogicalExpressions().get(0)); Pair right = this.buildDPhyperGraph(groupExpression.child(1).getLogicalExpressions().get(0)); - return Pair.of(this.addEdge(join, left, right), + return Pair.of(this.addJoin(join, left, right), LongBitmap.or(left.second, right.second)); } @@ -424,10 +454,10 @@ private void addStructInfo(HyperGraph other) { other.getNodes().forEach(n -> this.addStructInfoNode(n.getPlan())); other.getComplexProject().forEach((t, projectList) -> projectList.forEach(e -> this.addAlias((Alias) e, t << offset))); - other.getEdges().forEach(this::addEdgeOfInfo); + other.getJoinEdges().forEach(this::addEdgeOfInfo); } - // Build Graph for matching mv + // Build Graph for matching mv, return join edge set and nodes in this plan private Pair buildStructInfo(Plan plan) { if (plan instanceof GroupPlan) { Group group = ((GroupPlan) plan).getGroup(); @@ -454,14 +484,21 @@ private Pair buildStructInfo(Plan plan) { } // process Join - if (isValidJoin(plan)) { + if (isValidJoinForStructInfo(plan)) { LogicalJoin join = (LogicalJoin) plan; Pair left = this.buildStructInfo(plan.child(0)); Pair right = this.buildStructInfo(plan.child(1)); - return Pair.of(this.addEdge(join, left, right), + return Pair.of(this.addJoin(join, left, right), LongBitmap.or(left.second, right.second)); } + if (isValidFilter(plan)) { + LogicalFilter filter = (LogicalFilter) plan; + Pair child = this.buildStructInfo(filter.child()); + this.addFilter(filter, child); + return Pair.of(new BitSet(), 
child.second); + } + // process Other Node int idx = this.addStructInfoNode(plan); return Pair.of(new BitSet(), LongBitmap.newBitmap(idx)); @@ -480,6 +517,23 @@ public static boolean isValidJoin(Plan plan) { && !join.getExpressions().isEmpty(); } + /** + * inner join group without mark slot + */ + public static boolean isValidJoinForStructInfo(Plan plan) { + if (!(plan instanceof LogicalJoin)) { + return false; + } + + LogicalJoin join = (LogicalJoin) plan; + return !join.isMarkJoin() + && !join.getExpressions().isEmpty(); + } + + public static boolean isValidFilter(Plan plan) { + return plan instanceof LogicalFilter; + } + /** * the project with alias and slot */ @@ -502,14 +556,14 @@ public void modifyEdge(int edgeIndex, long newLeft, long newRight) { // When modify an edge in hyper graph, we need to update the left and right nodes // For these nodes that are only in the old edge, we need remove the edge from them // For these nodes that are only in the new edge, we need to add the edge to them - Edge edge = edges.get(edgeIndex); + Edge edge = joinEdges.get(edgeIndex); if (treeEdgesCache.containsKey(edge.getReferenceNodes())) { treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, false); } updateEdges(edge, edge.getLeftExtendedNodes(), newLeft); updateEdges(edge, edge.getRightExtendedNodes(), newRight); - edges.get(edgeIndex).setLeftExtendedNodes(newLeft); - edges.get(edgeIndex).setRightExtendedNodes(newRight); + joinEdges.get(edgeIndex).setLeftExtendedNodes(newLeft); + joinEdges.get(edgeIndex).setRightExtendedNodes(newRight); if (treeEdgesCache.containsKey(edge.getReferenceNodes())) { treeEdgesCache.get(edge.getReferenceNodes()).set(edgeIndex, true); } @@ -534,7 +588,7 @@ private void updateEdges(Edge edge, long oldNodes, long newNodes) { */ public String toDottyHyperGraph() { StringBuilder builder = new StringBuilder(); - builder.append(String.format("digraph G { # %d edges\n", edges.size())); + builder.append(String.format("digraph G { # %d edges\n", 
joinEdges.size())); List graphvisNodes = new ArrayList<>(); for (AbstractNode node : nodes) { String nodeName = node.getName(); @@ -550,11 +604,11 @@ public String toDottyHyperGraph() { nodeID, nodeName, rowCount)); graphvisNodes.add(nodeName); } - for (int i = 0; i < edges.size(); i += 1) { - Edge edge = edges.get(i); + for (int i = 0; i < joinEdges.size(); i += 1) { + JoinEdge edge = joinEdges.get(i); // TODO: add cardinality to label String label = String.format("%.2f", edge.getSelectivity()); - if (edges.get(i).isSimple()) { + if (joinEdges.get(i).isSimple()) { String arrowHead = ""; if (edge.getJoin().getJoinType() == JoinType.INNER_JOIN) { arrowHead = ",arrowhead=none"; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java index 55c346ce1f9b53..98c5542d10e3fe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumerator.java @@ -19,6 +19,8 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.AbstractReceiver; @@ -79,7 +81,7 @@ public boolean enumerate() { int size = nodes.size(); // Init edgeCalculator - edgeCalculator = new EdgeCalculator(hyperGraph.getEdges()); + edgeCalculator = new EdgeCalculator(hyperGraph.getJoinEdges()); for (AbstractNode node : nodes) { 
edgeCalculator.initSubgraph(node.getNodeMap()); } @@ -149,7 +151,7 @@ private boolean enumerateCmpRec(long csg, long cmp, long forbiddenNodes) { edgeCalculator.unionEdges(cmp, subset); if (receiver.contain(newCmp)) { // We check all edges for finding an edge. - List edges = edgeCalculator.connectCsgCmp(csg, newCmp); + List edges = edgeCalculator.connectCsgCmp(csg, newCmp); if (edges.isEmpty()) { continue; } @@ -185,7 +187,7 @@ private boolean emitCsg(long csg) { for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) { long cmp = LongBitmap.newBitmap(nodeIndex); // whether there is an edge between csg and cmp - List edges = edgeCalculator.connectCsgCmp(csg, cmp); + List edges = edgeCalculator.connectCsgCmp(csg, cmp); if (!edges.isEmpty()) { if (!receiver.emitCsgCmp(csg, cmp, edges)) { return false; @@ -241,7 +243,7 @@ public long calcNeighborhood(long subgraph, long forbiddenNodes, EdgeCalculator } static class EdgeCalculator { - final List edges; + final List edges; // It cached all edges that contained by this subgraph, Note we always // use bitset store edge map because the number of edges can be very large // We split these into simple edges (only one node on each side) and complex edges (others) @@ -254,7 +256,7 @@ static class EdgeCalculator { // complex edges HashMap overlapEdges = new HashMap<>(); - EdgeCalculator(List edges) { + EdgeCalculator(List edges) { this.edges = edges; } @@ -326,10 +328,10 @@ public void unionEdges(long bitmap1, long bitmap2) { overlapEdges.put(subgraph, overlaps); } - public List connectCsgCmp(long csg, long cmp) { + public List connectCsgCmp(long csg, long cmp) { Preconditions.checkArgument( containSimpleEdges.containsKey(csg) && containSimpleEdges.containsKey(cmp)); - List foundEdges = new ArrayList<>(); + List foundEdges = new ArrayList<>(); BitSet edgeMap = new BitSet(); edgeMap.or(containSimpleEdges.get(csg)); edgeMap.and(containSimpleEdges.get(cmp)); diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java similarity index 52% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java index 4c7ce312ba1414..6f0920e1f5420a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/Edge.java @@ -15,98 +15,68 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.jobs.joinorder.hypergraph; +package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge; import org.apache.doris.common.Pair; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; - -import com.google.common.base.Preconditions; import java.util.BitSet; -import java.util.HashSet; import java.util.List; import java.util.Set; -import javax.annotation.Nullable; /** * Edge in HyperGraph */ -public class Edge { - final int index; - final LogicalJoin join; - final double selectivity; +public abstract class Edge { + private final int index; + private final double selectivity; // "RequiredNodes" refers to the nodes that can activate this edge based on // specific requirements. These requirements are established during the building process. // "ExtendNodes" encompasses both the "RequiredNodes" and any additional nodes // added by the graph simplifier. 
- private long leftRequiredNodes = LongBitmap.newBitmap(); - private long rightRequiredNodes = LongBitmap.newBitmap(); - private long leftExtendedNodes = LongBitmap.newBitmap(); - private long rightExtendedNodes = LongBitmap.newBitmap(); - - private long referenceNodes = LongBitmap.newBitmap(); + private final long leftRequiredNodes; + private final long rightRequiredNodes; + private long leftExtendedNodes; + private long rightExtendedNodes; // record the left child edges and right child edges in origin plan tree private final BitSet leftChildEdges; private final BitSet rightChildEdges; // record the edges in the same operator - private final BitSet curJoinEdges = new BitSet(); + private final BitSet curOperatorEdges = new BitSet(); // record all sub nodes behind in this operator. It's T function in paper private final long subTreeNodes; /** * Create simple edge. */ - public Edge(LogicalJoin join, int index, BitSet leftChildEdges, BitSet rightChildEdges, Long subTreeNodes) { + Edge(int index, BitSet leftChildEdges, BitSet rightChildEdges, + long subTreeNodes, long leftRequiredNodes, long rightRequiredNodes) { this.index = index; - this.join = join; this.selectivity = 1.0; this.leftChildEdges = leftChildEdges; this.rightChildEdges = rightChildEdges; + this.leftRequiredNodes = leftRequiredNodes; + this.rightRequiredNodes = rightRequiredNodes; + this.leftExtendedNodes = leftRequiredNodes; + this.rightExtendedNodes = rightRequiredNodes; this.subTreeNodes = subTreeNodes; } - public LogicalJoin getJoin() { - return join; - } - - public JoinType getJoinType() { - return join.getJoinType(); - } - public boolean isSimple() { return LongBitmap.getCardinality(leftExtendedNodes) == 1 && LongBitmap.getCardinality(rightExtendedNodes) == 1; } - public void addLeftNode(long left) { + public void addLeftExtendNode(long left) { this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, left); - referenceNodes = LongBitmap.or(referenceNodes, left); } - public void 
addLeftNodes(long... bitmaps) { - for (long bitmap : bitmaps) { - this.leftExtendedNodes = LongBitmap.or(this.leftExtendedNodes, bitmap); - referenceNodes = LongBitmap.or(referenceNodes, bitmap); - } - } - - public void addRightNode(long right) { + public void addRightExtendNode(long right) { this.rightExtendedNodes = LongBitmap.or(this.rightExtendedNodes, right); - referenceNodes = LongBitmap.or(referenceNodes, right); - } - - public void addRightNodes(long... bitmaps) { - for (long bitmap : bitmaps) { - LongBitmap.or(this.rightExtendedNodes, bitmap); - LongBitmap.or(referenceNodes, bitmap); - } } public long getSubTreeNodes() { @@ -121,22 +91,22 @@ public BitSet getLeftChildEdges() { return leftChildEdges; } - public Pair getLeftEdgeNodes(List edges) { + public Pair getLeftEdgeNodes(List edges) { return Pair.of(leftChildEdges, getLeftSubNodes(edges)); } - public Pair getRightEdgeNodes(List edges) { + public Pair getRightEdgeNodes(List edges) { return Pair.of(rightChildEdges, getRightSubNodes(edges)); } - public long getLeftSubNodes(List edges) { + public long getLeftSubNodes(List edges) { if (leftChildEdges.isEmpty()) { return leftRequiredNodes; } return edges.get(leftChildEdges.nextSetBit(0)).getSubTreeNodes(); } - public long getRightSubNodes(List edges) { + public long getRightSubNodes(List edges) { if (rightChildEdges.isEmpty()) { return rightRequiredNodes; } @@ -144,7 +114,6 @@ public long getRightSubNodes(List edges) { } public void setLeftExtendedNodes(long leftExtendedNodes) { - referenceNodes = LongBitmap.clear(referenceNodes); this.leftExtendedNodes = leftExtendedNodes; } @@ -157,7 +126,6 @@ public BitSet getRightChildEdges() { } public void setRightExtendedNodes(long rightExtendedNodes) { - referenceNodes = LongBitmap.clear(referenceNodes); this.rightExtendedNodes = rightExtendedNodes; } @@ -165,24 +133,16 @@ public long getLeftRequiredNodes() { return leftRequiredNodes; } - public void setLeftRequiredNodes(long left) { - this.leftRequiredNodes = left; 
- } - public long getRightRequiredNodes() { return rightRequiredNodes; } - public void setRightRequiredNodes(long right) { - this.rightRequiredNodes = right; - } - public void addCurJoinEdges(BitSet edges) { - curJoinEdges.or(edges); + curOperatorEdges.or(edges); } - public BitSet getCurJoinEdges() { - return curJoinEdges; + public BitSet getCurOperatorEdges() { + return curOperatorEdges; } public boolean isSub(Edge edge) { @@ -192,10 +152,7 @@ public boolean isSub(Edge edge) { } public long getReferenceNodes() { - if (LongBitmap.getCardinality(referenceNodes) == 0) { - referenceNodes = LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes); - } - return referenceNodes; + return LongBitmap.newBitmapUnion(leftExtendedNodes, rightExtendedNodes); } public long getRequireNodes() { @@ -210,51 +167,14 @@ public double getSelectivity() { return selectivity; } - public Expression getExpression() { - Preconditions.checkArgument(join.getExpressions().size() == 1); - return join.getExpressions().get(0); - } - - public List getHashJoinConjuncts() { - return join.getHashJoinConjuncts(); - } - - public List getOtherJoinConjuncts() { - return join.getOtherJoinConjuncts(); - } + public abstract Set getInputSlots(); - public final Set getInputSlots() { - Set slots = new HashSet<>(); - join.getExpressions().stream().forEach(expression -> slots.addAll(expression.getInputSlots())); - return slots; - } + public abstract List getExpressions(); @Override public String toString() { return String.format("<%s - %s>", LongBitmap.toString(leftExtendedNodes), LongBitmap.toString( rightExtendedNodes)); } - - /** - * extract join type and conjuncts from edges - */ - public static @Nullable JoinType extractJoinTypeAndConjuncts(List edges, - List hashConjuncts, List otherConjuncts) { - JoinType joinType = null; - for (Edge edge : edges) { - if (edge.getJoinType() != joinType && joinType != null) { - return null; - } - Preconditions.checkArgument(joinType == null || joinType == 
edge.getJoinType()); - joinType = edge.getJoinType(); - hashConjuncts.addAll(edge.getHashJoinConjuncts()); - otherConjuncts.addAll(edge.getOtherJoinConjuncts()); - } - return joinType; - } - - public static Edge createTempEdge(LogicalJoin join) { - return new Edge(join, -1, null, null, 0L); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java new file mode 100644 index 00000000000000..d04031067d8682 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/FilterEdge.java @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; +import java.util.Set; + +/** + * Edge that represents a filter and records the indexes of the join edges passed to addRejectJoin + */ +public class FilterEdge extends Edge { + private final LogicalFilter<? extends Plan> filter; + private final List<Integer> rejectEdges; + + public FilterEdge(LogicalFilter<? extends Plan> filter, int index, + BitSet childEdges, long subTreeNodes, long childRequireNodes) { + super(index, childEdges, new BitSet(), subTreeNodes, childRequireNodes, 0L); + this.filter = filter; + rejectEdges = new ArrayList<>(); + } + + public void addRejectJoin(JoinEdge joinEdge) { + rejectEdges.add(joinEdge.getIndex()); + } + + public List<Integer> getRejectEdges() { + return rejectEdges; + } + + @Override + public Set<Slot> getInputSlots() { + return filter.getInputSlots(); + } + + @Override + public List<? extends Expression> getExpressions() { + return filter.getExpressions(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java new file mode 100644 index 00000000000000..81e80bee85fe37 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/edge/JoinEdge.java @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.jobs.joinorder.hypergraph.edge; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; + +import com.google.common.base.Preconditions; + +import java.util.BitSet; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +/** + * Edge that represents a join and exposes the join's type, conjuncts and input slots + */ +public class JoinEdge extends Edge { + + private final LogicalJoin<? extends Plan, ? extends Plan> join; + + public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index, + BitSet leftChildEdges, BitSet rightChildEdges, long subTreeNodes, + long leftRequireNodes, long rightRequireNodes) { + super(index, leftChildEdges, rightChildEdges, subTreeNodes, leftRequireNodes, rightRequireNodes); + this.join = join; + } + + public JoinType getJoinType() { + return join.getJoinType(); + } + + public LogicalJoin<? extends Plan, ? extends Plan> getJoin() { + return join; + } + + /** + * Collect conjuncts of the edges into the two lists; return their common join type, or null if types differ + */ + public static @Nullable JoinType extractJoinTypeAndConjuncts(List<JoinEdge> edges, + List<Expression> hashConjuncts, List<Expression> otherConjuncts) { + JoinType joinType = null; + for (JoinEdge edge : edges) { + if (edge.getJoinType() != joinType && joinType != null) { + return null; + } + Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType()); + joinType = edge.getJoinType(); +
hashConjuncts.addAll(edge.getHashJoinConjuncts()); + otherConjuncts.addAll(edge.getOtherJoinConjuncts()); + } + return joinType; + } + + public Expression getExpression() { + Preconditions.checkArgument(join.getExpressions().size() == 1); + return join.getExpressions().get(0); + } + + @Override + public List<? extends Expression> getExpressions() { + return join.getExpressions(); + } + + public List<Expression> getHashJoinConjuncts() { + return join.getHashJoinConjuncts(); + } + + public List<Expression> getOtherJoinConjuncts() { + return join.getOtherJoinConjuncts(); + } + + @Override + public Set<Slot> getInputSlots() { + Set<Slot> slots = new HashSet<>(); + join.getExpressions().forEach(expression -> slots.addAll(expression.getInputSlots())); + return slots; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/AbstractNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/AbstractNode.java index 803d637a0da388..16ee7340876ef4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/AbstractNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/AbstractNode.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/DPhyperNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/DPhyperNode.java index b905646820a2a0..7601d2dff4ccfd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/DPhyperNode.java +++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/DPhyperNode.java @@ -17,7 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.memo.Group; import com.google.common.base.Preconditions; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java index d71a08ea13af70..042e22fcf88a80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java index ddd1394f9df96a..89bd1ffa3ee0c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/AbstractReceiver.java @@ -17,7 +17,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; 
import org.apache.doris.nereids.memo.Group; import java.util.List; @@ -26,7 +26,7 @@ * A interface of receiver */ public interface AbstractReceiver { - boolean emitCsgCmp(long csg, long cmp, List edges); + boolean emitCsgCmp(long csg, long cmp, List edges); void addGroup(long bitSet, Group group); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java index d14ce91469b90f..ede98f6befd524 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/Counter.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.memo.Group; import com.google.common.base.Preconditions; @@ -51,7 +51,7 @@ public Counter(int limit) { * @param edges the join operator * @return the left and the right can be connected by the edge */ - public boolean emitCsgCmp(long left, long right, List edges) { + public boolean emitCsgCmp(long left, long right, List edges) { Preconditions.checkArgument(counter.containsKey(left)); Preconditions.checkArgument(counter.containsKey(right)); emitCount += 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index c1a830faa9efb0..cc0064c6e482fa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -20,9 +20,10 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.jobs.cascades.CostAndEnforcerJob; import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob; -import org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.memo.CopyInResult; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; @@ -103,7 +104,7 @@ public PlanReceiver(JobContext jobContext, int limit, HyperGraph hyperGraph, Set * @return the left and the right can be connected by the edge */ @Override - public boolean emitCsgCmp(long left, long right, List edges) { + public boolean emitCsgCmp(long left, long right, List edges) { Preconditions.checkArgument(planTable.containsKey(left)); Preconditions.checkArgument(planTable.containsKey(right)); processMissedEdges(left, right, edges); @@ -122,7 +123,7 @@ public boolean emitCsgCmp(long left, long right, List edges) { List hashConjuncts = new ArrayList<>(); List otherConjuncts = new ArrayList<>(); - JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); + JoinType joinType = JoinEdge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts); if (joinType == null) { return true; } @@ -149,7 +150,7 @@ public boolean emitCsgCmp(long left, long right, List edges) { // be aware that the requiredOutputSlots is a superset of the actual output of current node // check proposeProject method to get how to create a project node for the outputs of current node. 
- private Set calculateRequiredSlots(long left, long right, List edges) { + private Set calculateRequiredSlots(long left, long right, List edges) { // required output slots = final outputs + slot of unused edges + complex project exprs(if there is any) // 1. add finalOutputs to requiredOutputSlots Set requiredOutputSlots = new HashSet<>(this.finalOutputs); @@ -162,7 +163,7 @@ private Set calculateRequiredSlots(long left, long right, List edges // 2. add unused edges' input slots to requiredOutputSlots usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap); - for (Edge edge : hyperGraph.getEdges()) { + for (Edge edge : hyperGraph.getJoinEdges()) { if (!usedEdgesBitmap.get(edge.getIndex())) { requiredOutputSlots.addAll(edge.getInputSlots()); } @@ -180,7 +181,7 @@ private Set calculateRequiredSlots(long left, long right, List edges } // add any missed edge into edges to connect left and right - private void processMissedEdges(long left, long right, List edges) { + private void processMissedEdges(long left, long right, List edges) { // find all used edges BitSet usedEdgesBitmap = new BitSet(); usedEdgesBitmap.or(usdEdges.get(left)); @@ -191,9 +192,8 @@ private void processMissedEdges(long left, long right, List edges) { long allReferenceNodes = LongBitmap.or(left, right); // find the edge which is not in usedEdgesBitmap and its referenced nodes is subset of allReferenceNodes - for (Edge edge : hyperGraph.getEdges()) { - long referenceNodes = - LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes()); + for (JoinEdge edge : hyperGraph.getJoinEdges()) { + long referenceNodes = LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes()); if (LongBitmap.isSubset(referenceNodes, allReferenceNodes) && !usedEdgesBitmap.get(edge.getIndex())) { // add the missed edge to edges @@ -220,8 +220,8 @@ private List proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan ri List plans = Lists.newArrayList(); if 
(JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) { plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts, - Optional.empty(), joinProperties, - left, right)); + Optional.empty(), joinProperties, + left, right)); if (joinType.isSwapJoinType()) { plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(), joinProperties, @@ -241,17 +241,6 @@ private List proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan ri return plans; } - private boolean extractIsMarkJoin(List edges) { - boolean isMarkJoin = false; - JoinType joinType = null; - for (Edge edge : edges) { - Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType()); - isMarkJoin = edge.getJoin().isMarkJoin() || isMarkJoin; - joinType = edge.getJoinType(); - } - return isMarkJoin; - } - @Override public void addGroup(long bitmap, Group group) { Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1); @@ -330,7 +319,7 @@ private void makeLogicalExpression(Supplier root) { } } - private List proposeProject(List allChild, List edges, long left, long right) { + private List proposeProject(List allChild, List edges, long left, long right) { long fullKey = LongBitmap.newBitmapUnion(left, right); List outputs = allChild.get(0).getOutput(); Set outputSet = allChild.get(0).getOutputSet(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java index 384522c0f679c9..c1873d09e481ee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewJoinRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.exploration.mv; -import 
org.apache.doris.nereids.jobs.joinorder.hypergraph.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode; import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping; @@ -91,7 +91,7 @@ protected boolean checkPattern(StructInfo structInfo) { return false; } } - for (Edge edge : hyperGraph.getEdges()) { + for (JoinEdge edge : hyperGraph.getJoinEdges()) { if (!edge.getJoin().accept(StructInfo.JOIN_PATTERN_CHECKER, SUPPORTED_JOIN_TYPE_SET)) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java index 5be9eee897e3a5..3e23ae2e49c487 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/StructInfo.java @@ -96,7 +96,7 @@ private void init() { this.predicates = Predicates.of(); // Collect predicate from join condition in hyper graph - this.hyperGraph.getEdges().forEach(edge -> { + this.hyperGraph.getJoinEdges().forEach(edge -> { List hashJoinConjuncts = edge.getHashJoinConjuncts(); hashJoinConjuncts.forEach(conjunctExpr -> { predicates.addPredicate(conjunctExpr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughJoin.java index 9165a95b6c67d6..ebee8b9c1a2e96 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughJoin.java @@ -41,7 +41,7 @@ public class 
PushDownFilterThroughJoin extends OneRewriteRuleFactory { public static final PushDownFilterThroughJoin INSTANCE = new PushDownFilterThroughJoin(); - private static final ImmutableList COULD_PUSH_THROUGH_LEFT = ImmutableList.of( + public static final ImmutableList COULD_PUSH_THROUGH_LEFT = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, @@ -50,7 +50,7 @@ public class PushDownFilterThroughJoin extends OneRewriteRuleFactory { JoinType.CROSS_JOIN ); - private static final ImmutableList COULD_PUSH_THROUGH_RIGHT = ImmutableList.of( + public static final ImmutableList COULD_PUSH_THROUGH_RIGHT = ImmutableList.of( JoinType.INNER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.RIGHT_SEMI_JOIN, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java index b65192ad008960..c26e1650781799 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/GraphSimplifierTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; import org.apache.doris.nereids.trees.expressions.Alias; @@ -261,7 +262,7 @@ void test64Clique() { GraphSimplifier graphSimplifier = new GraphSimplifier(hyperGraph); graphSimplifier.simplifyGraph(1); - for (Edge edge : hyperGraph.getEdges()) { + for (Edge edge : hyperGraph.getJoinEdges()) { System.out.println(edge); } Assertions.assertTrue(subgraphEnumerator.enumerate()); diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java index fc079c28874e10..af9c9d7e3c1754 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java @@ -102,7 +102,7 @@ void testRandomQuery() { HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(); HyperGraph hyperGraph = hyperGraphBuilder.randomBuildWith(tableNum, edgeNum); Assertions.assertEquals(hyperGraph.getNodes().size(), tableNum); - Assertions.assertEquals(hyperGraph.getEdges().size(), edgeNum); + Assertions.assertEquals(hyperGraph.getJoinEdges().size(), edgeNum); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java index 47dc68c9e04bd7..c1e68cd55a9ec4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/SubgraphEnumeratorTest.java @@ -19,6 +19,7 @@ import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator; +import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge; import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.util.HyperGraphBuilder; @@ -129,7 +130,7 @@ private int countAndCheck(long bitmap, HyperGraph hyperGraph, HashMap { + HyperGraph structInfo = HyperGraph.toStructInfo(j).get(0); + 
Assertions.assertTrue(structInfo.getJoinEdge(0).getJoinType().isLeftOuterJoin()); + Assertions.assertEquals(0, (int) structInfo.getFilterEdge(0).getRejectEdges().get(0)); + return true; + })); + + sql = "select * from (select id from T1 where id = 0) T1 left outer join T2 " + + "on T1.id = T2.id "; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin() + .when(j -> { + HyperGraph structInfo = HyperGraph.toStructInfo(j).get(0); + Assertions.assertTrue(structInfo.getJoinEdge(0).getJoinType().isLeftOuterJoin()); + Assertions.assertTrue(structInfo.getFilterEdge(0).getRejectEdges().isEmpty()); + return true; + })); + } }