diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md
index 5bbe935f1eef..08335ff1143f 100644
--- a/docs/multi-stage-query/reference.md
+++ b/docs/multi-stage-query/reference.md
@@ -234,7 +234,7 @@ The following table lists the context parameters for the MSQ task engine:
| `maxNumTasks` | SELECT, INSERT, REPLACE
The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.
May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority. | 2 |
| `taskAssignment` | SELECT, INSERT, REPLACE
Determines how many tasks to use. Possible values include:
`max`: Uses as many tasks as possible, up to `maxNumTasks`.
`auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 512 MiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When calculating the size of files, the weighted size is used, which considers the file format and compression format used if any. When file sizes cannot be determined through directory listing (for example: http), behaves the same as `max`. | `max` |
| `finalizeAggregations` | SELECT, INSERT, REPLACE
Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggregations.md). | true |
-| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE
Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. | `broadcast` |
+| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE
Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. This parameter is a hint to the MSQ engine; the engine may execute a join differently than specified. See [Joins](#joins) for more details. | `broadcast` |
| `rowsInMemory` | INSERT or REPLACE
Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. | 100,000 |
| `segmentSortOrder` | INSERT or REPLACE
Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.
You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, consider an INSERT query that uses `CLUSTERED BY country` and has `segmentSortOrder` set to `__time,city`. Within each time chunk, Druid assigns rows to segments based on `country`, and then within each of those segments, Druid sorts those rows by `__time` first, then `city`, then `country`. | empty list |
| `maxParseExceptions`| SELECT, INSERT, REPLACE
Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1. | 0 |
@@ -253,6 +253,12 @@ Joins in multi-stage queries use one of two algorithms based on what you set the
If you omit this context parameter, the MSQ task engine uses broadcast since it's the default join algorithm. The context parameter applies to the entire SQL statement, so you can't mix different
join algorithms in the same query.
+`sqlJoinAlgorithm` is a hint to the planner, which may ignore it if it determines in advance that the specified
+algorithm would be detrimental to the join's performance. This detection is currently very limited, so the hint is
+respected in most cases; set it appropriately for your workload. See the advantages and drawbacks of the
+[broadcast](#broadcast) and [sort-merge](#sort-merge) join algorithms to decide which one to use.
+
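As a brief illustration (the query and table names here are hypothetical, not part of this change), the hint is supplied through the query context of a SQL-based MSQ task request:

```json
{
  "query": "SELECT t1.page, t2.user FROM t1 INNER JOIN t2 ON t1.page = t2.page",
  "context": {
    "sqlJoinAlgorithm": "sortMerge"
  }
}
```

Because the parameter is only a hint, check the task report to confirm which algorithm actually executed the join.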
### Broadcast
The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how
@@ -439,7 +445,7 @@ The following table describes error codes you may encounter in the `multiStageQu
| `TooManyInputFiles` | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).
If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.
`maxInputFiles`: The maximum number of input files/segments per worker per stage.
`minNumWorker`: The minimum number of workers required for a successful run. |
| `TooManyPartitions` | Exceeded the maximum number of partitions for a stage (25,000 partitions).
This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |
| `TooManyClusteredByColumns` | Exceeded the maximum number of clustering columns for a stage (1,500 columns).
This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.
`maxColumns`: The limit on columns which was exceeded.
`stage`: The stage number exceeding the limit. |
-| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.
`numBytes`: Number of bytes buffered, which may include other keys.
`maxBytes`: Maximum number of bytes buffered. |
+| `TooManyRowsWithSameKey` | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when the join is executed with the sort-merge algorithm. | `key`: The key that had a large number of rows.
`numBytes`: Number of bytes buffered, which may include other keys.
`maxBytes`: Maximum number of bytes buffered. |
| `TooManyColumns` | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.
`maxColumns`: The limit on columns which was exceeded. |
| `TooManyWarnings` | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit.
`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. |
| `TooManyWorkers` | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously.
`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1,000 workers), then you can increase the limit by adding more memory to each task. |
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
index 95f5eae7bb4d..477c3e0e1982 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
@@ -30,6 +30,7 @@
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.UOE;
+import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.input.InputSpec;
import org.apache.druid.msq.input.external.ExternalInputSpec;
import org.apache.druid.msq.input.inline.InlineInputSpec;
@@ -56,6 +57,7 @@
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
@@ -79,6 +81,8 @@ public class DataSourcePlan
*/
private static final Map<String, Object> CONTEXT_MAP_NO_SEGMENT_GRANULARITY = new HashMap<>();
+ private static final Logger log = new Logger(DataSourcePlan.class);
+
static {
CONTEXT_MAP_NO_SEGMENT_GRANULARITY.put(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, null);
}
@@ -144,9 +148,13 @@ public static DataSourcePlan forDataSource(
broadcast
);
} else if (dataSource instanceof JoinDataSource) {
- final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+ final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+ final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm(
+ preferredJoinAlgorithm,
+ ((JoinDataSource) dataSource)
+ );
- switch (joinAlgorithm) {
+ switch (deducedJoinAlgorithm) {
case BROADCAST:
return forBroadcastHashJoin(
queryKit,
@@ -171,7 +179,7 @@ public static DataSourcePlan forDataSource(
);
default:
- throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm);
+ throw new UOE("Cannot handle join algorithm [%s]", deducedJoinAlgorithm);
}
} else {
throw new UOE("Cannot handle dataSource [%s]", dataSource);
@@ -198,6 +206,48 @@ public Optional getSubQueryDefBuilder()
return Optional.ofNullable(subQueryDefBuilder);
}
+ /**
+ * Deduces the join algorithm to use for the given join datasource. Ideally, this decision would be made while
+ * planning the native query, but the required structure was not in place when this function was added, so it is
+ * made while planning the MSQ query instead. The deduction takes into account the algorithm specified by
+ * "sqlJoinAlgorithm" in the query context and the join condition present in the query.
+ */
+ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource)
+ {
+ JoinAlgorithm deducedJoinAlgorithm;
+ if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) {
+ deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+ } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) {
+ deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE;
+ } else {
+ deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+ }
+
+ if (deducedJoinAlgorithm != preferredJoinAlgorithm) {
+ log.debug(
+ "User wanted to plan join [%s] as [%s], however the join will be executed as [%s]",
+ joinDataSource,
+ preferredJoinAlgorithm.toString(),
+ deducedJoinAlgorithm.toString()
+ );
+ }
+
+ return deducedJoinAlgorithm;
+ }
+
+ /**
+ * Checks if the join condition on two tables "table1" and "table2" is of the form
+ * table1.columnA = table2.columnA && table1.columnB = table2.columnB && ...
+ * The sort-merge algorithm can efficiently execute join conditions of this form.
+ */
+ private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis)
+ {
+ return joinConditionAnalysis.getEquiConditions()
+ .stream()
+ .allMatch(equality -> equality.getLeftExpr().isIdentifier());
+ }
+
/**
* Whether this datasource must be processed by a single worker. True if, and only if, all inputs are broadcast.
*/
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index d751946f24a5..0d4b3aff2f2d 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -41,12 +41,14 @@
import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
import org.apache.druid.msq.test.CounterSnapshotMatcher;
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.LookupDataSource;
+import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -93,6 +95,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -2013,6 +2016,106 @@ public void testSelectRowsGetUntruncatedByDefault() throws IOException
.verifyResults();
}
+ @Test
+ public void testJoinUsesDifferentAlgorithm()
+ {
+
+ // This test asserts that the join is executed with a different algorithm than the one supplied. In sqlCompatible()
+ // mode the query is planned differently and the sortMerge processor is used, so a similar test for that case lives
+ // in CalciteJoinQueryMSQTest; rather than duplicating it here, this test is skipped in sqlCompatible() mode.
+ if (NullHandling.sqlCompatible()) {
+ return;
+ }
+
+ RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build();
+
+ Map<String, Object> queryContext = new HashMap<>(context);
+ queryContext.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString());
+
+ Query<?> expectedQuery;
+
+ expectedQuery = GroupByQuery
+ .builder()
+ .setDataSource(
+ join(
+ new QueryDataSource(
+ newScanQueryBuilder()
+ .dataSource("foo")
+ .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG))
+ .columns("v0")
+ .context(defaultScanQueryContext(
+ queryContext,
+ RowSignature.builder().add("v0", ColumnType.LONG).build()
+ ))
+ .intervals(querySegmentSpec(Intervals.ETERNITY))
+ .build()
+ ),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setDataSource("foo")
+ .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
+ .setDimensions(
+ new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+ new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ )
+ .setContext(queryContext)
+ .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+ .setGranularity(Granularities.ALL)
+ .build()
+
+ ),
+ "j0.",
+ "(floor(100) == \"j0.d0\")",
+ JoinType.LEFT
+ )
+ )
+ .setAggregatorSpecs(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ new SelectorDimFilter("j0.d1", null, null),
+ "a0"
+ )
+ )
+ .setContext(queryContext)
+ .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+ .setGranularity(Granularities.ALL)
+ .build();
+
+ testSelectQuery()
+ .setSql(
+ "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) AS cnt "
+ + "FROM foo"
+ )
+ .setExpectedRowSignature(rowSignature)
+ .setExpectedMSQSpec(
+ MSQSpec
+ .builder()
+ .query(expectedQuery)
+ .columnMappings(new ColumnMappings(
+ ImmutableList.of(
+ new ColumnMapping("a0", "cnt")
+ )
+ ))
+ .destination(isDurableStorageDestination()
+ ? DurableStorageMSQDestination.INSTANCE
+ : TaskReportMSQDestination.INSTANCE)
+ .tuningConfig(MSQTuningConfig.defaultConfig())
+ .build())
+ .setQueryContext(queryContext)
+ .addAdhocReportAssertions(
+ msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch(
+ stage -> stage.getStageDefinition()
+ .getProcessorFactory()
+ .getClass()
+ .equals(SortMergeJoinFrameProcessorFactory.class)
+ ),
+ "assert the query didn't use sort merge"
+ )
+ .setExpectedResultRows(ImmutableList.of(new Object[]{6L}))
+ .verifyResults();
+ }
+
@Nonnull
private List