From 1ef49c5ffac429c7f1904a7e65efeda0603b7f53 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Tue, 20 Jun 2023 11:24:43 +0530 Subject: [PATCH 01/12] initial commit --- .../SortMergeJoinFrameProcessorFactory.java | 14 +++++---- .../druid/sql/calcite/rule/DruidJoinRule.java | 30 ++++++++++++++++--- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java index 9aa50630929d..33734143ccc8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java @@ -204,12 +204,14 @@ public static List> toKeyColumns(final JoinConditionAnalysis con retVal.add(new ArrayList<>()); // Right-side key columns for (final Equality equiCondition : condition.getEquiConditions()) { - final String leftColumn = Preconditions.checkNotNull( - equiCondition.getLeftExpr().getBindingIfIdentifier(), - "leftExpr#getBindingIfIdentifier" - ); - retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING)); + if (!equiCondition.getLeftExpr().isLiteral()) { + final String leftColumn = Preconditions.checkNotNull( + equiCondition.getLeftExpr().getBindingIfIdentifier(), + "leftExpr#getBindingIfIdentifier" + ); + retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING)); + } retVal.get(1).add(new KeyColumn(equiCondition.getRightColumn(), KeyOrder.ASCENDING)); } @@ -234,7 +236,7 @@ public static JoinConditionAnalysis validateCondition(final JoinConditionAnalysi throw new IAE("Cannot handle non-equijoin condition: %s", condition.getOriginalExpression()); } - if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier())) { + if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier() && !c.getLeftExpr().isLiteral())) { throw new IAE( "Cannot handle equality condition involving left-hand expression: %s", condition.getOriginalExpression() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index 2d6e5e56a43e..7cf4fd0a1742 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -241,6 +241,7 @@ private Optional analyzeCondition( final List subConditions = decomposeAnd(condition); final List> equalitySubConditions = new ArrayList<>(); final List literalSubConditions = new ArrayList<>(); + final List inputRefSubConditions = new ArrayList<>(); final int numLeftFields = leftRowType.getFieldCount(); final Set rightColumns = new HashSet<>(); @@ -266,6 +267,11 @@ private Optional analyzeCondition( continue; } + if (subCondition.isA(SqlKind.INPUT_REF)) { + inputRefSubConditions.add((RexInputRef) subCondition); + continue; + } + if (!subCondition.isA(SqlKind.EQUALS)) { // If it's not EQUALS, it's not supported. plannerContext.setPlanningError( @@ -310,7 +316,13 @@ && isLeftExpression(operands.get(1), numLeftFields)) { } } - return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions)); + return Optional.of( + new ConditionAnalysis( + numLeftFields, + equalitySubConditions, + literalSubConditions, + inputRefSubConditions + )); } @VisibleForTesting @@ -375,15 +387,19 @@ static class ConditionAnalysis */ private final List literalSubConditions; + private final List inputRefs; + ConditionAnalysis( int numLeftFields, List> equalitySubConditions, - List literalSubConditions + List literalSubConditions, + List inputRefs ) { this.numLeftFields = numLeftFields; this.equalitySubConditions = equalitySubConditions; this.literalSubConditions = literalSubConditions; + this.inputRefs = inputRefs; } public ConditionAnalysis pushThroughLeftProject(final Project leftProject) @@ -403,7 +419,8 @@ public ConditionAnalysis pushThroughLeftProject(final Project leftProject) ) ) .collect(Collectors.toList()), - literalSubConditions + literalSubConditions, + inputRefs ); } @@ -428,7 +445,8 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) ) ) .collect(Collectors.toList()), - literalSubConditions + literalSubConditions, + inputRefs ); } @@ -454,6 +472,10 @@ public RexNode getCondition(final RexBuilder rexBuilder) equalitySubConditions .stream() .map(equality -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, equality.lhs, equality.rhs)) + .collect(Collectors.toList()), + inputRefs + .stream() + .map(inputRef -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, inputRef, rexBuilder.makeLiteral(true))) .collect(Collectors.toList()) ), false From 37771fccd3601cc2b8fe55933fbc570f8c6d8862 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Fri, 30 Jun 2023 12:01:42 +0530 Subject: [PATCH 02/12] use algorithm as a hint --- .../druid/msq/querykit/DataSourcePlan.java | 30 +++++++- .../druid/sql/calcite/rule/DruidJoinRule.java | 72 +++++++++---------- .../sql/calcite/rule/DruidJoinRuleTest.java | 21 ++++-- 3 files changed, 73 insertions(+), 50 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index d6a21fc13382..b339f33bc59d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -56,6 +56,7 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.sql.calcite.external.ExternalDataSource; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.planner.JoinAlgorithm; @@ -144,9 +145,13 @@ public static DataSourcePlan forDataSource( broadcast ); } else if (dataSource instanceof JoinDataSource) { - final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext); + final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext); + final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm( + preferredJoinAlgorithm, + ((JoinDataSource) dataSource) + ); - switch (joinAlgorithm) { + switch (deducedJoinAlgorithm) { case BROADCAST: return forBroadcastHashJoin( queryKit, @@ -171,7 +176,7 @@ public static DataSourcePlan forDataSource( ); default: - throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm); + throw new UOE("Cannot handle join algorithm [%s]", deducedJoinAlgorithm); } } else { throw new UOE("Cannot handle dataSource [%s]", dataSource); @@ -198,6 +203,25 @@ public Optional getSubQueryDefBuilder() return Optional.ofNullable(subQueryDefBuilder); } + private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource) + { + if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { + return JoinAlgorithm.BROADCAST; + } else { + if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + return JoinAlgorithm.SORT_MERGE; + } + } + return JoinAlgorithm.BROADCAST; + } + + private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis) + { + return joinConditionAnalysis.getEquiConditions() + .stream() + .allMatch(equality -> equality.getLeftExpr().isIdentifier()); + } + /** * Whether this datasource must be processed by a single worker. True if, and only if, all inputs are broadcast. */ diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index 7cf4fd0a1742..add568608360 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -92,7 +92,7 @@ public boolean matches(RelOptRuleCall call) // 1) Can handle the join condition as a native join. // 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL). // 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL). - return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right) + return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right, join.getCluster().getRexBuilder()) && left.getPartialDruidQuery() != null && right.getPartialDruidQuery() != null; } @@ -116,7 +116,8 @@ public void onMatch(RelOptRuleCall call) ConditionAnalysis conditionAnalysis = analyzeCondition( join.getCondition(), join.getLeft().getRowType(), - right + right, + rexBuilder ).get(); final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel); @@ -223,9 +224,9 @@ private static RexNode makeNullableIfLiteral(final RexNode rexNode, final RexBui * Returns whether {@link #analyzeCondition} would return something. */ @VisibleForTesting - boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel right) + boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel right, final RexBuilder rexBuilder) { - return analyzeCondition(condition, leftRowType, right).isPresent(); + return analyzeCondition(condition, leftRowType, right, rexBuilder).isPresent(); } /** @@ -235,13 +236,13 @@ boolean canHandleCondition(final RexNode condition, final RelDataType leftRowTyp private Optional analyzeCondition( final RexNode condition, final RelDataType leftRowType, - final DruidRel right + final DruidRel right, + final RexBuilder rexBuilder ) { final List subConditions = decomposeAnd(condition); final List> equalitySubConditions = new ArrayList<>(); final List literalSubConditions = new ArrayList<>(); - final List inputRefSubConditions = new ArrayList<>(); final int numLeftFields = leftRowType.getFieldCount(); final Set rightColumns = new HashSet<>(); @@ -267,10 +268,6 @@ private Optional analyzeCondition( continue; } - if (subCondition.isA(SqlKind.INPUT_REF)) { - inputRefSubConditions.add((RexInputRef) subCondition); - continue; - } if (!subCondition.isA(SqlKind.EQUALS)) { // If it's not EQUALS, it's not supported. @@ -281,16 +278,28 @@ private Optional analyzeCondition( return Optional.empty(); } - final List operands = ((RexCall) subCondition).getOperands(); - Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%,d]", operands.size()); + RexNode firstOperand; + RexNode secondOperand; - if (isLeftExpression(operands.get(0), numLeftFields) && isRightInputRef(operands.get(1), numLeftFields)) { - equalitySubConditions.add(Pair.of(operands.get(0), (RexInputRef) operands.get(1))); - rightColumns.add((RexInputRef) operands.get(1)); - } else if (isRightInputRef(operands.get(0), numLeftFields) - && isLeftExpression(operands.get(1), numLeftFields)) { - equalitySubConditions.add(Pair.of(operands.get(1), (RexInputRef) operands.get(0))); - rightColumns.add((RexInputRef) operands.get(0)); + if (subCondition.isA(SqlKind.INPUT_REF)) { + firstOperand = rexBuilder.makeLiteral(true); + secondOperand = (RexInputRef) subCondition; + } else if (subCondition.isA(SqlKind.EQUALS)) { + final List operands = ((RexCall) subCondition).getOperands(); + Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%,d]", operands.size()); + firstOperand = operands.get(0); + secondOperand = operands.get(1); + } else { + return Optional.empty(); + } + + if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) { + equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand)); + rightColumns.add((RexInputRef) secondOperand); + } else if (isRightInputRef(firstOperand, numLeftFields) + && isLeftExpression(secondOperand, numLeftFields)) { + equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand)); + rightColumns.add((RexInputRef) firstOperand); } else { // Cannot handle this condition. plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types."); @@ -320,8 +329,7 @@ && isLeftExpression(operands.get(1), numLeftFields)) { new ConditionAnalysis( numLeftFields, equalitySubConditions, - literalSubConditions, - inputRefSubConditions + literalSubConditions )); } @@ -353,13 +361,6 @@ static List decomposeAnd(final RexNode condition) private boolean isLeftExpression(final RexNode rexNode, final int numLeftFields) { - if (!plannerContext.getJoinAlgorithm().canHandleLeftExpressions()) { - // Must be INPUT_REF. - if (!rexNode.isA(SqlKind.INPUT_REF)) { - return false; - } - } - return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode)); } @@ -387,19 +388,16 @@ static class ConditionAnalysis */ private final List literalSubConditions; - private final List inputRefs; ConditionAnalysis( int numLeftFields, List> equalitySubConditions, - List literalSubConditions, - List inputRefs + List literalSubConditions ) { this.numLeftFields = numLeftFields; this.equalitySubConditions = equalitySubConditions; this.literalSubConditions = literalSubConditions; - this.inputRefs = inputRefs; } public ConditionAnalysis pushThroughLeftProject(final Project leftProject) @@ -419,8 +417,7 @@ public ConditionAnalysis pushThroughLeftProject(final Project leftProject) ) ) .collect(Collectors.toList()), - literalSubConditions, - inputRefs + literalSubConditions ); } @@ -445,8 +442,7 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) ) ) .collect(Collectors.toList()), - literalSubConditions, - inputRefs + literalSubConditions ); } @@ -472,10 +468,6 @@ public RexNode getCondition(final RexBuilder rexBuilder) equalitySubConditions .stream() .map(equality -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, equality.lhs, equality.rhs)) - .collect(Collectors.toList()), - inputRefs - .stream() - .map(inputRef -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, inputRef, rexBuilder.makeLiteral(true))) .collect(Collectors.toList()) ), false diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java index 41c6895dff25..e531580162ee 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java @@ -84,7 +84,8 @@ public void test_canHandleCondition_leftEqRight() rexBuilder.makeInputRef(joinType, 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -104,7 +105,8 @@ public void test_canHandleCondition_leftFnEqRight() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -124,7 +126,8 @@ public void test_canHandleCondition_leftEqRightFn() ) ), leftType, - null + null, + rexBuilder ) ); } @@ -140,7 +143,8 @@ public void test_canHandleCondition_leftEqLeft() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0) ), leftType, - null + null, + rexBuilder ) ); } @@ -156,7 +160,8 @@ public void test_canHandleCondition_rightEqRight() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -168,7 +173,8 @@ public void test_canHandleCondition_true() druidJoinRule.canHandleCondition( rexBuilder.makeLiteral(true), leftType, - null + null, + rexBuilder ) ); } @@ -180,7 +186,8 @@ public void test_canHandleCondition_false() druidJoinRule.canHandleCondition( rexBuilder.makeLiteral(false), leftType, - null + null, + rexBuilder ) ); } From aef958748c16cb9b783045cb1c443cccc9396c1a Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 02:18:58 +0530 Subject: [PATCH 03/12] use algorithm as a hint 2 --- .../apache/druid/msq/querykit/DataSourcePlan.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index b339f33bc59d..dbd825df72bb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -25,6 +25,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSets; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.KeyColumn; import org.apache.druid.java.util.common.IAE; @@ -205,12 +206,15 @@ public Optional getSubQueryDefBuilder() private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource) { + if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { return JoinAlgorithm.BROADCAST; - } else { - if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { - return JoinAlgorithm.SORT_MERGE; - } + } + + // preferredJoinAlgorithm would only be sortMerge now + + if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + return JoinAlgorithm.SORT_MERGE; } return JoinAlgorithm.BROADCAST; } From ff2f6a47e9a7e3b3f4cb7062719cb49e637e5ea2 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 02:36:44 +0530 Subject: [PATCH 04/12] error message changes --- .../druid/msq/querykit/DataSourcePlan.java | 1 - .../druid/sql/calcite/rule/DruidJoinRule.java | 26 ++++++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index dbd825df72bb..ef3b1616ec52 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -25,7 +25,6 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSets; -import org.apache.druid.error.DruidException; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.KeyColumn; import org.apache.druid.java.util.common.IAE; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index add568608360..881e511948e3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -39,6 +39,7 @@ import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.druid.java.util.common.Pair; @@ -268,28 +269,33 @@ private Optional analyzeCondition( continue; } - - if (!subCondition.isA(SqlKind.EQUALS)) { - // If it's not EQUALS, it's not supported. - plannerContext.setPlanningError( - "SQL requires a join with '%s' condition that is not supported.", - subCondition.getKind() - ); - return Optional.empty(); - } - RexNode firstOperand; RexNode secondOperand; if (subCondition.isA(SqlKind.INPUT_REF)) { firstOperand = rexBuilder.makeLiteral(true); secondOperand = (RexInputRef) subCondition; + + if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) { + plannerContext.setPlanningError( + "SQL requires a join with '%s' condition where the column is of the type %s, that is not supported", + subCondition.getKind(), + secondOperand.getType().getSqlTypeName() + ); + return Optional.empty(); + + } } else if (subCondition.isA(SqlKind.EQUALS)) { final List operands = ((RexCall) subCondition).getOperands(); Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%,d]", operands.size()); firstOperand = operands.get(0); secondOperand = operands.get(1); } else { + // If it's not EQUALS or a BOOLEAN input ref, it's not supported. + plannerContext.setPlanningError( + "SQL requires a join with '%s' condition that is not supported.", + subCondition.getKind() + ); return Optional.empty(); } From f51f9d18a891e02fbbc5907081e77c8fcc60e716 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 10:17:25 +0530 Subject: [PATCH 05/12] add test --- .../druid/sql/calcite/rule/DruidJoinRule.java | 2 +- .../sql/calcite/CalciteJoinQueryTest.java | 48 +++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index 881e511948e3..2ccb6c9eb275 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -274,7 +274,7 @@ private Optional analyzeCondition( if (subCondition.isA(SqlKind.INPUT_REF)) { firstOperand = rexBuilder.makeLiteral(true); - secondOperand = (RexInputRef) subCondition; + secondOperand = subCondition; if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) { plannerContext.setPlanningError( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index 337926d462ae..2599b78279ae 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -5652,4 +5652,52 @@ public void testJoinsWithThreeConditions() ) ); } + + @Test + public void testJoinWithInputRefCondition() + { + Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + testQuery( + "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) " + + "FROM foo", + context, + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .build() + ), + "j0.", + "(floor(100) == \"j0.d0\")", + JoinType.LEFT + ) + ) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new SelectorDimFilter("j0.d1", null, null) + ) + )) + .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build() + ), + ImmutableList.of( + new Object[]{6L} + ) + ); + } } From 3c6ebabb9be3094f11af547e6a96702d83895a5e Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 10:19:34 +0530 Subject: [PATCH 06/12] cannotVectorize --- .../java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index 2599b78279ae..2556ab7d1574 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -5656,6 +5656,7 @@ public void testJoinsWithThreeConditions() @Test public void testJoinWithInputRefCondition() { + cannotVectorize(); Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); testQuery( "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) " From 61cbecfa9bd3e33b847e1e082864a2a039c067af Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 12:16:18 +0530 Subject: [PATCH 07/12] cleanup --- .../SortMergeJoinFrameProcessorFactory.java | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java index 33734143ccc8..76e05d3ce0cf 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java @@ -180,7 +180,8 @@ public ProcessorsAndChannels, Long> makeProcessors( stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()), rightPrefix, keyColumns, - joinType + joinType, + frameContext.memoryParameters().getSortMergeJoinMemory() ); } ); @@ -204,14 +205,12 @@ public static List> toKeyColumns(final JoinConditionAnalysis con retVal.add(new ArrayList<>()); // Right-side key columns for (final Equality equiCondition : condition.getEquiConditions()) { + final String leftColumn = Preconditions.checkNotNull( + equiCondition.getLeftExpr().getBindingIfIdentifier(), + "leftExpr#getBindingIfIdentifier" + ); - if (!equiCondition.getLeftExpr().isLiteral()) { - final String leftColumn = Preconditions.checkNotNull( - equiCondition.getLeftExpr().getBindingIfIdentifier(), - "leftExpr#getBindingIfIdentifier" - ); - retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING)); - } + retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING)); retVal.get(1).add(new KeyColumn(equiCondition.getRightColumn(), KeyOrder.ASCENDING)); } @@ -236,7 +235,7 @@ public static JoinConditionAnalysis validateCondition(final JoinConditionAnalysi throw new IAE("Cannot handle non-equijoin condition: %s", condition.getOriginalExpression()); } - if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier() && !c.getLeftExpr().isLiteral())) { + if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier())) { throw new IAE( "Cannot handle equality condition involving left-hand expression: %s", condition.getOriginalExpression() From 49b599f61be77d84c0fd4cb1ad93f716b93aa247 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 3 Jul 2023 16:02:43 +0530 Subject: [PATCH 08/12] test fix for sqlcompatible case --- .../sql/calcite/CalciteJoinQueryTest.java | 143 +++++++++++++----- 1 file changed, 109 insertions(+), 34 deletions(-) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index 2556ab7d1574..104297b4a763 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -5658,44 +5658,119 @@ public void testJoinWithInputRefCondition() { cannotVectorize(); Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + + Query expectedQuery; + + if (!NullHandling.sqlCompatible()) { + expectedQuery = Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "1", + ColumnType.LONG + )) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .build() + ), + "j0.", + "(floor(100) == \"j0.d0\")", + JoinType.LEFT + ) + ) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new SelectorDimFilter("j0.d1", null, null) + ) + )) + .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build(); + + } else { + expectedQuery = Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + join( + new TableDataSource("foo"), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource("foo") + .aggregators( + new CountAggregatorFactory("a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a1"), + not(selector("m1", null, null)), + "a1" + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build() + ), + "j0.", + "1", + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "1", + ColumnType.LONG + )) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .build() + ), + "_j0.", + "(floor(100) == \"_j0.d0\")", + JoinType.LEFT + ) + ) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + or( + new SelectorDimFilter("j0.a0", "0", null), + and( + selector("_j0.d1", null, null), + expressionFilter("(\"j0.a1\" >= \"j0.a0\")") + ) + + ) + ) + )) + .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build(); + + } + testQuery( "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) " + "FROM foo", context, - ImmutableList.of( - Druids.newTimeseriesQueryBuilder() - .dataSource( - join( - new TableDataSource(CalciteTests.DATASOURCE1), - new QueryDataSource( - GroupByQuery.builder() - .setInterval(querySegmentSpec(Filtration.eternity())) - .setGranularity(Granularities.ALL) - .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) - .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) - .setDimensions( - new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), - new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) - ) - .build() - ), - "j0.", - "(floor(100) == \"j0.d0\")", - JoinType.LEFT - ) - ) - .granularity(Granularities.ALL) - .aggregators(aggregators( - new FilteredAggregatorFactory( - new CountAggregatorFactory("a0"), - new SelectorDimFilter("j0.d1", null, null) - ) - )) - .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) - .intervals(querySegmentSpec(Filtration.eternity())) - .context(context) - .build() - ), + ImmutableList.of(expectedQuery), ImmutableList.of( new Object[]{6L} ) From 9309720f8658d4cbbe9e09c6696efb31c02a5599 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 10 Jul 2023 10:58:48 +0530 Subject: [PATCH 09/12] review --- docs/multi-stage-query/reference.md | 10 ++- .../druid/msq/querykit/DataSourcePlan.java | 37 ++++++-- .../apache/druid/msq/exec/MSQSelectTest.java | 87 +++++++++++++++++++ .../apache/druid/msq/test/MSQTestBase.java | 17 ++++ .../sql/calcite/CalciteJoinQueryTest.java | 2 +- 5 files changed, 143 insertions(+), 10 deletions(-) diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md index 5bbe935f1eef..08335ff1143f 100644 --- a/docs/multi-stage-query/reference.md +++ b/docs/multi-stage-query/reference.md @@ -234,7 +234,7 @@ The following table lists the context parameters for the MSQ task engine: | `maxNumTasks` | SELECT, INSERT, REPLACE

The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.

May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority. | 2 | | `taskAssignment` | SELECT, INSERT, REPLACE

Determines how many tasks to use. Possible values include:
  • `max`: Uses as many tasks as possible, up to `maxNumTasks`.
  • `auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 512 MiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When calculating the size of files, the weighted size is used, which considers the file format and compression format used if any. When file sizes cannot be determined through directory listing (for example: http), behaves the same as `max`.
| `max` | | `finalizeAggregations` | SELECT, INSERT, REPLACE

Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggregations.md). | true | -| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE

Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. | `broadcast` | +| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE

Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. This is a hint to the MSQ engine and the actual joins in the query may proceed in a different way than specified. See [Joins](#joins) for more details. | `broadcast` | | `rowsInMemory` | INSERT or REPLACE

Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. | 100,000 | | `segmentSortOrder` | INSERT or REPLACE

Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.

You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, consider an INSERT query that uses `CLUSTERED BY country` and has `segmentSortOrder` set to `__time,city`. Within each time chunk, Druid assigns rows to segments based on `country`, and then within each of those segments, Druid sorts those rows by `__time` first, then `city`, then `country`. | empty list | | `maxParseExceptions`| SELECT, INSERT, REPLACE

Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1. | 0 | @@ -253,6 +253,12 @@ Joins in multi-stage queries use one of two algorithms based on what you set the If you omit this context parameter, the MSQ task engine uses broadcast since it's the default join algorithm. The context parameter applies to the entire SQL statement, so you can't mix different join algorithms in the same query. +`sqlJoinAlgorithm` is a hint to the planner to execute the join in the specified manner. The planner can decide to ignore +the hint if it deduces that the specified algorithm can be detrimental to the performance of the join beforehand. This intelligence +is very limited as of now, and the `sqlJoinAlgorithm` set would be respected in most cases, therefore the user should set it +appropriately. See the advantages and the drawbacks for the [broadcast](#broadcast) and the [sort-merge](#sort-merge) join to +determine which join to use beforehand. + ### Broadcast The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how @@ -439,7 +445,7 @@ The following table describes error codes you may encounter in the `multiStageQu | `TooManyInputFiles` | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).

