Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.Counter;
Expand Down Expand Up @@ -81,7 +83,7 @@ public class GraphSimplifier {
*/
public GraphSimplifier(HyperGraph graph) {
this.graph = graph;
edgeSize = graph.getEdges().size();
edgeSize = graph.getJoinEdges().size();
for (int i = 0; i < edgeSize; i++) {
BestSimplification bestSimplification = new BestSimplification();
simplifications.add(bestSimplification);
Expand All @@ -91,7 +93,7 @@ public GraphSimplifier(HyperGraph graph) {
cacheStats.put(node.getNodeMap(), dPhyperNode.getGroup().getStatistics());
cacheCost.put(node.getNodeMap(), dPhyperNode.getRowCount());
}
validEdges = graph.getEdges().stream()
validEdges = graph.getJoinEdges().stream()
.filter(e -> {
for (Slot slot : e.getJoin().getConditionSlot()) {
boolean contains = false;
Expand Down Expand Up @@ -136,8 +138,8 @@ private void initFirstStep() {
public boolean isTotalOrder() {
for (int i = 0; i < edgeSize; i++) {
for (int j = i + 1; j < edgeSize; j++) {
Edge edge1 = graph.getEdge(i);
Edge edge2 = graph.getEdge(j);
Edge edge1 = graph.getJoinEdge(i);
Edge edge2 = graph.getJoinEdge(j);
List<Long> superset = new ArrayList<>();
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getLeftExtendedNodes(), superset);
tryGetSuperset(edge1.getLeftExtendedNodes(), edge2.getRightExtendedNodes(), superset);
Expand Down Expand Up @@ -342,8 +344,8 @@ private void updatePriorityQueue(int index) {
}

private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int edgeIndex2) {
Edge edge1 = graph.getEdge(edgeIndex1);
Edge edge2 = graph.getEdge(edgeIndex2);
JoinEdge edge1 = graph.getJoinEdge(edgeIndex1);
JoinEdge edge2 = graph.getJoinEdge(edgeIndex2);
if (edge1.isSub(edge2) || edge2.isSub(edge1)
|| circleDetector.checkCircleWithEdge(edgeIndex1, edgeIndex2)
|| circleDetector.checkCircleWithEdge(edgeIndex2, edgeIndex1)
Expand All @@ -358,8 +360,8 @@ private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int
|| !cacheStats.containsKey(left2) || !cacheStats.containsKey(right2)) {
return Optional.empty();
}
Edge edge1Before2;
Edge edge2Before1;
JoinEdge edge1Before2;
JoinEdge edge2Before1;
List<Long> superBitset = new ArrayList<>();
if (tryGetSuperset(left1, left2, superBitset)) {
// (common Join1 right1) Join2 right2
Expand Down Expand Up @@ -394,36 +396,34 @@ private Optional<SimplificationStep> makeSimplificationStep(int edgeIndex1, int
return Optional.of(simplificationStep);
}

private Edge constructEdge(long leftNodes, Edge edge, long rightNodes) {
private JoinEdge constructEdge(long leftNodes, JoinEdge edge, long rightNodes) {
LogicalJoin<? extends Plan, ? extends Plan> join;
if (graph.getEdges().size() > 64 * 63 / 8) {
if (graph.getJoinEdges().size() > 64 * 63 / 8) {
// If there are too many edges, it is advisable to return the "edge" directly
// to avoid lengthy enumeration time.
join = edge.getJoin();
} else {
BitSet validEdgesMap = graph.getEdgesInOperator(leftNodes, rightNodes);
List<Expression> hashConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.mapToObj(i -> graph.getJoinEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
List<Expression> otherConditions = validEdgesMap.stream()
.mapToObj(i -> graph.getEdge(i).getJoin().getHashJoinConjuncts())
.mapToObj(i -> graph.getJoinEdge(i).getJoin().getHashJoinConjuncts())
.flatMap(Collection::stream)
.collect(Collectors.toList());
join = edge.getJoin().withJoinConjuncts(hashConditions, otherConditions);
}

Edge newEdge = new Edge(
join,
edge.getIndex(), edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes());
newEdge.setLeftRequiredNodes(edge.getLeftRequiredNodes());
newEdge.setRightRequiredNodes(edge.getRightRequiredNodes());
newEdge.addLeftNode(leftNodes);
newEdge.addRightNode(rightNodes);
JoinEdge newEdge = new JoinEdge(join, edge.getIndex(),
edge.getLeftChildEdges(), edge.getRightChildEdges(), edge.getSubTreeNodes(),
edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
newEdge.addLeftExtendNode(leftNodes);
newEdge.addRightExtendNode(rightNodes);
return newEdge;
}

private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) {
private void deriveStats(JoinEdge edge, long leftBitmap, long rightBitmap) {
// The bitmap may differ from the edge's reference slots.
// Taking into account the order: edge1<{1} - {2}> edge2<{1,3} - {4}>.
// Actually, we are considering the sequence {1,3} - {2} - {4}
Expand All @@ -438,7 +438,7 @@ private void deriveStats(Edge edge, long leftBitmap, long rightBitmap) {
cacheStats.put(bitmap, joinStats);
}

private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
private double calCost(JoinEdge edge, long leftBitmap, long rightBitmap) {
long bitmap = LongBitmap.newBitmapUnion(leftBitmap, rightBitmap);
Preconditions.checkArgument(cacheStats.containsKey(leftBitmap) && cacheStats.containsKey(rightBitmap)
&& cacheStats.containsKey(bitmap),
Expand All @@ -461,7 +461,7 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
return cost;
}

private @Nullable Edge threeLeftJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
private @Nullable JoinEdge threeLeftJoin(long bitmap1, JoinEdge edge1, long bitmap2, JoinEdge edge2, long bitmap3) {
// (plan1 edge1 plan2) edge2 plan3
// if the left and right is overlapping, just return null.
Preconditions.checkArgument(
Expand All @@ -471,7 +471,7 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
if (LongBitmap.isOverlap(newLeft, bitmap3)) {
return null;
}
Edge newEdge = constructEdge(newLeft, edge2, bitmap3);
JoinEdge newEdge = constructEdge(newLeft, edge2, bitmap3);

deriveStats(edge1, bitmap1, bitmap2);
deriveStats(newEdge, newLeft, bitmap3);
Expand All @@ -481,15 +481,16 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
return newEdge;
}

private @Nullable Edge threeRightJoin(long bitmap1, Edge edge1, long bitmap2, Edge edge2, long bitmap3) {
Preconditions.checkArgument(
cacheStats.containsKey(bitmap1) && cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
private @Nullable JoinEdge threeRightJoin(long bitmap1, JoinEdge edge1, long bitmap2,
JoinEdge edge2, long bitmap3) {
Preconditions.checkArgument(cacheStats.containsKey(bitmap1)
&& cacheStats.containsKey(bitmap2) && cacheStats.containsKey(bitmap3));
// plan1 edge1 (plan2 edge2 plan3)
long newRight = LongBitmap.newBitmapUnion(bitmap2, bitmap3);
if (LongBitmap.isOverlap(bitmap1, newRight)) {
return null;
}
Edge newEdge = constructEdge(bitmap1, edge1, newRight);
JoinEdge newEdge = constructEdge(bitmap1, edge1, newRight);

deriveStats(edge2, bitmap2, bitmap3);
deriveStats(newEdge, bitmap1, newRight);
Expand All @@ -498,8 +499,8 @@ private double calCost(Edge edge, long leftBitmap, long rightBitmap) {
return newEdge;
}

private SimplificationStep orderJoin(Edge edge1Before2,
Edge edge2Before1, int edgeIndex1, int edgeIndex2) {
private SimplificationStep orderJoin(JoinEdge edge1Before2,
JoinEdge edge2Before1, int edgeIndex1, int edgeIndex2) {
double cost1Before2 = calCost(edge1Before2,
edge1Before2.getLeftExtendedNodes(), edge1Before2.getRightExtendedNodes());
double cost2Before1 = calCost(edge2Before1,
Expand All @@ -515,16 +516,16 @@ private SimplificationStep orderJoin(Edge edge1Before2,
step = new SimplificationStep(benefit, edgeIndex1, edgeIndex2,
edge1Before2.getLeftExtendedNodes(),
edge1Before2.getRightExtendedNodes(),
graph.getEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getEdge(edgeIndex2).getRightExtendedNodes());
graph.getJoinEdge(edgeIndex2).getLeftExtendedNodes(),
graph.getJoinEdge(edgeIndex2).getRightExtendedNodes());
} else {
if (cost2Before1 != 0) {
benefit = cost1Before2 / cost2Before1;
}
// choose edge2Before1
step = new SimplificationStep(benefit, edgeIndex2, edgeIndex1, edge2Before1.getLeftExtendedNodes(),
edge2Before1.getRightExtendedNodes(), graph.getEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getEdge(edgeIndex1).getRightExtendedNodes());
edge2Before1.getRightExtendedNodes(), graph.getJoinEdge(edgeIndex1).getLeftExtendedNodes(),
graph.getJoinEdge(edgeIndex1).getRightExtendedNodes());
}
return step;
}
Expand All @@ -545,9 +546,9 @@ private boolean tryGetSuperset(long bitmap1, long bitmap2, List<Long> superset)
*/
private void extractJoinDependencies() {
for (int i = 0; i < edgeSize; i++) {
Edge edge1 = graph.getEdge(i);
Edge edge1 = graph.getJoinEdge(i);
for (int j = i + 1; j < edgeSize; j++) {
Edge edge2 = graph.getEdge(j);
Edge edge2 = graph.getJoinEdge(j);
if (edge1.isSub(edge2)) {
Preconditions.checkArgument(circleDetector.tryAddDirectedEdge(i, j),
"Edge %s violates Edge %s", edge1, edge2);
Expand Down
Loading