diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index 831c9b139d3d..e6ddb4d723dc 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -43,7 +43,6 @@ import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageDefinitionBuilder; import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory; -import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.DataSource; import org.apache.druid.query.FilteredDataSource; import org.apache.druid.query.InlineDataSource; @@ -424,21 +423,11 @@ private static DataSourcePlan forQuery( @Nullable final QueryContext parentContext ) { - // check if parentContext has a window operator - final Map windowShuffleMap = new HashMap<>(); - if (parentContext != null && parentContext.containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) { - windowShuffleMap.put(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL, parentContext.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)); - } final QueryDefinition subQueryDef = queryKit.makeQueryDefinition( queryId, // Subqueries ignore SQL_INSERT_SEGMENT_GRANULARITY, even if set in the context. It's only used for the // outermost query, and setting it for the subquery makes us erroneously add bucketing where it doesn't belong. - windowShuffleMap.isEmpty() - ? 
dataSource.getQuery() - .withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY) - : dataSource.getQuery() - .withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY) - .withOverriddenContext(windowShuffleMap), + dataSource.getQuery().withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY), queryKit, ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount), maxWorkerCount, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java index 23e13f176d7b..a814640f7042 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryKit.java @@ -20,7 +20,6 @@ package org.apache.druid.msq.querykit; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.collect.ImmutableMap; import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.KeyColumn; import org.apache.druid.frame.key.KeyOrder; @@ -88,17 +87,6 @@ public QueryDefinition makeQueryDefinition( List> operatorList = getOperatorListFromQuery(originalQuery); log.info("Created operatorList with operator factories: [%s]", operatorList); - ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount); - // add this shuffle spec to the last stage of the inner query - - final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); - if (nextShuffleSpec != null) { - final ClusterBy windowClusterBy = nextShuffleSpec.clusterBy(); - originalQuery = (WindowOperatorQuery) originalQuery.withOverriddenContext(ImmutableMap.of( - MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL, - windowClusterBy - )); - } final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( queryKit, queryId, @@ -112,7 +100,8 
@@ public QueryDefinition makeQueryDefinition( false ); - dataSourcePlan.getSubQueryDefBuilder().ifPresent(queryDefBuilder::addAll); + ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount); + final QueryDefinitionBuilder queryDefBuilder = makeQueryDefinitionBuilder(queryId, dataSourcePlan, nextShuffleSpec); final int firstStageNumber = Math.max(minStageNumber, queryDefBuilder.getNextStageNumber()); final WindowOperatorQuery queryToRun = (WindowOperatorQuery) originalQuery.withDataSource(dataSourcePlan.getNewDataSource()); @@ -309,12 +298,16 @@ private ShuffleSpec findShuffleSpecForNextWindow(List operatorF } } - if (partition == null || partition.getPartitionColumns().isEmpty()) { + if (partition == null) { // If operatorFactories doesn't have any partitioning factory, then we should keep the shuffle spec from previous stage. // This indicates that we already have the data partitioned correctly, and hence we don't need to do any shuffling. return null; } + if (partition.getPartitionColumns().isEmpty()) { + return MixShuffleSpec.instance(); + } + List keyColsOfWindow = new ArrayList<>(); for (String partitionColumn : partition.getPartitionColumns()) { KeyColumn kc; @@ -328,4 +321,41 @@ private ShuffleSpec findShuffleSpecForNextWindow(List operatorF return new HashShuffleSpec(new ClusterBy(keyColsOfWindow, 0), maxWorkerCount); } + + /** + * Copies the stages of the inner query into a new builder, overriding the shuffle spec of its final stage with + * the shuffling required by the first window stage.
+ * + * @param queryId id of the query definition being built + * @param dataSourcePlan plan whose sub-query stages, if any, are copied into the returned builder + * @param shuffleSpec shuffle spec required by the first window stage; when null the sub-query stages are copied unchanged + * @return builder containing the sub-query stages + */ + private QueryDefinitionBuilder makeQueryDefinitionBuilder(String queryId, DataSourcePlan dataSourcePlan, ShuffleSpec shuffleSpec) + { + final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId); + if (!dataSourcePlan.getSubQueryDefBuilder().isPresent()) { + // No sub-query stages to copy. + return queryDefBuilder; + } + if (shuffleSpec == null) { + // A null spec means the previous stage's shuffle spec is kept, so copy the stages as they are. + queryDefBuilder.addAll(dataSourcePlan.getSubQueryDefBuilder().get()); + return queryDefBuilder; + } + final QueryDefinition subQueryDef = dataSourcePlan.getSubQueryDefBuilder().get().build(); + final int previousStageNumber = subQueryDef.getFinalStageDefinition().getStageNumber(); + for (final StageDefinition stageDef : subQueryDef.getStageDefinitions()) { + if (stageDef.getStageNumber() == previousStageNumber) { + RowSignature rowSignature = QueryKitUtils.sortableSignature( + stageDef.getSignature(), + shuffleSpec.clusterBy().getColumns() + ); + queryDefBuilder.add(StageDefinition.builder(stageDef).shuffleSpec(shuffleSpec).signature(rowSignature)); + } else { + queryDefBuilder.add(StageDefinition.builder(stageDef)); + } + } + return queryDefBuilder; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java index eb9953402bad..2bf77fd8d0cf 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java @@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.msq.input.stage.StageInputSpec; -import org.apache.druid.msq.kernel.HashShuffleSpec; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.ShuffleSpec; @@ -39,7 +38,6 @@ import org.apache.druid.msq.querykit.ShuffleSpecFactories; import
org.apache.druid.msq.querykit.ShuffleSpecFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; -import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.DimensionComparisonUtils; import org.apache.druid.query.Query; import org.apache.druid.query.dimension.DimensionSpec; @@ -168,104 +166,40 @@ public QueryDefinition makeQueryDefinition( partitionBoost ); - final ShuffleSpec nextShuffleWindowSpec = getShuffleSpecForNextWindow(originalQuery, maxWorkerCount); + queryDefBuilder.add( + StageDefinition.builder(firstStageNumber + 1) + .inputs(new StageInputSpec(firstStageNumber)) + .signature(resultSignature) + .maxWorkerCount(maxWorkerCount) + .shuffleSpec( + shuffleSpecFactoryPostAggregation != null + ? shuffleSpecFactoryPostAggregation.build(resultClusterBy, false) + : null + ) + .processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun)) + ); - if (nextShuffleWindowSpec == null) { + if (doLimitOrOffset) { + final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false); + final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec(); queryDefBuilder.add( - StageDefinition.builder(firstStageNumber + 1) - .inputs(new StageInputSpec(firstStageNumber)) + StageDefinition.builder(firstStageNumber + 2) + .inputs(new StageInputSpec(firstStageNumber + 1)) .signature(resultSignature) - .maxWorkerCount(maxWorkerCount) - .shuffleSpec( - shuffleSpecFactoryPostAggregation != null - ? 
shuffleSpecFactoryPostAggregation.build(resultClusterBy, false) - : null - ) - .processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun)) - ); - - if (doLimitOrOffset) { - final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false); - final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec(); - queryDefBuilder.add( - StageDefinition.builder(firstStageNumber + 2) - .inputs(new StageInputSpec(firstStageNumber + 1)) - .signature(resultSignature) - .maxWorkerCount(1) - .shuffleSpec(finalShuffleSpec) - .processorFactory( - new OffsetLimitFrameProcessorFactory( - limitSpec.getOffset(), - limitSpec.isLimited() ? (long) limitSpec.getLimit() : null - ) - ) - ); - } - } else { - final RowSignature stageSignature; - // sort the signature to make sure the prefix is aligned - stageSignature = QueryKitUtils.sortableSignature( - resultSignature, - nextShuffleWindowSpec.clusterBy().getColumns() - ); - - - queryDefBuilder.add( - StageDefinition.builder(firstStageNumber + 1) - .inputs(new StageInputSpec(firstStageNumber)) - .signature(stageSignature) - .maxWorkerCount(maxWorkerCount) - .shuffleSpec(doLimitOrOffset ? (shuffleSpecFactoryPostAggregation != null - ? shuffleSpecFactoryPostAggregation.build( - resultClusterBy, - false + .maxWorkerCount(1) + .shuffleSpec(finalShuffleSpec) + .processorFactory( + new OffsetLimitFrameProcessorFactory( + limitSpec.getOffset(), + limitSpec.isLimited() ? 
(long) limitSpec.getLimit() : null + ) ) - : null) : nextShuffleWindowSpec) - .processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun)) ); - if (doLimitOrOffset) { - final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec(); - final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false); - queryDefBuilder.add( - StageDefinition.builder(firstStageNumber + 2) - .inputs(new StageInputSpec(firstStageNumber + 1)) - .signature(resultSignature) - .maxWorkerCount(1) - .shuffleSpec(finalShuffleSpec) - .processorFactory( - new OffsetLimitFrameProcessorFactory( - limitSpec.getOffset(), - limitSpec.isLimited() ? (long) limitSpec.getLimit() : null - ) - ) - ); - } } return queryDefBuilder.build(); } - /** - * @param originalQuery which has the context for the next shuffle if that's present in the next window - * @param maxWorkerCount max worker count - * @return shuffle spec without partition boosting for next stage, null if there is no partition by for next window - */ - private ShuffleSpec getShuffleSpecForNextWindow(GroupByQuery originalQuery, int maxWorkerCount) - { - final ShuffleSpec nextShuffleWindowSpec; - if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) { - final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext() - .get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL); - nextShuffleWindowSpec = new HashShuffleSpec( - windowClusterBy, - maxWorkerCount - ); - } else { - nextShuffleWindowSpec = null; - } - return nextShuffleWindowSpec; - } - /** * Intermediate signature of a particular {@link GroupByQuery}. Does not include post-aggregators, and all * aggregations are nonfinalized. 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java index 48a17a9e84e2..2a90616fe1db 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java @@ -37,7 +37,6 @@ import org.apache.druid.msq.querykit.ShuffleSpecFactories; import org.apache.druid.msq.querykit.ShuffleSpecFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; -import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.Query; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.segment.column.ColumnType; @@ -129,26 +128,8 @@ public QueryDefinition makeQueryDefinition( ); } - // Update partition by of next window - final RowSignature signatureSoFar = signatureBuilder.build(); - boolean addShuffle = true; - if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) { - final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext() - .get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL); - for (KeyColumn c : windowClusterBy.getColumns()) { - if (!signatureSoFar.contains(c.columnName())) { - addShuffle = false; - break; - } - } - if (addShuffle) { - clusterByColumns.addAll(windowClusterBy.getColumns()); - } - } else { - // Add partition boosting column. 
- clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING)); - signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG); - } + clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING)); + signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG); final ClusterBy clusterBy = QueryKitUtils.clusterByWithSegmentGranularity(new ClusterBy(clusterByColumns, 0), segmentGranularity); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index cbf9a1a905f0..aa12e0f093b3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -167,8 +167,6 @@ public class MultiStageQueryContext public static final String CTX_ARRAY_INGEST_MODE = "arrayIngestMode"; public static final ArrayIngestMode DEFAULT_ARRAY_INGEST_MODE = ArrayIngestMode.MVD; - public static final String NEXT_WINDOW_SHUFFLE_COL = "__windowShuffleCol"; - public static final String MAX_ROWS_MATERIALIZED_IN_WINDOW = "maxRowsMaterializedInWindow"; public static final String CTX_SKIP_TYPE_VERIFICATION = "skipTypeVerification"; diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQWindowTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQWindowTest.java index 5cc84ac6ee61..32b840bb64e2 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQWindowTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQWindowTest.java @@ -37,7 +37,9 @@ import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnnestDataSource; import 
org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; @@ -48,6 +50,7 @@ import org.apache.druid.query.operator.window.WindowFrame; import org.apache.druid.query.operator.window.WindowFramedAggregateProcessor; import org.apache.druid.query.operator.window.WindowOperatorFactory; +import org.apache.druid.query.operator.window.ranking.WindowRowNumberProcessor; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.LegacySegmentSpec; import org.apache.druid.segment.column.ColumnType; @@ -65,6 +68,8 @@ import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -1842,7 +1847,7 @@ public void testSelectWithWikipediaEmptyOverWithCustomContext(String contextName .setSql( "select cityName, added, SUM(added) OVER () cc from wikipedia") .setQueryContext(customContext) - .setExpectedMSQFault(new TooManyRowsInAWindowFault(15676, 200)) + .setExpectedMSQFault(new TooManyRowsInAWindowFault(15921, 200)) .verifyResults(); } @@ -2048,4 +2053,235 @@ public void testReplaceGroupByOnWikipedia(String contextName, Map context) + { + final Map multipleWorkerContext = new HashMap<>(context); + multipleWorkerContext.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 5); + + final RowSignature rowSignature = RowSignature.builder() + .add("countryName", ColumnType.STRING) + .add("cityName", ColumnType.STRING) + .add("channel", ColumnType.STRING) + .add("c1", ColumnType.LONG) + .add("c2", ColumnType.LONG) + .build(); + + final Map contextWithRowSignature = + ImmutableMap.builder() + 
.putAll(multipleWorkerContext) + .put( + DruidQuery.CTX_SCAN_SIGNATURE, + "[{\"name\":\"d0\",\"type\":\"STRING\"},{\"name\":\"d1\",\"type\":\"STRING\"},{\"name\":\"d2\",\"type\":\"STRING\"},{\"name\":\"w0\",\"type\":\"LONG\"},{\"name\":\"w1\",\"type\":\"LONG\"}]" + ) + .build(); + + final GroupByQuery groupByQuery = GroupByQuery.builder() + .setDataSource(CalciteTests.WIKIPEDIA) + .setInterval(querySegmentSpec(Filtration + .eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec( + "countryName", + "d0", + ColumnType.STRING + ), + new DefaultDimensionSpec( + "cityName", + "d1", + ColumnType.STRING + ), + new DefaultDimensionSpec( + "channel", + "d2", + ColumnType.STRING + ) + )) + .setDimFilter(in("countryName", ImmutableList.of("Austria", "Republic of Korea"))) + .setContext(multipleWorkerContext) + .build(); + + final AggregatorFactory[] aggs = { + new FilteredAggregatorFactory(new CountAggregatorFactory("w1"), notNull("d2"), "w1") + }; + + final WindowOperatorQuery windowQuery = new WindowOperatorQuery( + new QueryDataSource(groupByQuery), + new LegacySegmentSpec(Intervals.ETERNITY), + multipleWorkerContext, + RowSignature.builder() + .add("d0", ColumnType.STRING) + .add("d1", ColumnType.STRING) + .add("d2", ColumnType.STRING) + .add("w0", ColumnType.LONG) + .add("w1", ColumnType.LONG).build(), + ImmutableList.of( + new NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d0"), ColumnWithDirection.ascending("d1"), ColumnWithDirection.ascending("d2"))), + new NaivePartitioningOperatorFactory(Collections.emptyList()), + new WindowOperatorFactory(new WindowRowNumberProcessor("w0")), + new NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d1"), ColumnWithDirection.ascending("d0"), ColumnWithDirection.ascending("d2"))), + new NaivePartitioningOperatorFactory(Collections.singletonList("d1")), + new WindowOperatorFactory(new 
WindowFramedAggregateProcessor(WindowFrame.forOrderBy("d0", "d1", "d2"), aggs)) + ), + ImmutableList.of() + ); + + final ScanQuery scanQuery = Druids.newScanQueryBuilder() + .dataSource(new QueryDataSource(windowQuery)) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("d0", "d1", "d2", "w0", "w1") + .orderBy( + ImmutableList.of( + new ScanQuery.OrderBy("d0", ScanQuery.Order.ASCENDING), + new ScanQuery.OrderBy("d1", ScanQuery.Order.ASCENDING), + new ScanQuery.OrderBy("d2", ScanQuery.Order.ASCENDING) + ) + ) + .columnTypes(ColumnType.STRING, ColumnType.STRING, ColumnType.STRING, ColumnType.LONG, ColumnType.LONG) + .limit(Long.MAX_VALUE) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(contextWithRowSignature) + .build(); + + final String sql = "select countryName, cityName, channel, \n" + + "row_number() over (order by countryName, cityName, channel) as c1, \n" + + "count(channel) over (partition by cityName order by countryName, cityName, channel) as c2\n" + + "from wikipedia\n" + + "where countryName in ('Austria', 'Republic of Korea')\n" + + "group by countryName, cityName, channel " + + "order by countryName, cityName, channel"; + + final String nullValue = NullHandling.sqlCompatible() ? 
null : ""; + + testSelectQuery() + .setSql(sql) + .setExpectedMSQSpec(MSQSpec.builder() + .query(scanQuery) + .columnMappings( + new ColumnMappings(ImmutableList.of( + new ColumnMapping("d0", "countryName"), + new ColumnMapping("d1", "cityName"), + new ColumnMapping("d2", "channel"), + new ColumnMapping("w0", "c1"), + new ColumnMapping("w1", "c2") + ) + )) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .build()) + .setExpectedRowSignature(rowSignature) + .setExpectedResultRows( + ImmutableList.of( + new Object[]{"Austria", nullValue, "#de.wikipedia", 1L, 1L}, + new Object[]{"Austria", "Horsching", "#de.wikipedia", 2L, 1L}, + new Object[]{"Austria", "Vienna", "#de.wikipedia", 3L, 1L}, + new Object[]{"Austria", "Vienna", "#es.wikipedia", 4L, 2L}, + new Object[]{"Austria", "Vienna", "#tr.wikipedia", 5L, 3L}, + new Object[]{"Republic of Korea", nullValue, "#en.wikipedia", 6L, 2L}, + new Object[]{"Republic of Korea", nullValue, "#ja.wikipedia", 7L, 3L}, + new Object[]{"Republic of Korea", nullValue, "#ko.wikipedia", 8L, 4L}, + new Object[]{"Republic of Korea", "Jeonju", "#ko.wikipedia", 9L, 1L}, + new Object[]{"Republic of Korea", "Seongnam-si", "#ko.wikipedia", 10L, 1L}, + new Object[]{"Republic of Korea", "Seoul", "#ko.wikipedia", 11L, 1L}, + new Object[]{"Republic of Korea", "Suwon-si", "#ko.wikipedia", 12L, 1L}, + new Object[]{"Republic of Korea", "Yongsan-dong", "#ko.wikipedia", 13L, 1L} + ) + ) + .setQueryContext(multipleWorkerContext) + // Stage 0 + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().totalFiles(1), + 0, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(872).frames(1), + 0, 0, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(4, 4, 4, 1).bytes(251, 266, 300, 105).frames(1, 1, 1, 1), + 0, 0, "shuffle" + ) + // Stage 1, Worker 0 + .setExpectedCountersForStageWorkerChannel( + 
CounterSnapshotMatcher.with().rows(4).bytes(251).frames(1), + 1, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(4).bytes(251).frames(1), + 1, 0, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(4).bytes(251).frames(1), + 1, 0, "shuffle" + ) + + // Stage 1, Worker 1 + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 4).bytes(0, 266).frames(0, 1), + 1, 1, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 4).bytes(0, 266).frames(0, 1), + 1, 1, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(4).bytes(266).frames(1), + 1, 1, "shuffle" + ) + + // Stage 1, Worker 2 + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 0, 4).bytes(0, 0, 300).frames(0, 0, 1), + 1, 2, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 0, 4).bytes(0, 0, 300).frames(0, 0, 1), + 1, 2, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(4).bytes(300).frames(1), + 1, 2, "shuffle" + ) + + // Stage 1, Worker 3 + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 0, 0, 1).bytes(0, 0, 0, 105).frames(0, 0, 0, 1), + 1, 3, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(0, 0, 0, 1).bytes(0, 0, 0, 105).frames(0, 0, 0, 1), + 1, 3, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(1).bytes(105).frames(1), + 1, 3, "shuffle" + ) + + // Stage 2 (window stage) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(922).frames(4), + 2, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(1158).frames(1), + 2, 0, "output" + ) + + // Stage 3, Worker 0 + 
.setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(1158).frames(1), + 3, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(1379).frames(1), + 3, 0, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher.with().rows(13).bytes(1327).frames(1), + 3, 0, "shuffle" + ) + .verifyResults(); + } } diff --git a/processing/src/main/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessor.java b/processing/src/main/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessor.java index 98b09b6f80d1..e920a9a6fd3c 100644 --- a/processing/src/main/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessor.java +++ b/processing/src/main/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessor.java @@ -30,6 +30,7 @@ import java.util.Collections; import java.util.List; +import java.util.Objects; public class WindowRowNumberProcessor implements Processor { @@ -137,4 +138,23 @@ public List getOutputColumnNames() { return Collections.singletonList(outputColumn); } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WindowRowNumberProcessor that = (WindowRowNumberProcessor) o; + return Objects.equals(outputColumn, that.outputColumn); + } + + @Override + public int hashCode() + { + return Objects.hashCode(outputColumn); + } } diff --git a/processing/src/test/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessorTest.java b/processing/src/test/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessorTest.java index f4f9b5bfeee4..4c0864b287d2 100644 --- a/processing/src/test/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessorTest.java +++ 
b/processing/src/test/java/org/apache/druid/query/operator/window/ranking/WindowRowNumberProcessorTest.java @@ -19,6 +19,7 @@ package org.apache.druid.query.operator.window.ranking; +import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.query.operator.window.Processor; import org.apache.druid.query.operator.window.RowsAndColumnsHelper; import org.apache.druid.query.rowsandcols.MapOfColumnsRowsAndColumns; @@ -61,4 +62,13 @@ public void testRowNumberProcessing() final RowsAndColumns results = processor.process(rac); expectations.validate(results); } + + @Test + public void testEqualsAndHashcode() + { + EqualsVerifier.forClass(WindowRowNumberProcessor.class) + .withNonnullFields("outputColumn") + .usingGetClass() + .verify(); + } }