diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md index 5bbe935f1eef..08335ff1143f 100644 --- a/docs/multi-stage-query/reference.md +++ b/docs/multi-stage-query/reference.md @@ -234,7 +234,7 @@ The following table lists the context parameters for the MSQ task engine: | `maxNumTasks` | SELECT, INSERT, REPLACE

The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.

May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority. | 2 | | `taskAssignment` | SELECT, INSERT, REPLACE

Determines how many tasks to use. Possible values include: | `max` | | `finalizeAggregations` | SELECT, INSERT, REPLACE

Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggregations.md). | true | -| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE

Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. | `broadcast` | +| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE

Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. This parameter is a hint to the MSQ engine; the joins in the query may be executed differently than specified. See [Joins](#joins) for more details. | `broadcast` | | `rowsInMemory` | INSERT or REPLACE

Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. | 100,000 | | `segmentSortOrder` | INSERT or REPLACE

Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.

You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, consider an INSERT query that uses `CLUSTERED BY country` and has `segmentSortOrder` set to `__time,city`. Within each time chunk, Druid assigns rows to segments based on `country`, and then within each of those segments, Druid sorts those rows by `__time` first, then `city`, then `country`. | empty list | | `maxParseExceptions`| SELECT, INSERT, REPLACE

Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1. | 0 | @@ -253,6 +253,12 @@ Joins in multi-stage queries use one of two algorithms based on what you set the If you omit this context parameter, the MSQ task engine uses broadcast since it's the default join algorithm. The context parameter applies to the entire SQL statement, so you can't mix different join algorithms in the same query. +`sqlJoinAlgorithm` is a hint to the planner to execute the join in the specified manner. The planner can decide to ignore +the hint if it deduces beforehand that the specified algorithm would hurt the join's performance: for example, a sort-merge +hint is honored only when the left-hand side of every equality condition is a plain column (`t1.x = t2.x`, not `FLOOR(100) = t2.x`). +This intelligence is currently very limited, so the specified `sqlJoinAlgorithm` is respected in most cases; set it appropriately. +See the advantages and drawbacks of the [broadcast](#broadcast) and [sort-merge](#sort-merge) joins to determine which algorithm to use. +
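The deduction this patch introduces (see `DataSourcePlan.deduceJoinAlgorithm` in the code changes below) boils down to the following sketch. The type names come from this change; the standalone method form is illustrative:

```java
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;

// Sketch of the hint deduction: a broadcast hint is always honored, while a
// sortMerge hint survives only if the left-hand side of every equi-condition
// is a plain column reference (e.g. t1.x = t2.x), not an expression.
static JoinAlgorithm deduce(JoinAlgorithm preferred, JoinDataSource joinDataSource)
{
  if (JoinAlgorithm.BROADCAST.equals(preferred)) {
    return JoinAlgorithm.BROADCAST;
  }
  final boolean allColumnEqualities =
      joinDataSource.getConditionAnalysis()
                    .getEquiConditions()
                    .stream()
                    .allMatch(equality -> equality.getLeftExpr().isIdentifier());
  return allColumnEqualities ? JoinAlgorithm.SORT_MERGE : JoinAlgorithm.BROADCAST;
}
```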

### Broadcast The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how @@ -439,7 +445,7 @@ The following table describes error codes you may encounter in the `multiStageQu | `TooManyInputFiles` | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).

If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.

`maxInputFiles`: The maximum number of input files/segments per worker per stage.

`minNumWorker`: The minimum number of workers required for a successful run. | | `TooManyPartitions` | Exceeded the maximum number of partitions for a stage (25,000 partitions).

This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded | | `TooManyClusteredByColumns` | Exceeded the maximum number of clustering columns for a stage (1,500 columns).

This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.

`maxColumns`: The limit on columns which was exceeded. `stage`: The stage number exceeding the limit

| -| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.

`numBytes`: Number of bytes buffered, which may include other keys.

`maxBytes`: Maximum number of bytes buffered. | +| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when the join is executed with the sort-merge join algorithm. | `key`: The key that had a large number of rows.

`numBytes`: Number of bytes buffered, which may include other keys.

`maxBytes`: Maximum number of bytes buffered. | | `TooManyColumns` | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.

`maxColumns`: The limit on columns which was exceeded. | | `TooManyWarnings` | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit.

`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. | | `TooManyWorkers` | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously.

`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1,000 workers), then you can increase the limit by adding more memory to each task. | diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index 95f5eae7bb4d..477c3e0e1982 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -30,6 +30,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.UOE; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.external.ExternalInputSpec; import org.apache.druid.msq.input.inline.InlineInputSpec; @@ -56,6 +57,7 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.sql.calcite.external.ExternalDataSource; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.planner.JoinAlgorithm; @@ -79,6 +81,8 @@ public class DataSourcePlan */ private static final Map<String, Object> CONTEXT_MAP_NO_SEGMENT_GRANULARITY = new HashMap<>(); + private static final Logger log = new Logger(DataSourcePlan.class); + static { CONTEXT_MAP_NO_SEGMENT_GRANULARITY.put(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, null); } @@ -144,9 +148,13 @@ public static DataSourcePlan forDataSource( broadcast ); } else if (dataSource instanceof JoinDataSource) { - final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext); + final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext); + final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm( + preferredJoinAlgorithm, + ((JoinDataSource) dataSource) + ); - switch (joinAlgorithm) { + switch (deducedJoinAlgorithm) { case BROADCAST: return forBroadcastHashJoin( queryKit, @@ -171,7 +179,7 @@ public static DataSourcePlan forDataSource( ); default: - throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm); + throw new UOE("Cannot handle join algorithm [%s]", deducedJoinAlgorithm); } } else { throw new UOE("Cannot handle dataSource [%s]", dataSource); @@ -198,6 +206,48 @@ public Optional<QueryDefinitionBuilder> getSubQueryDefBuilder() return Optional.ofNullable(subQueryDefBuilder); } + /** + * Contains the logic that deduces the join algorithm to use. Ideally, this decision would be made while planning the + * native query; however, the resources and structure to do so were not in place when this function was added, so the + * deduction happens while planning the MSQ query instead. + * It takes into account the algorithm specified by "sqlJoinAlgorithm" in the query context and the join condition + * present in the query.
+ */ + private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource) + { + JoinAlgorithm deducedJoinAlgorithm; + if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { + deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; + } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE; + } else { + deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; + } + + if (deducedJoinAlgorithm != preferredJoinAlgorithm) { + log.debug( + "User wanted to plan join [%s] as [%s], however the join will be executed as [%s]", + joinDataSource, + preferredJoinAlgorithm.toString(), + deducedJoinAlgorithm.toString() + ); + } + + return deducedJoinAlgorithm; + } + + /** + * Checks if the join condition on two tables "table1" and "table2" is of the form + * table1.columnA = table2.columnA && table1.columnB = table2.columnB && .... + * The sortMerge algorithm can handle these types of join conditions. + */ + private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis) + { + return joinConditionAnalysis.getEquiConditions() + .stream() + .allMatch(equality -> equality.getLeftExpr().isIdentifier()); + } + /** * Whether this datasource must be processed by a single worker. True if, and only if, all inputs are broadcast. */
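As a usage sketch of the identifier check above: `JoinConditionAnalysis.forExpression` (an existing Druid API) can be used to see which condition shapes keep the `sortMerge` hint. The condition strings below are illustrative:

```java
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.segment.join.JoinConditionAnalysis;

class JoinConditionSketch
{
  // Same check as isConditionEqualityOnLeftAndRightColumns above.
  static boolean keepsSortMergeHint(String condition)
  {
    final JoinConditionAnalysis analysis =
        JoinConditionAnalysis.forExpression(condition, "j0.", ExprMacroTable.nil());
    return analysis.getEquiConditions()
                   .stream()
                   .allMatch(equality -> equality.getLeftExpr().isIdentifier());
  }

  public static void main(String[] args)
  {
    System.out.println(keepsSortMergeHint("x == \"j0.x\""));            // true: column = column
    System.out.println(keepsSortMergeHint("floor(100) == \"j0.d0\"")); // false: expression on the left
  }
}
```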
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index d751946f24a5..0d4b3aff2f2d 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -41,12 +41,14 @@ import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; import org.apache.druid.msq.indexing.report.MSQResultsReport; +import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory; import org.apache.druid.msq.test.CounterSnapshotMatcher; import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.LookupDataSource; +import org.apache.druid.query.Query; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; @@ -93,6 +95,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -2013,6 +2016,106 @@ public void testSelectRowsGetUntruncatedByDefault() throws IOException .verifyResults(); } + @Test + public void testJoinUsesDifferentAlgorithm() + { + + // This test asserts that the join algorithm used is different from the one supplied. In sqlCompatible() mode + // the query gets planned differently and does use the sortMerge processor; instead of adding separate handling + // here, a similar test is described in CalciteJoinQueryMSQTest, so to avoid repeating it this test is skipped + // in sqlCompatible() mode. + if (NullHandling.sqlCompatible()) { + return; + } + + RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build(); + + Map<String, Object> queryContext = new HashMap<>(context); + queryContext.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString()); + + Query<?> expectedQuery; + + expectedQuery = GroupByQuery + .builder() + .setDataSource( + join( + new QueryDataSource( + newScanQueryBuilder() + .dataSource("foo") + .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG)) + .columns("v0") + .context(defaultScanQueryContext( + queryContext, + RowSignature.builder().add("v0", ColumnType.LONG).build() + )) + .intervals(querySegmentSpec(Intervals.ETERNITY)) + .build() + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource("foo") + .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG)) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .setContext(queryContext) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setGranularity(Granularities.ALL) + .build() + + ), + "j0.", + "(floor(100) == \"j0.d0\")", + JoinType.LEFT + ) + ) + .setAggregatorSpecs( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new SelectorDimFilter("j0.d1", null, null), + "a0" + ) + ) + .setContext(queryContext) + .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY)) + .setGranularity(Granularities.ALL) + .build(); + + testSelectQuery() + .setSql( + "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) AS cnt " + + "FROM foo" + ) + .setExpectedRowSignature(rowSignature) + .setExpectedMSQSpec( + MSQSpec + .builder() + .query(expectedQuery) + .columnMappings(new ColumnMappings( + ImmutableList.of( + new ColumnMapping("a0", "cnt") + ) + )) + .destination(isDurableStorageDestination() + ? DurableStorageMSQDestination.INSTANCE + : TaskReportMSQDestination.INSTANCE) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .build()) + .setQueryContext(queryContext) + .addAdhocReportAssertions( + msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch( + stage -> stage.getStageDefinition() + .getProcessorFactory() + .getClass() + .equals(SortMergeJoinFrameProcessorFactory.class) + ), + "assert the query didn't use sort merge" + ) + .setExpectedResultRows(ImmutableList.of(new Object[]{6L})) + .verifyResults(); + } + @Nonnull private List<Object[]> expectedMultiValueFooRowsGroup() { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 68965f40bfed..4e00fd657ac5 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -202,6 +202,7 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -807,6 +808,7 @@ public abstract class MSQTester<Builder extends MSQTester<Builder>> protected Set<Interval> expectedTombstoneIntervals = null; protected List<Object[]> expectedResultRows = null; protected Matcher<Throwable> expectedValidationErrorMatcher = null; + protected List<Pair<Predicate<MSQTaskReportPayload>, String>> adhocReportAssertionAndReasons = new ArrayList<>(); protected Matcher<Throwable> expectedExecutionErrorMatcher = null; protected MSQFault expectedMSQFault = null; protected Class<? extends MSQFault> expectedMSQFaultClass = null; @@ -868,6 +870,12 @@ public Builder setExpectedMSQSpec(MSQSpec expectedMSQSpec) return asBuilder(); } + public Builder addAdhocReportAssertions(Predicate<MSQTaskReportPayload> predicate, String reason) + { + this.adhocReportAssertionAndReasons.add(Pair.of(predicate, reason)); + return asBuilder(); + } + public Builder setExpectedValidationErrorMatcher(Matcher<Throwable> expectedValidationErrorMatcher) { this.expectedValidationErrorMatcher = expectedValidationErrorMatcher; @@ -1230,6 +1238,11 @@ public void verifyResults() } Assert.assertEquals(expectedTombstoneSegmentIds, tombstoneSegmentIds); } + + for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) { + Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(reportPayload)); + } + // assert results assertResultsEquals(sql, expectedResultRows, transformedOutputRows); } @@ -1340,6 +1353,9 @@ public Pair<MSQSpec, Pair<List<MSQResultsReport.ColumnAndType>, List<Object[]>>> log.info("found row signature %s", payload.getResults().getSignature()); log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n"))); + for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) { + Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(payload)); + } log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec)); return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows));
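For reference, a condensed, hypothetical use of the new `addAdhocReportAssertions` hook (the SQL and the other expectations a real test would set are elided); the predicate body is the exact pattern `testJoinUsesDifferentAlgorithm` uses above:

```java
// The predicate runs against the MSQTaskReportPayload once the query finishes;
// the reason string is reported if the assertion fails.
testSelectQuery()
    .setSql("SELECT ...")  // illustrative
    .addAdhocReportAssertions(
        msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch(
            stage -> stage.getStageDefinition()
                          .getProcessorFactory()
                          .getClass()
                          .equals(SortMergeJoinFrameProcessorFactory.class)
        ),
        "assert the query didn't use sort merge"
    )
    .verifyResults();
```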
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index c6ba49992163..7bbcc44799bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -39,6 +39,7 @@ import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.druid.java.util.common.Pair; @@ -92,7 +93,7 @@ public boolean matches(RelOptRuleCall call) // 1) Can handle the join condition as a native join. // 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL). // 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL). - return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right) + return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right, join.getCluster().getRexBuilder()) && left.getPartialDruidQuery() != null && right.getPartialDruidQuery() != null; } @@ -116,7 +117,8 @@ public void onMatch(RelOptRuleCall call) ConditionAnalysis conditionAnalysis = analyzeCondition( join.getCondition(), join.getLeft().getRowType(), - right + right, + rexBuilder ).get(); final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel); @@ -223,9 +225,9 @@ private static RexNode makeNullableIfLiteral(final RexNode rexNode, final RexBui * Returns whether {@link #analyzeCondition} would return something. */ @VisibleForTesting - boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right) + boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right, final RexBuilder rexBuilder) { - return analyzeCondition(condition, leftRowType, right).isPresent(); + return analyzeCondition(condition, leftRowType, right, rexBuilder).isPresent(); } /** @@ -235,7 +237,8 @@ boolean canHandleCondition(final RexNode condition, final RelDataType leftRowTyp private Optional<ConditionAnalysis> analyzeCondition( final RexNode condition, final RelDataType leftRowType, - final DruidRel<?> right + final DruidRel<?> right, + final RexBuilder rexBuilder ) { final List<RexNode> subConditions = decomposeAnd(condition); @@ -266,8 +269,29 @@ private Optional<ConditionAnalysis> analyzeCondition( continue; } - if (!subCondition.isA(SqlKind.EQUALS)) { - // If it's not EQUALS, it's not supported. + RexNode firstOperand; + RexNode secondOperand; + + if (subCondition.isA(SqlKind.INPUT_REF)) { + firstOperand = rexBuilder.makeLiteral(true); + secondOperand = subCondition; + + if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) { + plannerContext.setPlanningError( + "SQL requires a join with '%s' condition where the column is of type %s, which is not supported", + subCondition.getKind(), + secondOperand.getType().getSqlTypeName() + ); + return Optional.empty(); + + } + } else if (subCondition.isA(SqlKind.EQUALS)) { + final List<RexNode> operands = ((RexCall) subCondition).getOperands(); + Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size()); + firstOperand = operands.get(0); + secondOperand = operands.get(1); + } else { + // If it's not EQUALS or a BOOLEAN input ref, it's not supported. plannerContext.setPlanningError( "SQL requires a join with '%s' condition that is not supported.", subCondition.getKind() @@ -275,16 +299,13 @@ private Optional<ConditionAnalysis> analyzeCondition( return Optional.empty(); } - final List<RexNode> operands = ((RexCall) subCondition).getOperands(); - Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size()); - - if (isLeftExpression(operands.get(0), numLeftFields) && isRightInputRef(operands.get(1), numLeftFields)) { - equalitySubConditions.add(Pair.of(operands.get(0), (RexInputRef) operands.get(1))); - rightColumns.add((RexInputRef) operands.get(1)); - } else if (isRightInputRef(operands.get(0), numLeftFields) - && isLeftExpression(operands.get(1), numLeftFields)) { - equalitySubConditions.add(Pair.of(operands.get(1), (RexInputRef) operands.get(0))); - rightColumns.add((RexInputRef) operands.get(0)); + if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) { + equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand)); + rightColumns.add((RexInputRef) secondOperand); + } else if (isRightInputRef(firstOperand, numLeftFields) + && isLeftExpression(secondOperand, numLeftFields)) { + equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand)); + rightColumns.add((RexInputRef) firstOperand); } else { // Cannot handle this condition. plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types."); @@ -310,7 +331,12 @@ && isLeftExpression(operands.get(1), numLeftFields)) { } } - return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions)); + return Optional.of( + new ConditionAnalysis( + numLeftFields, + equalitySubConditions, + literalSubConditions + )); } @VisibleForTesting @@ -341,13 +367,6 @@ static List<RexNode> decomposeAnd(final RexNode condition) private boolean isLeftExpression(final RexNode rexNode, final int numLeftFields) { - if (!plannerContext.getJoinAlgorithm().canHandleLeftExpressions()) { - // Must be INPUT_REF. - if (!rexNode.isA(SqlKind.INPUT_REF)) { - return false; - } - } - return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode)); } @@ -375,6 +394,7 @@ static class ConditionAnalysis */ private final List<RexLiteral> literalSubConditions; + ConditionAnalysis( int numLeftFields, List<Pair<RexNode, RexInputRef>> equalitySubConditions,
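The new `INPUT_REF` branch above normalizes a bare boolean column in a join condition into an equality with literal `TRUE`, so the existing operand-matching logic can handle it. A minimal sketch, assuming a Calcite `RexBuilder`, type factory, and `numLeftFields` are in scope:

```java
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;

// A bare boolean reference such as "j0.d1" (e.g. produced while planning
// "FLOOR(100) NOT IN (SELECT m1 FROM foo)") is rewritten into the operand
// pair that the existing equality matching expects: TRUE = <boolean column>.
static RexNode[] normalizeBareBooleanRef(RexBuilder rexBuilder, RelDataTypeFactory typeFactory, int numLeftFields)
{
  final RexNode subCondition = rexBuilder.makeInputRef(
      typeFactory.createSqlType(SqlTypeName.BOOLEAN),
      numLeftFields  // illustrative index: first field of the right-hand input
  );
  return new RexNode[]{rexBuilder.makeLiteral(true), subCondition};
}
```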
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index db631fe67439..c4ff4a17a3f4 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -5654,4 +5654,128 @@ public void testJoinsWithThreeConditions() ) ); } + + @Test + public void testJoinWithInputRefCondition() + { + cannotVectorize(); + Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + + Query<?> expectedQuery; + + if (!NullHandling.sqlCompatible()) { + expectedQuery = Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "1", + ColumnType.LONG + )) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .build() + ), + "j0.", + "(floor(100) == \"j0.d0\")", + JoinType.LEFT + ) + ) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new SelectorDimFilter("j0.d1", null, null) + ) + )) + .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build(); + + } else { + expectedQuery = Druids.newTimeseriesQueryBuilder() + .dataSource( + join( + join( + new TableDataSource("foo"), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource("foo") + .aggregators( + new CountAggregatorFactory("a0"), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a1"), + not(selector("m1", null, null)), + "a1" + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build() + ), + "j0.", + "1", + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1)) + .setVirtualColumns(expressionVirtualColumn( + "v0", + "1", + ColumnType.LONG + )) + .setDimensions( + new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "d1", ColumnType.LONG) + ) + .build() + ), + "_j0.", + "(floor(100) == \"_j0.d0\")", + JoinType.LEFT + ) + ) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + or( + new SelectorDimFilter("j0.a0", "0", null), + and( + selector("_j0.d1", null, null), + expressionFilter("(\"j0.a1\" >= \"j0.a0\")") + ) + + ) + ) + )) + .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0")) + .intervals(querySegmentSpec(Filtration.eternity())) + .context(context) + .build(); + + } + + testQuery( + "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) " + + "FROM foo", + context,
ImmutableList.of(expectedQuery), + ImmutableList.of( + new Object[]{6L} + ) + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java index 41c6895dff25..e531580162ee 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java @@ -84,7 +84,8 @@ public void test_canHandleCondition_leftEqRight() rexBuilder.makeInputRef(joinType, 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -104,7 +105,8 @@ public void test_canHandleCondition_leftFnEqRight() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -124,7 +126,8 @@ public void test_canHandleCondition_leftEqRightFn() ) ), leftType, - null + null, + rexBuilder ) ); } @@ -140,7 +143,8 @@ public void test_canHandleCondition_leftEqLeft() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0) ), leftType, - null + null, + rexBuilder ) ); } @@ -156,7 +160,8 @@ public void test_canHandleCondition_rightEqRight() rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1) ), leftType, - null + null, + rexBuilder ) ); } @@ -168,7 +173,8 @@ public void test_canHandleCondition_true() druidJoinRule.canHandleCondition( rexBuilder.makeLiteral(true), leftType, - null + null, + rexBuilder ) ); } @@ -180,7 +186,8 @@ public void test_canHandleCondition_false() druidJoinRule.canHandleCondition( rexBuilder.makeLiteral(false), leftType, - null + null, + rexBuilder ) ); }