If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.

`maxInputFiles`: The maximum number of input files/segments per worker per stage.

`minNumWorker`: The minimum number of workers required for a successful run. | | `TooManyPartitions` | Exceeded the maximum number of partitions for a stage (25,000 partitions).

This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded | | `TooManyClusteredByColumns` | Exceeded the maximum number of clustering columns for a stage (1,500 columns).

This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.

`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit

| -| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.

`numBytes`: Number of bytes buffered, which may include other keys.

`maxBytes`: Maximum number of bytes buffered. | +| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when join is executed via the sort-merge join algorithm. | `key`: The key that had a large number of rows.

`numBytes`: Number of bytes buffered, which may include other keys.

`maxBytes`: Maximum number of bytes buffered. | | `TooManyColumns` | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.

`maxColumns`: The limit on columns which was exceeded. | | `TooManyWarnings` | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit.

`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. | | `TooManyWorkers` | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously.

`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1,000 workers), then you can increase the limit by adding more memory to each task. | diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index bada81f9c21e..477c3e0e1982 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -30,6 +30,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.UOE; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.external.ExternalInputSpec; import org.apache.druid.msq.input.inline.InlineInputSpec; @@ -80,6 +81,8 @@ public class DataSourcePlan */ private static final Map CONTEXT_MAP_NO_SEGMENT_GRANULARITY = new HashMap<>(); + private static final Logger log = new Logger(DataSourcePlan.class); + static { CONTEXT_MAP_NO_SEGMENT_GRANULARITY.put(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, null); } @@ -203,21 +206,41 @@ public Optional getSubQueryDefBuilder() return Optional.ofNullable(subQueryDefBuilder); } + /** + * Contains the logic that deduces the join algorithm to be used. Ideally, this should reside while planning the + * native query, however we don't have the resources and the structure in place (when adding this function) to do so. + * Therefore, this is done while planning the MSQ query + * It takes into account the algorithm specified by "sqlJoinAlgorithm" in the query context and the join condition + * that is present in the query. + */ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource) { - + JoinAlgorithm deducedJoinAlgorithm; if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { - return JoinAlgorithm.BROADCAST; + deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; + } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE; + } else { + deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; } - // preferredJoinAlgorithm would only be sortMerge now - - if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { - return JoinAlgorithm.SORT_MERGE; + if (deducedJoinAlgorithm != preferredJoinAlgorithm) { + log.debug( + "User wanted to plan join [%s] as [%s], however the join will be executed as [%s]", + joinDataSource, + preferredJoinAlgorithm.toString(), + deducedJoinAlgorithm.toString() + ); } - return JoinAlgorithm.BROADCAST; + + return deducedJoinAlgorithm; } + /** + * Checks if the join condition on two tables "table1" and "table2" is of the form + * table1.columnA = table2.columnA && table1.columnB = table2.columnB && .... + * sortMerge algorithm can help these types of join conditions + */ private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis) { return joinConditionAnalysis.getEquiConditions() diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index 0f4210e7f59a..1434768ab2b3 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -39,12 +39,14 @@ import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory; import org.apache.druid.msq.test.CounterSnapshotMatcher; import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.LookupDataSource; +import org.apache.druid.query.Query; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; @@ -1964,6 +1966,91 @@ public void testSelectRowsGetUntruncatedInReportsByDefault() throws IOException .verifyResults(); } + @Test + public void testJoinUsesDifferentAlgorithm() + { + RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build(); + + Map queryContext = new HashMap<>(context); + queryContext.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString()); + + Query expectedQuery; + + expectedQuery = GroupByQuery + .builder() + .setDataSource( + join( + new QueryDataSource( + newScanQueryBuilder() + .dataSource("foo") + .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG)) + .columns("v0") + .context(defaultScanQueryContext( + queryContext, + RowSignature.builder().add("v0", ColumnType.LONG).build() + )) + .intervals(querySegmentSpec(Intervals.ETERNITY)) + .build() + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource("foo") + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .setContext(queryContext) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setGranularity(Granularities.ALL) + .build() + + ), + "j0.", + "(floor(100) == \"j0.d0\")", + JoinType.LEFT + ) + ) + .setAggregatorSpecs( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new SelectorDimFilter("j0.d1", null, null), + "a0" + ) + ) + .setContext(queryContext) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setGranularity(Granularities.ALL) + .build(); + + testSelectQuery() + .setSql( + "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) AS cnt " + + "FROM foo" + ) + .setExpectedRowSignature(rowSignature) + .setExpectedMSQSpec( + MSQSpec + .builder() + .query(expectedQuery) + .columnMappings(new ColumnMappings( + ImmutableList.of( + new ColumnMapping("a0", "cnt") + ) + )) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .build()) + .setQueryContext(queryContext) + .addAdhocReportAssertions( + msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch( + stage -> stage.getStageDefinition().getProcessorFactory().getClass().equals(SortMergeJoinFrameProcessorFactory.class) + ), + "assert the query didn't use sort merge" + ) + .setExpectedResultRows(ImmutableList.of(new Object[]{6L})) + .verifyResults(); + } + @Nonnull private List expectedMultiValueFooRowsGroup() { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 736ec2f430da..79afe97f74d4 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -193,6 +193,7 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -798,6 +799,7 @@ public abstract class MSQTester> protected Set expectedTombstoneIntervals = null; protected List expectedResultRows = null; protected Matcher expectedValidationErrorMatcher = null; + protected List, String>> adhocReportAssertionAndReasons = new ArrayList<>(); protected Matcher expectedExecutionErrorMatcher = null; protected MSQFault expectedMSQFault = null; protected Class expectedMSQFaultClass = null; @@ -859,6 +861,12 @@ public Builder setExpectedMSQSpec(MSQSpec expectedMSQSpec) return asBuilder(); } + public Builder addAdhocReportAssertions(Predicate predicate, String reason) + { + this.adhocReportAssertionAndReasons.add(Pair.of(predicate, reason)); + return asBuilder(); + } + public Builder setExpectedValidationErrorMatcher(Matcher expectedValidationErrorMatcher) { this.expectedValidationErrorMatcher = expectedValidationErrorMatcher; @@ -1221,6 +1229,11 @@ public void verifyResults() } Assert.assertEquals(expectedTombstoneSegmentIds, tombstoneSegmentIds); } + + for (Pair, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) { + Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(reportPayload)); + } + // assert results assertResultsEquals(sql, expectedResultRows, transformedOutputRows); } @@ -1302,6 +1315,10 @@ public Pair, List>> log.info("found row signature %s", payload.getResults().getSignature()); log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n"))); + for (Pair, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) { + Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(payload)); + } + final MSQSpec spec = indexingServiceClient.getMSQControllerTask(controllerId).getQuerySpec(); log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec)); return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows)); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index 104297b4a763..d0c70e7901f7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -5659,7 +5659,7 @@ public void testJoinWithInputRefCondition() cannotVectorize(); Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT); - Query expectedQuery; + Query expectedQuery; if (!NullHandling.sqlCompatible()) { expectedQuery = Druids.newTimeseriesQueryBuilder() From 501265665372d2f217df950efef71dc5b28be487 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 10 Jul 2023 11:08:24 +0530 Subject: [PATCH 10/12] import --- .../src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index 6f51cbe456fa..e5ec0e79e779 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -95,6 +95,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; From d429be7a02fa17ea42c8ae6938a4b9c706b5bba9 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Mon, 10 Jul 2023 11:36:03 +0530 Subject: [PATCH 11/12] select destination --- .../src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index e5ec0e79e779..f1e4565e74e7 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -2088,6 +2088,9 @@ public void testJoinUsesDifferentAlgorithm() new ColumnMapping("a0", "cnt") ) )) + .destination(isDurableStorageDestination() + ? DurableStorageMSQDestination.INSTANCE + : TaskReportMSQDestination.INSTANCE) .tuningConfig(MSQTuningConfig.defaultConfig()) .build()) .setQueryContext(queryContext) From dce72e79bc7e9c858bdb1663edd715327ff3e6c6 Mon Sep 17 00:00:00 2001 From: Laksh Singla Date: Tue, 11 Jul 2023 09:26:57 +0530 Subject: [PATCH 12/12] test --- .../org/apache/druid/msq/exec/MSQSelectTest.java | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index f1e4565e74e7..0d4b3aff2f2d 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -2019,6 +2019,15 @@ public void testSelectRowsGetUntruncatedByDefault() throws IOException @Test public void testJoinUsesDifferentAlgorithm() { + + // This test asserts that the join algorithnm used is a different one from that supplied. In sqlCompatible() mode + // the query gets planned differently, therefore we do use the sortMerge processor. Instead of having separate + // handling, a similar test has been described in CalciteJoinQueryMSQTest, therefore we don't want to repeat that + // here, hence ignoring in sqlCompatible() mode + if (NullHandling.sqlCompatible()) { + return; + } + RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build(); Map queryContext = new HashMap<>(context); @@ -2096,7 +2105,10 @@ public void testJoinUsesDifferentAlgorithm() .setQueryContext(queryContext) .addAdhocReportAssertions( msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch( - stage -> stage.getStageDefinition().getProcessorFactory().getClass().equals(SortMergeJoinFrameProcessorFactory.class) + stage -> stage.getStageDefinition() + .getProcessorFactory() + .getClass() + .equals(SortMergeJoinFrameProcessorFactory.class) ), "assert the query didn't use sort merge" )