From 28acd0839ceeab058b6b9fbdbdc497a713faeb37 Mon Sep 17 00:00:00 2001
From: Paul Rogers
Date: Thu, 8 Sep 2022 00:46:45 +0200
Subject: [PATCH 1/9] Replace QueryContext with a facade

* Removes the QueryContext class and its attendant race conditions.
* Replaces it with a lightweight facade for accessing the context.
* Uses a simpler, alternative way to authorize user context values.
* Moves query-related methods to Queries.
* Moves more context-value methods to QueryContext.
* Revises the non-default "get" methods to return null if the value is not set.
* Various build fixes.
---
 .../druid/benchmark/query/SqlBenchmark.java | 5 +-
 .../query/SqlExpressionBenchmark.java | 3 +-
 .../query/SqlNestedDataBenchmark.java | 3 +-
 .../benchmark/query/SqlVsNativeBenchmark.java | 4 +-
 .../MaterializedViewQuery.java | 8 +-
 .../MaterializedViewQueryTest.java | 3 +-
 .../movingaverage/MovingAverageQuery.java | 2 +-
 .../MovingAverageQueryRunner.java | 6 +-
 ...blesSketchApproxQuantileSqlAggregator.java | 7 +-
 .../sql/DoublesSketchObjectSqlAggregator.java | 6 +-
 .../apache/druid/msq/exec/ControllerImpl.java | 14 +-
 .../org/apache/druid/msq/exec/WorkerImpl.java | 5 +-
 .../druid/msq/indexing/MSQControllerTask.java | 2 +-
 .../druid/msq/querykit/QueryKitUtils.java | 2 +-
 .../msq/querykit/groupby/GroupByQueryKit.java | 3 +-
 .../druid/msq/querykit/scan/ScanQueryKit.java | 4 +-
 .../org/apache/druid/msq/sql/MSQMode.java | 10 +-
 .../druid/msq/sql/MSQTaskQueryMaker.java | 21 +-
 .../druid/msq/sql/MSQTaskSqlEngine.java | 6 +-
 .../msq/util/MultiStageQueryContext.java | 76 +--
 .../org/apache/druid/msq/sql/MSQModeTest.java | 14 +-
 .../apache/druid/msq/test/MSQTestBase.java | 4 +-
 .../msq/util/MultiStageQueryContextTest.java | 39 +-
 .../tools/ServerManagerForQueryErrorTest.java | 14 +-
 .../ServerManagerForQueryErrorTest.java | 14 +-
 .../druid/query/BadQueryContextException.java | 5 +
 .../org/apache/druid/query/BaseQuery.java | 33 +-
 .../druid/query/BySegmentQueryRunner.java | 4 +-
 .../query/BySegmentSkippingQueryRunner.java | 2 +-
 .../query/ChainedExecutionQueryRunner.java | 7 +-
 .../query/FinalizeResultsQueryRunner.java | 5 +-
 .../druid/query/GroupByMergedQueryRunner.java | 10 +-
 .../java/org/apache/druid/query/Queries.java | 22 +
 .../java/org/apache/druid/query/Query.java | 85 ++-
 .../org/apache/druid/query/QueryContext.java | 560 ++++++++++++++----
 .../org/apache/druid/query/QueryContexts.java | 478 ++++-----------
 .../druid/query/SubqueryQueryRunner.java | 2 +-
 .../druid/query/groupby/GroupByQuery.java | 8 +-
 .../query/groupby/GroupByQueryConfig.java | 36 +-
 .../query/groupby/GroupByQueryEngine.java | 2 +-
 .../query/groupby/GroupByQueryHelper.java | 4 +-
 .../groupby/GroupByQueryQueryToolChest.java | 9 +-
 .../GroupByMergingQueryRunnerV2.java | 13 +-
 .../epinephelinae/GroupByQueryEngineV2.java | 8 +-
 .../vector/VectorGroupByEngine.java | 4 +-
 .../groupby/orderby/DefaultLimitSpec.java | 7 +-
 .../groupby/strategy/GroupByStrategyV1.java | 2 +-
 .../groupby/strategy/GroupByStrategyV2.java | 21 +-
 .../SegmentMetadataQueryRunnerFactory.java | 9 +-
 .../apache/druid/query/scan/ScanQuery.java | 4 +-
 .../druid/query/scan/ScanQueryEngine.java | 4 +-
 .../query/scan/ScanQueryLimitRowIterator.java | 2 +-
 .../query/scan/ScanQueryRunnerFactory.java | 3 +-
 .../druid/query/search/SearchQueryConfig.java | 2 +-
 .../search/SearchQueryQueryToolChest.java | 4 +-
 .../druid/query/select/SelectQuery.java | 8 +-
 .../spec/SpecificSegmentQueryRunner.java | 2 +-
 .../TimeBoundaryQueryQueryToolChest.java | 6 +-
 .../query/timeseries/TimeseriesQuery.java | 6 +-
 .../timeseries/TimeseriesQueryEngine.java | 6 +-
 .../TimeseriesQueryQueryToolChest.java | 3 +-
 .../druid/query/topn/TopNQueryEngine.java | 2 +-
 .../query/topn/TopNQueryQueryToolChest.java | 4 +-
 .../apache/druid/segment/VirtualColumns.java | 4 +-
 .../apache/druid/segment/filter/Filters.java | 2 +-
 .../rewrite/JoinFilterRewriteConfig.java | 13 +-
 .../apache/druid/query/QueryContextTest.java | 284 ++---
 .../apache/druid/query/QueryContextsTest.java | 97 ++-
 .../DataSourceMetadataQueryTest.java | 16 +-
 .../VectorGroupByEngineIteratorTest.java | 3 +-
 .../timeboundary/TimeBoundaryQueryTest.java | 19 +-
 .../org/apache/druid/client/CacheUtil.java | 10 +-
 .../druid/client/CachingClusteredClient.java | 34 +-
 .../druid/client/DirectDruidClient.java | 17 +-
 .../druid/client/JsonParserIterator.java | 2 +-
 .../apache/druid/query/RetryQueryRunner.java | 6 +-
 .../server/ClientQuerySegmentWalker.java | 6 +-
 .../apache/druid/server/QueryLifecycle.java | 40 +-
 .../apache/druid/server/QueryResource.java | 11 +-
 .../apache/druid/server/QueryScheduler.java | 3 +-
 .../SetAndVerifyContextQueryRunner.java | 17 +-
 .../scheduling/HiLoQueryLaningStrategy.java | 8 +-
 .../scheduling/ManualQueryLaningStrategy.java | 4 +-
 .../scheduling/NoQueryLaningStrategy.java | 3 +-
 ...sholdBasedQueryPrioritizationStrategy.java | 4 +-
 ...ingClusteredClientCacheKeyManagerTest.java | 9 +-
 .../client/CachingClusteredClientTest.java | 8 +-
 .../druid/client/JsonParserIteratorTest.java | 10 +-
 ...nifiedIndexerAppenderatorsManagerTest.java | 2 -
 .../druid/server/QueryLifecycleTest.java | 27 +-
 .../druid/server/QuerySchedulerTest.java | 7 +-
 .../SetAndVerifyContextQueryRunnerTest.java | 9 +-
 .../ManualTieredBrokerSelectorStrategy.java | 12 +-
 .../PriorityTieredBrokerSelectorStrategy.java | 3 +-
 .../router/TieredBrokerHostSelector.java | 3 +-
 .../router/TieredBrokerSelectorStrategy.java | 3 +-
 .../apache/druid/sql/AbstractStatement.java | 53 +-
 .../org/apache/druid/sql/DirectStatement.java | 6 +-
 .../apache/druid/sql/PreparedStatement.java | 3 +-
 .../druid/sql/SqlExecutionReporter.java | 15 +-
 .../org/apache/druid/sql/SqlQueryPlus.java | 37 +-
 .../druid/sql/avatica/DruidConnection.java | 15 +-
 .../druid/sql/avatica/DruidJdbcStatement.java | 7 +-
 .../apache/druid/sql/avatica/DruidMeta.java | 13 +-
 .../sql/calcite/planner/DruidPlanner.java | 13 +-
 .../sql/calcite/planner/IngestHandler.java | 6 +-
 .../sql/calcite/planner/PlannerConfig.java | 28 +-
 .../sql/calcite/planner/PlannerContext.java | 39 +-
 .../sql/calcite/planner/PlannerFactory.java | 20 +-
 .../sql/calcite/planner/QueryHandler.java | 4 +-
 .../calcite/planner/SqlStatementHandler.java | 2 +
 .../druid/sql/calcite/rel/DruidQuery.java | 69 ++-
 .../druid/sql/calcite/rule/DruidJoinRule.java | 2 +-
 .../sql/calcite/run/NativeSqlEngine.java | 7 +-
 .../druid/sql/calcite/run/SqlEngine.java | 5 +-
 .../druid/sql/calcite/run/SqlEngines.java | 6 +-
 .../sql/calcite/view/DruidViewMacro.java | 8 +-
 .../druid/sql/calcite/view/ViewSqlEngine.java | 5 +-
 .../org/apache/druid/sql/http/SqlQuery.java | 6 +
 .../apache/druid/sql/SqlStatementTest.java | 4 +-
 .../druid/sql/avatica/DruidStatementTest.java | 3 +-
 .../sql/calcite/CalciteJoinQueryTest.java | 4 +-
 .../druid/sql/calcite/CalciteQueryTest.java | 33 +-
 .../sql/calcite/CalciteScanSignatureTest.java | 3 +-
 .../sql/calcite/IngestionTestSqlEngine.java | 5 +-
 .../SqlVectorizedExpressionSanityTest.java | 11 +-
 .../expression/ExpressionTestHelper.java | 5 +-
 .../external/ExternalTableScanRuleTest.java | 6 +-
 .../planner/CalcitePlannerModuleTest.java | 6 +-
.../calcite/planner/DruidRexExecutorTest.java | 4 +- .../sql/calcite/rule/DruidJoinRuleTest.java | 6 +- 131 files changed, 1372 insertions(+), 1450 deletions(-) diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java index 35974e4430df..09b208182263 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java @@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.aggregation.datasketches.hll.sql.HllSketchApproxCountDistinctSqlAggregator; @@ -516,7 +515,7 @@ public void querySql(Blackhole blackhole) throws Exception QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize ); final String sql = QUERIES.get(Integer.parseInt(query)); - try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) { final PlannerResult plannerResult = planner.plan(); final Sequence resultSequence = plannerResult.run().getResults(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); @@ -534,7 +533,7 @@ public void planSql(Blackhole blackhole) throws Exception QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize ); final String sql = QUERIES.get(Integer.parseInt(query)); - try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) { final PlannerResult plannerResult = planner.plan(); blackhole.consume(plannerResult); } diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java index e1c27afc9889..0e4ba5b4f009 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java @@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.DruidProcessingConfig; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.segment.QueryableIndex; @@ -352,7 +351,7 @@ public void querySql(Blackhole blackhole) throws Exception QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize ); final String sql = QUERIES.get(Integer.parseInt(query)); - try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) { final PlannerResult plannerResult = planner.plan(); final Sequence resultSequence = plannerResult.run().getResults(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); diff --git 
a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java index ab3f5de9cef0..ed7ad8f214d7 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlNestedDataBenchmark.java @@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.DruidProcessingConfig; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.expression.TestExprMacroTable; @@ -318,7 +317,7 @@ public void querySql(Blackhole blackhole) throws Exception QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize ); final String sql = QUERIES.get(Integer.parseInt(query)); - try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, new QueryContext(context))) { + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, context)) { final PlannerResult plannerResult = planner.plan(); final Sequence resultSequence = plannerResult.run().getResults(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java index b11188eb98c9..8b0ed6c96c28 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java @@ -26,7 +26,6 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.aggregation.CountAggregatorFactory; @@ -66,6 +65,7 @@ import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.infra.Blackhole; +import java.util.Collections; import java.util.concurrent.TimeUnit; /** @@ -167,7 +167,7 @@ public void queryNative(Blackhole blackhole) @OutputTimeUnit(TimeUnit.MILLISECONDS) public void queryPlanner(Blackhole blackhole) throws Exception { - try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, new QueryContext())) { + try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sqlQuery, Collections.emptyMap())) { final PlannerResult plannerResult = planner.plan(); final Sequence resultSequence = plannerResult.run().getResults(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); diff --git a/extensions-contrib/materialized-view-selection/src/main/java/org/apache/druid/query/materializedview/MaterializedViewQuery.java b/extensions-contrib/materialized-view-selection/src/main/java/org/apache/druid/query/materializedview/MaterializedViewQuery.java index 73c74dba68ce..d771667f6eaa 100644 --- a/extensions-contrib/materialized-view-selection/src/main/java/org/apache/druid/query/materializedview/MaterializedViewQuery.java +++ 
b/extensions-contrib/materialized-view-selection/src/main/java/org/apache/druid/query/materializedview/MaterializedViewQuery.java @@ -28,7 +28,6 @@ import org.apache.druid.query.BaseQuery; import org.apache.druid.query.DataSource; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.filter.DimFilter; @@ -41,6 +40,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.List; import java.util.Map; import java.util.Objects; @@ -146,12 +146,6 @@ public Map getContext() return query.getContext(); } - @Override - public QueryContext getQueryContext() - { - return query.getQueryContext(); - } - @Override public boolean isDescending() { diff --git a/extensions-contrib/materialized-view-selection/src/test/java/org/apache/druid/query/materializedview/MaterializedViewQueryTest.java b/extensions-contrib/materialized-view-selection/src/test/java/org/apache/druid/query/materializedview/MaterializedViewQueryTest.java index 13dfe567cd22..ad9913eca555 100644 --- a/extensions-contrib/materialized-view-selection/src/test/java/org/apache/druid/query/materializedview/MaterializedViewQueryTest.java +++ b/extensions-contrib/materialized-view-selection/src/test/java/org/apache/druid/query/materializedview/MaterializedViewQueryTest.java @@ -121,7 +121,6 @@ public void testGetContextHumanReadableBytes() .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT) .build(); MaterializedViewQuery query = new MaterializedViewQuery(topNQuery, optimizer); - Assert.assertEquals(20_000_000, query.getContextAsHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes()); - + Assert.assertEquals(20_000_000, query.context().getHumanReadableBytes("maxOnDiskStorage", HumanReadableBytes.ZERO).getBytes()); } } diff --git a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQuery.java b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQuery.java index 5ac36de51047..280bc8ccceb2 100644 --- a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQuery.java +++ b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQuery.java @@ -237,7 +237,7 @@ public String getType() @JsonIgnore public boolean getContextSortByDimsFirst() { - return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false); + return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false); } @Override diff --git a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java index 80cc45fbc29a..7753d55c0632 100644 --- a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java +++ b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java @@ -30,7 +30,6 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.query.DataSource; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryPlus; import 
org.apache.druid.query.QueryRunner; @@ -52,6 +51,7 @@ import org.joda.time.Period; import javax.annotation.Nullable; + import java.util.List; import java.util.Map; import java.util.Optional; @@ -124,7 +124,7 @@ public Sequence run(QueryPlus query, ResponseContext responseContext) ResponseContext gbqResponseContext = ResponseContext.createEmpty(); gbqResponseContext.merge(responseContext); gbqResponseContext.putQueryFailDeadlineMs( - System.currentTimeMillis() + QueryContexts.getTimeout(gbq) + System.currentTimeMillis() + gbq.context().getTimeout() ); Sequence results = gbq.getRunner(walker).run(QueryPlus.wrap(gbq), gbqResponseContext); @@ -164,7 +164,7 @@ public Sequence run(QueryPlus query, ResponseContext responseContext) ResponseContext tsqResponseContext = ResponseContext.createEmpty(); tsqResponseContext.merge(responseContext); tsqResponseContext.putQueryFailDeadlineMs( - System.currentTimeMillis() + QueryContexts.getTimeout(tsq) + System.currentTimeMillis() + tsq.context().getTimeout() ); Sequence> results = tsq.getRunner(walker).run(QueryPlus.wrap(tsq), tsqResponseContext); diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java index abba4616c656..c6729e3036fd 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchApproxQuantileSqlAggregator.java @@ -49,6 +49,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; + import java.util.List; public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator @@ -171,7 +172,7 @@ public Aggregation toDruidAggregation( histogramName, input.getDirectColumn(), k, - getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) + getMaxStreamLengthFromQueryContext(plannerContext.queryContext()) ); } else { String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( @@ -182,7 +183,7 @@ public Aggregation toDruidAggregation( histogramName, virtualColumnName, k, - getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) + getMaxStreamLengthFromQueryContext(plannerContext.queryContext()) ); } @@ -201,7 +202,7 @@ public Aggregation toDruidAggregation( static long getMaxStreamLengthFromQueryContext(QueryContext queryContext) { - return queryContext.getAsLong( + return queryContext.getLong( CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH, DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH ); diff --git a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java index abe516fa1f3a..04654daaf238 100644 --- a/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java +++ b/extensions-core/datasketches/src/main/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchObjectSqlAggregator.java @@ -46,6 +46,7 @@ import 
org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; + import java.util.List; public class DoublesSketchObjectSqlAggregator implements SqlAggregator @@ -113,7 +114,7 @@ public Aggregation toDruidAggregation( histogramName, input.getDirectColumn(), k, - DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) + DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext()) ); } else { String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression( @@ -124,7 +125,7 @@ public Aggregation toDruidAggregation( histogramName, virtualColumnName, k, - DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.getQueryContext()) + DoublesSketchApproxQuantileSqlAggregator.getMaxStreamLengthFromQueryContext(plannerContext.queryContext()) ); } @@ -136,7 +137,6 @@ public Aggregation toDruidAggregation( private static class DoublesSketchSqlAggFunction extends SqlAggFunction { - private static final String SIGNATURE1 = "'" + NAME + "(column)'\n"; private static final String SIGNATURE2 = "'" + NAME + "(column, k)'\n"; DoublesSketchSqlAggFunction() diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index b6414c3db6e8..b5bb6bdd2bfb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -518,7 +518,7 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) closer.register(netClient::close); final boolean isDurableStorageEnabled = - MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext()); + MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context()); final QueryDefinition queryDef = makeQueryDefinition( id(), @@ -1191,7 +1191,7 @@ private Yielder getFinalResultsYielder( final InputChannelFactory inputChannelFactory; - if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) { + if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) { inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( id(), () -> taskIds, @@ -1294,7 +1294,7 @@ private void publishSegmentsIfNeeded( */ private void cleanUpDurableStorageIfNeeded() { - if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext())) { + if (MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().context())) { final String controllerDirName = DurableStorageOutputChannelFactory.getControllerDirectory(task.getId()); try { // Delete all temporary files as a failsafe @@ -1454,7 +1454,7 @@ private static GranularitySpec makeGranularitySpecForIngestion( ) { if (isRollupQuery) { - final String queryGranularity = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, ""); + final String queryGranularity = query.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY, ""); if (timeIsGroupByDimension((GroupByQuery) query, columnMappings) && !queryGranularity.isEmpty()) { return new ArbitraryGranularitySpec( @@ -1483,7 +1483,7 @@ private static boolean timeIsGroupByDimension(GroupByQuery 
groupByQuery, ColumnM { if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME); - return queryTimeColumn.equals(groupByQuery.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD)); + return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD)); } else { return false; } @@ -1505,8 +1505,8 @@ private static boolean timeIsGroupByDimension(GroupByQuery groupByQuery, ColumnM private static boolean isRollupQuery(Query query) { return query instanceof GroupByQuery - && !MultiStageQueryContext.isFinalizeAggregations(query.getQueryContext()) - && !query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true); + && !MultiStageQueryContext.isFinalizeAggregations(query.context()) + && !query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true); } private static boolean isInlineResults(final MSQSpec querySpec) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 275965abbe40..37be81674979 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -106,6 +106,7 @@ import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.PrioritizedCallable; import org.apache.druid.query.PrioritizedRunnable; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.server.DruidNode; @@ -177,7 +178,9 @@ public WorkerImpl(MSQWorkerTask task, WorkerContext context) this.context = context; this.selfDruidNode = context.selfNode(); this.processorBouncer = context.processorBouncer(); - this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(task.getContext()); + this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled( + QueryContext.of(task.getContext()) + ); } @Override diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java index ff4c8c19ed05..3e733168170b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java @@ -106,7 +106,7 @@ public MSQControllerTask( this.sqlQueryContext = sqlQueryContext; this.sqlTypeNames = sqlTypeNames; - if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().getContext())) { + if (MultiStageQueryContext.isDurableStorageEnabled(querySpec.getQuery().context())) { this.remoteFetchExecutorService = Executors.newCachedThreadPool(Execs.makeThreadFactory(getId() + "-remote-fetcher-%d")); } else { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java index 8630fec754d2..fcd723291650 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java @@ -191,7 +191,7 @@ public static RowSignature sortableSignature( public static VirtualColumn makeSegmentGranularityVirtualColumn(final Query query) { final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(query.getContext()); - final String timeColumnName = query.getQueryContext().getAsString(QueryKitUtils.CTX_TIME_COLUMN_NAME); + final String timeColumnName = query.context().getString(QueryKitUtils.CTX_TIME_COLUMN_NAME); if (timeColumnName == null || Granularities.ALL.equals(segmentGranularity)) { return null; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java index 3e494c805b25..411fe118a29e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java @@ -37,7 +37,6 @@ import org.apache.druid.msq.querykit.ShuffleSpecFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.having.AlwaysHavingSpec; @@ -205,7 +204,7 @@ static RowSignature computeResultSignature(final GroupByQuery query) */ static boolean isFinalize(final GroupByQuery query) { - return QueryContexts.isFinalize(query, true); + return query.context().isFinalize(true); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java index edd553b2b046..5bfb70b52c91 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java @@ -57,7 +57,7 @@ public static RowSignature getAndValidateSignature(final ScanQuery scanQuery, fi { RowSignature scanSignature; try { - final String s = scanQuery.getQueryContext().getAsString(DruidQuery.CTX_SCAN_SIGNATURE); + final String s = scanQuery.context().getString(DruidQuery.CTX_SCAN_SIGNATURE); scanSignature = jsonMapper.readValue(s, RowSignature.class); } catch (JsonProcessingException e) { @@ -74,7 +74,7 @@ public static RowSignature getAndValidateSignature(final ScanQuery scanQuery, fi * 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec */ // No ordering, but there is a limit or an offset. These work by funneling everything through a single partition. - // So there is no point in forcing any particular partitioning. Since everything is funnelled into a single + // So there is no point in forcing any particular partitioning. 
Since everything is funneled into a single // partition without a ClusterBy, we don't need to necessarily create it via the resultShuffleSpecFactory provided @Override public QueryDefinition makeQueryDefinition( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java index e1daafadf298..6485f3ab7005 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQMode.java @@ -23,9 +23,10 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.indexing.error.MSQWarnings; -import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; import javax.annotation.Nullable; + import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -64,7 +65,7 @@ public String toString() return value; } - public static void populateDefaultQueryContext(final String modeStr, final QueryContext originalQueryContext) + public static void populateDefaultQueryContext(final String modeStr, final Map originalQueryContext) { MSQMode mode = MSQMode.fromString(modeStr); if (mode == null) { @@ -74,8 +75,7 @@ public static void populateDefaultQueryContext(final String modeStr, final Query Arrays.stream(MSQMode.values()).map(m -> m.value).collect(Collectors.toList()) ); } - Map defaultQueryContext = mode.defaultQueryContext; - log.debug("Populating default query context with %s for the %s multi stage query mode", defaultQueryContext, mode); - originalQueryContext.addDefaultParams(defaultQueryContext); + log.debug("Populating default query context with %s for the %s multi stage query mode", mode.defaultQueryContext, mode); + QueryContexts.addDefaults(originalQueryContext, mode.defaultQueryContext); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java index c1611f52db86..d1a7a80991f3 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java @@ -42,6 +42,7 @@ import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.TaskReportMSQDestination; import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.rpc.indexing.OverlordClient; @@ -59,6 +60,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -109,13 +111,14 @@ public QueryResponse runQuery(final DruidQuery druidQuery) { String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId()); - String msqMode = MultiStageQueryContext.getMSQMode(plannerContext.getQueryContext()); + QueryContext queryContext = plannerContext.queryContext(); + String msqMode = MultiStageQueryContext.getMSQMode(queryContext); if (msqMode != null) { MSQMode.populateDefaultQueryContext(msqMode, plannerContext.getQueryContext()); } final String ctxDestination = - 
DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(plannerContext.getQueryContext())); + DimensionHandlerUtils.convertObjectToString(MultiStageQueryContext.getDestination(queryContext)); Object segmentGranularity; try { @@ -128,7 +131,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) + "segment graularity"); } - final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(plannerContext.getQueryContext()); + final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(queryContext); if (maxNumTasks < 2) { throw new IAE(MultiStageQueryContext.CTX_MAX_NUM_TASKS @@ -139,16 +142,16 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final int maxNumWorkers = maxNumTasks - 1; final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment( - plannerContext.getQueryContext(), + queryContext, DEFAULT_ROWS_PER_SEGMENT ); final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory( - plannerContext.getQueryContext(), + queryContext, DEFAULT_ROWS_IN_MEMORY ); - final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(plannerContext.getQueryContext()); + final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(queryContext); final List replaceTimeChunks = Optional.ofNullable(plannerContext.getQueryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS)) @@ -213,7 +216,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) } final List segmentSortOrder = MultiStageQueryContext.decodeSortOrder( - MultiStageQueryContext.getSortOrder(plannerContext.getQueryContext()) + MultiStageQueryContext.getSortOrder(queryContext) ); validateSegmentSortOrder( @@ -245,7 +248,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) .query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides)) .columnMappings(new ColumnMappings(columnMappings)) .destination(destination) - .assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(plannerContext.getQueryContext())) + .assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(queryContext)) .tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment)) .build(); @@ -253,7 +256,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) taskId, querySpec, plannerContext.getSql(), - plannerContext.getQueryContext().getMergedParams(), + plannerContext.getQueryContext(), sqlTypeNames, null ); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index 02563f27e506..8c9caee43359 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -38,7 +38,6 @@ import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.msq.querykit.QueryKitUtils; import org.apache.druid.msq.util.MultiStageQueryContext; -import org.apache.druid.query.QueryContext; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; @@ -52,6 +51,7 @@ import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; public class MSQTaskSqlEngine implements SqlEngine @@ -86,7 +86,7 @@ public String name() } @Override - public void validateContext(QueryContext 
queryContext) throws ValidationException + public void validateContext(Map queryContext) throws ValidationException { SqlEngines.validateNoSpecialContextKeys(queryContext, SYSTEM_CONTEXT_PARAMETERS); } @@ -207,7 +207,7 @@ private static void validateInsert( try { segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext( - plannerContext.getQueryContext().getMergedParams() + plannerContext.getQueryContext() ); } catch (Exception e) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 7a6b576e68d9..97eb5d2a1941 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -26,12 +26,12 @@ import com.opencsv.RFC4180Parser; import com.opencsv.RFC4180ParserBuilder; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.java.util.common.Numbers; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.sql.MSQMode; import org.apache.druid.query.QueryContext; import javax.annotation.Nullable; + import java.io.IOException; import java.util.Arrays; import java.util.Collections; @@ -59,7 +59,7 @@ public class MultiStageQueryContext private static final boolean DEFAULT_FINALIZE_AGGREGATIONS = true; public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage"; - private static final String DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = "false"; + private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false; public static final String CTX_DESTINATION = "destination"; private static final String DEFAULT_DESTINATION = null; @@ -77,48 +77,34 @@ public class MultiStageQueryContext private static final Pattern LOOKS_LIKE_JSON_ARRAY = Pattern.compile("^\\s*\\[.*", Pattern.DOTALL); - public static String getMSQMode(QueryContext queryContext) + public static String getMSQMode(final QueryContext queryContext) { - return (String) MultiStageQueryContext.getValueFromPropertyMap( - queryContext.getMergedParams(), + return queryContext.getString( CTX_MSQ_MODE, - null, DEFAULT_MSQ_MODE ); } - public static boolean isDurableStorageEnabled(Map propertyMap) + public static boolean isDurableStorageEnabled(final QueryContext queryContext) { - return Boolean.parseBoolean( - String.valueOf( - getValueFromPropertyMap( - propertyMap, - CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, - null, - DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE - ) - ) + return queryContext.getBoolean( + CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, + DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE ); } public static boolean isFinalizeAggregations(final QueryContext queryContext) { - return Numbers.parseBoolean( - getValueFromPropertyMap( - queryContext.getMergedParams(), - CTX_FINALIZE_AGGREGATIONS, - null, - DEFAULT_FINALIZE_AGGREGATIONS - ) + return queryContext.getBoolean( + CTX_FINALIZE_AGGREGATIONS, + DEFAULT_FINALIZE_AGGREGATIONS ); } public static WorkerAssignmentStrategy getAssignmentStrategy(final QueryContext queryContext) { - String assignmentStrategyString = (String) getValueFromPropertyMap( - queryContext.getMergedParams(), + String assignmentStrategyString = queryContext.getString( CTX_TASK_ASSIGNMENT_STRATEGY, - null, DEFAULT_TASK_ASSIGNMENT_STRATEGY ); @@ -127,47 +113,33 @@ public static WorkerAssignmentStrategy getAssignmentStrategy(final 
QueryContext public static int getMaxNumTasks(final QueryContext queryContext) { - return Numbers.parseInt( - getValueFromPropertyMap( - queryContext.getMergedParams(), - CTX_MAX_NUM_TASKS, - null, - DEFAULT_MAX_NUM_TASKS - ) + return queryContext.getInt( + CTX_MAX_NUM_TASKS, + DEFAULT_MAX_NUM_TASKS ); } public static Object getDestination(final QueryContext queryContext) { - return getValueFromPropertyMap( - queryContext.getMergedParams(), + return queryContext.get( CTX_DESTINATION, - null, DEFAULT_DESTINATION ); } public static int getRowsPerSegment(final QueryContext queryContext, int defaultRowsPerSegment) { - return Numbers.parseInt( - getValueFromPropertyMap( - queryContext.getMergedParams(), - CTX_ROWS_PER_SEGMENT, - null, - defaultRowsPerSegment - ) + return queryContext.getInt( + CTX_ROWS_PER_SEGMENT, + defaultRowsPerSegment ); } public static int getRowsInMemory(final QueryContext queryContext, int defaultRowsInMemory) { - return Numbers.parseInt( - getValueFromPropertyMap( - queryContext.getMergedParams(), - CTX_ROWS_IN_MEMORY, - null, - defaultRowsInMemory - ) + return queryContext.getInt( + CTX_ROWS_IN_MEMORY, + defaultRowsInMemory ); } @@ -196,10 +168,8 @@ public static Object getValueFromPropertyMap( public static String getSortOrder(final QueryContext queryContext) { - return (String) getValueFromPropertyMap( - queryContext.getMergedParams(), + return queryContext.getString( CTX_SORT_ORDER, - null, DEFAULT_SORT_ORDER ); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java index ea9bb4ff5b3d..fa0dac1cf518 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java @@ -22,34 +22,36 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.indexing.error.MSQWarnings; -import org.apache.druid.query.QueryContext; import org.junit.Assert; import org.junit.Test; +import java.util.Collections; +import java.util.Map; + public class MSQModeTest { @Test public void testPopulateQueryContextWhenNoSupercedingValuePresent() { - QueryContext originalQueryContext = new QueryContext(); + Map originalQueryContext = Collections.emptyMap(); MSQMode.populateDefaultQueryContext("strict", originalQueryContext); - Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext.getMergedParams()); + Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext); } @Test public void testPopulateQueryContextWhenSupercedingValuePresent() { - QueryContext originalQueryContext = new QueryContext(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10)); + Map originalQueryContext = Collections.emptyMap(); MSQMode.populateDefaultQueryContext("strict", originalQueryContext); - Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext.getMergedParams()); + Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext); } @Test public void testPopulateQueryContextWhenInvalidMode() { - QueryContext originalQueryContext = new QueryContext(); + Map originalQueryContext = Collections.emptyMap(); Assert.assertThrows(ISE.class, () -> { 
MSQMode.populateDefaultQueryContext("fake_mode", originalQueryContext); }); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 880c6f4adead..a08cf0ca3ece 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -89,7 +89,6 @@ import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.ForwardingQueryProcessingPool; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -162,6 +161,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.io.Closeable; import java.io.File; import java.io.IOException; @@ -587,7 +587,7 @@ private String runMultiStageQuery(String query, Map context) final DirectStatement stmt = sqlStatementFactory.directStatement( new SqlQueryPlus( query, - new QueryContext(context), + context, Collections.emptyList(), CalciteTests.REGULAR_USER_AUTH_RESULT ) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java index 23beeebd8f59..8f0876f1b146 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/MultiStageQueryContextTest.java @@ -27,6 +27,7 @@ import org.junit.Test; import javax.annotation.Nullable; + import java.util.List; import java.util.Map; @@ -46,33 +47,33 @@ public class MultiStageQueryContextTest @Test public void isDurableStorageEnabled_noParameterSetReturnsDefaultValue() { - Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(ImmutableMap.of())); + Assert.assertFalse(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.empty())); } @Test public void isDurableStorageEnabled_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, "true"); - Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(propertyMap)); + Assert.assertTrue(MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(propertyMap))); } @Test public void isFinalizeAggregations_noParameterSetReturnsDefaultValue() { - Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(new QueryContext())); + Assert.assertTrue(MultiStageQueryContext.isFinalizeAggregations(QueryContext.empty())); } @Test public void isFinalizeAggregations_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_FINALIZE_AGGREGATIONS, "false"); - Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(new QueryContext(propertyMap))); + Assert.assertFalse(MultiStageQueryContext.isFinalizeAggregations(QueryContext.of(propertyMap))); } @Test public void getAssignmentStrategy_noParameterSetReturnsDefaultValue() { - Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(new QueryContext())); + Assert.assertEquals(WorkerAssignmentStrategy.MAX, MultiStageQueryContext.getAssignmentStrategy(QueryContext.empty())); } @Test @@ 
-81,67 +82,67 @@ public void getAssignmentStrategy_parameterSetReturnsCorrectValue() Map propertyMap = ImmutableMap.of(CTX_TASK_ASSIGNMENT_STRATEGY, "AUTO"); Assert.assertEquals( WorkerAssignmentStrategy.AUTO, - MultiStageQueryContext.getAssignmentStrategy(new QueryContext(propertyMap)) + MultiStageQueryContext.getAssignmentStrategy(QueryContext.of(propertyMap)) ); } @Test public void getMaxNumTasks_noParameterSetReturnsDefaultValue() { - Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(new QueryContext())); + Assert.assertEquals(DEFAULT_MAX_NUM_TASKS, MultiStageQueryContext.getMaxNumTasks(QueryContext.empty())); } @Test public void getMaxNumTasks_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); - Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap))); + Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap))); } @Test public void getMaxNumTasks_legacyParameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_MAX_NUM_TASKS, 101); - Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(new QueryContext(propertyMap))); + Assert.assertEquals(101, MultiStageQueryContext.getMaxNumTasks(QueryContext.of(propertyMap))); } @Test public void getDestination_noParameterSetReturnsDefaultValue() { - Assert.assertNull(MultiStageQueryContext.getDestination(new QueryContext())); + Assert.assertNull(MultiStageQueryContext.getDestination(QueryContext.empty())); } @Test public void getDestination_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_DESTINATION, "dataSource"); - Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(new QueryContext(propertyMap))); + Assert.assertEquals("dataSource", MultiStageQueryContext.getDestination(QueryContext.of(propertyMap))); } @Test public void getRowsPerSegment_noParameterSetReturnsDefaultValue() { - Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(new QueryContext(), 1000)); + Assert.assertEquals(1000, MultiStageQueryContext.getRowsPerSegment(QueryContext.empty(), 1000)); } @Test public void getRowsPerSegment_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_ROWS_PER_SEGMENT, 10); - Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(new QueryContext(propertyMap), 1000)); + Assert.assertEquals(10, MultiStageQueryContext.getRowsPerSegment(QueryContext.of(propertyMap), 1000)); } @Test public void getRowsInMemory_noParameterSetReturnsDefaultValue() { - Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(new QueryContext(), 1000)); + Assert.assertEquals(1000, MultiStageQueryContext.getRowsInMemory(QueryContext.empty(), 1000)); } @Test public void getRowsInMemory_parameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_ROWS_IN_MEMORY, 10); - Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(new QueryContext(propertyMap), 1000)); + Assert.assertEquals(10, MultiStageQueryContext.getRowsInMemory(QueryContext.of(propertyMap), 1000)); } @Test @@ -161,27 +162,27 @@ public void testDecodeSortOrder() @Test public void getSortOrderNoParameterSetReturnsDefaultValue() { - Assert.assertNull(MultiStageQueryContext.getSortOrder(new QueryContext())); + Assert.assertNull(MultiStageQueryContext.getSortOrder(QueryContext.empty())); } @Test public void getSortOrderParameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_SORT_ORDER, "a, b,\"c,d\""); - 
Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(new QueryContext(propertyMap))); + Assert.assertEquals("a, b,\"c,d\"", MultiStageQueryContext.getSortOrder(QueryContext.of(propertyMap))); } @Test public void getMSQModeNoParameterSetReturnsDefaultValue() { - Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(new QueryContext())); + Assert.assertEquals("strict", MultiStageQueryContext.getMSQMode(QueryContext.empty())); } @Test public void getMSQModeParameterSetReturnsCorrectValue() { Map propertyMap = ImmutableMap.of(CTX_MSQ_MODE, "nonStrict"); - Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(new QueryContext(propertyMap))); + Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap))); } private static List decodeSortOrder(@Nullable final String input) diff --git a/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java b/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java index 160cd9db3a8c..60a057ece799 100644 --- a/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java +++ b/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java @@ -34,6 +34,7 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactory; @@ -127,7 +128,8 @@ protected QueryRunner buildQueryRunnerForSegment( Optional cacheKeyPrefix ) { - if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { + final QueryContext queryContext = query.context(); + if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { final MutableBoolean isIgnoreSegment = new MutableBoolean(false); queryToIgnoredSegments.compute( query.getMostSpecificId(), @@ -147,7 +149,7 @@ protected QueryRunner buildQueryRunnerForSegment( LOG.info("Pretending I don't have segment [%s]", descriptor); return new ReportTimelineMissingSegmentQueryRunner<>(descriptor); } - } else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -162,7 +164,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw new QueryTimeoutException("query timeout test"); } }; - } else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -177,7 +179,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test"); } }; - } else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -192,7 +194,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw new QueryUnsupportedException("query unsupported test"); } }; - } else if 
(query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -207,7 +209,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw new ResourceLimitExceededException("resource limit exceeded test"); } }; - } else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override diff --git a/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java b/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java index ec3ad43a73de..7b434667fa94 100644 --- a/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java +++ b/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java @@ -34,6 +34,7 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactory; @@ -125,7 +126,8 @@ protected QueryRunner buildQueryRunnerForSegment( Optional cacheKeyPrefix ) { - if (query.getContextBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { + final QueryContext queryContext = query.context(); + if (queryContext.getBoolean(QUERY_RETRY_TEST_CONTEXT_KEY, false)) { final MutableBoolean isIgnoreSegment = new MutableBoolean(false); queryToIgnoredSegments.compute( query.getMostSpecificId(), @@ -145,7 +147,7 @@ protected QueryRunner buildQueryRunnerForSegment( LOG.info("Pretending I don't have segment[%s]", descriptor); return new ReportTimelineMissingSegmentQueryRunner<>(descriptor); } - } else if (query.getContextBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_TIMEOUT_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -160,7 +162,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw new QueryTimeoutException("query timeout test"); } }; - } else if (query.getContextBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_CAPACITY_EXCEEDED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -175,7 +177,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw QueryCapacityExceededException.withErrorMessageAndResolvedHost("query capacity exceeded test"); } }; - } else if (query.getContextBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_UNSUPPORTED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -190,7 +192,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat throw new QueryUnsupportedException("query unsupported test"); } }; - } else if (query.getContextBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(RESOURCE_LIMIT_EXCEEDED_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override @@ -205,7 +207,7 @@ public Yielder 
toYielder(OutType initValue, YieldingAccumulat throw new ResourceLimitExceededException("resource limit exceeded test"); } }; - } else if (query.getContextBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { + } else if (queryContext.getBoolean(QUERY_FAILURE_TEST_CONTEXT_KEY, false)) { return (queryPlus, responseContext) -> new Sequence() { @Override diff --git a/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java b/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java index 1991656332c0..29f63b1f40ee 100644 --- a/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java +++ b/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java @@ -32,6 +32,11 @@ public BadQueryContextException(Exception e) this(ERROR_CODE, e.getMessage(), ERROR_CLASS); } + public BadQueryContextException(String msg) + { + this(ERROR_CODE, msg, ERROR_CLASS); + } + @JsonCreator private BadQueryContextException( @JsonProperty("error") String errorCode, diff --git a/processing/src/main/java/org/apache/druid/query/BaseQuery.java b/processing/src/main/java/org/apache/druid/query/BaseQuery.java index a4c1a999a812..88a59781d946 100644 --- a/processing/src/main/java/org/apache/druid/query/BaseQuery.java +++ b/processing/src/main/java/org/apache/druid/query/BaseQuery.java @@ -27,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Ordering; import org.apache.druid.guice.annotations.ExtensionPoint; -import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.PeriodGranularity; @@ -39,6 +38,7 @@ import javax.annotation.Nullable; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -61,7 +61,7 @@ public static void checkInterrupted() public static final String SQL_QUERY_ID = "sqlQueryId"; private final DataSource dataSource; private final boolean descending; - private final QueryContext context; + private final Map context; private final QuerySegmentSpec querySegmentSpec; private volatile Duration duration; private final Granularity granularity; @@ -89,7 +89,10 @@ public BaseQuery( Preconditions.checkNotNull(granularity, "Must specify a granularity"); this.dataSource = dataSource; - this.context = new QueryContext(context); + // There is no semantic difference between an empty and a null context. + // Ensure that a context always exists to avoid the need to check for + // a null context. Jackson serialization will omit empty contexts. + this.context = context == null ? Collections.emptyMap() : context; this.querySegmentSpec = querySegmentSpec; this.descending = descending; this.granularity = granularity; @@ -172,25 +175,7 @@ public DateTimeZone getTimezone() @JsonInclude(Include.NON_DEFAULT) public Map getContext() { - return context.getMergedParams(); - } - - @Override - public QueryContext getQueryContext() - { - return context; - } - - @Override - public boolean getContextBoolean(String key, boolean defaultValue) - { - return context.getAsBoolean(key, defaultValue); - } - - @Override - public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue) - { - return context.getAsHumanReadableBytes(key, defaultValue); + return context == null ? 
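/* context is initialized to a non-null map in the constructor above, so this null check is purely defensive */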
Collections.emptyMap() : context; } /** @@ -228,7 +213,7 @@ public Ordering getResultOrdering() @Override public String getId() { - return context.getAsString(QUERY_ID); + return context().getString(QUERY_ID); } @Override @@ -241,7 +226,7 @@ public Query withSubQueryId(String subQueryId) @Override public String getSubQueryId() { - return context.getAsString(SUB_QUERY_ID); + return context().getString(SUB_QUERY_ID); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/BySegmentQueryRunner.java b/processing/src/main/java/org/apache/druid/query/BySegmentQueryRunner.java index cd386c7a6ba8..c0a6f55832c9 100644 --- a/processing/src/main/java/org/apache/druid/query/BySegmentQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/BySegmentQueryRunner.java @@ -35,7 +35,7 @@ * * Note that despite the type parameter "T", this runner may not actually return sequences with type T. They * may really be of type {@code Result>}, if "bySegment" is set. Downstream consumers - * of the returned sequence must be aware of this, and can use {@link QueryContexts#isBySegment(Query)} to + * of the returned sequence must be aware of this, and can use {@link QueryContext#isBySegment()} to * know what to expect. */ public class BySegmentQueryRunner implements QueryRunner @@ -55,7 +55,7 @@ public BySegmentQueryRunner(SegmentId segmentId, DateTime timestamp, QueryRunner @SuppressWarnings("unchecked") public Sequence run(final QueryPlus queryPlus, ResponseContext responseContext) { - if (QueryContexts.isBySegment(queryPlus.getQuery())) { + if (queryPlus.getQuery().context().isBySegment()) { final Sequence baseSequence = base.run(queryPlus, responseContext); final List results = baseSequence.toList(); return Sequences.simple( diff --git a/processing/src/main/java/org/apache/druid/query/BySegmentSkippingQueryRunner.java b/processing/src/main/java/org/apache/druid/query/BySegmentSkippingQueryRunner.java index 7061d47b5483..a0dd96a31ac9 100644 --- a/processing/src/main/java/org/apache/druid/query/BySegmentSkippingQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/BySegmentSkippingQueryRunner.java @@ -39,7 +39,7 @@ public BySegmentSkippingQueryRunner( @Override public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) { - if (QueryContexts.isBySegment(queryPlus.getQuery())) { + if (queryPlus.getQuery().context().isBySegment()) { return baseRunner.run(queryPlus, responseContext); } diff --git a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java index 995cd4235937..5ff044f65641 100644 --- a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java @@ -78,7 +78,7 @@ public ChainedExecutionQueryRunner( public Sequence run(final QueryPlus queryPlus, final ResponseContext responseContext) { Query query = queryPlus.getQuery(); - final int priority = QueryContexts.getPriority(query); + final int priority = query.context().getPriority(); final Ordering ordering = query.getResultOrdering(); final QueryPlus threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); return new BaseSequence>( @@ -137,9 +137,10 @@ public Iterable call() queryWatcher.registerQueryFuture(query, future); try { + final QueryContext context = query.context(); return new MergeIterable<>( - QueryContexts.hasTimeout(query) ? 
- future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) : + context.hasTimeout() ? + future.get(context.getTimeout(), TimeUnit.MILLISECONDS) : future.get(), ordering.nullsFirst() ).iterator(); diff --git a/processing/src/main/java/org/apache/druid/query/FinalizeResultsQueryRunner.java b/processing/src/main/java/org/apache/druid/query/FinalizeResultsQueryRunner.java index e73b9cc4619d..a8b590b5dbfa 100644 --- a/processing/src/main/java/org/apache/druid/query/FinalizeResultsQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/FinalizeResultsQueryRunner.java @@ -56,8 +56,9 @@ public FinalizeResultsQueryRunner( public Sequence run(final QueryPlus queryPlus, ResponseContext responseContext) { final Query query = queryPlus.getQuery(); - final boolean isBySegment = QueryContexts.isBySegment(query); - final boolean shouldFinalize = QueryContexts.isFinalize(query, true); + final QueryContext queryContext = query.context(); + final boolean isBySegment = queryContext.isBySegment(); + final boolean shouldFinalize = queryContext.isFinalize(true); final Query queryToRun; final Function finalizerFn; diff --git a/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java b/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java index 26eca1340a69..d0d231b0c3df 100644 --- a/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java @@ -84,8 +84,9 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext respo querySpecificConfig ); final Pair> bySegmentAccumulatorPair = GroupByQueryHelper.createBySegmentAccumulatorPair(); - final boolean bySegment = QueryContexts.isBySegment(query); - final int priority = QueryContexts.getPriority(query); + final QueryContext queryContext = query.context(); + final boolean bySegment = queryContext.isBySegment(); + final int priority = queryContext.getPriority(); final QueryPlus threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); final List> futures = Lists.newArrayList( @@ -173,8 +174,9 @@ private void waitForFutureCompletion( ListenableFuture> future = Futures.allAsList(futures); try { queryWatcher.registerQueryFuture(query, future); - if (QueryContexts.hasTimeout(query)) { - future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS); + final QueryContext context = query.context(); + if (context.hasTimeout()) { + future.get(context.getTimeout(), TimeUnit.MILLISECONDS); } else { future.get(); } diff --git a/processing/src/main/java/org/apache/druid/query/Queries.java b/processing/src/main/java/org/apache/druid/query/Queries.java index 58de4695faf4..9a758b351126 100644 --- a/processing/src/main/java/org/apache/druid/query/Queries.java +++ b/processing/src/main/java/org/apache/druid/query/Queries.java @@ -20,6 +20,7 @@ package org.apache.druid.query; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.druid.guice.annotations.PublicApi; @@ -36,6 +37,7 @@ import org.apache.druid.segment.column.ColumnHolder; import javax.annotation.Nullable; + import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -293,4 +295,24 @@ public static Set computeRequiredColumns( return requiredColumns; } + + public static Query withMaxScatterGatherBytes(Query query, long maxScatterGatherBytesLimit) + { + QueryContext context = 
query.context();
+ if (!context.containsKey(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY)) {
+ return query.withOverriddenContext(ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit));
+ }
+ context.verifyMaxScatterGatherBytes(maxScatterGatherBytesLimit);
+ return query;
+ }
+
+ public static <T> Query<T> withTimeout(Query<T> query, long timeout)
+ {
+ return query.withOverriddenContext(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, timeout));
+ }
+
+ public static <T> Query<T> withDefaultTimeout(Query<T> query, long defaultTimeout)
+ {
+ return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout));
+ }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java
index 5662b988c238..9e20ee653828 100644
--- a/processing/src/main/java/org/apache/druid/query/Query.java
+++ b/processing/src/main/java/org/apache/druid/query/Query.java
@@ -45,6 +45,7 @@ import org.joda.time.Interval;
 import javax.annotation.Nullable;
+
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -96,64 +97,60 @@ public interface Query<T>
 DateTimeZone getTimezone();

 /**
- * Use {@link #getQueryContext()} instead.
+ * Returns the context as an (immutable) map.
 */
- @Deprecated
 Map<String, Object> getContext();

 /**
- * Returns QueryContext for this query. This type distinguishes between user provided, system default, and system
- * generated query context keys so that authorization may be employed directly against the user supplied context
- * values.
- *
- * This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23.
- * Callers should check if the result of this method is null, and if so, they are dealing with a legacy query
- * implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to
- * manipulate the query context.
- *
- * Note for query context serialization and deserialization.
- * Currently, once a query is serialized, its queryContext can be different from the original queryContext
- * after the query is deserialized back. If the queryContext has any {@link QueryContext#defaultParams} or
- * {@link QueryContext#systemParams} in it, those will be found in {@link QueryContext#userParams}
- * after it is deserialized. This is because {@link BaseQuery#getContext()} uses
- * {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
+ * Returns the query context as a {@link QueryContext}, which provides
+ * convenience methods for accessing typed context values. The returned
+ * instance is a view on top of the context provided by {@link #getContext()}.
 */
- @Nullable
- default QueryContext getQueryContext()
+ default QueryContext context()
 {
- return null;
+ return QueryContext.of(getContext());
 }

 /**
 * Get context value and cast to ContextType in an unsafe way.
 *
- * For safe conversion, it's recommended to use following methods instead
+ * For safe conversion, it's recommended to use the following methods instead:
+ *
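+ * Each returns {@code null} (or the supplied default) when the key is absent.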

+ * {@link QueryContext#getBoolean(String)}
+ * {@link QueryContext#getString(String)}
+ * {@link QueryContext#getInt(String)}
+ * {@link QueryContext#getLong(String)}
+ * {@link QueryContext#getFloat(String)}
+ * {@link QueryContext#getEnum(String, Class, Enum)}
+ * {@link QueryContext#getHumanReadableBytes(String, HumanReadableBytes)}
 *
- * {@link QueryContext#getAsBoolean(String)}
- * {@link QueryContext#getAsString(String)}
- * {@link QueryContext#getAsInt(String)}
- * {@link QueryContext#getAsLong(String)}
- * {@link QueryContext#getAsFloat(String, float)}
- * {@link QueryContext#getAsEnum(String, Class, Enum)}
- * {@link QueryContext#getAsHumanReadableBytes(String, HumanReadableBytes)}
+ * @deprecated use {@code context().get()} instead
 */
+ @Deprecated
+ @SuppressWarnings("unchecked")
 @Nullable
 default <ContextType> ContextType getContextValue(String key)
 {
- if (getQueryContext() == null) {
- return null;
- } else {
- return (ContextType) getQueryContext().get(key);
- }
+ return (ContextType) context().get(key);
 }

+ /**
+ * @deprecated use {@code context().get(key, defaultValue)} instead
+ */
+ @SuppressWarnings("unchecked")
+ @Deprecated
+ default <ContextType> ContextType getContextValue(String key, ContextType defaultValue)
+ {
+ return (ContextType) context().get(key, defaultValue);
+ }
+
+ /**
+ * @deprecated use {@code context().getBoolean()} instead.
+ */
+ @Deprecated
 default boolean getContextBoolean(String key, boolean defaultValue)
 {
- if (getQueryContext() == null) {
- return defaultValue;
- } else {
- return getQueryContext().getAsBoolean(key, defaultValue);
- }
+ return context().getBoolean(key, defaultValue);
 }

 /**
@@ -164,14 +161,12 @@ default boolean getContextBoolean(String key, boolean defaultValue)
 * @param key The context key value being looked up
 * @param defaultValue The default to return if the key value doesn't exist or the context is null.
 * @return {@link HumanReadableBytes}
+ * @deprecated use {@code context().getHumanReadableBytes()} instead.
 */
- default HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue)
+ @Deprecated
+ default HumanReadableBytes getContextHumanReadableBytes(String key, HumanReadableBytes defaultValue)
 {
- if (getQueryContext() == null) {
- return defaultValue;
- } else {
- return getQueryContext().getAsHumanReadableBytes(key, defaultValue);
- }
+ return context().getHumanReadableBytes(key, defaultValue);
 }

 boolean isDescending();
@@ -230,7 +225,7 @@ default Query withSqlQueryId(String sqlQueryId)
 @Nullable
 default String getSqlQueryId()
 {
- return getQueryContext().getAsString(BaseQuery.SQL_QUERY_ID);
+ return context().getString(BaseQuery.SQL_QUERY_ID);
 }

 /**
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java
index f902bddb2d97..93e5dcf23c77 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContext.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java
@@ -20,236 +20,540 @@
 package org.apache.druid.query;

 import org.apache.druid.java.util.common.HumanReadableBytes;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.granularity.Granularity;
+import org.apache.druid.query.QueryContexts.Vectorize;
+import org.apache.druid.segment.QueryableIndexStorageAdapter;

 import javax.annotation.Nullable;
 import java.util.Collections;
 import java.util.Map;
 import java.util.Objects;
-import java.util.TreeMap;

 /**
- * Holder for query context parameters. There are 3 ways to set context params today.
- *
- * - Default parameters. These are set mostly via {@link DefaultQueryConfig#context}.
- *   Auto-generated queryId or sqlQueryId are also set as default parameters. These default parameters can
- *   be overridden by user or system parameters.
- * - User parameters. These are the params set by the user. User params override default parameters but
- *   are overridden by system parameters.
- * - System parameters. These are the params set by the Druid query engine for internal use only.
- *
- * You can use {@code getX} methods or {@link #getMergedParams()} to compute the context params
- * merging 3 types of params above.
- *
- * Currently, this class is mainly used for query context parameter authorization,
- * such as HTTP query endpoints or JDBC endpoint. Its usage can be expanded in the future if we
- * want to track user parameters and separate them from others during query processing.
+ * Immutable holder for query context parameters with typed access methods.
 */
 public class QueryContext
 {
- private final Map<String, Object> defaultParams;
- private final Map<String, Object> userParams;
- private final Map<String, Object> systemParams;
+ private static final QueryContext EMPTY = new QueryContext(null);
+
+ private final Map<String, Object> context;
+
+ public QueryContext(Map<String, Object> context)
+ {
+ this.context = context == null ? Collections.emptyMap() : Collections.unmodifiableMap(context);
+ }
+
+ public static QueryContext empty()
+ {
+ return EMPTY;
+ }
+
+ public static QueryContext of(Map<String, Object> context)
+ {
+ return new QueryContext(context == null ? Collections.emptyMap() : context);
+ }
+
+ public boolean isEmpty()
+ {
+ return context.isEmpty();
+ }
+
+ public Map<String, Object> getContext()
+ {
+ return context;
+ }

 /**
- * Cache of params merged.
+ * Check if the given key is set. If the client will then fetch the value,
+ * consider using one of the {@code get(String key)} methods instead:
+ * they each return {@code null} if the value is not set.
+ */
+ public boolean containsKey(String key)
+ {
+ return context.containsKey(key);
+ }
+
+ /**
+ * Return a value as a generic {@code Object}, returning {@code null} if the
+ * context value is not set.
 */
 @Nullable
- private Map<String, Object> mergedParams;
+ public Object get(String key)
+ {
+ return context.get(key);
+ }

- public QueryContext()
+ /**
+ * Return a value as a generic {@code Object}, returning the default value if the
+ * context value is not set.
+ */
+ public Object get(String key, Object defaultValue)
 {
- this(null);
+ final Object val = get(key);
+ return val == null ? defaultValue : val;
 }

- public QueryContext(@Nullable Map<String, Object> userParams)
+ /**
+ * Return a value as a {@code String}, returning {@code null} if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ @Nullable
+ public String getString(String key)
 {
- this(
- new TreeMap<>(),
- userParams == null ? new TreeMap<>() : new TreeMap<>(userParams),
- new TreeMap<>()
- );
+ return getString(key, null);
 }

- private QueryContext(
- final Map<String, Object> defaultParams,
- final Map<String, Object> userParams,
- final Map<String, Object> systemParams
- )
+ public String getString(String key, String defaultValue)
 {
- this.defaultParams = defaultParams;
- this.userParams = userParams;
- this.systemParams = systemParams;
- this.mergedParams = null;
+ return QueryContexts.parseString(context, key, defaultValue);
 }

- private void invalidateMergedParams()
+ /**
+ * Return a value as a {@code Boolean}, returning {@code null} if the
+ * context value is not set.
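+ * For example (illustrative usage only, not part of this patch):
+ * <pre>{@code
+ * QueryContext ctx = QueryContext.of(ImmutableMap.of("useCache", "false"));
+ * ctx.getBoolean("useCache");      // Boolean.FALSE: parsed from the string value
+ * ctx.getBoolean("populateCache"); // null: the key is not set
+ * }</pre>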
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public Boolean getBoolean(final String key)
 {
- this.mergedParams = null;
+ return QueryContexts.getAsBoolean(key, get(key));
 }

- public boolean isEmpty()
+ /**
+ * Return a value as a {@code boolean}, returning the default value if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public boolean getBoolean(final String key, final boolean defaultValue)
 {
- return defaultParams.isEmpty() && userParams.isEmpty() && systemParams.isEmpty();
+ return QueryContexts.parseBoolean(context, key, defaultValue);
 }

- public void addDefaultParam(String key, Object val)
+ /**
+ * Return a value as an {@code Integer}, returning {@code null} if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public Integer getInt(final String key)
+ {
+ return QueryContexts.getAsInt(key, get(key));
+ }
+
+ /**
+ * Return a value as an {@code int}, returning the default value if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public int getInt(final String key, final int defaultValue)
 {
- invalidateMergedParams();
- defaultParams.put(key, val);
+ return QueryContexts.parseInt(context, key, defaultValue);
 }

- public void addDefaultParams(Map<String, Object> defaultParams)
+ /**
+ * Return a value as a {@code Long}, returning {@code null} if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public Long getLong(final String key)
 {
- invalidateMergedParams();
- this.defaultParams.putAll(defaultParams);
+ return QueryContexts.getAsLong(key, get(key));
 }

- public void addSystemParam(String key, Object val)
+ /**
+ * Return a value as a {@code long}, returning the default value if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ public long getLong(final String key, final long defaultValue)
 {
- invalidateMergedParams();
- this.systemParams.put(key, val);
+ return QueryContexts.parseLong(context, key, defaultValue);
 }

- public Object removeUserParam(String key)
+ /**
+ * Return a value as a {@code Float}, returning {@code null} if the
+ * context value is not set.
+ *
+ * @throws BadQueryContextException for an invalid value
+ */
+ @SuppressWarnings("unused")
+ public Float getFloat(final String key)
 {
- invalidateMergedParams();
- return userParams.remove(key);
+ return QueryContexts.getAsFloat(key, get(key));
 }

 /**
- * Returns only the context parameters the user sets.
- * The returned map does not include the parameters that have been removed via {@link #removeUserParam}.
+ * Return a value as a {@code float}, returning the default value if the
+ * context value is not set.
 *
- * Callers should use {@code getX} methods or {@link #getMergedParams()} instead to use the whole context params.
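+ * (Unlike {@link #getFloat(String)}, this overload never returns {@code null}.)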
+ * @throws BadQueryContextException for an invalid value
 */
- public Map<String, Object> getUserParams()
+ public float getFloat(final String key, final float defaultValue)
 {
- return userParams;
+ return QueryContexts.getAsFloat(key, get(key), defaultValue);
+ }
+
+ public HumanReadableBytes getHumanReadableBytes(final String key, final HumanReadableBytes defaultValue)
+ {
+ return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue);
+ }
+
+ public <E extends Enum<E>> E getEnum(String key, Class<E> clazz, E defaultValue)
+ {
+ return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue);
+ }
+
+ public Granularity getGranularity(String key)
+ {
+ final Object value = get(key);
+ if (value == null) {
+ return null;
+ }
+ if (value instanceof Granularity) {
+ return (Granularity) value;
+ } else {
+ throw QueryContexts.badTypeException(key, "a Granularity", value);
+ }
 }

 public boolean isDebug()
 {
- return getAsBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
+ return getBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
 }

- public boolean isEnableJoinLeftScanDirect()
+ public boolean isBySegment()
 {
- return getAsBoolean(
- QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
- QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
+ return isBySegment(QueryContexts.DEFAULT_BY_SEGMENT);
+ }
+
+ public boolean isBySegment(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.BY_SEGMENT_KEY, defaultValue);
+ }
+
+ public boolean isPopulateCache()
+ {
+ return isPopulateCache(QueryContexts.DEFAULT_POPULATE_CACHE);
+ }
+
+ public boolean isPopulateCache(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.POPULATE_CACHE_KEY, defaultValue);
+ }
+
+ public boolean isUseCache()
+ {
+ return isUseCache(QueryContexts.DEFAULT_USE_CACHE);
+ }
+
+ public boolean isUseCache(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.USE_CACHE_KEY, defaultValue);
+ }
+
+ public boolean isPopulateResultLevelCache()
+ {
+ return isPopulateResultLevelCache(QueryContexts.DEFAULT_POPULATE_RESULTLEVEL_CACHE);
+ }
+
+ public boolean isPopulateResultLevelCache(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue);
+ }
+
+ public boolean isUseResultLevelCache()
+ {
+ return isUseResultLevelCache(QueryContexts.DEFAULT_USE_RESULTLEVEL_CACHE);
+ }
+
+ public boolean isUseResultLevelCache(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.USE_RESULT_LEVEL_CACHE_KEY, defaultValue);
+ }
+
+ public boolean isFinalize(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.FINALIZE_KEY, defaultValue);
+ }
+
+ public boolean isSerializeDateTimeAsLong(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue);
+ }
+
+ public boolean isSerializeDateTimeAsLongInner(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue);
+ }
+
+ public Vectorize getVectorize()
+ {
+ return getVectorize(QueryContexts.DEFAULT_VECTORIZE);
+ }
+
+ public Vectorize getVectorize(Vectorize defaultValue)
+ {
+ return getEnum(QueryContexts.VECTORIZE_KEY, Vectorize.class, defaultValue);
+ }
+
+ public Vectorize getVectorizeVirtualColumns()
+ {
+ return getVectorizeVirtualColumns(QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN);
+ }
+
+ public Vectorize getVectorizeVirtualColumns(Vectorize defaultValue)
+ {
+ return getEnum(
+ QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY,
+ Vectorize.class,
+ defaultValue
 );
 }

- @SuppressWarnings("unused")
- public
boolean containsKey(String key) + public int getVectorSize() { - return get(key) != null; + return getVectorSize(QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE); } - @Nullable - public Object get(String key) + public int getVectorSize(int defaultSize) { - Object val = systemParams.get(key); - if (val != null) { - return val; - } - val = userParams.get(key); - return val == null ? defaultParams.get(key) : val; + return getInt(QueryContexts.VECTOR_SIZE_KEY, defaultSize); } - @SuppressWarnings("unused") - public Object getOrDefault(String key, Object defaultValue) + public int getMaxSubqueryRows(int defaultSize) { - final Object val = get(key); - return val == null ? defaultValue : val; + return getInt(QueryContexts.MAX_SUBQUERY_ROWS_KEY, defaultSize); } - @Nullable - public String getAsString(String key) + public int getUncoveredIntervalsLimit() { - Object val = get(key); - return val == null ? null : val.toString(); + return getUncoveredIntervalsLimit(QueryContexts.DEFAULT_UNCOVERED_INTERVALS_LIMIT); } - public String getAsString(String key, String defaultValue) + public int getUncoveredIntervalsLimit(int defaultValue) { - Object val = get(key); - return val == null ? defaultValue : val.toString(); + return getInt(QueryContexts.UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue); } - @Nullable - public Boolean getAsBoolean(String key) + public int getPriority() { - return QueryContexts.getAsBoolean(key, get(key)); + return getPriority(QueryContexts.DEFAULT_PRIORITY); } - public boolean getAsBoolean( - final String key, - final boolean defaultValue - ) + public int getPriority(int defaultValue) { - return QueryContexts.getAsBoolean(key, get(key), defaultValue); + return getInt(QueryContexts.PRIORITY_KEY, defaultValue); } - public Integer getAsInt(final String key) + public String getLane() { - return QueryContexts.getAsInt(key, get(key)); + return getString(QueryContexts.LANE_KEY); } - public int getAsInt( - final String key, - final int defaultValue - ) + public boolean getEnableParallelMerges() { - return QueryContexts.getAsInt(key, get(key), defaultValue); + return getBoolean( + QueryContexts.BROKER_PARALLEL_MERGE_KEY, + QueryContexts.DEFAULT_ENABLE_PARALLEL_MERGE + ); } - public Long getAsLong(final String key) + public int getParallelMergeInitialYieldRows(int defaultValue) { - return QueryContexts.getAsLong(key, get(key)); + return getInt(QueryContexts.BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue); } - public long getAsLong(final String key, final long defaultValue) + public int getParallelMergeSmallBatchRows(int defaultValue) { - return QueryContexts.getAsLong(key, get(key), defaultValue); + return getInt(QueryContexts.BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue); } - public HumanReadableBytes getAsHumanReadableBytes(final String key, final HumanReadableBytes defaultValue) + public int getParallelMergeParallelism(int defaultValue) { - return QueryContexts.getAsHumanReadableBytes(key, get(key), defaultValue); + return getInt(QueryContexts.BROKER_PARALLELISM, defaultValue); } - public float getAsFloat(final String key, final float defaultValue) + public long getJoinFilterRewriteMaxSize() { - return QueryContexts.getAsFloat(key, get(key), defaultValue); + return getLong( + QueryContexts.JOIN_FILTER_REWRITE_MAX_SIZE_KEY, + QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE + ); } - public > E getAsEnum(String key, Class clazz, E defaultValue) + public boolean getEnableJoinFilterPushDown() { - return QueryContexts.getAsEnum(key, get(key), clazz, defaultValue); + 
return getBoolean(
+ QueryContexts.JOIN_FILTER_PUSH_DOWN_KEY,
+ QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN
+ );
+ }
+
+ public boolean getEnableJoinFilterRewrite()
+ {
+ return getBoolean(
+ QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY,
+ QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE
+ );
+ }
+
+ public boolean isSecondaryPartitionPruningEnabled()
+ {
+ return getBoolean(
+ QueryContexts.SECONDARY_PARTITION_PRUNING_KEY,
+ QueryContexts.DEFAULT_SECONDARY_PARTITION_PRUNING
+ );
+ }
+
+ public long getMaxQueuedBytes(long defaultValue)
+ {
+ return getLong(QueryContexts.MAX_QUEUED_BYTES_KEY, defaultValue);
 }

- public Map<String, Object> getMergedParams()
+ public long getMaxScatterGatherBytes()
 {
- if (mergedParams == null) {
- final Map<String, Object> merged = new TreeMap<>(defaultParams);
- merged.putAll(userParams);
- merged.putAll(systemParams);
- mergedParams = Collections.unmodifiableMap(merged);
+ return getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
+ }
+
+ public boolean hasTimeout()
+ {
+ return getTimeout() != QueryContexts.NO_TIMEOUT;
+ }
+
+ public long getTimeout()
+ {
+ return getTimeout(getDefaultTimeout());
+ }
+
+ public long getTimeout(long defaultTimeout)
+ {
+ final long timeout = getLong(QueryContexts.TIMEOUT_KEY, defaultTimeout);
+ if (timeout >= 0) {
+ return timeout;
 }
- return mergedParams;
+ throw new BadQueryContextException(
+ StringUtils.format(
+ "Timeout [%s] must be a non-negative value, but was %d",
+ QueryContexts.TIMEOUT_KEY,
+ timeout
+ )
+ );
 }

- public QueryContext copy()
+ public long getDefaultTimeout()
 {
- return new QueryContext(
- new TreeMap<>(defaultParams),
- new TreeMap<>(userParams),
- new TreeMap<>(systemParams)
+ final long defaultTimeout = getLong(QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS);
+ if (defaultTimeout >= 0) {
+ return defaultTimeout;
+ }
+ throw new BadQueryContextException(
+ StringUtils.format(
+ "Timeout [%s] must be a non-negative value, but was %d",
+ QueryContexts.DEFAULT_TIMEOUT_KEY,
+ defaultTimeout
+ )
 );
 }

+ public void verifyMaxQueryTimeout(long maxQueryTimeout)
+ {
+ long timeout = getTimeout();
+ if (timeout > maxQueryTimeout) {
+ throw new BadQueryContextException(
+ StringUtils.format(
+ "Configured %s = %d is more than the enforced limit of %d.",
+ QueryContexts.TIMEOUT_KEY,
+ timeout,
+ maxQueryTimeout
+ )
+ );
+ }
+ }
+
+ public void verifyMaxScatterGatherBytes(long maxScatterGatherBytesLimit)
+ {
+ long curr = getLong(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 0);
+ if (curr > maxScatterGatherBytesLimit) {
+ throw new BadQueryContextException(
+ StringUtils.format(
+ "Configured %s = %d is more than the enforced limit of %d.",
+ QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY,
+ curr,
+ maxScatterGatherBytesLimit
+ )
+ );
+ }
+ }
+
+ public int getNumRetriesOnMissingSegments(int defaultValue)
+ {
+ return getInt(QueryContexts.NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue);
+ }
+
+ public boolean allowReturnPartialResults(boolean defaultValue)
+ {
+ return getBoolean(QueryContexts.RETURN_PARTIAL_RESULTS_KEY, defaultValue);
+ }
+
+ public boolean getEnableJoinFilterRewriteValueColumnFilters()
+ {
+ return getBoolean(
+ QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY,
+ QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS
+ );
+ }
+
+ public boolean getEnableRewriteJoinToFilter()
+ {
+ return getBoolean(
+ QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
+ QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
+ );
+ }
+
+ public boolean
getEnableJoinLeftScanDirect() + { + return getBoolean( + QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT, + QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT + ); + } + + public int getInSubQueryThreshold() + { + return getInSubQueryThreshold(QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD); + } + + public int getInSubQueryThreshold(int defaultValue) + { + return getInt( + QueryContexts.IN_SUB_QUERY_THRESHOLD_KEY, + defaultValue + ); + } + + public boolean isTimeBoundaryPlanningEnabled() + { + return getBoolean( + QueryContexts.TIME_BOUNDARY_PLANNING_KEY, + QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING + ); + } + + public String getBrokerServiceName() + { + return getString(QueryContexts.BROKER_SERVICE_NAME); + } + @Override public boolean equals(Object o) { @@ -259,23 +563,21 @@ public boolean equals(Object o) if (o == null || getClass() != o.getClass()) { return false; } - QueryContext context = (QueryContext) o; - return getMergedParams().equals(context.getMergedParams()); + QueryContext other = (QueryContext) o; + return context.equals(other.context); } @Override public int hashCode() { - return Objects.hash(getMergedParams()); + return Objects.hash(context); } @Override public String toString() { return "QueryContext{" + - "defaultParams=" + defaultParams + - ", userParams=" + userParams + - ", systemParams=" + systemParams + + "context=" + context + '}'; } } diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index 6138979faca7..2aec4ea04484 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -21,18 +21,17 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; import org.apache.druid.guice.annotations.PublicApi; import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Numbers; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.segment.QueryableIndexStorageAdapter; import javax.annotation.Nullable; + import java.util.Map; +import java.util.Map.Entry; import java.util.TreeMap; import java.util.concurrent.TimeUnit; @@ -80,6 +79,8 @@ public class QueryContexts public static final String SERIALIZE_DATE_TIME_AS_LONG_KEY = "serializeDateTimeAsLong"; public static final String SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY = "serializeDateTimeAsLongInner"; public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit"; + public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold"; + public static final boolean DEFAULT_BY_SEGMENT = false; public static final boolean DEFAULT_POPULATE_CACHE = true; @@ -150,330 +151,40 @@ public String toString() } } - public static boolean isBySegment(Query query) - { - return isBySegment(query, DEFAULT_BY_SEGMENT); - } - - public static boolean isBySegment(Query query, boolean defaultValue) - { - return query.getContextBoolean(BY_SEGMENT_KEY, defaultValue); - } - - public static boolean isPopulateCache(Query query) - { - return isPopulateCache(query, DEFAULT_POPULATE_CACHE); - } - - public static boolean isPopulateCache(Query query, boolean defaultValue) - { - return query.getContextBoolean(POPULATE_CACHE_KEY, defaultValue); - } - - 
public static boolean isUseCache(Query query) - { - return isUseCache(query, DEFAULT_USE_CACHE); - } - - public static boolean isUseCache(Query query, boolean defaultValue) - { - return query.getContextBoolean(USE_CACHE_KEY, defaultValue); - } - - public static boolean isPopulateResultLevelCache(Query query) - { - return isPopulateResultLevelCache(query, DEFAULT_POPULATE_RESULTLEVEL_CACHE); - } - - public static boolean isPopulateResultLevelCache(Query query, boolean defaultValue) - { - return query.getContextBoolean(POPULATE_RESULT_LEVEL_CACHE_KEY, defaultValue); - } - - public static boolean isUseResultLevelCache(Query query) - { - return isUseResultLevelCache(query, DEFAULT_USE_RESULTLEVEL_CACHE); - } - - public static boolean isUseResultLevelCache(Query query, boolean defaultValue) - { - return query.getContextBoolean(USE_RESULT_LEVEL_CACHE_KEY, defaultValue); - } - - public static boolean isFinalize(Query query, boolean defaultValue) - - { - return query.getContextBoolean(FINALIZE_KEY, defaultValue); - } - - public static boolean isSerializeDateTimeAsLong(Query query, boolean defaultValue) - { - return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_KEY, defaultValue); - } - - public static boolean isSerializeDateTimeAsLongInner(Query query, boolean defaultValue) - { - return query.getContextBoolean(SERIALIZE_DATE_TIME_AS_LONG_INNER_KEY, defaultValue); - } - - public static Vectorize getVectorize(Query query) - { - return getVectorize(query, QueryContexts.DEFAULT_VECTORIZE); - } - - public static Vectorize getVectorize(Query query, Vectorize defaultValue) - { - return query.getQueryContext().getAsEnum(VECTORIZE_KEY, Vectorize.class, defaultValue); - } - - public static Vectorize getVectorizeVirtualColumns(Query query) - { - return getVectorizeVirtualColumns(query, QueryContexts.DEFAULT_VECTORIZE_VIRTUAL_COLUMN); - } - - public static Vectorize getVectorizeVirtualColumns(Query query, Vectorize defaultValue) - { - return query.getQueryContext().getAsEnum(VECTORIZE_VIRTUAL_COLUMNS_KEY, Vectorize.class, defaultValue); - } - - public static int getVectorSize(Query query) - { - return getVectorSize(query, QueryableIndexStorageAdapter.DEFAULT_VECTOR_SIZE); - } - - public static int getVectorSize(Query query, int defaultSize) - { - return query.getQueryContext().getAsInt(VECTOR_SIZE_KEY, defaultSize); - } - - public static int getMaxSubqueryRows(Query query, int defaultSize) - { - return query.getQueryContext().getAsInt(MAX_SUBQUERY_ROWS_KEY, defaultSize); - } - - public static int getUncoveredIntervalsLimit(Query query) - { - return getUncoveredIntervalsLimit(query, DEFAULT_UNCOVERED_INTERVALS_LIMIT); - } - - public static int getUncoveredIntervalsLimit(Query query, int defaultValue) - { - return query.getQueryContext().getAsInt(UNCOVERED_INTERVALS_LIMIT_KEY, defaultValue); - } - - public static int getPriority(Query query) - { - return getPriority(query, DEFAULT_PRIORITY); - } - - public static int getPriority(Query query, int defaultValue) - { - return query.getQueryContext().getAsInt(PRIORITY_KEY, defaultValue); - } - - public static String getLane(Query query) - { - return query.getQueryContext().getAsString(LANE_KEY); - } - - public static boolean getEnableParallelMerges(Query query) - { - return query.getContextBoolean(BROKER_PARALLEL_MERGE_KEY, DEFAULT_ENABLE_PARALLEL_MERGE); - } - - public static int getParallelMergeInitialYieldRows(Query query, int defaultValue) - { - return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_INITIAL_YIELD_ROWS_KEY, defaultValue); - } - - 
public static int getParallelMergeSmallBatchRows(Query query, int defaultValue) - { - return query.getQueryContext().getAsInt(BROKER_PARALLEL_MERGE_SMALL_BATCH_ROWS_KEY, defaultValue); - } - - public static int getParallelMergeParallelism(Query query, int defaultValue) - { - return query.getQueryContext().getAsInt(BROKER_PARALLELISM, defaultValue); - } - - public static boolean getEnableJoinFilterRewriteValueColumnFilters(Query query) - { - return query.getContextBoolean( - JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, - DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS - ); - } - - public static boolean getEnableRewriteJoinToFilter(Query query) - { - return query.getContextBoolean( - REWRITE_JOIN_TO_FILTER_ENABLE_KEY, - DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER - ); - } - - public static long getJoinFilterRewriteMaxSize(Query query) - { - return query.getQueryContext().getAsLong(JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE); - } - - public static boolean getEnableJoinFilterPushDown(Query query) - { - return query.getContextBoolean(JOIN_FILTER_PUSH_DOWN_KEY, DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN); - } - - public static boolean getEnableJoinFilterRewrite(Query query) - { - return query.getContextBoolean(JOIN_FILTER_REWRITE_ENABLE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE); - } - - public static boolean getEnableJoinLeftScanDirect(Map context) - { - return parseBoolean(context, SQL_JOIN_LEFT_SCAN_DIRECT, DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT); - } - - public static boolean isSecondaryPartitionPruningEnabled(Query query) - { - return query.getContextBoolean(SECONDARY_PARTITION_PRUNING_KEY, DEFAULT_SECONDARY_PARTITION_PRUNING); - } - - public static boolean isDebug(Query query) - { - return query.getContextBoolean(ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG); - } - - public static boolean isDebug(Map queryContext) - { - return parseBoolean(queryContext, ENABLE_DEBUG, DEFAULT_ENABLE_DEBUG); - } - - public static int getInSubQueryThreshold(Map context) - { - return getInSubQueryThreshold(context, DEFAULT_IN_SUB_QUERY_THRESHOLD); - } - - public static int getInSubQueryThreshold(Map context, int defaultValue) - { - return parseInt(context, IN_SUB_QUERY_THRESHOLD_KEY, defaultValue); - } - - public static boolean isTimeBoundaryPlanningEnabled(Map queryContext) - { - return parseBoolean(queryContext, TIME_BOUNDARY_PLANNING_KEY, DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING); - } - - public static Query withMaxScatterGatherBytes(Query query, long maxScatterGatherBytesLimit) - { - Long curr = query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY); - if (curr == null) { - return query.withOverriddenContext(ImmutableMap.of(MAX_SCATTER_GATHER_BYTES_KEY, maxScatterGatherBytesLimit)); - } else { - if (curr > maxScatterGatherBytesLimit) { - throw new IAE( - "configured [%s = %s] is more than enforced limit of [%s].", - MAX_SCATTER_GATHER_BYTES_KEY, - curr, - maxScatterGatherBytesLimit - ); - } else { - return query; - } - } - } - - public static Query verifyMaxQueryTimeout(Query query, long maxQueryTimeout) - { - long timeout = getTimeout(query); - if (timeout > maxQueryTimeout) { - throw new IAE( - "configured [%s = %s] is more than enforced limit of maxQueryTimeout [%s].", - TIMEOUT_KEY, - timeout, - maxQueryTimeout - ); - } else { - return query; - } - } - - public static long getMaxQueuedBytes(Query query, long defaultValue) - { - return query.getQueryContext().getAsLong(MAX_QUEUED_BYTES_KEY, defaultValue); - } - - public static long getMaxScatterGatherBytes(Query query) 
- { - return query.getQueryContext().getAsLong(MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE); - } - - public static boolean hasTimeout(Query query) - { - return getTimeout(query) != NO_TIMEOUT; - } - - public static long getTimeout(Query query) - { - return getTimeout(query, getDefaultTimeout(query)); - } - - public static long getTimeout(Query query, long defaultTimeout) - { - try { - final long timeout = query.getQueryContext().getAsLong(TIMEOUT_KEY, defaultTimeout); - Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout); - return timeout; - } - catch (IAE e) { - throw new BadQueryContextException(e); - } - } - - public static Query withTimeout(Query query, long timeout) - { - return query.withOverriddenContext(ImmutableMap.of(TIMEOUT_KEY, timeout)); - } - - public static Query withDefaultTimeout(Query query, long defaultTimeout) - { - return query.withOverriddenContext(ImmutableMap.of(QueryContexts.DEFAULT_TIMEOUT_KEY, defaultTimeout)); - } - - static long getDefaultTimeout(Query query) + private QueryContexts() { - final long defaultTimeout = query.getQueryContext().getAsLong(DEFAULT_TIMEOUT_KEY, DEFAULT_TIMEOUT_MILLIS); - Preconditions.checkState(defaultTimeout >= 0, "Timeout must be a non negative value, but was [%s]", defaultTimeout); - return defaultTimeout; } - public static int getNumRetriesOnMissingSegments(Query query, int defaultValue) + public static long parseLong(Map context, String key, long defaultValue) { - return query.getQueryContext().getAsInt(NUM_RETRIES_ON_MISSING_SEGMENTS_KEY, defaultValue); + return getAsLong(key, context.get(key), defaultValue); } - public static boolean allowReturnPartialResults(Query query, boolean defaultValue) + public static int parseInt(Map context, String key, int defaultValue) { - return query.getContextBoolean(RETURN_PARTIAL_RESULTS_KEY, defaultValue); + return getAsInt(key, context.get(key), defaultValue); } - public static String getBrokerServiceName(Map queryContext) + @Nullable + public static String parseString(Map context, String key) { - return queryContext == null ? 
null : (String) queryContext.get(BROKER_SERVICE_NAME); + return parseString(context, key, null); } - @SuppressWarnings("unused") - static long parseLong(Map context, String key, long defaultValue) + public static boolean parseBoolean(Map context, String key, boolean defaultValue) { - return getAsLong(key, context.get(key), defaultValue); + return getAsBoolean(key, context.get(key), defaultValue); } - static int parseInt(Map context, String key, int defaultValue) + public static String parseString(Map context, String key, String defaultValue) { - return getAsInt(key, context.get(key), defaultValue); + return getAsString(key, context.get(key), defaultValue); } - static boolean parseBoolean(Map context, String key, boolean defaultValue) + @SuppressWarnings("unused") // To keep IntelliJ inspections happy + public static float parseFloat(Map context, String key, float defaultValue) { - return getAsBoolean(key, context.get(key), defaultValue); + return getAsFloat(key, context.get(key), defaultValue); } public static String getAsString( @@ -486,14 +197,13 @@ public static String getAsString( return defaultValue; } else if (value instanceof String) { return (String) value; - } else { - throw new IAE("Expected key [%s] to be a String, but got [%s]", key, value.getClass().getName()); } + throw badTypeException(key, "a String", value); } @Nullable public static Boolean getAsBoolean( - final String parameter, + final String key, final Object value ) { @@ -503,13 +213,12 @@ public static Boolean getAsBoolean( return Boolean.parseBoolean((String) value); } else if (value instanceof Boolean) { return (Boolean) value; - } else { - throw new IAE("Expected parameter [%s] to be a Boolean, but got [%s]", parameter, value.getClass().getName()); } + throw badTypeException(key, "a Boolean", value); } /** - * Get the value of a parameter as a {@code boolean}. The parameter is expected + * Get the value of a context value as a {@code boolean}. The value is expected * to be {@code null}, a string or a {@code Boolean} object. */ public static boolean getAsBoolean( @@ -534,24 +243,24 @@ public static Integer getAsInt(String key, Object value) return Numbers.parseInt(value); } catch (NumberFormatException ignored) { - throw new IAE("Expected key [%s] in integer format, but got [%s]", key, value); + throw badValueException(key, "in integer format", value); } } - throw new IAE("Expected key [%s] to be an Integer, but got [%s]", key, value.getClass().getName()); + throw badTypeException(key, "an Integer", value); } /** - * Get the value of a parameter as an {@code int}. The parameter is expected + * Get the value of a context value as an {@code int}. The value is expected * to be {@code null}, a string or a {@code Number} object. */ public static int getAsInt( - final String ke, + final String key, final Object value, final int defaultValue ) { - Integer val = getAsInt(ke, value); + Integer val = getAsInt(key, value); return val == null ? defaultValue : val; } @@ -567,14 +276,14 @@ public static Long getAsLong(String key, Object value) return Numbers.parseLong(value); } catch (NumberFormatException ignored) { - throw new IAE("Expected key [%s] in long format, but got [%s]", key, value); + throw badValueException(key, "in long format", value); } } - throw new IAE("Expected key [%s] to be a Long, but got [%s]", key, value.getClass().getName()); + throw badTypeException(key, "a Long", value); } /** - * Get the value of a parameter as an {@code long}. 
The parameter is expected + * Get the value of a context value as an {@code long}. The value is expected * to be {@code null}, a string or a {@code Number} object. */ public static long getAsLong( @@ -587,43 +296,57 @@ public static long getAsLong( return val == null ? defaultValue : val; } - public static HumanReadableBytes getAsHumanReadableBytes( - final String parameter, - final Object value, - final HumanReadableBytes defaultValue - ) + /** + * Get the value of a context value as an {@code Float}. The value is expected + * to be {@code null}, a string or a {@code Number} object. + */ + public static Float getAsFloat(final String key, final Object value) { - if (null == value) { - return defaultValue; + if (value == null) { + return null; } else if (value instanceof Number) { - return HumanReadableBytes.valueOf(Numbers.parseLong(value)); + return ((Number) value).floatValue(); } else if (value instanceof String) { try { - return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value)); + return Float.parseFloat((String) value); } - catch (IAE e) { - throw new IAE("Expected key [%s] in human readable format, but got [%s]", parameter, value); + catch (NumberFormatException ignored) { + throw badValueException(key, "in float format", value); } } + throw badTypeException(key, "a Float", value); + } - throw new IAE("Expected key [%s] to be a human readable number, but got [%s]", parameter, value.getClass().getName()); + public static float getAsFloat( + final String key, + final Object value, + final float defaultValue + ) + { + Float val = getAsFloat(key, value); + return val == null ? defaultValue : val; } - public static float getAsFloat(String key, Object value, float defaultValue) + public static HumanReadableBytes getAsHumanReadableBytes( + final String key, + final Object value, + final HumanReadableBytes defaultValue + ) { if (null == value) { return defaultValue; } else if (value instanceof Number) { - return ((Number) value).floatValue(); + return HumanReadableBytes.valueOf(Numbers.parseLong(value)); } else if (value instanceof String) { try { - return Float.parseFloat((String) value); + return HumanReadableBytes.valueOf(HumanReadableBytes.parse((String) value)); } - catch (NumberFormatException ignored) { - throw new IAE("Expected key [%s] in float format, but got [%s]", key, value); + catch (IAE e) { + throw badValueException(key, "a human readable number", value); } } - throw new IAE("Expected key [%s] to be a Float, but got [%s]", key, value.getClass().getName()); + + throw badTypeException(key, "a human readable number", value); } public static Map override( @@ -635,40 +358,77 @@ public static Map override( if (context != null) { overridden.putAll(context); } - overridden.putAll(overrides); + if (overrides != null) { + overridden.putAll(overrides); + } return overridden; } - private QueryContexts() - { - } - - public static > E getAsEnum(String key, Object val, Class clazz, E defaultValue) + public static > E getAsEnum(String key, Object value, Class clazz, E defaultValue) { - if (val == null) { + if (value == null) { return defaultValue; } try { - if (val instanceof String) { - return Enum.valueOf(clazz, StringUtils.toUpperCase((String) val)); - } else if (val instanceof Boolean) { - return Enum.valueOf(clazz, StringUtils.toUpperCase(String.valueOf(val))); + if (value instanceof String) { + return Enum.valueOf(clazz, StringUtils.toUpperCase((String) value)); + } else if (value instanceof Boolean) { + return Enum.valueOf(clazz, 
StringUtils.toUpperCase(String.valueOf(value)));
 }
 }
 catch (IllegalArgumentException e) {
- throw new IAE("Expected key [%s] must be value of enum [%s], but got [%s].",
- key,
- clazz.getName(),
- val.toString());
+ throw badValueException(
+ key,
+ StringUtils.format("a value of enum [%s]", clazz.getSimpleName()),
+ value
+ );
 }
- throw new ISE(
- "Expected key [%s] must be type of [%s], actual type is [%s].",
+ throw badTypeException(
 key,
- clazz.getName(),
- val.getClass()
+ StringUtils.format("of type [%s]", clazz.getSimpleName()),
+ value
+ );
+ }
+
+ public static BadQueryContextException badValueException(
+ final String key,
+ final String expected,
+ final Object actual
+ )
+ {
+ return new BadQueryContextException(
+ StringUtils.format(
+ "Expected key [%s] to be %s, but got [%s]",
+ key,
+ expected,
+ actual
+ )
+ );
+ }
+
+ public static BadQueryContextException badTypeException(
+ final String key,
+ final String expected,
+ final Object actual
+ )
+ {
+ return new BadQueryContextException(
+ StringUtils.format(
+ "Expected key [%s] to be %s, but got [%s]",
+ key,
+ expected,
+ actual.getClass().getName()
+ )
 );
 }
+
+ public static void addDefaults(Map<String, Object> context, Map<String, Object> defaults)
+ {
+ for (Entry<String, Object> entry : defaults.entrySet()) {
+ context.putIfAbsent(entry.getKey(), entry.getValue());
+ }
+ }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/SubqueryQueryRunner.java b/processing/src/main/java/org/apache/druid/query/SubqueryQueryRunner.java
index a6bcbb8e1364..cd17cc13e4de 100644
--- a/processing/src/main/java/org/apache/druid/query/SubqueryQueryRunner.java
+++ b/processing/src/main/java/org/apache/druid/query/SubqueryQueryRunner.java
@@ -41,7 +41,7 @@ public Sequence run(final QueryPlus queryPlus, ResponseContext responseCon
 {
 DataSource dataSource = queryPlus.getQuery().getDataSource();
 boolean forcePushDownNestedQuery = queryPlus.getQuery()
- .getContextBoolean(
+ .context().getBoolean(
 GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY,
 false
 );
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
index 31d4b031e664..9a63a796d971 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
@@ -450,7 +450,7 @@ public String getType()
 @JsonIgnore
 public boolean getContextSortByDimsFirst()
 {
- return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
+ return context().getBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false);
 }

 @JsonIgnore
@@ -465,7 +465,7 @@ public boolean isApplyLimitPushDown()
 @JsonIgnore
 public boolean getApplyLimitPushDownFromContext()
 {
- return getContextBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true);
+ return context().getBoolean(GroupByQueryConfig.CTX_KEY_APPLY_LIMIT_PUSH_DOWN, true);
 }

 @Override
@@ -487,7 +487,7 @@ public Ordering getResultOrdering()
 private boolean validateAndGetForceLimitPushDown()
 {
- final boolean forcePushDown = getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false);
+ final boolean forcePushDown = context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false);
 if (forcePushDown) {
 if (!(limitSpec instanceof DefaultLimitSpec)) {
 throw new IAE("When forcing limit push down, a limit spec must be provided.");
@@ -748,7 +748,7 @@ private int compareDims(List dimensions, ResultRow lhs, ResultRow
 @Nullable
 private DateTime computeUniversalTimestamp()
 {
- final
String timestampStringFromContext = getQueryContext().getAsString(CTX_KEY_FUDGE_TIMESTAMP, ""); + final String timestampStringFromContext = context().getString(CTX_KEY_FUDGE_TIMESTAMP, ""); final Granularity granularity = getGranularity(); if (!timestampStringFromContext.isEmpty()) { diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryConfig.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryConfig.java index ac018b942c0b..380cf825f603 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryConfig.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryConfig.java @@ -23,6 +23,7 @@ import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.groupby.strategy.GroupByStrategySelector; import org.apache.druid.utils.JvmUtils; @@ -335,25 +336,26 @@ public boolean isMultiValueUnnestingEnabled() public GroupByQueryConfig withOverrides(final GroupByQuery query) { final GroupByQueryConfig newConfig = new GroupByQueryConfig(); - newConfig.defaultStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, getDefaultStrategy()); - newConfig.singleThreaded = query.getQueryContext().getAsBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded()); + final QueryContext queryContext = query.context(); + newConfig.defaultStrategy = queryContext.getString(CTX_KEY_STRATEGY, getDefaultStrategy()); + newConfig.singleThreaded = queryContext.getBoolean(CTX_KEY_IS_SINGLE_THREADED, isSingleThreaded()); newConfig.maxIntermediateRows = Math.min( - query.getQueryContext().getAsInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()), + queryContext.getInt(CTX_KEY_MAX_INTERMEDIATE_ROWS, getMaxIntermediateRows()), getMaxIntermediateRows() ); newConfig.maxResults = Math.min( - query.getQueryContext().getAsInt(CTX_KEY_MAX_RESULTS, getMaxResults()), + queryContext.getInt(CTX_KEY_MAX_RESULTS, getMaxResults()), getMaxResults() ); newConfig.bufferGrouperMaxSize = Math.min( - query.getQueryContext().getAsInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()), + queryContext.getInt(CTX_KEY_BUFFER_GROUPER_MAX_SIZE, getBufferGrouperMaxSize()), getBufferGrouperMaxSize() ); - newConfig.bufferGrouperMaxLoadFactor = query.getQueryContext().getAsFloat( + newConfig.bufferGrouperMaxLoadFactor = queryContext.getFloat( CTX_KEY_BUFFER_GROUPER_MAX_LOAD_FACTOR, getBufferGrouperMaxLoadFactor() ); - newConfig.bufferGrouperInitialBuckets = query.getQueryContext().getAsInt( + newConfig.bufferGrouperInitialBuckets = queryContext.getInt( CTX_KEY_BUFFER_GROUPER_INITIAL_BUCKETS, getBufferGrouperInitialBuckets() ); @@ -362,33 +364,33 @@ public GroupByQueryConfig withOverrides(final GroupByQuery query) // choose a default value lower than the max allowed when the context key is missing in the client query. 
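// For illustration only (a sketch of assumed usage, not code added by this patch):
// the assignment below clamps the client's context override to the server-side
// maximum. Given a hypothetical client context of {"maxOnDiskStorage": "500M"}
// and a 1 GiB server cap:
//
//   long requested = query.context()
//       .getHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage())
//       .getBytes();                                          // 500_000_000
//   long effective = Math.min(requested, getMaxOnDiskStorage().getBytes()); // 500M wins
//
// Human-readable strings such as "500M" are parsed by HumanReadableBytes.parse(),
// per getAsHumanReadableBytes() in QueryContexts above.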
newConfig.maxOnDiskStorage = HumanReadableBytes.valueOf( Math.min( - query.getContextAsHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(), + queryContext.getHumanReadableBytes(CTX_KEY_MAX_ON_DISK_STORAGE, getDefaultOnDiskStorage()).getBytes(), getMaxOnDiskStorage().getBytes() ) ); newConfig.maxSelectorDictionarySize = maxSelectorDictionarySize; // No overrides newConfig.maxMergingDictionarySize = maxMergingDictionarySize; // No overrides - newConfig.forcePushDownLimit = query.getContextBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit()); - newConfig.applyLimitPushDownToSegment = query.getContextBoolean( + newConfig.forcePushDownLimit = queryContext.getBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit()); + newConfig.applyLimitPushDownToSegment = queryContext.getBoolean( CTX_KEY_APPLY_LIMIT_PUSH_DOWN_TO_SEGMENT, isApplyLimitPushDownToSegment() ); - newConfig.forceHashAggregation = query.getContextBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation()); - newConfig.forcePushDownNestedQuery = query.getContextBoolean( + newConfig.forceHashAggregation = queryContext.getBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation()); + newConfig.forcePushDownNestedQuery = queryContext.getBoolean( CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, isForcePushDownNestedQuery() ); - newConfig.intermediateCombineDegree = query.getQueryContext().getAsInt( + newConfig.intermediateCombineDegree = queryContext.getInt( CTX_KEY_INTERMEDIATE_COMBINE_DEGREE, getIntermediateCombineDegree() ); - newConfig.numParallelCombineThreads = query.getQueryContext().getAsInt( + newConfig.numParallelCombineThreads = queryContext.getInt( CTX_KEY_NUM_PARALLEL_COMBINE_THREADS, getNumParallelCombineThreads() ); - newConfig.mergeThreadLocal = query.getContextBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal()); - newConfig.vectorize = query.getContextBoolean(QueryContexts.VECTORIZE_KEY, isVectorize()); - newConfig.enableMultiValueUnnesting = query.getContextBoolean( + newConfig.mergeThreadLocal = queryContext.getBoolean(CTX_KEY_MERGE_THREAD_LOCAL, isMergeThreadLocal()); + newConfig.vectorize = queryContext.getBoolean(QueryContexts.VECTORIZE_KEY, isVectorize()); + newConfig.enableMultiValueUnnesting = queryContext.getBoolean( CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, isMultiValueUnnestingEnabled() ); diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryEngine.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryEngine.java index d971e1bcce2c..e4236f411d02 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryEngine.java @@ -96,7 +96,7 @@ public Sequence process( "Null storage adapter found. Probably trying to issue a query against a segment being memory unmapped." ); } - if (!query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) { + if (!query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true)) { throw new UOE( "GroupBy v1 does not support %s as false. 
Set %s to true or use groupBy v2", GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryHelper.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryHelper.java index 4fd84c9b62b3..9f65fac85fba 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryHelper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryHelper.java @@ -100,7 +100,7 @@ public String apply(DimensionSpec input) ); final IncrementalIndex index; - final boolean sortResults = query.getContextBoolean(CTX_KEY_SORT_RESULTS, true); + final boolean sortResults = query.context().getBoolean(CTX_KEY_SORT_RESULTS, true); // All groupBy dimensions are strings, for now. final List dimensionSchemas = new ArrayList<>(); @@ -118,7 +118,7 @@ public String apply(DimensionSpec input) final AppendableIndexBuilder indexBuilder; - if (query.getContextBoolean("useOffheap", false)) { + if (query.context().getBoolean("useOffheap", false)) { throw new UnsupportedOperationException( "The 'useOffheap' option is no longer available for groupBy v1. Please move to the newer groupBy engine, " + "which always operates off-heap, by removing any custom 'druid.query.groupBy.defaultStrategy' runtime " diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java index 0ae0a67d4591..173e6babd0d4 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java @@ -45,7 +45,6 @@ import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.DataSource; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -118,7 +117,7 @@ public GroupByQueryQueryToolChest( public QueryRunner mergeResults(final QueryRunner runner) { return (queryPlus, responseContext) -> { - if (QueryContexts.isBySegment(queryPlus.getQuery())) { + if (queryPlus.getQuery().context().isBySegment()) { return runner.run(queryPlus, responseContext); } @@ -304,7 +303,7 @@ GroupByQuery rewriteNestedQueryForPushDown(GroupByQuery query) private Sequence finalizeSubqueryResults(Sequence subqueryResult, GroupByQuery subquery) { final Sequence finalizingResults; - if (QueryContexts.isFinalize(subquery, false)) { + if (subquery.context().isFinalize(false)) { finalizingResults = new MappedSequence<>( subqueryResult, makePreComputeManipulatorFn( @@ -321,7 +320,7 @@ private Sequence finalizeSubqueryResults(Sequence subquery public static boolean isNestedQueryPushDown(GroupByQuery q, GroupByStrategy strategy) { return q.getDataSource() instanceof QueryDataSource - && q.getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false) + && q.context().getBoolean(GroupByQueryConfig.CTX_KEY_FORCE_PUSH_DOWN_NESTED_QUERY, false) && q.getSubtotalsSpec() == null && strategy.supportsNestedQueryPushDown(); } @@ -418,7 +417,7 @@ public TypeReference getResultTypeReference() @Override public ObjectMapper decorateObjectMapper(final ObjectMapper objectMapper, final GroupByQuery query) { - final boolean resultAsArray = query.getContextBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false); + final boolean resultAsArray = 
query.context().getBoolean(GroupByQueryConfig.CTX_KEY_ARRAY_RESULT_ROWS, false); if (resultAsArray && !queryConfig.isIntermediateResultAsMapCompat()) { // We can assume ResultRow are serialized and deserialized as arrays. No need for special decoration, diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java index 59642bb91974..6718dff9f834 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java @@ -45,7 +45,7 @@ import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable; import org.apache.druid.query.ChainedExecutionQueryRunner; import org.apache.druid.query.DruidProcessingConfig; -import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryProcessingPool; @@ -134,7 +134,7 @@ public Sequence run(final QueryPlus queryPlus, final Respo // merge buffer, otherwise the query will allocate too many merge buffers. This is potentially sub-optimal as it // will involve materializing the results for each sink before starting to feed them into the outer merge buffer. // I'm not sure of a better way to do this without tweaking how realtime servers do queries. - final boolean forceChainedExecution = query.getContextBoolean( + final boolean forceChainedExecution = query.context().getBoolean( CTX_KEY_MERGE_RUNNERS_USING_CHAINED_EXECUTION, false ); @@ -144,7 +144,8 @@ public Sequence run(final QueryPlus queryPlus, final Respo ) .withoutThreadUnsafeState(); - if (QueryContexts.isBySegment(query) || forceChainedExecution) { + final QueryContext queryContext = query.context(); + if (queryContext.isBySegment() || forceChainedExecution) { ChainedExecutionQueryRunner runner = new ChainedExecutionQueryRunner<>(queryProcessingPool, queryWatcher, queryables); return runner.run(queryPlusForRunners, responseContext); } @@ -156,12 +157,12 @@ public Sequence run(final QueryPlus queryPlus, final Respo StringUtils.format("druid-groupBy-%s_%s", UUID.randomUUID(), query.getId()) ); - final int priority = QueryContexts.getPriority(query); + final int priority = queryContext.getPriority(); // Figure out timeoutAt time now, so we can apply the timeout to both the mergeBufferPool.take and the actual // query processing together. 
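// For illustration only (a sketch of assumed usage, not code added by this patch):
// the change below swaps the former statics QueryContexts.getTimeout(query) and
// QueryContexts.hasTimeout(query) for typed getters on the per-query facade:
//
//   QueryContext ctx = query.context();
//   if (ctx.hasTimeout()) {
//     long deadlineMillis = System.currentTimeMillis() + ctx.getTimeout();
//   }
//
// When the timeout key is absent, getTimeout() falls back to the default;
// QueryContextsTest.testEmptyQueryTimeout() below pins that default at 300_000 ms.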
- final long queryTimeout = QueryContexts.getTimeout(query); - final boolean hasTimeout = QueryContexts.hasTimeout(query); + final long queryTimeout = queryContext.getTimeout(); + final boolean hasTimeout = queryContext.hasTimeout(); final long timeoutAt = System.currentTimeMillis() + queryTimeout; return new BaseSequence<>( diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java index 96055d521d50..d5141ff415ab 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java @@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.ColumnSelectorPlus; import org.apache.druid.query.DruidProcessingConfig; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.dimension.ColumnSelectorStrategyFactory; @@ -77,6 +76,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.io.Closeable; import java.nio.ByteBuffer; import java.util.Iterator; @@ -141,7 +141,7 @@ public static Sequence process( try { final String fudgeTimestampString = NullHandling.emptyToNullIfNeeded( - query.getQueryContext().getAsString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP) + query.context().getString(GroupByStrategyV2.CTX_KEY_FUDGE_TIMESTAMP) ); final DateTime fudgeTimestamp = fudgeTimestampString == null @@ -151,7 +151,7 @@ public static Sequence process( final Filter filter = Filters.convertToCNFFromQueryContext(query, Filters.toFilter(query.getFilter())); final Interval interval = Iterables.getOnlyElement(query.getIntervals()); - final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize( + final boolean doVectorize = query.context().getVectorize().shouldVectorize( VectorGroupByEngine.canVectorize(query, storageAdapter, filter) ); @@ -496,7 +496,7 @@ public GroupByEngineIterator( // Time is the same for every row in the cursor this.timestamp = fudgeTimestamp != null ? 
fudgeTimestamp : cursor.getTime(); this.allSingleValueDims = allSingleValueDims; - this.allowMultiValueGrouping = query.getContextBoolean( + this.allowMultiValueGrouping = query.context().getBoolean( GroupByQueryConfig.CTX_KEY_ENABLE_MULTI_VALUE_UNNESTING, true ); diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java index ceaa75f85d42..137f7587b8ce 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngine.java @@ -28,7 +28,6 @@ import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.parsers.CloseableIterator; import org.apache.druid.query.DruidProcessingConfig; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.aggregation.AggregatorAdapters; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.filter.Filter; @@ -56,6 +55,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.io.IOException; import java.nio.ByteBuffer; import java.util.Collections; @@ -150,7 +150,7 @@ public CloseableIterator make() interval, query.getVirtualColumns(), false, - QueryContexts.getVectorSize(query), + query.context().getVectorSize(), groupByQueryMetrics ); diff --git a/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java b/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java index 33650cd5d9d1..b49fa0d2a8cc 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/orderby/DefaultLimitSpec.java @@ -37,6 +37,7 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.TopNSequence; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.dimension.DimensionSpec; @@ -232,9 +233,11 @@ public Function, Sequence> build(final GroupByQue } if (!sortingNeeded) { - String timestampField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); + final QueryContext queryContext = query.context(); + String timestampField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); if (timestampField != null && !timestampField.isEmpty()) { - int timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); + // Will NPE if the key is not set + int timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); sortingNeeded = query.getContextSortByDimsFirst() ? 
timestampResultFieldIndex != query.getDimensions().size() - 1 : timestampResultFieldIndex != 0; diff --git a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV1.java b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV1.java index bc475514f61b..8c119d59dd43 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV1.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV1.java @@ -91,7 +91,7 @@ public boolean isCacheable(boolean willMergeRunners) @Override public boolean doMergeResults(final GroupByQuery query) { - return query.getContextBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true); + return query.context().getBoolean(GroupByQueryQueryToolChest.GROUP_BY_MERGE_KEY, true); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java index 060127880122..0ee5078efd81 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/strategy/GroupByStrategyV2.java @@ -44,6 +44,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryPlus; @@ -132,8 +133,9 @@ public GroupByQueryResource prepareResource(GroupByQuery query) return new GroupByQueryResource(); } else { final List> mergeBufferHolders; - if (QueryContexts.hasTimeout(query)) { - mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, QueryContexts.getTimeout(query)); + final QueryContext context = query.context(); + if (context.hasTimeout()) { + mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum, context.getTimeout()); } else { mergeBufferHolders = mergeBufferPool.takeBatch(requiredMergeBufferNum); } @@ -221,9 +223,10 @@ public Sequence mergeResults( Granularity granularity = query.getGranularity(); List dimensionSpecs = query.getDimensions(); // the CTX_TIMESTAMP_RESULT_FIELD is set in DruidQuery.java - final String timestampResultField = query.getQueryContext().getAsString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); + final QueryContext queryContext = query.context(); + final String timestampResultField = queryContext.getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD); final boolean hasTimestampResultField = (timestampResultField != null && !timestampResultField.isEmpty()) - && query.getContextBoolean(CTX_KEY_OUTERMOST, true) + && queryContext.getBoolean(CTX_KEY_OUTERMOST, true) && !query.isApplyLimitPushDown(); int timestampResultFieldIndex = 0; if (hasTimestampResultField) { @@ -249,7 +252,7 @@ public Sequence mergeResults( // the granularity and dimensions are slightly different. 
// now, part of the query plan logic is handled in GroupByStrategyV2, not only in DruidQuery.toGroupByQuery() final Granularity timestampResultFieldGranularity - = query.getContextValue(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY); + = queryContext.getGranularity(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_GRANULARITY); dimensionSpecs = query.getDimensions() .stream() @@ -258,7 +261,7 @@ public Sequence mergeResults( granularity = timestampResultFieldGranularity; // when timestampResultField is the last dimension, should set sortByDimsFirst=true, // otherwise the downstream is sorted by row's timestamp first which makes the final ordering not as expected - timestampResultFieldIndex = query.getQueryContext().getAsInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); + timestampResultFieldIndex = queryContext.getInt(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD_INDEX); if (!query.getContextSortByDimsFirst() && timestampResultFieldIndex == query.getDimensions().size() - 1) { context.put(GroupByQuery.CTX_KEY_SORT_BY_DIMS_FIRST, true); } @@ -312,8 +315,8 @@ public Sequence mergeResults( // Apply postaggregators if this is the outermost mergeResults (CTX_KEY_OUTERMOST) and we are not executing a // pushed-down subquery (CTX_KEY_EXECUTING_NESTED_QUERY). - if (!query.getContextBoolean(CTX_KEY_OUTERMOST, true) - || query.getContextBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) { + if (!queryContext.getBoolean(CTX_KEY_OUTERMOST, true) + || queryContext.getBoolean(GroupByQueryConfig.CTX_KEY_EXECUTING_NESTED_QUERY, false)) { return mergedResults; } else if (query.getPostAggregatorSpecs().isEmpty()) { if (!hasTimestampResultField) { @@ -405,7 +408,7 @@ private void moveOrReplicateTimestampInRow( public Sequence applyPostProcessing(Sequence results, GroupByQuery query) { // Don't apply limit here for inner results, that will be pushed down to the BufferHashGrouper - if (query.getContextBoolean(CTX_KEY_OUTERMOST, true)) { + if (query.context().getBoolean(CTX_KEY_OUTERMOST, true)) { return query.postProcess(results); } else { return results; diff --git a/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryRunnerFactory.java b/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryRunnerFactory.java index bc7dc9339b9e..827c3a86b07f 100644 --- a/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryRunnerFactory.java +++ b/processing/src/main/java/org/apache/druid/query/metadata/SegmentMetadataQueryRunnerFactory.java @@ -31,7 +31,7 @@ import org.apache.druid.query.AbstractPrioritizedQueryRunnerCallable; import org.apache.druid.query.ConcatQueryRunner; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryProcessingPool; @@ -205,7 +205,7 @@ public Sequence run( ) { final Query query = queryPlus.getQuery(); - final int priority = QueryContexts.getPriority(query); + final int priority = query.context().getPriority(); final QueryPlus threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState(); ListenableFuture> future = queryProcessingPool.submitRunnerTask( new AbstractPrioritizedQueryRunnerCallable, SegmentAnalysis>(priority, input) @@ -219,8 +219,9 @@ public Sequence call() ); try { queryWatcher.registerQueryFuture(query, future); - if (QueryContexts.hasTimeout(query)) { - return future.get(QueryContexts.getTimeout(query), 
TimeUnit.MILLISECONDS); + final QueryContext context = query.context(); + if (context.hasTimeout()) { + return future.get(context.getTimeout(), TimeUnit.MILLISECONDS); } else { return future.get(); } diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java index 41273907623e..57f32bc4398e 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java @@ -264,7 +264,7 @@ public static void verifyOrderByForNativeExecution(final ScanQuery query) private Integer validateAndGetMaxRowsQueuedForOrdering() { final Integer maxRowsQueuedForOrdering = - getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING); + context().getInt(ScanQueryConfig.CTX_KEY_MAX_ROWS_QUEUED_FOR_ORDERING); Preconditions.checkArgument( maxRowsQueuedForOrdering == null || maxRowsQueuedForOrdering > 0, "maxRowsQueuedForOrdering must be greater than 0" ); @@ -275,7 +275,7 @@ private Integer validateAndGetMaxRowsQueuedForOrdering() private Integer validateAndGetMaxSegmentPartitionsOrderedInMemory() { final Integer maxSegmentPartitionsOrderedInMemory = - getQueryContext().getAsInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING); + context().getInt(ScanQueryConfig.CTX_KEY_MAX_SEGMENT_PARTITIONS_FOR_ORDERING); Preconditions.checkArgument( maxSegmentPartitionsOrderedInMemory == null || maxSegmentPartitionsOrderedInMemory > 0, "maxSegmentPartitionsOrderedInMemory must be greater than 0" diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java index b61ffa4bf4a0..56b9087793c7 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java @@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.guava.BaseSequence; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.context.ResponseContext; @@ -46,6 +45,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -78,7 +78,7 @@ public Sequence process( if (numScannedRows != null && numScannedRows >= query.getScanRowsLimit() && query.getTimeOrder().equals(ScanQuery.Order.NONE)) { return Sequences.empty(); } - final boolean hasTimeout = QueryContexts.hasTimeout(query); + final boolean hasTimeout = query.context().hasTimeout(); final Long timeoutAt = responseContext.getTimeoutTime(); final StorageAdapter adapter = segment.asStorageAdapter(); diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java index 6a081b654865..ee90ca17a38b 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java @@ -99,7 +99,7 @@ public ScanResultValue next() // We want to perform multi-event ScanResultValue limiting if we are not time-ordering or are at the // inner-level if we are time-ordering if (query.getTimeOrder() ==
ScanQuery.Order.NONE || - !query.getContextBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) { + !query.context().getBoolean(ScanQuery.CTX_KEY_OUTERMOST, true)) { ScanResultValue batch = yielder.get(); List events = (List) batch.getEvents(); if (events.size() <= limit - count) { diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java index 8aec07679b3a..0013c4f84fcd 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java @@ -33,7 +33,6 @@ import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunner; @@ -94,7 +93,7 @@ public QueryRunner mergeRunners( // Note: this variable is effective only when queryContext has a timeout. // See the comment of ResponseContext.Key.TIMEOUT_AT. - final long timeoutAt = System.currentTimeMillis() + QueryContexts.getTimeout(queryPlus.getQuery()); + final long timeoutAt = System.currentTimeMillis() + queryPlus.getQuery().context().getTimeout(); responseContext.putTimeoutTime(timeoutAt); if (query.getTimeOrder().equals(ScanQuery.Order.NONE)) { diff --git a/processing/src/main/java/org/apache/druid/query/search/SearchQueryConfig.java b/processing/src/main/java/org/apache/druid/query/search/SearchQueryConfig.java index c7b6f201f88c..74a3c1465d91 100644 --- a/processing/src/main/java/org/apache/druid/query/search/SearchQueryConfig.java +++ b/processing/src/main/java/org/apache/druid/query/search/SearchQueryConfig.java @@ -55,7 +55,7 @@ public SearchQueryConfig withOverrides(final SearchQuery query) { final SearchQueryConfig newConfig = new SearchQueryConfig(); newConfig.maxSearchLimit = query.getLimit(); - newConfig.searchStrategy = query.getQueryContext().getAsString(CTX_KEY_STRATEGY, searchStrategy); + newConfig.searchStrategy = query.context().getString(CTX_KEY_STRATEGY, searchStrategy); return newConfig; } } diff --git a/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java index ff3c5b8e0166..06459d7073a2 100644 --- a/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java @@ -34,7 +34,6 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryToolChest; @@ -46,6 +45,7 @@ import org.apache.druid.query.dimension.DimensionSpec; import javax.annotation.Nullable; + import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -329,7 +329,7 @@ public Sequence> run( return runner.run(queryPlus, responseContext); } - final boolean isBySegment = QueryContexts.isBySegment(query); + final boolean isBySegment = query.context().isBySegment(); return Sequences.map( runner.run(queryPlus.withQuery(query.withLimit(config.getMaxSearchLimit())), responseContext), diff 
--git a/processing/src/main/java/org/apache/druid/query/select/SelectQuery.java b/processing/src/main/java/org/apache/druid/query/select/SelectQuery.java index e6a9ec197fd3..606ccbeb3573 100644 --- a/processing/src/main/java/org/apache/druid/query/select/SelectQuery.java +++ b/processing/src/main/java/org/apache/druid/query/select/SelectQuery.java @@ -24,7 +24,6 @@ import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.query.DataSource; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.filter.DimFilter; @@ -34,6 +33,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.List; import java.util.Map; @@ -110,12 +110,6 @@ public Map getContext() throw new RuntimeException(REMOVED_ERROR_MESSAGE); } - @Override - public QueryContext getQueryContext() - { - throw new RuntimeException(REMOVED_ERROR_MESSAGE); - } - @Override public boolean isDescending() { diff --git a/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java b/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java index 95217cc6dfc3..94f13eb58c74 100644 --- a/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java @@ -68,7 +68,7 @@ public Sequence run(final QueryPlus input, final ResponseContext responseC ) ); - final boolean setName = input.getQuery().getContextBoolean(CTX_SET_THREAD_NAME, true); + final boolean setName = input.getQuery().context().getBoolean(CTX_SET_THREAD_NAME, true); final Query query = queryPlus.getQuery(); diff --git a/processing/src/main/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryQueryToolChest.java index a5e58d094616..1bb109ef04f7 100644 --- a/processing/src/main/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryQueryToolChest.java @@ -35,6 +35,7 @@ import org.apache.druid.query.DefaultGenericQueryMetricsFactory; import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -232,9 +233,10 @@ public RowSignature resultArraySignature(TimeBoundaryQuery query) { if (query.isMinTime() || query.isMaxTime()) { RowSignature.Builder builder = RowSignature.builder(); + final QueryContext queryContext = query.context(); String outputName = query.isMinTime() ? 
- query.getQueryContext().getAsString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) : - query.getQueryContext().getAsString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME); + queryContext.getString(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MIN_TIME) : + queryContext.getString(TimeBoundaryQuery.MAX_TIME_ARRAY_OUTPUT_NAME, TimeBoundaryQuery.MAX_TIME); return builder.add(outputName, ColumnType.LONG).build(); } return super.resultArraySignature(query); diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java index 291428f5c2a3..52fddccdf7b5 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java @@ -154,17 +154,17 @@ public int getLimit() public boolean isGrandTotal() { - return getContextBoolean(CTX_GRAND_TOTAL, false); + return context().getBoolean(CTX_GRAND_TOTAL, false); } public String getTimestampResultField() { - return getQueryContext().getAsString(CTX_TIMESTAMP_RESULT_FIELD); + return context().getString(CTX_TIMESTAMP_RESULT_FIELD); } public boolean isSkipEmptyBuckets() { - return getContextBoolean(SKIP_EMPTY_BUCKETS, false); + return context().getBoolean(SKIP_EMPTY_BUCKETS, false); } @Nullable diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java index f65020a8b926..afb23523e5b4 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java @@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.io.Closer; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerHelper; import org.apache.druid.query.Result; import org.apache.druid.query.aggregation.Aggregator; @@ -49,6 +48,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.nio.ByteBuffer; import java.util.Collections; import java.util.List; @@ -101,7 +101,7 @@ public Sequence> process( final ColumnInspector inspector = query.getVirtualColumns().wrapInspector(adapter); - final boolean doVectorize = QueryContexts.getVectorize(query).shouldVectorize( + final boolean doVectorize = query.context().getVectorize().shouldVectorize( adapter.canVectorize(filter, query.getVirtualColumns(), descending) && VirtualColumns.shouldVectorize(query, query.getVirtualColumns(), adapter) && query.getAggregatorSpecs().stream().allMatch(aggregatorFactory -> aggregatorFactory.canVectorize(inspector)) @@ -141,7 +141,7 @@ private Sequence> processVectorized( queryInterval, query.getVirtualColumns(), descending, - QueryContexts.getVectorSize(query), + query.context().getVectorSize(), timeseriesQueryMetrics ); diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java index 82f802bf988d..5a4417aa719f 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java +++ 
b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java @@ -37,7 +37,6 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryToolChest; @@ -147,7 +146,7 @@ public Sequence> doRun( !query.isSkipEmptyBuckets() && // Returns empty sequence if bySegment is set because bySegment results are mostly used for // caching in historicals or debugging where the exact results are preferred. - !QueryContexts.isBySegment(query)) { + !query.context().isBySegment()) { // Usually it is NOT Okay to materialize results via toList(), but Granularity is ALL thus // we have only one record. final List<Result<TimeseriesResultValue>> val = baseResults.toList(); diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java index 657bb8931fb9..50b8a30d1028 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryEngine.java @@ -138,7 +138,7 @@ private TopNMapFn getMapFn( // if sorted by dimension we should aggregate all metrics in a single pass, use the regular pooled algorithm for // this topNAlgorithm = new PooledTopNAlgorithm(adapter, query, bufferPool); - } else if (selector.isAggregateTopNMetricFirst() || query.getContextBoolean("doAggregateTopNMetricFirst", false)) { + } else if (selector.isAggregateTopNMetricFirst() || query.context().getBoolean("doAggregateTopNMetricFirst", false)) { // for high cardinality dimensions with larger result sets we aggregate with only the ordering aggregation to // compute the first 'n' values, and then for the rest of the metrics but for only the 'n' values topNAlgorithm = new AggregateTopNMetricFirstAlgorithm(adapter, query, bufferPool); diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java index f6fa421719da..a7785a01be1a 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java @@ -574,12 +574,12 @@ public Sequence> run( } final TopNQuery query = (TopNQuery) input; - final int minTopNThreshold = query.getQueryContext().getAsInt("minTopNThreshold", config.getMinTopNThreshold()); + final int minTopNThreshold = query.context().getInt(QueryContexts.MIN_TOP_N_THRESHOLD, config.getMinTopNThreshold()); if (query.getThreshold() > minTopNThreshold) { return runner.run(queryPlus, responseContext); } - final boolean isBySegment = QueryContexts.isBySegment(query); + final boolean isBySegment = query.context().isBySegment(); return Sequences.map( runner.run(queryPlus.withQuery(query.withThreshold(minTopNThreshold)), responseContext), diff --git a/processing/src/main/java/org/apache/druid/segment/VirtualColumns.java b/processing/src/main/java/org/apache/druid/segment/VirtualColumns.java index c3815286d90f..82f46370fb66 100644 --- a/processing/src/main/java/org/apache/druid/segment/VirtualColumns.java +++ b/processing/src/main/java/org/apache/druid/segment/VirtualColumns.java @@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Pair; import
org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.segment.column.ColumnCapabilities; @@ -48,6 +47,7 @@ import org.apache.druid.segment.virtual.VirtualizedColumnSelectorFactory; import javax.annotation.Nullable; + import java.util.HashMap; import java.util.List; import java.util.Map; @@ -120,7 +120,7 @@ public static VirtualColumns nullToEmpty(@Nullable VirtualColumns virtualColumns public static boolean shouldVectorize(Query query, VirtualColumns virtualColumns, ColumnInspector inspector) { if (virtualColumns.getVirtualColumns().length > 0) { - return QueryContexts.getVectorizeVirtualColumns(query).shouldVectorize(virtualColumns.canVectorize(inspector)); + return query.context().getVectorizeVirtualColumns().shouldVectorize(virtualColumns.canVectorize(inspector)); } else { return true; } diff --git a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java index 9c91cd491c4c..7f9abfae2fa0 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java @@ -215,7 +215,7 @@ public static Filter convertToCNFFromQueryContext(Query query, @Nullable Filter if (filter == null) { return null; } - boolean useCNF = query.getContextBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF); + boolean useCNF = query.context().getBoolean(QueryContexts.USE_FILTER_CNF_KEY, QueryContexts.DEFAULT_USE_FILTER_CNF); try { return useCNF ? Filters.toCnf(filter) : filter; } diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java index 88bf00bf4e4f..7e5eca79b846 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java @@ -20,7 +20,7 @@ package org.apache.druid.segment.join.filter.rewrite; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryContext; import java.util.Objects; @@ -76,12 +76,13 @@ public JoinFilterRewriteConfig( public static JoinFilterRewriteConfig forQuery(final Query query) { + QueryContext context = query.context(); return new JoinFilterRewriteConfig( - QueryContexts.getEnableJoinFilterPushDown(query), - QueryContexts.getEnableJoinFilterRewrite(query), - QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query), - QueryContexts.getEnableRewriteJoinToFilter(query), - QueryContexts.getJoinFilterRewriteMaxSize(query) + context.getEnableJoinFilterPushDown(), + context.getEnableJoinFilterRewrite(), + context.getEnableJoinFilterRewriteValueColumnFilters(), + context.getEnableRewriteJoinToFilter(), + context.getJoinFilterRewriteMaxSize() ); } diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java index 77d70c66412f..74be113e1b7a 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java @@ -24,7 +24,6 @@ import nl.jqno.equalsverifier.EqualsVerifier; import 
nl.jqno.equalsverifier.Warning; import org.apache.druid.java.util.common.HumanReadableBytes; -import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; @@ -38,6 +37,7 @@ import org.junit.Test; import javax.annotation.Nullable; + import java.util.Collections; import java.util.List; import java.util.Map; @@ -51,63 +51,61 @@ public void testEquals() .suppress(Warning.NONFINAL_FIELDS, Warning.ALL_FIELDS_SHOULD_BE_USED) .usingGetClass() .forClass(QueryContext.class) - .withNonnullFields("defaultParams", "userParams", "systemParams") + .withNonnullFields("context") .verify(); } @Test public void testEmptyParam() { - final QueryContext context = new QueryContext(); - Assert.assertEquals(ImmutableMap.of(), context.getMergedParams()); + final QueryContext context = QueryContext.empty(); + Assert.assertEquals(ImmutableMap.of(), context.getContext()); } @Test public void testIsEmpty() { - Assert.assertTrue(new QueryContext().isEmpty()); - Assert.assertFalse(new QueryContext(ImmutableMap.of("k", "v")).isEmpty()); - QueryContext context = new QueryContext(); - context.addDefaultParam("k", "v"); - Assert.assertFalse(context.isEmpty()); - context = new QueryContext(); - context.addSystemParam("k", "v"); - Assert.assertFalse(context.isEmpty()); + Assert.assertTrue(QueryContext.empty().isEmpty()); + Assert.assertFalse(QueryContext.of(ImmutableMap.of("k", "v")).isEmpty()); } @Test public void testGetString() { - final QueryContext context = new QueryContext( + final QueryContext context = QueryContext.of( ImmutableMap.of("key", "val", "key2", 2) ); Assert.assertEquals("val", context.get("key")); - Assert.assertEquals("val", context.getAsString("key")); - Assert.assertEquals("2", context.getAsString("key2")); - Assert.assertNull(context.getAsString("non-exist")); + Assert.assertEquals("val", context.getString("key")); + Assert.assertNull(context.getString("non-exist")); + Assert.assertEquals("foo", context.getString("non-exist", "foo")); + + Assert.assertThrows(BadQueryContextException.class, () -> context.getString("key2")); } @Test public void testGetBoolean() { - final QueryContext context = new QueryContext( + final QueryContext context = QueryContext.of( ImmutableMap.of( "key1", "true", "key2", true ) ); - Assert.assertTrue(context.getAsBoolean("key1", false)); - Assert.assertTrue(context.getAsBoolean("key2", false)); - Assert.assertFalse(context.getAsBoolean("non-exist", false)); + Assert.assertTrue(context.getBoolean("key1", false)); + Assert.assertTrue(context.getBoolean("key2", false)); + Assert.assertTrue(context.getBoolean("key1")); + Assert.assertFalse(context.getBoolean("non-exist", false)); + Assert.assertNull(context.getBoolean("non-exist")); } @Test public void testGetInt() { - final QueryContext context = new QueryContext( + final QueryContext context = QueryContext.of( ImmutableMap.of( "key1", "100", "key2", 100, @@ -115,17 +113,17 @@ public void testGetInt() ) ); - Assert.assertEquals(100, context.getAsInt("key1", 0)); - Assert.assertEquals(100, context.getAsInt("key2", 0)); - Assert.assertEquals(0, context.getAsInt("non-exist", 0)); + Assert.assertEquals(100, context.getInt("key1", 0)); + Assert.assertEquals(100, context.getInt("key2", 0)); + Assert.assertEquals(0, context.getInt("non-exist", 0)); - Assert.assertThrows(IAE.class, () -> context.getAsInt("key3", 5)); + Assert.assertThrows(BadQueryContextException.class, 
() -> context.getInt("key3", 5)); } @Test public void testGetLong() { - final QueryContext context = new QueryContext( + final QueryContext context = QueryContext.of( ImmutableMap.of( "key1", "100", "key2", 100, @@ -133,17 +131,17 @@ public void testGetLong() ) ); - Assert.assertEquals(100L, context.getAsLong("key1", 0)); - Assert.assertEquals(100L, context.getAsLong("key2", 0)); - Assert.assertEquals(0L, context.getAsLong("non-exist", 0)); + Assert.assertEquals(100L, context.getLong("key1", 0)); + Assert.assertEquals(100L, context.getLong("key2", 0)); + Assert.assertEquals(0L, context.getLong("non-exist", 0)); - Assert.assertThrows(IAE.class, () -> context.getAsLong("key3", 5)); + Assert.assertThrows(BadQueryContextException.class, () -> context.getLong("key3", 5)); } @Test public void testGetFloat() { - final QueryContext context = new QueryContext( + final QueryContext context = QueryContext.of( ImmutableMap.of( "f1", "500", "f2", 500, @@ -152,11 +150,11 @@ public void testGetFloat() ) ); - Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f1", 100))); - Assert.assertEquals(0, Float.compare(500, context.getAsFloat("f2", 100))); - Assert.assertEquals(0, Float.compare(500.1f, context.getAsFloat("f3", 100))); + Assert.assertEquals(0, Float.compare(500, context.getFloat("f1", 100))); + Assert.assertEquals(0, Float.compare(500, context.getFloat("f2", 100))); + Assert.assertEquals(0, Float.compare(500.1f, context.getFloat("f3", 100))); - Assert.assertThrows(IAE.class, () -> context.getAsLong("f4", 5)); + Assert.assertThrows(BadQueryContextException.class, () -> context.getFloat("f4", 5)); } @Test @@ -172,167 +170,30 @@ public void testGetHumanReadableBytes() .put("m6", "abc") .build() ); - Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500 * 1024 * 1024L, context.getAsHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500_000_000, context.getAsHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes()); - - Assert.assertThrows(IAE.class, () -> context.getAsHumanReadableBytes("m6", HumanReadableBytes.ZERO)); - } - - @Test - public void testAddSystemParamOverrideUserParam() - { - final QueryContext context = new QueryContext( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ) - ); - context.addSystemParam("sys1", "sysVal1"); - context.addSystemParam("conflict", "sysVal2"); - - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ), - context.getUserParams() - ); - - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "sys1", "sysVal1", - "conflict", "sysVal2" - ), - context.getMergedParams() - ); - } - - @Test - public void testUserParamOverrideDefaultParam() - { - final QueryContext context = new QueryContext( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ) - ); - context.addDefaultParams( - ImmutableMap.of( - "default1", "defaultVal1" - ) - ); - context.addDefaultParam("conflict", "defaultVal2"); - - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ), - context.getUserParams() - ); - - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "default1", "defaultVal1", - "conflict", "userVal2" - ), - 
context.getMergedParams() ); - } - - @Test - public void testRemoveUserParam() - { - final QueryContext context = new QueryContext( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ) - ); - context.addDefaultParams( - ImmutableMap.of( - "default1", "defaultVal1", - "conflict", "defaultVal2" - ) - ); + Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes()); + Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes()); + Assert.assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes()); + Assert.assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes()); + Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "default1", "defaultVal1", - "conflict", "userVal2" - ), - context.getMergedParams() - ); - Assert.assertEquals("userVal2", context.removeUserParam("conflict")); - Assert.assertEquals( - ImmutableMap.of( - "user1", "userVal1", - "default1", "defaultVal1", - "conflict", "defaultVal2" - ), - context.getMergedParams() - ); + Assert.assertThrows(BadQueryContextException.class, () -> context.getHumanReadableBytes("m6", HumanReadableBytes.ZERO)); } @Test - public void testGetMergedParams() + public void testDefaultEnableQueryDebugging() { - final QueryContext context = new QueryContext( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ) - ); - context.addDefaultParams( - ImmutableMap.of( - "default1", "defaultVal1", - "conflict", "defaultVal2" - ) - ); - - Assert.assertSame(context.getMergedParams(), context.getMergedParams()); - } - - @Test - public void testCopy() - { - final QueryContext context = new QueryContext( - ImmutableMap.of( - "user1", "userVal1", - "conflict", "userVal2" - ) - ); - - context.addDefaultParams( - ImmutableMap.of( - "default1", "defaultVal1", - "conflict", "defaultVal2" - ) - ); - - context.addSystemParam("sys1", "val1"); - - final Map<String, Object> merged = ImmutableMap.copyOf(context.getMergedParams()); - - final QueryContext context2 = context.copy(); - context2.removeUserParam("conflict"); - context2.addSystemParam("sys2", "val2"); - context2.addDefaultParam("default3", "defaultVal3"); - - Assert.assertEquals(merged, context.getMergedParams()); + Assert.assertFalse(QueryContext.empty().isDebug()); + Assert.assertTrue(QueryContext.of(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)).isDebug()); } + // This test is a bit silly. It is retained because another test uses the + // LegacyContextQuery class.
@Test public void testLegacyReturnsLegacy() { - Query legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar")); - Assert.assertNull(legacy.getQueryContext()); + Map<String, Object> context = ImmutableMap.of("foo", "bar"); + Query<Map<String, Object>> legacy = new LegacyContextQuery(context); + Assert.assertEquals(context, legacy.getContext()); } @Test @@ -345,10 +206,10 @@ public void testNonLegacyIsNotLegacyContext() .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount"))) .context(ImmutableMap.of("foo", "bar")) .build(); - Assert.assertNotNull(timeseries.getQueryContext()); + Assert.assertNotNull(timeseries.getContext()); } - public static class LegacyContextQuery implements Query + public static class LegacyContextQuery implements Query<Map<String, Object>> { private final Map<String, Object> context; @@ -382,9 +243,9 @@ public String getType() } @Override - public QueryRunner getRunner(QuerySegmentWalker walker) + public QueryRunner<Map<String, Object>> getRunner(QuerySegmentWalker walker) { - return new NoopQueryRunner(); + return new NoopQueryRunner<>(); } @Override @@ -417,31 +278,6 @@ public Map getContext() return context; } - @Override - public boolean getContextBoolean(String key, boolean defaultValue) - { - if (context == null || !context.containsKey(key)) { - return defaultValue; - } - return (boolean) context.get(key); - } - - @Override - public HumanReadableBytes getContextAsHumanReadableBytes(String key, HumanReadableBytes defaultValue) - { - if (null == context || !context.containsKey(key)) { - return defaultValue; - } - Object value = context.get(key); - if (value instanceof Number) { - return HumanReadableBytes.valueOf(((Number) value).longValue()); - } else if (value instanceof String) { - return new HumanReadableBytes((String) value); - } else { - throw new IAE("Expected parameter [%s] to be in human readable format", key); - } - } - @Override public boolean isDescending() { @@ -449,19 +285,19 @@ public boolean isDescending() } @Override - public Ordering getResultOrdering() + public Ordering<Map<String, Object>> getResultOrdering() { return Ordering.natural(); } @Override - public Query withQuerySegmentSpec(QuerySegmentSpec spec) + public Query<Map<String, Object>> withQuerySegmentSpec(QuerySegmentSpec spec) { return new LegacyContextQuery(context); } @Override - public Query withId(String id) + public Query<Map<String, Object>> withId(String id) { context.put(BaseQuery.QUERY_ID, id); return this; } @@ -475,7 +311,7 @@ public String getId() } @Override - public Query withSubQueryId(String subQueryId) + public Query<Map<String, Object>> withSubQueryId(String subQueryId) { context.put(BaseQuery.SUB_QUERY_ID, subQueryId); return this; } @@ -489,21 +325,15 @@ public String getSubQueryId() } @Override - public Query withDataSource(DataSource dataSource) + public Query<Map<String, Object>> withDataSource(DataSource dataSource) { return this; } @Override - public Query withOverriddenContext(Map contextOverride) + public Query<Map<String, Object>> withOverriddenContext(Map<String, Object> contextOverride) { return new LegacyContextQuery(contextOverride); } - - @Override - public Object getContextValue(String key) - { - return context.get(key); - } } } diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextsTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextsTest.java index 7431022240d3..d47bb558fe97 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextsTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextsTest.java @@ -22,7 +22,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.druid.java.util.common.HumanReadableBytes; -import
org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.junit.Assert; @@ -47,7 +46,7 @@ public void testDefaultQueryTimeout() false, new HashMap<>() ); - Assert.assertEquals(300_000, QueryContexts.getDefaultTimeout(query)); + Assert.assertEquals(300_000, query.context().getDefaultTimeout()); } @Test @@ -59,10 +58,10 @@ public void testEmptyQueryTimeout() false, new HashMap<>() ); - Assert.assertEquals(300_000, QueryContexts.getTimeout(query)); + Assert.assertEquals(300_000, query.context().getTimeout()); - query = QueryContexts.withDefaultTimeout(query, 60_000); - Assert.assertEquals(60_000, QueryContexts.getTimeout(query)); + query = Queries.withDefaultTimeout(query, 60_000); + Assert.assertEquals(60_000, query.context().getTimeout()); } @Test @@ -74,17 +73,17 @@ public void testQueryTimeout() false, ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000) ); - Assert.assertEquals(1000, QueryContexts.getTimeout(query)); + Assert.assertEquals(1000, query.context().getTimeout()); - query = QueryContexts.withDefaultTimeout(query, 1_000_000); - Assert.assertEquals(1000, QueryContexts.getTimeout(query)); + query = Queries.withDefaultTimeout(query, 1_000_000); + Assert.assertEquals(1000, query.context().getTimeout()); } @Test public void testQueryMaxTimeout() { - exception.expect(IAE.class); - exception.expectMessage("configured [timeout = 1000] is more than enforced limit of maxQueryTimeout [100]."); + exception.expect(BadQueryContextException.class); + exception.expectMessage("Configured timeout = 1000 is more than enforced limit of 100."); Query query = new TestQuery( new TableDataSource("test"), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), @@ -92,14 +91,14 @@ public void testQueryMaxTimeout() ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1000) ); - QueryContexts.verifyMaxQueryTimeout(query, 100); + query.context().verifyMaxQueryTimeout(100); } @Test public void testMaxScatterGatherBytes() { - exception.expect(IAE.class); - exception.expectMessage("configured [maxScatterGatherBytes = 1000] is more than enforced limit of [100]."); + exception.expect(BadQueryContextException.class); + exception.expectMessage("Configured maxScatterGatherBytes = 1000 is more than enforced limit of 100."); Query query = new TestQuery( new TableDataSource("test"), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), @@ -107,7 +106,7 @@ public void testMaxScatterGatherBytes() ImmutableMap.of(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, 1000) ); - QueryContexts.withMaxScatterGatherBytes(query, 100); + Queries.withMaxScatterGatherBytes(query, 100); } @Test @@ -119,7 +118,7 @@ public void testDisableSegmentPruning() false, ImmutableMap.of(QueryContexts.SECONDARY_PARTITION_PRUNING_KEY, false) ); - Assert.assertFalse(QueryContexts.isSecondaryPartitionPruningEnabled(query)); + Assert.assertFalse(query.context().isSecondaryPartitionPruningEnabled()); } @Test @@ -131,7 +130,7 @@ public void testDefaultSegmentPruning() false, ImmutableMap.of() ); - Assert.assertTrue(QueryContexts.isSecondaryPartitionPruningEnabled(query)); + Assert.assertTrue(query.context().isSecondaryPartitionPruningEnabled()); } @Test @@ -139,7 +138,7 @@ public void testDefaultInSubQueryThreshold() { Assert.assertEquals( QueryContexts.DEFAULT_IN_SUB_QUERY_THRESHOLD, - QueryContexts.getInSubQueryThreshold(ImmutableMap.of()) + QueryContext.empty().getInSubQueryThreshold() ); } @@ -148,32 +147,32 @@ public void 
testDefaultPlanTimeBoundarySql() { Assert.assertEquals( QueryContexts.DEFAULT_ENABLE_TIME_BOUNDARY_PLANNING, - QueryContexts.isTimeBoundaryPlanningEnabled(ImmutableMap.of()) + QueryContext.empty().isTimeBoundaryPlanningEnabled() ); } @Test public void testGetEnableJoinLeftScanDirect() { - Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of())); - Assert.assertTrue(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of( + Assert.assertFalse(QueryContext.empty().getEnableJoinLeftScanDirect()); + Assert.assertTrue(QueryContext.of(ImmutableMap.of( QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT, true - ))); - Assert.assertFalse(QueryContexts.getEnableJoinLeftScanDirect(ImmutableMap.of( + )).getEnableJoinLeftScanDirect()); + Assert.assertFalse(QueryContext.of(ImmutableMap.of( QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT, false - ))); + )).getEnableJoinLeftScanDirect()); } @Test public void testGetBrokerServiceName() { Map queryContext = new HashMap<>(); - Assert.assertNull(QueryContexts.getBrokerServiceName(queryContext)); + Assert.assertNull(QueryContext.of(queryContext).getBrokerServiceName()); queryContext.put(QueryContexts.BROKER_SERVICE_NAME, "hotBroker"); - Assert.assertEquals("hotBroker", QueryContexts.getBrokerServiceName(queryContext)); + Assert.assertEquals("hotBroker", QueryContext.of(queryContext).getBrokerServiceName()); } @Test @@ -182,8 +181,8 @@ public void testGetBrokerServiceName_withNonStringValue() Map queryContext = new HashMap<>(); queryContext.put(QueryContexts.BROKER_SERVICE_NAME, 100); - exception.expect(ClassCastException.class); - QueryContexts.getBrokerServiceName(queryContext); + exception.expect(BadQueryContextException.class); + QueryContext.of(queryContext).getBrokerServiceName(); } @Test @@ -193,38 +192,12 @@ public void testGetTimeout_withNonNumericValue() queryContext.put(QueryContexts.TIMEOUT_KEY, "2000'"); exception.expect(BadQueryContextException.class); - QueryContexts.getTimeout(new TestQuery( + new TestQuery( new TableDataSource("test"), new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), false, queryContext - )); - } - - @Test - public void testDefaultEnableQueryDebugging() - { - Query query = new TestQuery( - new TableDataSource("test"), - new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), - false, - ImmutableMap.of() - ); - Assert.assertFalse(QueryContexts.isDebug(query)); - Assert.assertFalse(QueryContexts.isDebug(query.getContext())); - } - - @Test - public void testEnableQueryDebuggingSetToTrue() - { - Query query = new TestQuery( - new TableDataSource("test"), - new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))), - false, - ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true) - ); - Assert.assertTrue(QueryContexts.isDebug(query)); - Assert.assertTrue(QueryContexts.isDebug(query.getContext())); + ).context().getTimeout(); } @Test @@ -237,7 +210,7 @@ public void testGetAs() QueryContexts.getAsString("foo", 10, null); Assert.fail(); } - catch (IAE e) { + catch (BadQueryContextException e) { // Expected } @@ -249,7 +222,7 @@ public void testGetAs() QueryContexts.getAsBoolean("foo", 10, false); Assert.fail(); } - catch (IAE e) { + catch (BadQueryContextException e) { // Expected } @@ -262,7 +235,7 @@ public void testGetAs() QueryContexts.getAsInt("foo", true, 20); Assert.fail(); } - catch (IAE e) { + catch (BadQueryContextException e) { // Expected } @@ -275,7 +248,7 @@ public void testGetAs() QueryContexts.getAsLong("foo", true, 20); Assert.fail(); } - catch (IAE e) { + 
catch (BadQueryContextException e) { // Expected } } @@ -314,12 +287,12 @@ public void testGetEnum() Assert.assertEquals( QueryContexts.Vectorize.FORCE, - query.getQueryContext().getAsEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) + query.context().getEnum("e1", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) ); Assert.assertThrows( - IAE.class, - () -> query.getQueryContext().getAsEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) + BadQueryContextException.class, + () -> query.context().getEnum("e2", QueryContexts.Vectorize.class, QueryContexts.Vectorize.FALSE) ); } } diff --git a/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java b/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java index 0b9717919d82..3f0c0401e9a9 100644 --- a/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java @@ -31,6 +31,7 @@ import org.apache.druid.query.Druids; import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -102,13 +103,14 @@ public void testContextSerde() throws Exception ), Query.class ); - Assert.assertEquals((Integer) 1, serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY)); - Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY)); - Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY)); - Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY)); - Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.USE_CACHE_KEY, false)); - Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.POPULATE_CACHE_KEY, false)); - Assert.assertEquals(true, serdeQuery.getContextBoolean(QueryContexts.FINALIZE_KEY, false)); + final QueryContext queryContext = serdeQuery.context(); + Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY)); + Assert.assertEquals("true", queryContext.getString(QueryContexts.POPULATE_CACHE_KEY)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY, false)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY, false)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY, false)); } @Test diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngineIteratorTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngineIteratorTest.java index 1c36cdfc9fff..8bc83277c65b 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngineIteratorTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/vector/VectorGroupByEngineIteratorTest.java @@ -20,7 +20,6 @@ package org.apache.druid.query.groupby.epinephelinae.vector; import org.apache.commons.lang3.mutable.MutableObject; 
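Taken together, the QueryContextsTest and DataSourceMetadataQueryTest hunks above pin down the facade's coercion contract: strings and numbers convert where the conversion is unambiguous, unset keys return null from the no-default getters, and impossible conversions now surface as BadQueryContextException rather than IAE or ClassCastException. A minimal sketch of that contract, with hypothetical keys and values:

    QueryContext ctx = QueryContext.of(ImmutableMap.of("timeout", "2000", "broker", 100));
    ctx.getInt("timeout");      // 2000: a numeric string parses to an int
    ctx.getBoolean("useCache"); // null: the key is unset and no default was supplied
    ctx.getString("broker");    // throws BadQueryContextException: 100 is not a string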
-import org.apache.druid.query.QueryContexts;
 import org.apache.druid.query.QueryRunnerTestHelper;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@@ -68,7 +67,7 @@ public void testCreateOneGrouperAndCloseItWhenClose() throws IOException
         interval,
         query.getVirtualColumns(),
         false,
-        QueryContexts.getVectorSize(query),
+        query.context().getVectorSize(),
         null
     );
     final List dimensions = query.getDimensions().stream().map(
diff --git a/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryTest.java b/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryTest.java
index aaa293f4355c..d062be439a07 100644
--- a/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryTest.java
@@ -24,6 +24,7 @@
 import org.apache.druid.jackson.DefaultObjectMapper;
 import org.apache.druid.query.Druids;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
 import org.apache.druid.query.QueryContexts;
 import org.junit.Assert;
 import org.junit.Test;
@@ -78,10 +79,11 @@ public void testContextSerde() throws Exception
         ),
         TimeBoundaryQuery.class
     );
-    Assert.assertEquals(new Integer(1), serdeQuery.getQueryContext().getAsInt(QueryContexts.PRIORITY_KEY));
-    Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.USE_CACHE_KEY));
-    Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.POPULATE_CACHE_KEY));
-    Assert.assertEquals(true, serdeQuery.getQueryContext().getAsBoolean(QueryContexts.FINALIZE_KEY));
+    final QueryContext queryContext = serdeQuery.context();
+    Assert.assertEquals(1, (int) queryContext.getInt(QueryContexts.PRIORITY_KEY));
+    Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.USE_CACHE_KEY));
+    Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.POPULATE_CACHE_KEY));
+    Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.FINALIZE_KEY));
   }
 
   @Test
@@ -116,9 +118,10 @@ public void testContextSerde2() throws Exception
     );
 
-    Assert.assertEquals("1", serdeQuery.getQueryContext().getAsString(QueryContexts.PRIORITY_KEY));
-    Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.USE_CACHE_KEY));
-    Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.POPULATE_CACHE_KEY));
-    Assert.assertEquals("true", serdeQuery.getQueryContext().getAsString(QueryContexts.FINALIZE_KEY));
+    final QueryContext queryContext = serdeQuery.context();
+    Assert.assertEquals("1", queryContext.get(QueryContexts.PRIORITY_KEY));
+    Assert.assertEquals("true", queryContext.get(QueryContexts.USE_CACHE_KEY));
+    Assert.assertEquals("true", queryContext.get(QueryContexts.POPULATE_CACHE_KEY));
+    Assert.assertEquals("true", queryContext.get(QueryContexts.FINALIZE_KEY));
   }
 }
diff --git a/server/src/main/java/org/apache/druid/client/CacheUtil.java b/server/src/main/java/org/apache/druid/client/CacheUtil.java
index 88d713c19aba..a5741ebd0db2 100644
--- a/server/src/main/java/org/apache/druid/client/CacheUtil.java
+++ b/server/src/main/java/org/apache/druid/client/CacheUtil.java
@@ -24,12 +24,12 @@
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.query.CacheStrategy;
 import org.apache.druid.query.Query;
-import org.apache.druid.query.QueryContexts;
 import org.apache.druid.query.QueryToolChest;
 import 
org.apache.druid.query.SegmentDescriptor; import org.joda.time.Interval; import javax.annotation.Nullable; + import java.nio.ByteBuffer; public class CacheUtil @@ -109,7 +109,7 @@ public static boolean isUseSegmentCache( ) { return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) - && QueryContexts.isUseCache(query) + && query.context().isUseCache() && cacheConfig.isUseCache(); } @@ -129,7 +129,7 @@ public static boolean isPopulateSegmentCache( ) { return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) - && QueryContexts.isPopulateCache(query) + && query.context().isPopulateCache() && cacheConfig.isPopulateCache(); } @@ -149,7 +149,7 @@ public static boolean isUseResultCache( ) { return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) - && QueryContexts.isUseResultLevelCache(query) + && query.context().isUseResultLevelCache() && cacheConfig.isUseResultLevelCache(); } @@ -169,7 +169,7 @@ public static boolean isPopulateResultCache( ) { return isQueryCacheable(query, cacheStrategy, cacheConfig, serverType) - && QueryContexts.isPopulateResultLevelCache(query) + && query.context().isPopulateResultLevelCache() && cacheConfig.isPopulateResultLevelCache(); } diff --git a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java index 69952e22ce3a..025d079ad33e 100644 --- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java +++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java @@ -60,6 +60,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.Queries; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryPlus; @@ -282,10 +283,11 @@ private class SpecificQueryRunnable this.useCache = CacheUtil.isUseSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER); this.populateCache = CacheUtil.isPopulateSegmentCache(query, strategy, cacheConfig, CacheUtil.ServerType.BROKER); - this.isBySegment = QueryContexts.isBySegment(query); + final QueryContext queryContext = query.context(); + this.isBySegment = queryContext.isBySegment(); // Note that enabling this leads to putting uncovered intervals information in the response headers // and might blow up in some cases https://github.com/apache/druid/issues/2108 - this.uncoveredIntervalsLimit = QueryContexts.getUncoveredIntervalsLimit(query); + this.uncoveredIntervalsLimit = queryContext.getUncoveredIntervalsLimit(); // For nested queries, we need to look at the intervals of the inner most query. 
this.intervals = dataSourceAnalysis.getBaseQuerySegmentSpec() .map(QuerySegmentSpec::getIntervals) @@ -304,9 +306,10 @@ private ImmutableMap makeDownstreamQueryContext() { final ImmutableMap.Builder contextBuilder = new ImmutableMap.Builder<>(); - final int priority = QueryContexts.getPriority(query); + final QueryContext queryContext = query.context(); + final int priority = queryContext.getPriority(); contextBuilder.put(QueryContexts.PRIORITY_KEY, priority); - final String lane = QueryContexts.getLane(query); + final String lane = queryContext.getLane(); if (lane != null) { contextBuilder.put(QueryContexts.LANE_KEY, lane); } @@ -384,18 +387,19 @@ ClusterQueryResult run( private Sequence merge(List> sequencesByInterval) { BinaryOperator mergeFn = toolChest.createMergeFn(query); - if (processingConfig.useParallelMergePool() && QueryContexts.getEnableParallelMerges(query) && mergeFn != null) { + final QueryContext queryContext = query.context(); + if (processingConfig.useParallelMergePool() && queryContext.getEnableParallelMerges() && mergeFn != null) { return new ParallelMergeCombiningSequence<>( pool, sequencesByInterval, query.getResultOrdering(), mergeFn, - QueryContexts.hasTimeout(query), - QueryContexts.getTimeout(query), - QueryContexts.getPriority(query), - QueryContexts.getParallelMergeParallelism(query, processingConfig.getMergePoolDefaultMaxQueryParallelism()), - QueryContexts.getParallelMergeInitialYieldRows(query, processingConfig.getMergePoolTaskInitialYieldRows()), - QueryContexts.getParallelMergeSmallBatchRows(query, processingConfig.getMergePoolSmallBatchRows()), + queryContext.hasTimeout(), + queryContext.getTimeout(), + queryContext.getPriority(), + queryContext.getParallelMergeParallelism(processingConfig.getMergePoolDefaultMaxQueryParallelism()), + queryContext.getParallelMergeInitialYieldRows(processingConfig.getMergePoolTaskInitialYieldRows()), + queryContext.getParallelMergeSmallBatchRows(processingConfig.getMergePoolSmallBatchRows()), processingConfig.getMergePoolTargetTaskRunTimeMillis(), reportMetrics -> { QueryMetrics queryMetrics = queryPlus.getQueryMetrics(); @@ -437,7 +441,7 @@ private Set computeSegmentsToQuery( // Filter unneeded chunks based on partition dimension for (TimelineObjectHolder holder : serversLookup) { final Set> filteredChunks; - if (QueryContexts.isSecondaryPartitionPruningEnabled(query)) { + if (query.context().isSecondaryPartitionPruningEnabled()) { filteredChunks = DimFilterUtils.filterShards( query.getFilter(), holder.getObject(), @@ -652,12 +656,12 @@ private void addSequencesFromServer( final QueryRunner serverRunner = serverView.getQueryRunner(server); if (serverRunner == null) { - log.error("Server[%s] doesn't have a query runner", server.getName()); + log.error("Server [%s] doesn't have a query runner", server.getName()); return; } // Divide user-provided maxQueuedBytes by the number of servers, and limit each server to that much. 
- final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, httpClientConfig.getMaxQueuedBytes()); + final long maxQueuedBytes = query.context().getMaxQueuedBytes(httpClientConfig.getMaxQueuedBytes()); final long maxQueuedBytesPerServer = maxQueuedBytes / segmentsByServer.size(); final Sequence serverResults; @@ -776,7 +780,7 @@ static class CacheKeyManager this.dataSourceAnalysis = dataSourceAnalysis; this.joinableFactoryWrapper = joinableFactoryWrapper; this.isSegmentLevelCachingEnable = ((populateCache || useCache) - && !QueryContexts.isBySegment(query)); // explicit bySegment queries are never cached + && !query.context().isBySegment()); // explicit bySegment queries are never cached } diff --git a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java index 9c244bbdc731..f8257d443674 100644 --- a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java +++ b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java @@ -41,8 +41,9 @@ import org.apache.druid.java.util.http.client.response.HttpResponseHandler; import org.apache.druid.java.util.http.client.response.StatusResponseHandler; import org.apache.druid.java.util.http.client.response.StatusResponseHolder; +import org.apache.druid.query.Queries; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -152,7 +153,7 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext conte { final Query query = queryPlus.getQuery(); QueryToolChest> toolChest = warehouse.getToolChest(query); - boolean isBySegment = QueryContexts.isBySegment(query); + boolean isBySegment = query.context().isBySegment(); final JavaType queryResultType = isBySegment ? toolChest.getBySegmentResultType() : toolChest.getBaseResultType(); final ListenableFuture future; @@ -160,13 +161,15 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext conte final String cancelUrl = url + query.getId(); try { - log.debug("Querying queryId[%s] url[%s]", query.getId(), url); + log.debug("Querying queryId [%s] url [%s]", query.getId(), url); final long requestStartTimeNs = System.nanoTime(); - final long timeoutAt = query.getQueryContext().getAsLong(QUERY_FAIL_TIME); - final long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query); + final QueryContext queryContext = query.context(); + // Will NPE if the value is not set. + final long timeoutAt = queryContext.getLong(QUERY_FAIL_TIME); + final long maxScatterGatherBytes = queryContext.getMaxScatterGatherBytes(); final AtomicLong totalBytesGathered = context.getTotalBytes(); - final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, 0); + final long maxQueuedBytes = queryContext.getMaxQueuedBytes(0); final boolean usingBackpressure = maxQueuedBytes > 0; final HttpResponseHandler responseHandler = new HttpResponseHandler() @@ -454,7 +457,7 @@ private void checkTotalBytesLimit(long bytes) new Request( HttpMethod.POST, new URL(url) - ).setContent(objectMapper.writeValueAsBytes(QueryContexts.withTimeout(query, timeLeft))) + ).setContent(objectMapper.writeValueAsBytes(Queries.withTimeout(query, timeLeft))) .setHeader( HttpHeaders.Names.CONTENT_TYPE, isSmile ? 
SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON diff --git a/server/src/main/java/org/apache/druid/client/JsonParserIterator.java b/server/src/main/java/org/apache/druid/client/JsonParserIterator.java index 42834b0fbaf6..97c772ed1928 100644 --- a/server/src/main/java/org/apache/druid/client/JsonParserIterator.java +++ b/server/src/main/java/org/apache/druid/client/JsonParserIterator.java @@ -75,7 +75,7 @@ public JsonParserIterator( this.future = future; this.url = url; if (query != null) { - this.timeoutAt = query.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME, -1L); + this.timeoutAt = query.context().getLong(DirectDruidClient.QUERY_FAIL_TIME, -1L); this.queryId = query.getId(); } else { this.timeoutAt = -1; diff --git a/server/src/main/java/org/apache/druid/query/RetryQueryRunner.java b/server/src/main/java/org/apache/druid/query/RetryQueryRunner.java index 9302763d3e5f..00b9dd030c32 100644 --- a/server/src/main/java/org/apache/druid/query/RetryQueryRunner.java +++ b/server/src/main/java/org/apache/druid/query/RetryQueryRunner.java @@ -215,15 +215,15 @@ public boolean hasNext() if (sequence != null) { return true; } else { + final QueryContext queryContext = queryPlus.getQuery().context(); final List missingSegments = getMissingSegments(queryPlus, context); - final int maxNumRetries = QueryContexts.getNumRetriesOnMissingSegments( - queryPlus.getQuery(), + final int maxNumRetries = queryContext.getNumRetriesOnMissingSegments( config.getNumTries() ); if (missingSegments.isEmpty()) { return false; } else if (retryCount >= maxNumRetries) { - if (!QueryContexts.allowReturnPartialResults(queryPlus.getQuery(), config.isReturnPartialResults())) { + if (!queryContext.allowReturnPartialResults(config.isReturnPartialResults())) { throw new SegmentMissingException("No results found for segments[%s]", missingSegments); } else { return false; diff --git a/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java index 05cc77f1cd75..bc825d422202 100644 --- a/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java @@ -39,7 +39,6 @@ import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.PostProcessingOperator; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -60,6 +59,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; @@ -163,7 +163,7 @@ public QueryRunner getQueryRunnerForIntervals(Query query, Iterable QueryRunner decorateClusterRunner(Query query, QueryRunner .emitCPUTimeMetric(emitter) .postProcess( objectMapper.convertValue( - query.getQueryContext().getAsString("postProcessing"), + query.context().getString("postProcessing"), new TypeReference>() {} ) ) diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 40e5267b80b3..6de72014307a 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.ObjectWriter; import 
com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.collect.Iterables; import org.apache.druid.client.DirectDruidClient; import org.apache.druid.java.util.common.DateTimes; @@ -61,10 +62,10 @@ import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; + import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; -import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -102,6 +103,8 @@ public class QueryLifecycle @MonotonicNonNull private Query baseQuery; + @MonotonicNonNull + private Map userContext; public QueryLifecycle( final QueryToolChestWarehouse warehouse, @@ -195,17 +198,15 @@ public void initialize(final Query baseQuery) { transition(State.NEW, State.INITIALIZED); - if (baseQuery.getQueryContext() == null) { - QueryContext context = new QueryContext(baseQuery.getContext()); - context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); - context.addDefaultParams(defaultQueryConfig.getContext()); - - this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams()); - } else { - baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); - baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext()); - this.baseQuery = baseQuery; + userContext = baseQuery.getContext(); + String queryId = baseQuery.getId(); + if (Strings.isNullOrEmpty(queryId)) { + queryId = UUID.randomUUID().toString(); } + + Map mergedUserAndConfigContext = QueryContexts.override(defaultQueryConfig.getContext(), baseQuery.getContext()); + mergedUserAndConfigContext.put(BaseQuery.QUERY_ID, queryId); + this.baseQuery = baseQuery.withOverriddenContext(mergedUserAndConfigContext); this.toolChest = warehouse.getToolChest(this.baseQuery); } @@ -220,12 +221,6 @@ public void initialize(final Query baseQuery) public Access authorize(HttpServletRequest req) { transition(State.INITIALIZED, State.AUTHORIZING); - final Set contextKeys; - if (baseQuery.getQueryContext() == null) { - contextKeys = baseQuery.getContext().keySet(); - } else { - contextKeys = baseQuery.getQueryContext().getUserParams().keySet(); - } final Iterable resourcesToAuthorize = Iterables.concat( Iterables.transform( baseQuery.getDataSource().getTableNames(), @@ -233,7 +228,7 @@ public Access authorize(HttpServletRequest req) ), authConfig.authorizeQueryContextParams() ? 
Iterables.transform( - contextKeys, + userContext.keySet(), contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) ) : Collections.emptyList() @@ -353,7 +348,7 @@ public void emitLogsAndMetrics( if (e != null) { statsMap.put("exception", e.toString()); - if (QueryContexts.isDebug(baseQuery)) { + if (baseQuery.context().isDebug()) { log.warn(e, "Exception while processing queryId [%s]", baseQuery.getId()); } else { log.noStackTrace().warn(e, "Exception while processing queryId [%s]", baseQuery.getId()); @@ -403,9 +398,10 @@ public String threadName(String currThreadName) private boolean isSerializeDateTimeAsLong() { - final boolean shouldFinalize = QueryContexts.isFinalize(baseQuery, true); - return QueryContexts.isSerializeDateTimeAsLong(baseQuery, false) - || (!shouldFinalize && QueryContexts.isSerializeDateTimeAsLongInner(baseQuery, false)); + final QueryContext queryContext = baseQuery.context(); + final boolean shouldFinalize = queryContext.isFinalize(true); + return queryContext.isSerializeDateTimeAsLong(false) + || (!shouldFinalize && queryContext.isSerializeDateTimeAsLongInner(false)); } public ObjectWriter newOutputWriter(ResourceIOReaderWriter ioReaderWriter) diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index 1a72cfc3b8ec..a9178429cac4 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -46,7 +46,6 @@ import org.apache.druid.query.BadQueryException; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; @@ -78,6 +77,7 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.StreamingOutput; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -386,14 +386,7 @@ private Query readQuery( String prevEtag = getPreviousEtag(req); if (prevEtag != null) { - if (baseQuery.getQueryContext() == null) { - QueryContext context = new QueryContext(baseQuery.getContext()); - context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); - - return baseQuery.withOverriddenContext(context.getMergedParams()); - } else { - baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); - } + baseQuery.getContext().put(HEADER_IF_NONE_MATCH, prevEtag); } return baseQuery; diff --git a/server/src/main/java/org/apache/druid/server/QueryScheduler.java b/server/src/main/java/org/apache/druid/server/QueryScheduler.java index cc306e95fe78..762647aa86e3 100644 --- a/server/src/main/java/org/apache/druid/server/QueryScheduler.java +++ b/server/src/main/java/org/apache/druid/server/QueryScheduler.java @@ -38,7 +38,6 @@ import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryWatcher; @@ -254,7 +253,7 @@ int getLaneAvailableCapacity(String lane) @VisibleForTesting List acquireLanes(Query query) { - final String lane = QueryContexts.getLane(query); + final String lane = 
query.context().getLane(); final Optional laneConfig = lane == null ? Optional.empty() : laneRegistry.getConfiguration(lane); final Optional totalConfig = laneRegistry.getConfiguration(TOTAL); List hallPasses = new ArrayList<>(2); diff --git a/server/src/main/java/org/apache/druid/server/SetAndVerifyContextQueryRunner.java b/server/src/main/java/org/apache/druid/server/SetAndVerifyContextQueryRunner.java index 579484ddab65..a44aa2eb8348 100644 --- a/server/src/main/java/org/apache/druid/server/SetAndVerifyContextQueryRunner.java +++ b/server/src/main/java/org/apache/druid/server/SetAndVerifyContextQueryRunner.java @@ -22,8 +22,9 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.client.DirectDruidClient; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.query.Queries; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.context.ResponseContext; @@ -56,21 +57,23 @@ public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) public Query withTimeoutAndMaxScatterGatherBytes(Query query, ServerConfig serverConfig) { - Query newQuery = QueryContexts.verifyMaxQueryTimeout( - QueryContexts.withMaxScatterGatherBytes( - QueryContexts.withDefaultTimeout( + Query newQuery = + Queries.withMaxScatterGatherBytes( + Queries.withDefaultTimeout( query, Math.min(serverConfig.getDefaultQueryTimeout(), serverConfig.getMaxQueryTimeout()) ), serverConfig.getMaxScatterGatherBytes() - ), + ); + newQuery.context().verifyMaxQueryTimeout( serverConfig.getMaxQueryTimeout() ); // DirectDruidClient.QUERY_FAIL_TIME is used by DirectDruidClient and JsonParserIterator to determine when to // fail with a timeout exception final long failTime; - if (QueryContexts.hasTimeout(newQuery)) { - failTime = this.startTimeMillis + QueryContexts.getTimeout(newQuery); + final QueryContext context = newQuery.context(); + if (context.hasTimeout()) { + failTime = this.startTimeMillis + context.getTimeout(); } else { failTime = this.startTimeMillis + serverConfig.getMaxQueryTimeout(); } diff --git a/server/src/main/java/org/apache/druid/server/scheduling/HiLoQueryLaningStrategy.java b/server/src/main/java/org/apache/druid/server/scheduling/HiLoQueryLaningStrategy.java index cb365d82c598..150b93268050 100644 --- a/server/src/main/java/org/apache/druid/server/scheduling/HiLoQueryLaningStrategy.java +++ b/server/src/main/java/org/apache/druid/server/scheduling/HiLoQueryLaningStrategy.java @@ -26,6 +26,7 @@ import it.unimi.dsi.fastutil.objects.Object2IntMap; import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.server.QueryLaningStrategy; @@ -70,10 +71,11 @@ public Optional computeLane(QueryPlus query, Set getLaneLimits(int totalLimit) @Override public Optional computeLane(QueryPlus query, Set segments) { - return Optional.ofNullable(QueryContexts.getLane(query.getQuery())); + return Optional.ofNullable(query.getQuery().context().getLane()); } } diff --git a/server/src/main/java/org/apache/druid/server/scheduling/NoQueryLaningStrategy.java b/server/src/main/java/org/apache/druid/server/scheduling/NoQueryLaningStrategy.java index 8f830d6b555d..2ae1ec33a6c1 100644 --- 
a/server/src/main/java/org/apache/druid/server/scheduling/NoQueryLaningStrategy.java +++ b/server/src/main/java/org/apache/druid/server/scheduling/NoQueryLaningStrategy.java @@ -22,7 +22,6 @@ import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; import it.unimi.dsi.fastutil.objects.Object2IntMap; import org.apache.druid.client.SegmentServerSelector; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.server.QueryLaningStrategy; @@ -47,6 +46,6 @@ public Object2IntMap getLaneLimits(int totalLimit) @Override public Optional computeLane(QueryPlus query, Set segments) { - return Optional.ofNullable(QueryContexts.getLane(query.getQuery())); + return Optional.ofNullable(query.getQuery().context().getLane()); } } diff --git a/server/src/main/java/org/apache/druid/server/scheduling/ThresholdBasedQueryPrioritizationStrategy.java b/server/src/main/java/org/apache/druid/server/scheduling/ThresholdBasedQueryPrioritizationStrategy.java index 9469dcdc59d7..3f76951352c7 100644 --- a/server/src/main/java/org/apache/druid/server/scheduling/ThresholdBasedQueryPrioritizationStrategy.java +++ b/server/src/main/java/org/apache/druid/server/scheduling/ThresholdBasedQueryPrioritizationStrategy.java @@ -25,7 +25,6 @@ import org.apache.druid.client.SegmentServerSelector; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.server.QueryPrioritizationStrategy; import org.joda.time.DateTime; @@ -33,6 +32,7 @@ import org.joda.time.Period; import javax.annotation.Nullable; + import java.util.Optional; import java.util.Set; @@ -87,7 +87,7 @@ public Optional computePriority(QueryPlus query, Set segmentCountThreshold; if (violatesPeriodThreshold || violatesDurationThreshold || violatesSegmentThreshold) { - final int adjustedPriority = QueryContexts.getPriority(theQuery) - adjustment; + final int adjustedPriority = theQuery.context().getPriority() - adjustment; return Optional.of(adjustedPriority); } return Optional.empty(); diff --git a/server/src/test/java/org/apache/druid/client/CachingClusteredClientCacheKeyManagerTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientCacheKeyManagerTest.java index 4c43a009384c..14b757bfaf67 100644 --- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientCacheKeyManagerTest.java +++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientCacheKeyManagerTest.java @@ -19,12 +19,14 @@ package org.apache.druid.client; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Bytes; import org.apache.druid.client.selector.QueryableDruidServer; import org.apache.druid.client.selector.ServerSelector; import org.apache.druid.query.CacheStrategy; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.segment.join.JoinableFactoryWrapper; @@ -43,7 +45,6 @@ import java.util.Optional; import java.util.Set; -import static org.apache.druid.query.QueryContexts.DEFAULT_BY_SEGMENT; import static org.easymock.EasyMock.expect; import static org.easymock.EasyMock.replay; import static org.easymock.EasyMock.reset; @@ -67,7 +68,7 @@ public class CachingClusteredClientCacheKeyManagerTest extends EasyMockSupport public void 
setup() { expect(strategy.computeCacheKey(query)).andReturn(QUERY_CACHE_KEY).anyTimes(); - expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(false).anyTimes(); + expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, false))).anyTimes(); } @After @@ -203,7 +204,7 @@ public void testComputeEtag_noEffectifBySegment() { expect(dataSourceAnalysis.isJoin()).andReturn(false); reset(query); - expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes(); + expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes(); replayAll(); CachingClusteredClient.CacheKeyManager keyManager = makeKeyManager(); Set selectors = ImmutableSet.of( @@ -272,7 +273,7 @@ public void testSegmentQueryCacheKey_joinWithSupportedCaching() public void testSegmentQueryCacheKey_noCachingIfBySegment() { reset(query); - expect(query.getContextBoolean(QueryContexts.BY_SEGMENT_KEY, DEFAULT_BY_SEGMENT)).andReturn(true).anyTimes(); + expect(query.context()).andReturn(QueryContext.of(ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, true))).anyTimes(); replayAll(); byte[] cacheKey = makeKeyManager().computeSegmentLevelQueryCacheKey(); Assert.assertNull(cacheKey); diff --git a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java index a9ee86f1ccce..ca4547858948 100644 --- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java +++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java @@ -72,6 +72,7 @@ import org.apache.druid.query.Druids; import org.apache.druid.query.FinalizeResultsQueryRunner; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; @@ -2297,12 +2298,13 @@ public Iterable>> apply(@Nullable Integer input) for (Capture queryCapture : queryCaptures) { QueryPlus capturedQueryPlus = (QueryPlus) queryCapture.getValue(); Query capturedQuery = capturedQueryPlus.getQuery(); + final QueryContext queryContext = capturedQuery.context(); if (expectBySegment) { - Assert.assertEquals(true, capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY)); + Assert.assertEquals(true, queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY)); } else { Assert.assertTrue( - capturedQuery.getContextValue(QueryContexts.BY_SEGMENT_KEY) == null || - capturedQuery.getQueryContext().getAsBoolean(QueryContexts.BY_SEGMENT_KEY).equals(false) + queryContext.get(QueryContexts.BY_SEGMENT_KEY) == null || + !queryContext.getBoolean(QueryContexts.BY_SEGMENT_KEY) ); } } diff --git a/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java b/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java index bd39d8f6add7..b6010c7b94ee 100644 --- a/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java +++ b/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.Futures; import 
org.apache.druid.jackson.DefaultObjectMapper; @@ -309,13 +310,8 @@ private Query mockQuery(String queryId, long timeoutAt) Query query = Mockito.mock(Query.class); QueryContext context = Mockito.mock(QueryContext.class); Mockito.when(query.getId()).thenReturn(queryId); - Mockito.when(query.getQueryContext()).thenReturn(context); - Mockito.when( - context.getAsLong( - ArgumentMatchers.eq(DirectDruidClient.QUERY_FAIL_TIME), - ArgumentMatchers.eq(-1L) - ) - ).thenReturn(timeoutAt); + Mockito.when(query.context()).thenReturn( + QueryContext.of(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, timeoutAt))); return query; } } diff --git a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java index b95a00eff942..e27a66a20239 100644 --- a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java +++ b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java @@ -119,8 +119,6 @@ public void setup() @Test public void test_getBundle_knownDataSource() { - - final UnifiedIndexerAppenderatorsManager.DatasourceBundle bundle = manager.getBundle( Druids.newScanQueryBuilder() .dataSource(appenderator.getDataSource()) diff --git a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java index 1d1840b6c72e..05efe30e0c28 100644 --- a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java @@ -55,6 +55,9 @@ import javax.servlet.http.HttpServletRequest; +import java.util.HashMap; +import java.util.Map; + public class QueryLifecycleTest { private static final String DATASOURCE = "some_datasource"; @@ -197,21 +200,23 @@ public void testAuthorizeQueryContext_authorized() replayAll(); + final Map userContext = ImmutableMap.of("foo", "bar", "baz", "qux"); final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() .dataSource(DATASOURCE) .intervals(ImmutableList.of(Intervals.ETERNITY)) .aggregators(new CountAggregatorFactory("chocula")) - .context(ImmutableMap.of("foo", "bar", "baz", "qux")) + .context(userContext) .build(); lifecycle.initialize(query); + final Map revisedContext = new HashMap<>(lifecycle.getQuery().getContext()); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); + revisedContext.remove("queryId"); Assert.assertEquals( - ImmutableMap.of("foo", "bar", "baz", "qux"), - lifecycle.getQuery().getQueryContext().getUserParams() + userContext, + revisedContext ); - Assert.assertTrue(lifecycle.getQuery().getQueryContext().getMergedParams().containsKey("queryId")); - Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); } @@ -257,9 +262,6 @@ public void testAuthorizeLegacyQueryContext_authorized() EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) .andReturn(Access.OK); EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK); - // to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added - 
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE)) - .andReturn(Access.OK); EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) .andReturn(toolChest) @@ -271,10 +273,11 @@ public void testAuthorizeLegacyQueryContext_authorized() lifecycle.initialize(query); - Assert.assertNull(lifecycle.getQuery().getQueryContext()); - Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo")); - Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz")); - Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); + final Map revisedContext = lifecycle.getQuery().getContext(); + Assert.assertNotNull(revisedContext); + Assert.assertTrue(revisedContext.containsKey("foo")); + Assert.assertTrue(revisedContext.containsKey("baz")); + Assert.assertTrue(revisedContext.containsKey("queryId")); Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); } diff --git a/server/src/test/java/org/apache/druid/server/QuerySchedulerTest.java b/server/src/test/java/org/apache/druid/server/QuerySchedulerTest.java index 9b61145363ab..571684cacc2f 100644 --- a/server/src/test/java/org/apache/druid/server/QuerySchedulerTest.java +++ b/server/src/test/java/org/apache/druid/server/QuerySchedulerTest.java @@ -48,7 +48,6 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.topn.TopNQuery; @@ -150,7 +149,7 @@ public void testHiLoLo() throws ExecutionException, InterruptedException try { Query scheduledReport = scheduler.prioritizeAndLaneQuery(QueryPlus.wrap(report), ImmutableSet.of()); Assert.assertNotNull(scheduledReport); - Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(scheduledReport)); + Assert.assertEquals(HiLoQueryLaningStrategy.LOW, scheduledReport.context().getLane()); Sequence underlyingSequence = makeSequence(10); underlyingSequence = Sequences.wrap(underlyingSequence, new SequenceWrapper() @@ -412,8 +411,8 @@ public void testConfigHiLoWithThreshold() EasyMock.createMock(SegmentServerSelector.class) ) ); - Assert.assertEquals(-5, QueryContexts.getPriority(query)); - Assert.assertEquals(HiLoQueryLaningStrategy.LOW, QueryContexts.getLane(query)); + Assert.assertEquals(-5, query.context().getPriority()); + Assert.assertEquals(HiLoQueryLaningStrategy.LOW, query.context().getLane()); } @Test diff --git a/server/src/test/java/org/apache/druid/server/SetAndVerifyContextQueryRunnerTest.java b/server/src/test/java/org/apache/druid/server/SetAndVerifyContextQueryRunnerTest.java index 59853bc7c5d9..81abd5186cb2 100644 --- a/server/src/test/java/org/apache/druid/server/SetAndVerifyContextQueryRunnerTest.java +++ b/server/src/test/java/org/apache/druid/server/SetAndVerifyContextQueryRunnerTest.java @@ -36,7 +36,6 @@ public class SetAndVerifyContextQueryRunnerTest { - @Test public void testTimeoutIsUsedIfTimeoutIsNonZero() throws InterruptedException { @@ -58,7 +57,7 @@ public void testTimeoutIsUsedIfTimeoutIsNonZero() throws InterruptedException // time + 1 at the time the method was called // this means that after sleeping for 1 millis, the fail time should be less than the current time when checking Assert.assertTrue( - System.currentTimeMillis() > 
transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME) + System.currentTimeMillis() > transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME) ); } @@ -85,7 +84,7 @@ public long getDefaultQueryTimeout() Query transformed = queryRunner.withTimeoutAndMaxScatterGatherBytes(query, defaultConfig); // timeout is not set, default timeout has been set to long.max, make sure timeout is still in the future - Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)); + Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)); } @Test @@ -107,7 +106,7 @@ public void testTimeoutZeroIsNotImmediateTimeoutDefaultServersideMax() // timeout is set to 0, so withTimeoutAndMaxScatterGatherBytes should set QUERY_FAIL_TIME to be the current // time + max query timeout at the time the method was called // since default is long max, expect long max since current time would overflow - Assert.assertEquals((Long) Long.MAX_VALUE, transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME)); + Assert.assertEquals(Long.MAX_VALUE, (long) transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME)); } @Test @@ -137,7 +136,7 @@ public long getMaxQueryTimeout() // time + max query timeout at the time the method was called // this means that the fail time should be greater than the current time when checking Assert.assertTrue( - System.currentTimeMillis() < (Long) transformed.getQueryContext().getAsLong(DirectDruidClient.QUERY_FAIL_TIME) + System.currentTimeMillis() < transformed.context().getLong(DirectDruidClient.QUERY_FAIL_TIME) ); } } diff --git a/services/src/main/java/org/apache/druid/server/router/ManualTieredBrokerSelectorStrategy.java b/services/src/main/java/org/apache/druid/server/router/ManualTieredBrokerSelectorStrategy.java index c16ec0c035b6..5569946d0af2 100644 --- a/services/src/main/java/org/apache/druid/server/router/ManualTieredBrokerSelectorStrategy.java +++ b/services/src/main/java/org/apache/druid/server/router/ManualTieredBrokerSelectorStrategy.java @@ -26,11 +26,11 @@ import org.apache.commons.lang.StringUtils; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.sql.http.SqlQuery; import javax.annotation.Nullable; -import java.util.Map; /** * Implementation of {@link TieredBrokerSelectorStrategy} which uses the parameter @@ -58,15 +58,15 @@ public ManualTieredBrokerSelectorStrategy( } @Override - public Optional getBrokerServiceName(TieredBrokerConfig tierConfig, Query query) + public Optional getBrokerServiceName(TieredBrokerConfig tierConfig, Query query) { - return getBrokerServiceName(tierConfig, query.getContext()); + return getBrokerServiceName(tierConfig, query.context()); } @Override public Optional getBrokerServiceName(TieredBrokerConfig config, SqlQuery sqlQuery) { - return getBrokerServiceName(config, sqlQuery.getContext()); + return getBrokerServiceName(config, sqlQuery.queryContext()); } /** @@ -74,11 +74,11 @@ public Optional getBrokerServiceName(TieredBrokerConfig config, SqlQuery */ private Optional getBrokerServiceName( TieredBrokerConfig tierConfig, - Map queryContext + QueryContext queryContext ) { try { - final String contextBrokerService = QueryContexts.getBrokerServiceName(queryContext); + final String contextBrokerService = queryContext.getBrokerServiceName(); if 
(isValidBrokerService(contextBrokerService, tierConfig)) { // If the broker service in the query context is valid, use that diff --git a/services/src/main/java/org/apache/druid/server/router/PriorityTieredBrokerSelectorStrategy.java b/services/src/main/java/org/apache/druid/server/router/PriorityTieredBrokerSelectorStrategy.java index 5b6a8bb0ed9c..fd921028320d 100644 --- a/services/src/main/java/org/apache/druid/server/router/PriorityTieredBrokerSelectorStrategy.java +++ b/services/src/main/java/org/apache/druid/server/router/PriorityTieredBrokerSelectorStrategy.java @@ -23,7 +23,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Optional; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; /** */ @@ -45,7 +44,7 @@ public PriorityTieredBrokerSelectorStrategy( @Override public Optional getBrokerServiceName(TieredBrokerConfig tierConfig, Query query) { - final int priority = QueryContexts.getPriority(query); + final int priority = query.context().getPriority(); if (priority < minPriority || priority > maxPriority) { return Optional.of( diff --git a/services/src/main/java/org/apache/druid/server/router/TieredBrokerHostSelector.java b/services/src/main/java/org/apache/druid/server/router/TieredBrokerHostSelector.java index ae03665153ac..8819f3ea4a64 100644 --- a/services/src/main/java/org/apache/druid/server/router/TieredBrokerHostSelector.java +++ b/services/src/main/java/org/apache/druid/server/router/TieredBrokerHostSelector.java @@ -36,7 +36,6 @@ import org.apache.druid.java.util.common.lifecycle.LifecycleStop; import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContexts; import org.apache.druid.server.coordinator.rules.LoadRule; import org.apache.druid.server.coordinator.rules.Rule; import org.apache.druid.sql.http.SqlQuery; @@ -291,7 +290,7 @@ public Pair selectForSql(SqlQuery sqlQuery) brokerServiceName = tierConfig.getDefaultBrokerServiceName(); // Log if query debugging is enabled - if (QueryContexts.isDebug(sqlQuery.getContext())) { + if (sqlQuery.queryContext().isDebug()) { log.info( "No brokerServiceName found for SQL Query [%s], Context [%s]. Using default selector for [%s].", sqlQuery.getQuery(), diff --git a/services/src/main/java/org/apache/druid/server/router/TieredBrokerSelectorStrategy.java b/services/src/main/java/org/apache/druid/server/router/TieredBrokerSelectorStrategy.java index aee4ef88c9a5..af7a057a25a0 100644 --- a/services/src/main/java/org/apache/druid/server/router/TieredBrokerSelectorStrategy.java +++ b/services/src/main/java/org/apache/druid/server/router/TieredBrokerSelectorStrategy.java @@ -37,7 +37,6 @@ public interface TieredBrokerSelectorStrategy { - /** * Tries to determine the name of the Broker service to which the given native * query should be routed. @@ -46,7 +45,7 @@ public interface TieredBrokerSelectorStrategy * @param query Native (JSON) query to be routed * @return An empty Optional if the service name could not be determined. 
*/ - Optional getBrokerServiceName(TieredBrokerConfig config, Query query); + Optional getBrokerServiceName(TieredBrokerConfig config, Query query); /** * Tries to determine the name of the Broker service to which the given SqlQuery diff --git a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java index 99605918e234..be84970bb533 100644 --- a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java @@ -22,7 +22,6 @@ import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.tools.ValidationException; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthorizationUtils; @@ -32,7 +31,9 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import java.io.Closeable; +import java.util.Map; import java.util.Set; +import java.util.TreeMap; import java.util.UUID; import java.util.function.Function; @@ -41,6 +42,13 @@ * A statement is given by a lifecycle context and the statement * to execute. See derived classes for actions. Closing the statement * emits logs and metrics for the statement. + *
<p>
+ * The query context has a complex lifecycle. It is provided in the SQL query + * request ({@link SqlQueryPlus}), then modified during planning. The + * user-provided copy is immutable: a copy is made in this class, and that + * copy is the one which the planner adjusts as planning proceeds. Context + * authorization, if enabled, is done based on the user-provided context keys, + * not the internally-defined context. */ public abstract class AbstractStatement implements Closeable { @@ -49,6 +57,14 @@ public abstract class AbstractStatement implements Closeable protected final SqlToolbox sqlToolbox; protected final SqlQueryPlus queryPlus; protected final SqlExecutionReporter reporter; + + /** + * Copy of the query context modified during planning. Modifications are + * valid in tasks that run in the request thread. Once the query forks + * child threads, then concurrent modifications to the query context will + * result in an undefined muddle of race conditions. + */ + protected final Map queryContext; protected PlannerContext plannerContext; /** @@ -71,29 +87,36 @@ public AbstractStatement( ) { this.sqlToolbox = sqlToolbox; - this.queryPlus = queryPlus; this.reporter = new SqlExecutionReporter(this, remoteAddress); + this.queryPlus = queryPlus; - // Context is modified, not copied. - contextWithSqlId(queryPlus.context()) - .addDefaultParams(sqlToolbox.defaultQueryConfig.getContext()); - } + // TreeMap is required to get consistent ordering of keys, as needed by tests. + this.queryContext = new TreeMap<>(queryPlus.context()); - private static QueryContext contextWithSqlId(QueryContext queryContext) - { // "bySegment" results are never valid to use with SQL because the result format is incompatible // so, overwrite any user specified context to avoid exceptions down the line - if (queryContext.removeUserParam(QueryContexts.BY_SEGMENT_KEY) != null) { + if (this.queryContext.remove(QueryContexts.BY_SEGMENT_KEY) != null) { log.warn("'bySegment' results are not supported for SQL queries, ignoring query context parameter"); } - queryContext.addDefaultParam(PlannerContext.CTX_SQL_QUERY_ID, UUID.randomUUID().toString()); - return queryContext; + this.queryContext.putIfAbsent(PlannerContext.CTX_SQL_QUERY_ID, UUID.randomUUID().toString()); + for (Map.Entry entry : sqlToolbox.defaultQueryConfig.getContext().entrySet()) { + this.queryContext.putIfAbsent(entry.getKey(), entry.getValue()); + } } public String sqlQueryId() { - return queryPlus.context().getAsString(PlannerContext.CTX_SQL_QUERY_ID); + return QueryContexts.parseString(queryContext, PlannerContext.CTX_SQL_QUERY_ID); + } + + /** + * Returns the context as it evolves during planning. In general, this copy will not + * be the same as the one from {@code getQuery().context()}. + */ + public Map context() + { + return queryContext; } /** @@ -101,7 +124,7 @@ public String sqlQueryId() * will take part in the query. Must be called by the API methods, not * directly. */ - protected void validate(DruidPlanner planner) + protected void validate(final DruidPlanner planner) { plannerContext = planner.getPlannerContext(); plannerContext.setAuthenticationResult(queryPlus.authResult()); @@ -124,8 +147,8 @@ protected void validate(DruidPlanner planner) * context variables as well as query resources. 
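Note: the copy-then-default pattern described in the javadoc and constructor above can be shown standalone. This sketch assumes nothing beyond the JDK; the key names and default values are illustrative.

```java
import java.util.Map;
import java.util.TreeMap;
import java.util.UUID;

public class ContextDefaultsSketch
{
  public static void main(String[] args)
  {
    // What the user sent; kept unmodified for authorization.
    Map<String, Object> userContext = Map.of("priority", 10);
    // Server-side defaults (illustrative values).
    Map<String, Object> defaults = Map.of("priority", 0, "timeout", 30_000);

    // Copy, then fill: a default never overwrites a user-supplied key,
    // and the user's original map is never mutated.
    Map<String, Object> planningContext = new TreeMap<>(userContext);
    planningContext.putIfAbsent("sqlQueryId", UUID.randomUUID().toString());
    defaults.forEach(planningContext::putIfAbsent);

    System.out.println(planningContext); // priority stays 10, timeout is added
  }
}
```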
*/ protected void authorize( - DruidPlanner planner, - Function, Access> authorizer + final DruidPlanner planner, + final Function, Access> authorizer ) { boolean authorizeContextParams = sqlToolbox.authConfig.authorizeQueryContextParams(); diff --git a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java index 507216c2368b..776dc3091632 100644 --- a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java @@ -206,7 +206,11 @@ public ResultSet plan() try (DruidPlanner planner = sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), - queryPlus.context())) { + queryContext, + // Context keys for authorization. Use the user-provided keys, + // NOT the keys from the query context which, by this point, + // will have been extended with internally-defined values. + queryPlus.context().keySet())) { validate(planner); authorize(planner, authorizer()); diff --git a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java index b68d6160d3b6..3ba0a1bca9f1 100644 --- a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java @@ -68,7 +68,8 @@ public PrepareResult prepare() try (DruidPlanner planner = sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), - queryPlus.context())) { + queryContext, + queryPlus.context().keySet())) { validate(planner); authorize(planner, authorizer()); diff --git a/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java b/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java index c3280266bea7..0d7646d5f076 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java @@ -24,7 +24,6 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.server.QueryStats; @@ -123,16 +122,12 @@ public void emit() statsMap.put("sqlQuery/planningTimeMs", TimeUnit.NANOSECONDS.toMillis(planningTimeNanos)); statsMap.put("sqlQuery/bytes", bytesWritten); statsMap.put("success", success); - QueryContext queryContext; - if (plannerContext == null) { - queryContext = stmt.queryPlus.context(); - } else { + Map queryContext = stmt.queryContext; + if (plannerContext != null) { statsMap.put("identity", plannerContext.getAuthenticationResult().getIdentity()); - queryContext = stmt.queryPlus.context(); - queryContext.addSystemParam("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); + queryContext.put("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); } - final Map context = queryContext.getMergedParams(); - statsMap.put("context", context); + statsMap.put("context", queryContext); if (e != null) { statsMap.put("exception", e.toString()); @@ -145,7 +140,7 @@ public void emit() stmt.sqlToolbox.requestLogger.logSqlQuery( RequestLogLine.forSql( stmt.queryPlus.sql(), - context, + queryContext, DateTimes.utc(startMs), remoteAddress, new QueryStats(statsMap) diff --git a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java 
b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java index ebd43fb6a37a..b428777b306c 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java @@ -21,12 +21,12 @@ import com.google.common.base.Preconditions; import org.apache.calcite.avatica.remote.TypedValue; -import org.apache.druid.query.QueryContext; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.sql.http.SqlParameter; import org.apache.druid.sql.http.SqlQuery; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,25 +39,31 @@ * SQL requests come from a variety of sources in a variety of formats. Use * the {@link Builder} class to create an instance from the information * available at each point in the code. + *
<p>
+ * The query context has a complex lifecycle. The copy here should remain + * unchanged: this is the set of values which the user requested. Planning will + * add (and sometimes remove) values: that work should be done on a copy of the + * context so that we have a clean record of the user's original requested + * values. */ public class SqlQueryPlus { private final String sql; - private final QueryContext queryContext; + private final Map queryContext; private final List parameters; private final AuthenticationResult authResult; public SqlQueryPlus( String sql, - QueryContext queryContext, + Map queryContext, List parameters, AuthenticationResult authResult ) { this.sql = Preconditions.checkNotNull(sql); this.queryContext = queryContext == null - ? new QueryContext() - : queryContext; + ? Collections.emptyMap() + : new HashMap<>(queryContext); this.parameters = parameters == null ? Collections.emptyList() : parameters; @@ -84,7 +90,7 @@ public String sql() return sql; } - public QueryContext context() + public Map context() { return queryContext; } @@ -99,14 +105,9 @@ public AuthenticationResult authResult() return authResult; } - public SqlQueryPlus withContext(QueryContext context) - { - return new SqlQueryPlus(sql, context, parameters, authResult); - } - public SqlQueryPlus withContext(Map context) { - return new SqlQueryPlus(sql, new QueryContext(context), parameters, authResult); + return new SqlQueryPlus(sql, context, parameters, authResult); } public SqlQueryPlus withParameters(List parameters) @@ -117,7 +118,7 @@ public SqlQueryPlus withParameters(List parameters) public static class Builder { private String sql; - private QueryContext queryContext; + private Map queryContext; private List parameters; private AuthenticationResult authResult; @@ -130,20 +131,14 @@ public Builder sql(String sql) public Builder query(SqlQuery sqlQuery) { this.sql = sqlQuery.getQuery(); - this.queryContext = new QueryContext(sqlQuery.getContext()); + this.queryContext = sqlQuery.getContext(); this.parameters = sqlQuery.getParameterList(); return this; } - public Builder context(QueryContext queryContext) - { - this.queryContext = queryContext; - return this; - } - public Builder context(Map queryContext) { - this.queryContext = queryContext == null ? 
null : new QueryContext(queryContext); + this.queryContext = queryContext; return this; } diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java index a792e28b3a88..7cbeecd344f3 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java @@ -21,15 +21,14 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.QueryContext; import org.apache.druid.sql.PreparedStatement; import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlStatementFactory; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -46,8 +45,8 @@ public class DruidConnection private final String connectionId; private final int maxStatements; - private final ImmutableMap userSecret; - private final QueryContext context; + private final Map userSecret; + private final Map context; private final AtomicInteger statementCounter = new AtomicInteger(); private final AtomicReference> timeoutFuture = new AtomicReference<>(); @@ -64,12 +63,12 @@ public DruidConnection( final String connectionId, final int maxStatements, final Map userSecret, - final QueryContext context + final Map context ) { this.connectionId = Preconditions.checkNotNull(connectionId); this.maxStatements = maxStatements; - this.userSecret = ImmutableMap.copyOf(userSecret); + this.userSecret = userSecret; this.context = context; } @@ -97,7 +96,7 @@ public DruidJdbcStatement createStatement(SqlStatementFactory sqlStatementFactor final DruidJdbcStatement statement = new DruidJdbcStatement( connectionId, statementId, - context.copy(), + new HashMap(context), sqlStatementFactory ); @@ -127,7 +126,7 @@ public DruidJdbcPreparedStatement createPreparedStatement( @SuppressWarnings("GuardedBy") final PreparedStatement statement = sqlStatementFactory.preparedStatement( - sqlQueryPlus.withContext(context.copy()) + sqlQueryPlus.withContext(new HashMap(context)) ); final DruidJdbcPreparedStatement jdbcStmt = new DruidJdbcPreparedStatement( connectionId, diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java index 4d6fe45207cb..3eda8393154f 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java @@ -21,11 +21,12 @@ import com.google.common.base.Preconditions; import org.apache.calcite.avatica.Meta; -import org.apache.druid.query.QueryContext; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlStatementFactory; +import java.util.Map; + /** * Represents Druid's version of the JDBC {@code Statement} class: * can be executed multiple times, one after another, producing a @@ -34,12 +35,12 @@ public class DruidJdbcStatement extends AbstractDruidJdbcStatement { private final SqlStatementFactory lifecycleFactory; - protected final QueryContext queryContext; + protected final Map queryContext; public DruidJdbcStatement( final String connectionId, final 
int statementId, - final QueryContext queryContext, + final Map queryContext, final SqlStatementFactory lifecycleFactory ) { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index 75e4d70c5b70..f26a1f5837e5 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -42,7 +42,6 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.QueryContext; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.Authenticator; import org.apache.druid.server.security.AuthenticatorMapper; @@ -55,6 +54,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; @@ -165,20 +165,19 @@ public void openConnection(final ConnectionHandle ch, final Map try { // Build connection context. final Map secret = new HashMap<>(); - final Map contextMap = new HashMap<>(); + final ImmutableMap.Builder context = ImmutableMap.builder(); if (info != null) { for (Map.Entry entry : info.entrySet()) { if (SENSITIVE_CONTEXT_FIELDS.contains(entry.getKey())) { secret.put(entry.getKey(), entry.getValue()); } else { - contextMap.put(entry.getKey(), entry.getValue()); + context.put(entry.getKey(), entry.getValue()); } } } // we don't want to stringify arrays for JDBC ever because Avatica needs to handle this - final QueryContext context = new QueryContext(contextMap); - context.addSystemParam(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false); - openDruidConnection(ch.id, secret, context); + context.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false); + openDruidConnection(ch.id, secret, context.build()); } catch (NoSuchConnectionException e) { throw e; @@ -776,7 +775,7 @@ private AuthenticationResult authenticateConnection(final DruidConnection connec private DruidConnection openDruidConnection( final String connectionId, final Map userSecret, - final QueryContext context + final Map context ) { if (connectionCount.incrementAndGet() > config.getMaxConnections()) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java index 68db56dd37c0..a31375bad8ae 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java @@ -42,6 +42,7 @@ import java.io.Closeable; import java.util.HashSet; +import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -93,7 +94,7 @@ public void validate() throws SqlParseException, ValidationException Preconditions.checkState(state == State.START); // Validate query context. - engine.validateContext(plannerContext.getQueryContext()); + engine.validateContext(plannerContext.queryContextMap()); // Parse the query string. 
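Note: a standalone sketch of the secret/context split performed in the `DruidMeta.openConnection()` hunk above. The property names are illustrative, and the literal `sqlStringifyArrays` key is an assumption standing in for `PlannerContext.CTX_SQL_STRINGIFY_ARRAYS`.

```java
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class ConnectionContextSketch
{
  public static void main(String[] args)
  {
    // Illustrative JDBC connection properties and sensitive-key set.
    Set<String> sensitiveKeys = Set.of("user", "password");
    Map<String, String> info = Map.of("user", "bob", "password", "pw", "priority", "10");

    Map<String, String> secret = new HashMap<>();
    ImmutableMap.Builder<String, Object> context = ImmutableMap.builder();
    for (Map.Entry<String, String> entry : info.entrySet()) {
      if (sensitiveKeys.contains(entry.getKey())) {
        secret.put(entry.getKey(), entry.getValue());   // kept out of the query context
      } else {
        context.put(entry.getKey(), entry.getValue());  // becomes connection context
      }
    }
    // Avatica needs real arrays, so stringification is disabled up front.
    context.put("sqlStringifyArrays", false);
    System.out.println(context.build()); // e.g. {priority=10, sqlStringifyArrays=false}
  }
}
```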
SqlNode root = planner.parse(plannerContext.getSql()); @@ -189,7 +190,7 @@ public Set resourceActions(boolean includeContext) Set resourceActions = plannerContext.getResourceActions(); if (includeContext) { Set actions = new HashSet<>(resourceActions); - plannerContext.getQueryContext().getUserParams().keySet().forEach(contextParam -> actions.add( + plannerContext.queryContextKeys().forEach(contextParam -> actions.add( new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) )); return actions; @@ -253,7 +254,13 @@ public CalcitePlanner planner() @Override public QueryContext queryContext() { - return plannerContext.getQueryContext(); + return plannerContext.queryContext(); + } + + @Override + public Map queryContextMap() + { + return plannerContext.queryContextMap(); } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java index e59b8cf5e7f6..80dad64f89f2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java @@ -122,7 +122,7 @@ public void validate() throws ValidationException try { PlannerContext plannerContext = handlerContext.plannerContext(); if (ingestionGranularity != null) { - plannerContext.getQueryContext().addSystemParam( + plannerContext.queryContextMap().put( DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, plannerContext.getJsonMapper().writeValueAsString(ingestionGranularity) ); @@ -134,7 +134,7 @@ public void validate() throws ValidationException super.validate(); // Check if CTX_SQL_OUTER_LIMIT is specified and fail the query if it is. CTX_SQL_OUTER_LIMIT being provided causes // the number of rows inserted to be limited which is likely to be confusing and unintended. 
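Note: the authorization mapping in the `resourceActions()` hunk above can be sketched in isolation. The security classes used here are exactly the ones referenced in that hunk; only the wrapper class is invented for the sketch.

```java
import java.util.HashSet;
import java.util.Set;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.server.security.ResourceType;

public class ContextAuthSketch
{
  // One WRITE action per user-supplied context key. Keys added internally
  // during planning never show up here, simply because only the user's
  // original key set is iterated.
  public static Set<ResourceAction> contextActions(Set<String> userContextKeys)
  {
    final Set<ResourceAction> actions = new HashSet<>();
    for (String key : userContextKeys) {
      actions.add(new ResourceAction(new Resource(key, ResourceType.QUERY_CONTEXT), Action.WRITE));
    }
    return actions;
  }
}
```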
- if (handlerContext.queryContext().get(PlannerContext.CTX_SQL_OUTER_LIMIT) != null) { + if (handlerContext.queryContextMap().get(PlannerContext.CTX_SQL_OUTER_LIMIT) != null) { throw new ValidationException( StringUtils.format( "%s cannot be provided with %s.", @@ -336,7 +336,7 @@ public void validate() throws ValidationException handlerContext.timeZone()); super.validate(); if (replaceIntervals != null) { - handlerContext.queryContext().addSystemParam( + handlerContext.queryContextMap().put( DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS, String.join(",", replaceIntervals) ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java index 21fe70799bb5..2e6b2d5217f5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerConfig.java @@ -21,10 +21,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.apache.druid.java.util.common.UOE; -import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; import org.joda.time.DateTimeZone; import org.joda.time.Period; +import java.util.Map; import java.util.Objects; public class PlannerConfig @@ -169,7 +170,7 @@ public boolean isForceExpressionVirtualColumns() return forceExpressionVirtualColumns; } - public PlannerConfig withOverrides(final QueryContext queryContext) + public PlannerConfig withOverrides(final Map queryContext) { if (queryContext.isEmpty()) { return this; @@ -371,33 +372,40 @@ public Builder metadataRefreshPeriod(String value) return this; } - public Builder withOverrides(final QueryContext queryContext) + public Builder withOverrides(final Map queryContext) { - useApproximateCountDistinct = queryContext.getAsBoolean( + useApproximateCountDistinct = QueryContexts.parseBoolean( + queryContext, CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT, useApproximateCountDistinct ); - useGroupingSetForExactDistinct = queryContext.getAsBoolean( + useGroupingSetForExactDistinct = QueryContexts.parseBoolean( + queryContext, CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT, useGroupingSetForExactDistinct ); - useApproximateTopN = queryContext.getAsBoolean( + useApproximateTopN = QueryContexts.parseBoolean( + queryContext, CTX_KEY_USE_APPROXIMATE_TOPN, useApproximateTopN ); - computeInnerJoinCostAsFilter = queryContext.getAsBoolean( + computeInnerJoinCostAsFilter = QueryContexts.parseBoolean( + queryContext, CTX_COMPUTE_INNER_JOIN_COST_AS_FILTER, computeInnerJoinCostAsFilter ); - useNativeQueryExplain = queryContext.getAsBoolean( + useNativeQueryExplain = QueryContexts.parseBoolean( + queryContext, CTX_KEY_USE_NATIVE_QUERY_EXPLAIN, useNativeQueryExplain ); - forceExpressionVirtualColumns = queryContext.getAsBoolean( + forceExpressionVirtualColumns = QueryContexts.parseBoolean( + queryContext, CTX_KEY_FORCE_EXPRESSION_VIRTUAL_COLUMNS, forceExpressionVirtualColumns ); - final int queryContextMaxNumericInFilters = queryContext.getAsInt( + final int queryContextMaxNumericInFilters = QueryContexts.parseInt( + queryContext, CTX_MAX_NUMERIC_IN_FILTERS, maxNumericInFilters ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index 616823b1a66a..4c6bb50424fc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -48,6 +48,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; + import java.util.Collections; import java.util.List; import java.util.Map; @@ -82,7 +83,8 @@ public class PlannerContext private final DateTime localNow; private final DruidSchemaCatalog rootSchema; private final SqlEngine engine; - private final QueryContext queryContext; + private final Map queryContext; + private final Set contextKeys; private final String sqlQueryId; private final boolean stringifyArrays; private final CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(); @@ -110,7 +112,8 @@ private PlannerContext( final boolean stringifyArrays, final DruidSchemaCatalog rootSchema, final SqlEngine engine, - final QueryContext queryContext + final Map queryContext, + final Set contextKeys ) { this.sql = sql; @@ -121,6 +124,7 @@ private PlannerContext( this.rootSchema = rootSchema; this.engine = engine; this.queryContext = queryContext; + this.contextKeys = contextKeys; this.localNow = Preconditions.checkNotNull(localNow, "localNow"); this.stringifyArrays = stringifyArrays; @@ -140,7 +144,8 @@ public static PlannerContext create( final PlannerConfig plannerConfig, final DruidSchemaCatalog rootSchema, final SqlEngine engine, - final QueryContext queryContext + final Map queryContext, + final Set contextKeys ) { final DateTime utcNow; @@ -179,7 +184,8 @@ public static PlannerContext create( stringifyArrays, rootSchema, engine, - queryContext + queryContext, + contextKeys ); } @@ -219,11 +225,34 @@ public String getSchemaResourceType(String schema, String resourceName) return rootSchema.getResourceType(schema, resourceName); } - public QueryContext getQueryContext() + /** + * Return the query context as a mutable map. Use this form when + * modifying the context during planning. + */ + public Map queryContextMap() { return queryContext; } + /** + * Return the query context as an immutable object. Use this form + * when querying the context as it provides type-safe accessors. + */ + public QueryContext queryContext() + { + return QueryContext.of(queryContext); + } + + /** + * Returns the query context keys set by the user. (Actually, set by + * the request made on behalf of the user, which may include options set by + * intermediary services outside of Druid.) 
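Note: a minimal sketch contrasting the two context forms that `PlannerContext` now exposes, the mutable map used during planning and the immutable typed facade. It assumes `QueryContext.getBoolean(key, default)` and `getPriority()` behave like the other typed accessors used in this patch, including parsing of string-typed values.

```java
import java.util.HashMap;
import java.util.Map;
import org.apache.druid.query.QueryContext;

public class FacadeSketch
{
  public static void main(String[] args)
  {
    // Mutable map: the form the planner modifies in place during planning.
    Map<String, Object> raw = new HashMap<>();
    raw.put("debug", "true"); // string-typed values are parsed on access
    raw.put("priority", 10);

    // Immutable typed facade over the same values, as returned by
    // PlannerContext.queryContext().
    QueryContext context = QueryContext.of(raw);
    System.out.println(context.getBoolean("debug", false)); // true
    System.out.println(context.getPriority());              // 10
  }
}
```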
+ */ + public Set queryContextKeys() + { + return contextKeys; + } + public boolean isStringifyArrays() { return stringifyArrays; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index d8852a87a439..d0606b2f34e4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -38,8 +38,6 @@ import org.apache.calcite.tools.ValidationException; import org.apache.druid.guice.annotations.Json; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.QueryContext; -import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.NoopEscalator; @@ -49,7 +47,9 @@ import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; import org.apache.druid.sql.calcite.schema.DruidSchemaName; +import java.util.Map; import java.util.Properties; +import java.util.Set; public class PlannerFactory { @@ -97,7 +97,12 @@ public PlannerFactory( /** * Create a Druid query planner from an initial query context */ - public DruidPlanner createPlanner(final SqlEngine engine, final String sql, final QueryContext queryContext) + public DruidPlanner createPlanner( + final SqlEngine engine, + final String sql, + final Map queryContext, + Set contextKeys + ) { final PlannerContext context = PlannerContext.create( sql, @@ -107,7 +112,8 @@ public DruidPlanner createPlanner(final SqlEngine engine, final String sql, fina plannerConfig, rootSchema, engine, - queryContext + queryContext, + contextKeys ); return new DruidPlanner(buildFrameworkConfig(context), context, engine); @@ -118,9 +124,9 @@ public DruidPlanner createPlanner(final SqlEngine engine, final String sql, fina * and ready to go authorization result. 
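Note: a sketch of why `createPlanner()` now takes both a context map and a separate key set. The map may grow during planning while the key set freezes what the user actually sent; the call shape is shown commented out since it needs a real `PlannerFactory`, and the values are illustrative.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class PlannerInputsSketch
{
  public static void main(String[] args)
  {
    Map<String, Object> userContext = Map.of("useCache", false);
    Map<String, Object> planningContext = new HashMap<>(userContext);
    planningContext.put("sqlQueryId", "generated-by-server"); // internal addition

    // Only the user's keys are authorized, never the internal ones:
    Set<String> keysToAuthorize = userContext.keySet();
    // plannerFactory.createPlanner(engine, sql, planningContext, keysToAuthorize);
    System.out.println(keysToAuthorize); // [useCache], no internal keys
  }
}
```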
*/ @VisibleForTesting - public DruidPlanner createPlannerForTesting(final SqlEngine engine, final String sql, final QueryContext queryContext) + public DruidPlanner createPlannerForTesting(final SqlEngine engine, final String sql, final Map queryContext) { - final DruidPlanner thePlanner = createPlanner(engine, sql, queryContext); + final DruidPlanner thePlanner = createPlanner(engine, sql, queryContext, queryContext.keySet()); thePlanner.getPlannerContext() .setAuthenticationResult(NoopEscalator.getInstance().createEscalatedAuthenticationResult()); try { @@ -146,7 +152,7 @@ private FrameworkConfig buildFrameworkConfig(PlannerContext plannerContext) .withDecorrelationEnabled(false) .withTrimUnusedFields(false) .withInSubQueryThreshold( - QueryContexts.getInSubQueryThreshold(plannerContext.getQueryContext().getMergedParams()) + plannerContext.queryContext().getInSubQueryThreshold() ) .build(); return Frameworks diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java index 1d6a71b54271..eb1bd43ff625 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java @@ -60,7 +60,6 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.query.Query; -import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.server.QueryResponse; import org.apache.druid.server.security.Action; import org.apache.druid.server.security.Resource; @@ -537,8 +536,7 @@ protected PlannerResult planWithDruidConvention() throws ValidationException @Nullable private RelRoot possiblyWrapRootWithOuterLimitFromContext(RelRoot root) { - Object outerLimitObj = handlerContext.queryContext().get(PlannerContext.CTX_SQL_OUTER_LIMIT); - Long outerLimit = DimensionHandlerUtils.convertObjectToLong(outerLimitObj, true); + Long outerLimit = handlerContext.queryContext().getLong(PlannerContext.CTX_SQL_OUTER_LIMIT); if (outerLimit == null) { return root; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlStatementHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlStatementHandler.java index fa8c4fdb17e8..9185b9862f33 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlStatementHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlStatementHandler.java @@ -28,6 +28,7 @@ import org.apache.druid.sql.calcite.run.SqlEngine; import org.joda.time.DateTimeZone; +import java.util.Map; import java.util.Set; /** @@ -52,6 +53,7 @@ interface HandlerContext SqlEngine engine(); CalcitePlanner planner(); QueryContext queryContext(); + Map queryContextMap(); SchemaPlus defaultSchema(); ObjectMapper jsonMapper(); DateTimeZone timeZone(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index fcf9fb754d8a..44b0786cb69b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -49,7 +49,6 @@ import org.apache.druid.query.DataSource; import org.apache.druid.query.JoinDataSource; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.aggregation.AggregatorFactory; import 
org.apache.druid.query.aggregation.LongMaxAggregatorFactory; @@ -94,6 +93,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -894,7 +894,7 @@ private TimeBoundaryQuery toTimeBoundaryQuery() final DataSource newDataSource = dataSourceFiltrationPair.lhs; final Filtration filtration = dataSourceFiltrationPair.rhs; String bound = minTime ? TimeBoundaryQuery.MIN_TIME : TimeBoundaryQuery.MAX_TIME; - HashMap context = new HashMap<>(plannerContext.getQueryContext().getMergedParams()); + Map context = new HashMap<>(plannerContext.queryContextMap()); if (minTime) { context.put(TimeBoundaryQuery.MIN_TIME_ARRAY_OUTPUT_NAME, aggregatorFactory.getName()); } else { @@ -994,7 +994,7 @@ private TimeseriesQuery toTimeseriesQuery() if (!Granularities.ALL.equals(queryGranularity) || grouping.hasGroupingDimensionsDropped()) { theContext.put(TimeseriesQuery.SKIP_EMPTY_BUCKETS, true); } - theContext.putAll(plannerContext.getQueryContext().getMergedParams()); + theContext.putAll(plannerContext.queryContextMap()); final Pair dataSourceFiltrationPair = getFiltration( dataSource, @@ -1114,7 +1114,7 @@ private TopNQuery toTopNQuery() Granularities.ALL, grouping.getAggregatorFactories(), postAggregators, - ImmutableSortedMap.copyOf(plannerContext.getQueryContext().getMergedParams()) + ImmutableSortedMap.copyOf(plannerContext.queryContextMap()) ); } @@ -1171,7 +1171,7 @@ private GroupByQuery toGroupByQuery() havingSpec, Optional.ofNullable(sorting).orElse(Sorting.none()).limitSpec(), grouping.getSubtotals().toSubtotalsSpec(grouping.getDimensionSpecs()), - ImmutableSortedMap.copyOf(plannerContext.getQueryContext().getMergedParams()) + ImmutableSortedMap.copyOf(plannerContext.queryContextMap()) ); // We don't apply timestamp computation optimization yet when limit is pushed down. Maybe someday. if (query.getLimitSpec() instanceof DefaultLimitSpec && query.isApplyLimitPushDown()) { @@ -1332,8 +1332,8 @@ private ScanQuery toScanQuery() withScanSignatureIfNeeded( virtualColumns, scanColumnsList, - plannerContext.getQueryContext() - ).getMergedParams() + plannerContext.queryContextMap() + ) ); } @@ -1341,43 +1341,42 @@ private ScanQuery toScanQuery() * Returns a copy of "queryContext" with {@link #CTX_SCAN_SIGNATURE} added if the execution context has the * {@link EngineFeature#SCAN_NEEDS_SIGNATURE} feature. */ - private QueryContext withScanSignatureIfNeeded( + private Map withScanSignatureIfNeeded( final VirtualColumns virtualColumns, final List scanColumns, - final QueryContext queryContext + final Map queryContext ) { - if (plannerContext.engineHasFeature(EngineFeature.SCAN_NEEDS_SIGNATURE)) { - // Compute the signature of the columns that we are selecting. - final RowSignature.Builder scanSignatureBuilder = RowSignature.builder(); + if (!plannerContext.engineHasFeature(EngineFeature.SCAN_NEEDS_SIGNATURE)) { + return queryContext; + } + // Compute the signature of the columns that we are selecting. + final RowSignature.Builder scanSignatureBuilder = RowSignature.builder(); - for (final String columnName : scanColumns) { - final ColumnCapabilities capabilities = - virtualColumns.getColumnCapabilitiesWithFallback(sourceRowSignature, columnName); + for (final String columnName : scanColumns) { + final ColumnCapabilities capabilities = + virtualColumns.getColumnCapabilitiesWithFallback(sourceRowSignature, columnName); - if (capabilities == null) { - // No type for this column. This is a planner bug. 
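Note: the copy-and-extend step at the end of `withScanSignatureIfNeeded()` above can be sketched standalone. The literal `scanSignature` key is an assumption standing in for `CTX_SCAN_SIGNATURE`; the point is that the caller's context map is never mutated.

```java
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;

public class ScanSignatureSketch
{
  // Extend a copy of the context rather than mutating the caller's map.
  static Map<String, Object> withSignature(Map<String, Object> context, Object signature, ObjectMapper mapper)
  {
    try {
      Map<String, Object> revised = new HashMap<>(context);
      revised.put("scanSignature", mapper.writeValueAsString(signature));
      return revised;
    }
    catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
  }
}
```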
- throw new ISE("No type for column [%s]", columnName); - } - - scanSignatureBuilder.add(columnName, capabilities.toColumnType()); + if (capabilities == null) { + // No type for this column. This is a planner bug. + throw new ISE("No type for column [%s]", columnName); } - final RowSignature signature = scanSignatureBuilder.build(); + scanSignatureBuilder.add(columnName, capabilities.toColumnType()); + } - try { - final QueryContext newContext = queryContext.copy(); - newContext.addSystemParam( - CTX_SCAN_SIGNATURE, - plannerContext.getJsonMapper().writeValueAsString(signature) - ); - return newContext; - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } else { - return queryContext; + final RowSignature signature = scanSignatureBuilder.build(); + + try { + Map revised = new HashMap<>(queryContext); + revised.put( + CTX_SCAN_SIGNATURE, + plannerContext.getJsonMapper().writeValueAsString(signature) + ); + return revised; + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); } } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index c114fc171ea9..2ce9019a5173 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -73,7 +73,7 @@ private DruidJoinRule(final PlannerContext plannerContext) operand(DruidRel.class, any()) ) ); - this.enableLeftScanDirect = plannerContext.getQueryContext().isEnableJoinLeftScanDirect(); + this.enableLeftScanDirect = plannerContext.queryContext().getEnableJoinLeftScanDirect(); this.plannerContext = plannerContext; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java index e207a4d3c237..42bd83b0de5b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java @@ -28,8 +28,6 @@ import org.apache.calcite.tools.ValidationException; import org.apache.druid.guice.LazySingleton; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.query.QueryContext; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.timeboundary.TimeBoundaryQuery; import org.apache.druid.server.QueryLifecycleFactory; @@ -38,6 +36,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.DruidQuery; +import java.util.Map; import java.util.Set; @LazySingleton @@ -76,7 +75,7 @@ public String name() } @Override - public void validateContext(QueryContext queryContext) throws ValidationException + public void validateContext(Map queryContext) throws ValidationException { SqlEngines.validateNoSpecialContextKeys(queryContext, SYSTEM_CONTEXT_PARAMETERS); } @@ -103,7 +102,7 @@ public boolean feature(EngineFeature feature, PlannerContext plannerContext) case TOPN_QUERY: return true; case TIME_BOUNDARY_QUERY: - return QueryContexts.isTimeBoundaryPlanningEnabled(plannerContext.getQueryContext().getMergedParams()); + return plannerContext.queryContext().isTimeBoundaryPlanningEnabled(); case CAN_INSERT: case CAN_REPLACE: case READ_EXTERNAL_DATA: diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java index 
896cdd8db3e8..2734cd09b7fd 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java @@ -23,9 +23,10 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.tools.ValidationException; -import org.apache.druid.query.QueryContext; import org.apache.druid.sql.calcite.planner.PlannerContext; +import java.util.Map; + /** * Engine for running SQL queries. */ @@ -45,7 +46,7 @@ public interface SqlEngine * Validates a provided query context. Returns quietly if the context is OK; throws {@link ValidationException} * if the context has a problem. */ - void validateContext(QueryContext queryContext) throws ValidationException; + void validateContext(Map queryContext) throws ValidationException; /** * SQL row type that would be emitted by the {@link QueryMaker} from {@link #buildQueryMakerForSelect}. diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngines.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngines.java index f01a4714f437..30dd7926bd20 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngines.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngines.java @@ -21,8 +21,8 @@ import org.apache.calcite.tools.ValidationException; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.query.QueryContext; +import java.util.Map; import java.util.Set; public class SqlEngines @@ -35,10 +35,10 @@ public class SqlEngines * * This is a helper function used by {@link SqlEngine#validateContext} implementations. */ - public static void validateNoSpecialContextKeys(final QueryContext queryContext, final Set specialContextKeys) + public static void validateNoSpecialContextKeys(final Map queryContext, final Set specialContextKeys) throws ValidationException { - for (String contextParameterName : queryContext.getMergedParams().keySet()) { + for (String contextParameterName : queryContext.keySet()) { if (specialContextKeys.contains(contextParameterName)) { throw new ValidationException( StringUtils.format( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java index 7703427d3732..86240ee884f7 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java @@ -28,11 +28,11 @@ import org.apache.calcite.schema.TableMacro; import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.schema.impl.ViewTable; -import org.apache.druid.query.QueryContext; import org.apache.druid.sql.calcite.planner.DruidPlanner; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.schema.DruidSchemaName; +import java.util.Collections; import java.util.List; public class DruidViewMacro implements TableMacro @@ -58,7 +58,11 @@ public TranslatableTable apply(final List arguments) { final RelDataType rowType; try (final DruidPlanner planner = - plannerFactory.createPlanner(ViewSqlEngine.INSTANCE, viewSql, new QueryContext())) { + plannerFactory.createPlanner( + ViewSqlEngine.INSTANCE, + viewSql, + Collections.emptyMap(), + Collections.emptySet())) { planner.validate(); rowType = planner.prepare().getValidatedRowType(); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java 
b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java index 06395c7f5e6e..3bc1acdce8a2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java @@ -23,12 +23,13 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.query.QueryContext; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.run.EngineFeature; import org.apache.druid.sql.calcite.run.QueryMaker; import org.apache.druid.sql.calcite.run.SqlEngine; +import java.util.Map; + /** * Engine used for getting the row type of views. Does not do any actual planning or execution of the view. */ @@ -79,7 +80,7 @@ public boolean feature(EngineFeature feature, PlannerContext plannerContext) } @Override - public void validateContext(QueryContext queryContext) + public void validateContext(Map queryContext) { // No query context validation for view row typing. } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java b/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java index 242df5c68b0d..541f6d9fe65b 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java @@ -27,6 +27,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.calcite.avatica.remote.TypedValue; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.query.QueryContext; import java.util.List; import java.util.Map; @@ -133,6 +134,11 @@ public Map getContext() return context; } + public QueryContext queryContext() + { + return QueryContext.of(context); + } + @JsonProperty public List getParameters() { diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java index 0653c533a14e..f95e9e9326cb 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java @@ -493,7 +493,7 @@ public void testIgnoredQueryContextParametersAreIgnored() .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) .build(); DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); - Map context = stmt.query().context().getMergedParams(); + Map context = stmt.context(); Assert.assertEquals(2, context.size()); // should contain only query id, not bySegment since it is not valid for SQL Assert.assertTrue(context.containsKey(PlannerContext.CTX_SQL_QUERY_ID)); @@ -508,7 +508,7 @@ public void testDefaultQueryContextIsApplied() .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) .build(); DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); - Map context = stmt.query().context().getMergedParams(); + Map context = stmt.context(); Assert.assertEquals(2, context.size()); // Statement should contain default query context values for (String defaultContextKey : defaultQueryConfig.getContext().keySet()) { diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index 2483dd663eca..eba660d3c0c1 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -29,7 +29,6 @@ import org.apache.druid.java.util.common.DateTimes; import 
org.apache.druid.java.util.common.io.Closer; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.security.AllowAllAuthenticator; @@ -139,7 +138,7 @@ private DruidJdbcStatement jdbcStatement() return new DruidJdbcStatement( "", 0, - new QueryContext(), + Collections.emptyMap(), sqlStatementFactory ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index d614c23d5d22..8b4c91dbebd1 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -4915,7 +4915,7 @@ public void testInnerJoinWithFilterPushdownAndManyFiltersEmptyResults(Map queryContext) { // No validation. } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java index a8c6132ddb70..30692b80c85d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java @@ -25,7 +25,6 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.query.QueryContext; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.run.EngineFeature; @@ -33,6 +32,8 @@ import org.apache.druid.sql.calcite.run.SqlEngine; import org.apache.druid.sql.calcite.table.RowSignatures; +import java.util.Map; + public class IngestionTestSqlEngine implements SqlEngine { public static final IngestionTestSqlEngine INSTANCE = new IngestionTestSqlEngine(); @@ -48,7 +49,7 @@ public String name() } @Override - public void validateContext(QueryContext queryContext) + public void validateContext(Map queryContext) { // No validation. 
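Note: the test engines above deliberately skip validation; a non-trivial `validateContext()` built on the map-based `SqlEngines.validateNoSpecialContextKeys()` helper shown earlier might look like this sketch. The reserved key name is made up for illustration.

```java
import java.util.Map;
import java.util.Set;
import org.apache.calcite.tools.ValidationException;
import org.apache.druid.sql.calcite.run.SqlEngines;

public class ValidatingEngineSketch
{
  private static final Set<String> RESERVED_KEYS = Set.of("testEngine.internal");

  // Rejects any query whose context tries to set an engine-internal key.
  public void validateContext(Map<String, Object> queryContext) throws ValidationException
  {
    SqlEngines.validateNoSpecialContextKeys(queryContext, RESERVED_KEYS);
  }
}
```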
} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java index f49bb0c40ea9..5124ce64edb0 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java @@ -31,7 +31,6 @@ import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.math.expr.ExpressionProcessing; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.segment.QueryableIndex; @@ -60,8 +59,10 @@ import org.junit.runners.Parameterized; import javax.annotation.Nullable; + import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; @RunWith(Parameterized.class) @@ -181,17 +182,13 @@ public void testQuery() throws ValidationException public static void sanityTestVectorizedSqlQueries(PlannerFactory plannerFactory, String query) throws ValidationException { - final QueryContext vector = new QueryContext( - ImmutableMap.of( + final Map vector = ImmutableMap.of( QueryContexts.VECTORIZE_KEY, "force", QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, "force" - ) ); - final QueryContext nonvector = new QueryContext( - ImmutableMap.of( + final Map nonvector = ImmutableMap.of( QueryContexts.VECTORIZE_KEY, "false", QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, "false" - ) ); try ( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index 539e585e5fd8..e6f7d1519d13 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -34,7 +34,6 @@ import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.ValueMatcher; @@ -63,6 +62,7 @@ import org.junit.Assert; import javax.annotation.Nullable; + import java.math.BigDecimal; import java.util.Arrays; import java.util.Collections; @@ -88,7 +88,8 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) ) ), null /* Don't need engine */, - new QueryContext() + Collections.emptyMap(), + Collections.emptySet() ); private final RowSignature rowSignature; diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java index d536cbcbbb14..1a84d2080e51 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java @@ -23,7 +23,6 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.schema.SchemaPlus; -import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; import 
org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -39,6 +38,8 @@ import org.junit.Assert; import org.junit.Test; +import java.util.Collections; + public class ExternalTableScanRuleTest { @Test @@ -62,7 +63,8 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) ) ), engine, - new QueryContext() + Collections.emptyMap(), + Collections.emptySet() ); plannerContext.setQueryMaker( engine.buildQueryMakerForSelect(EasyMock.createMock(RelRoot.class), plannerContext) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java index 115b7c0237c2..37410727fea6 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java @@ -34,7 +34,6 @@ import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.JacksonModule; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.QueryContext; import org.apache.druid.server.QueryLifecycleFactory; import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.ResourceType; @@ -55,6 +54,8 @@ import javax.validation.Validation; import javax.validation.Validator; + +import java.util.Collections; import java.util.Set; import static org.apache.calcite.plan.RelOptRule.any; @@ -174,7 +175,8 @@ public void testExtensionCalciteRule() injector.getInstance(PlannerConfig.class), rootSchema, null, - new QueryContext() + Collections.emptyMap(), + Collections.emptySet() ); boolean containsCustomRule = injector.getInstance(CalciteRulesManager.class) .druidConventionRuleSet(context) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java index 6d910384e040..216634796b68 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java @@ -38,7 +38,6 @@ import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.query.QueryContext; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DirectOperatorConversion; @@ -97,7 +96,8 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) ) ), null /* Don't need an engine */, - new QueryContext() + Collections.emptyMap(), + Collections.emptySet() ); private final RexBuilder rexBuilder = new RexBuilder(new JavaTypeFactoryImpl()); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java index db6436608928..c09b74b5355e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java @@ -39,6 +39,7 @@ import java.math.BigDecimal; import java.util.List; +import java.util.Map; public class DruidJoinRuleTest { @@ -60,14 +61,15 @@ public class DruidJoinRuleTest ), ImmutableList.of("left", "right") ); - + private DruidJoinRule druidJoinRule; @Before public void setup() { PlannerContext plannerContext = 
Mockito.mock(PlannerContext.class); - Mockito.when(plannerContext.getQueryContext()).thenReturn(Mockito.mock(QueryContext.class)); + Mockito.when(plannerContext.getQueryContext()).thenReturn(Mockito.mock(Map.class)); + Mockito.when(plannerContext.queryContext()).thenReturn(QueryContext.empty()); druidJoinRule = DruidJoinRule.instance(plannerContext); } From 07ad674736f8267e5eef446637b52eb95ffbcca4 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Fri, 30 Sep 2022 13:16:30 -0700 Subject: [PATCH 2/9] Revisions from review comments --- .../src/main/java/org/apache/druid/query/Query.java | 10 ---------- .../java/org/apache/druid/server/QueryLifecycle.java | 8 +++++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java index 9e20ee653828..75c2601aa80c 100644 --- a/processing/src/main/java/org/apache/druid/query/Query.java +++ b/processing/src/main/java/org/apache/druid/query/Query.java @@ -134,16 +134,6 @@ default ContextType getContextValue(String key) return (ContextType) context().get(key); } - /** - * @deprecated use {@code queryContext().get(defaultValue)} instead - */ - @SuppressWarnings("unchecked") - @Deprecated - default ContextType getContextValue(String key, ContextType defaultValue) - { - return (ContextType) context().get(key, defaultValue); - } - /** * @deprecated use {@code queryContext().getBoolean()} instead. */ diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 6de72014307a..bf7460fa4699 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -64,8 +64,10 @@ import javax.servlet.http.HttpServletRequest; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -104,7 +106,7 @@ public class QueryLifecycle @MonotonicNonNull private Query baseQuery; @MonotonicNonNull - private Map userContext; + private Set userContextKeys; public QueryLifecycle( final QueryToolChestWarehouse warehouse, @@ -198,7 +200,7 @@ public void initialize(final Query baseQuery) { transition(State.NEW, State.INITIALIZED); - userContext = baseQuery.getContext(); + userContextKeys = new HashSet<>(baseQuery.getContext().keySet()); String queryId = baseQuery.getId(); if (Strings.isNullOrEmpty(queryId)) { queryId = UUID.randomUUID().toString(); @@ -228,7 +230,7 @@ public Access authorize(HttpServletRequest req) ), authConfig.authorizeQueryContextParams() ? 
Iterables.transform( - userContext.keySet(), + userContextKeys, contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) ) : Collections.emptyList() From 3c600ce880726742d51fab2acbb1cf8ae3b05641 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Mon, 3 Oct 2022 16:22:31 -0700 Subject: [PATCH 3/9] Revisions from review comments Decimal parsing for long parameters Streamline context access in queries --- .../druid/msq/sql/MSQTaskQueryMaker.java | 8 +- .../druid/msq/sql/MSQTaskSqlEngine.java | 6 +- .../org/apache/druid/query/BaseQuery.java | 17 +- .../java/org/apache/druid/query/Query.java | 3 + .../org/apache/druid/query/QueryContext.java | 9 +- .../org/apache/druid/query/QueryContexts.java | 57 ++++- .../druid/query/scan/ScanQueryEngine.java | 1 - .../search/SearchQueryQueryToolChest.java | 1 - .../timeseries/TimeseriesQueryEngine.java | 1 - .../apache/druid/query/QueryContextTest.java | 222 +++++++++++++++--- .../server/ClientQuerySegmentWalker.java | 1 - .../apache/druid/server/QueryLifecycle.java | 1 - .../apache/druid/server/QueryResource.java | 16 +- .../sql/calcite/rule/DruidJoinRuleTest.java | 2 - 14 files changed, 274 insertions(+), 71 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java index d1a7a80991f3..523a75d65974 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java @@ -114,7 +114,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) QueryContext queryContext = plannerContext.queryContext(); String msqMode = MultiStageQueryContext.getMSQMode(queryContext); if (msqMode != null) { - MSQMode.populateDefaultQueryContext(msqMode, plannerContext.getQueryContext()); + MSQMode.populateDefaultQueryContext(msqMode, plannerContext.queryContextMap()); } final String ctxDestination = @@ -122,7 +122,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) Object segmentGranularity; try { - segmentGranularity = Optional.ofNullable(plannerContext.getQueryContext() + segmentGranularity = Optional.ofNullable(plannerContext.queryContext() .get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) .orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY)); } @@ -154,7 +154,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(queryContext); final List replaceTimeChunks = - Optional.ofNullable(plannerContext.getQueryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS)) + Optional.ofNullable(plannerContext.queryContext().get(DruidSqlReplace.SQL_REPLACE_TIME_CHUNKS)) .map( s -> { if (s instanceof String && "all".equals(StringUtils.toLowerCase((String) s))) { @@ -256,7 +256,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) taskId, querySpec, plannerContext.getSql(), - plannerContext.getQueryContext(), + plannerContext.queryContextMap(), sqlTypeNames, null ); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index 8c9caee43359..a91844114dda 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -166,7 +166,7 @@ private static void validateSelect( { validateNoDuplicateAliases(fieldMappings); - if (plannerContext.getQueryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) { + if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) { throw new ValidationException( StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) ); @@ -207,14 +207,14 @@ private static void validateInsert( try { segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext( - plannerContext.getQueryContext() + plannerContext.queryContextMap() ); } catch (Exception e) { throw new ValidationException( StringUtils.format( "Invalid segmentGranularity: %s", - plannerContext.getQueryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) + plannerContext.queryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) ), e ); diff --git a/processing/src/main/java/org/apache/druid/query/BaseQuery.java b/processing/src/main/java/org/apache/druid/query/BaseQuery.java index 88a59781d946..6158632a1547 100644 --- a/processing/src/main/java/org/apache/druid/query/BaseQuery.java +++ b/processing/src/main/java/org/apache/druid/query/BaseQuery.java @@ -37,8 +37,6 @@ import org.joda.time.Interval; import javax.annotation.Nullable; - -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -61,7 +59,7 @@ public static void checkInterrupted() public static final String SQL_QUERY_ID = "sqlQueryId"; private final DataSource dataSource; private final boolean descending; - private final Map context; + private final QueryContext context; private final QuerySegmentSpec querySegmentSpec; private volatile Duration duration; private final Granularity granularity; @@ -89,10 +87,7 @@ public BaseQuery( Preconditions.checkNotNull(granularity, "Must specify a granularity"); this.dataSource = dataSource; - // There is no semantic difference between an empty and a null context. - // Ensure that a context always exists to avoid the need to check for - // a null context. Jackson serialization will omit empty contexts. - this.context = context == null ? Collections.emptyMap() : context; + this.context = QueryContext.of(context); this.querySegmentSpec = querySegmentSpec; this.descending = descending; this.granularity = granularity; @@ -175,7 +170,13 @@ public DateTimeZone getTimezone() @JsonInclude(Include.NON_DEFAULT) public Map getContext() { - return context == null ? Collections.emptyMap() : context; + return context.asMap(); + } + + @Override + public QueryContext context() + { + return context; } /** diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java index 75c2601aa80c..9d38dbe37169 100644 --- a/processing/src/main/java/org/apache/druid/query/Query.java +++ b/processing/src/main/java/org/apache/druid/query/Query.java @@ -105,6 +105,9 @@ public interface Query * Returns the query context as a {@link QueryContext}, which provides * convenience methods for accessing typed context values. The returned * instance is a view on top of the context provided by {@link #getContext()}. + *
<p>
+ * The default implementation is for backward compatibility. Derived classes should + * store and return the {@link QueryContext} directly. */ default QueryContext context() { diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java index 93e5dcf23c77..fffca9d62515 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContext.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java @@ -42,6 +42,9 @@ public class QueryContext public QueryContext(Map context) { + // There is no semantic difference between an empty and a null context. + // Ensure that a context always exists to avoid the need to check for + // a null context. Jackson serialization will omit empty contexts. this.context = context == null ? Collections.emptyMap() : Collections.unmodifiableMap(context); } @@ -52,7 +55,7 @@ public static QueryContext empty() public static QueryContext of(Map context) { - return new QueryContext(context == null ? Collections.emptyMap() : context); + return new QueryContext(context); } public boolean isEmpty() @@ -60,7 +63,7 @@ public boolean isEmpty() return context.isEmpty(); } - public Map getContext() + public Map asMap() { return context; } @@ -191,7 +194,7 @@ public Float getFloat(final String key) } /** - * Return a value as an {@code long}, returning the default value if the + * Return a value as a {@code float}, returning the default value if the * context value is not set. * * @throws BadQueryContextException for an invalid value diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index 2aec4ea04484..c7e2ed3e8427 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -29,7 +29,7 @@ import org.apache.druid.java.util.common.StringUtils; import javax.annotation.Nullable; - +import java.math.BigDecimal; import java.util.Map; import java.util.Map.Entry; import java.util.TreeMap; @@ -243,7 +243,16 @@ public static Integer getAsInt(String key, Object value) return Numbers.parseInt(value); } catch (NumberFormatException ignored) { + + // Attempt to handle trivial decimal values: 12.00, etc. + // This mimics how Jackson will convert "12.00" to an Integer on request. + try { + return new BigDecimal((String) value).intValueExact(); + } + catch (Exception nfe) { + // That didn't work either. Give up. + throw badValueException(key, "in integer format", value); + } } } @@ -276,7 +285,16 @@ public static Long getAsLong(String key, Object value) return Numbers.parseLong(value); } catch (NumberFormatException ignored) { + + // Attempt to handle trivial decimal values: 12.00, etc. + // This mimics how Jackson will convert "12.00" to a Long on request. + try { + return new BigDecimal((String) value).longValueExact(); + } + catch (Exception nfe) { + // That didn't work either. Give up. + throw badValueException(key, "in long format", value); + } } } throw badTypeException(key, "a Long", value); @@ -349,6 +367,39 @@ public static HumanReadableBytes getAsHumanReadableBytes( throw badTypeException(key, "a human readable number", value); } + /** + * Insert, update or remove a single key to produce an overridden context. + * Leaves the original context unchanged.
+ * + * @param context context to override + * @param key key to insert, update or remove + * @param value if {@code null}, remove the key. Otherwise, inert or replace + * the key. + * @return a new context map + */ + public static Map override( + final Map context, + final String key, + final Object value + ) + { + Map overridden = new TreeMap<>(); + if (value == null) { + overridden.remove(key); + } else { + overridden.put(key, value); + } + return overridden; + } + + /** + * Insert or replace multiple keys to produce an overridden context. + * Leaves the original context unchanged. + * + * @param context context to override + * @param overrides map of values to insert or replace + * @return a new context map + */ public static Map override( final Map context, final Map overrides diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java index 56b9087793c7..82a33962e7c5 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java @@ -45,7 +45,6 @@ import org.joda.time.Interval; import javax.annotation.Nullable; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; diff --git a/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java index 06459d7073a2..da59cf9c7f59 100644 --- a/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/search/SearchQueryQueryToolChest.java @@ -45,7 +45,6 @@ import org.apache.druid.query.dimension.DimensionSpec; import javax.annotation.Nullable; - import java.util.Collections; import java.util.Comparator; import java.util.HashMap; diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java index afb23523e5b4..7ae290dd7d48 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryEngine.java @@ -48,7 +48,6 @@ import org.joda.time.Interval; import javax.annotation.Nullable; - import java.nio.ByteBuffer; import java.util.Collections; import java.util.List; diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java index 74be113e1b7a..ebdbded3a724 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java @@ -19,6 +19,10 @@ package org.apache.druid.query; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.exc.MismatchedInputException; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Ordering; import nl.jqno.equalsverifier.EqualsVerifier; @@ -30,20 +34,30 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.spec.QuerySegmentSpec; +import org.apache.druid.segment.DimensionHandlerUtils; import org.joda.time.DateTimeZone; 
import org.joda.time.Duration; import org.joda.time.Interval; -import org.junit.Assert; import org.junit.Test; import javax.annotation.Nullable; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + public class QueryContextTest { + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + @Test public void testEquals() { @@ -55,18 +69,40 @@ public void testEquals() .verify(); } + /** + * Verify that a context with an null map is the same as a context with + * an empty map. + */ @Test - public void testEmptyParam() + public void testEmptyContext() { - final QueryContext context = QueryContext.empty(); - Assert.assertEquals(ImmutableMap.of(), context.getContext()); + { + final QueryContext context = new QueryContext(null); + assertEquals(ImmutableMap.of(), context.asMap()); + } + { + final QueryContext context = new QueryContext(new HashMap<>()); + assertEquals(ImmutableMap.of(), context.asMap()); + } + { + final QueryContext context = QueryContext.of(null); + assertEquals(ImmutableMap.of(), context.asMap()); + } + { + final QueryContext context = QueryContext.of(new HashMap<>()); + assertEquals(ImmutableMap.of(), context.asMap()); + } + { + final QueryContext context = QueryContext.empty(); + assertEquals(ImmutableMap.of(), context.asMap()); + } } @Test public void testIsEmpty() { - Assert.assertTrue(QueryContext.empty().isEmpty()); - Assert.assertFalse(QueryContext.of(ImmutableMap.of("k", "v")).isEmpty()); + assertTrue(QueryContext.empty().isEmpty()); + assertFalse(QueryContext.of(ImmutableMap.of("k", "v")).isEmpty()); } @Test @@ -77,12 +113,12 @@ public void testGetString() "key2", 2) ); - Assert.assertEquals("val", context.get("key")); - Assert.assertEquals("val", context.getString("key")); - Assert.assertNull(context.getString("non-exist")); - Assert.assertEquals("foo", context.getString("non-exist", "foo")); + assertEquals("val", context.get("key")); + assertEquals("val", context.getString("key")); + assertNull(context.getString("non-exist")); + assertEquals("foo", context.getString("non-exist", "foo")); - Assert.assertThrows(BadQueryContextException.class, () -> context.getString("key2")); + assertThrows(BadQueryContextException.class, () -> context.getString("key2")); } @Test @@ -95,11 +131,11 @@ public void testGetBoolean() ) ); - Assert.assertTrue(context.getBoolean("key1", false)); - Assert.assertTrue(context.getBoolean("key2", false)); - Assert.assertTrue(context.getBoolean("key1")); - Assert.assertFalse(context.getBoolean("non-exist", false)); - Assert.assertNull(context.getBoolean("non-exist")); + assertTrue(context.getBoolean("key1", false)); + assertTrue(context.getBoolean("key2", false)); + assertTrue(context.getBoolean("key1")); + assertFalse(context.getBoolean("non-exist", false)); + assertNull(context.getBoolean("non-exist")); } @Test @@ -113,11 +149,11 @@ public void testGetInt() ) ); - Assert.assertEquals(100, context.getInt("key1", 0)); - Assert.assertEquals(100, context.getInt("key2", 0)); - Assert.assertEquals(0, context.getInt("non-exist", 0)); + assertEquals(100, context.getInt("key1", 0)); + assertEquals(100, context.getInt("key2", 0)); + assertEquals(0, context.getInt("non-exist", 0)); - 
Assert.assertThrows(BadQueryContextException.class, () -> context.getInt("key3", 5)); + assertThrows(BadQueryContextException.class, () -> context.getInt("key3", 5)); } @Test @@ -131,11 +167,121 @@ public void testGetLong() ) ); - Assert.assertEquals(100L, context.getLong("key1", 0)); - Assert.assertEquals(100L, context.getLong("key2", 0)); - Assert.assertEquals(0L, context.getLong("non-exist", 0)); + assertEquals(100L, context.getLong("key1", 0)); + assertEquals(100L, context.getLong("key2", 0)); + assertEquals(0L, context.getLong("non-exist", 0)); + + assertThrows(BadQueryContextException.class, () -> context.getLong("key3", 5)); + } - Assert.assertThrows(BadQueryContextException.class, () -> context.getLong("key3", 5)); + /** + * Tests the several ways that Druid code parses context strings into Long + * values. The desired behavior is that "x" is parsed exactly the same as Jackson + * would parse x (where x is a valid number.) The context methods must emulate + * Jackson. The dimension utility method is included because some code used that + * for long parsing, and we must maintain backward compatibility. + *
<p>
+ * The exceptions in the {@code assertThrows} are not critical: the key thing is + * that we're documenting what works and what doesn't. If an exception changes, + * just update the tests. If something no longer throws an exception, we'll want + * to verify that we support the new use case consistently in all three paths. + */ + @Test + public void testGetLongCompatibility() throws JsonProcessingException + { + { + String value = null; + + // Only the context methods allow {"foo": null} to be parsed as a null Long. + assertNull(getContextLong(value)); + // Nulls not legal on this path. + assertThrows(NullPointerException.class, () -> getDimensionLong(value)); + // Nulls not legal on this path. + assertThrows(IllegalArgumentException.class, () -> getJsonLong(value)); + } + + { + String value = ""; + // Blank string not legal on this path. + assertThrows(BadQueryContextException.class, () -> getContextLong(value)); + assertNull(getDimensionLong(value)); + // Blank string not allowed where a value is expected. + assertThrows(MismatchedInputException.class, () -> getJsonLong(value)); + } + + { + String value = "0"; + assertEquals(0L, (long) getContextLong(value)); + assertEquals(0L, (long) getDimensionLong(value)); + assertEquals(0L, (long) getJsonLong(value)); + } + + { + String value = "+1"; + assertEquals(1L, (long) getContextLong(value)); + assertEquals(1L, (long) getDimensionLong(value)); + assertThrows(JsonParseException.class, () -> getJsonLong(value)); + } + + { + String value = "-1"; + assertEquals(-1L, (long) getContextLong(value)); + assertEquals(-1L, (long) getDimensionLong(value)); + assertEquals(-1L, (long) getJsonLong(value)); + } + + { + // Hexadecimal numbers are not supported in JSON. Druid also does not support + // them in strings. + String value = "0xabcd"; + assertThrows(BadQueryContextException.class, () -> getContextLong(value)); + // The dimension utils have a funny way of handling hex: they return null + assertNull(getDimensionLong(value)); + assertThrows(JsonParseException.class, () -> getJsonLong(value)); + } + + { + // Leading zeros supported by Druid parsing, but not by JSON. + String value = "05"; + assertEquals(5L, (long) getContextLong(value)); + assertEquals(5L, (long) getDimensionLong(value)); + assertThrows(JsonParseException.class, () -> getJsonLong(value)); + } + + { + // The dimension utils allow a float where a long is expected. + // Jackson can do this conversion. This test verifies that the context + // functions can handle the same conversion. + String value = "10.00"; + assertEquals(10L, (long) getContextLong(value)); + assertEquals(10L, (long) getDimensionLong(value)); + assertEquals(10L, (long) getJsonLong(value)); + } + + { + // None of the conversion methods allow a (thousands) separator. The comma + // would be ambiguous in JSON. Java allows the underscore, but JSON does + // not support this syntax, and neither does Druid's string-to-long conversion. 
+ String value = "1_234"; + assertThrows(BadQueryContextException.class, () -> getContextLong(value)); + assertNull(getDimensionLong(value)); + assertThrows(JsonParseException.class, () -> getJsonLong(value)); + } + } + + private static Long getContextLong(String value) + { + return QueryContexts.getAsLong("dummy", value); + } + + private static Long getJsonLong(String value) throws JsonProcessingException + { + return JSON_MAPPER.readValue(value, Long.class); + } + + private static Long getDimensionLong(String value) + { + return DimensionHandlerUtils.getExactLongFromDecimalString(value); } @Test @@ -150,11 +296,11 @@ public void testGetFloat() ) ); - Assert.assertEquals(0, Float.compare(500, context.getFloat("f1", 100))); - Assert.assertEquals(0, Float.compare(500, context.getFloat("f2", 100))); - Assert.assertEquals(0, Float.compare(500.1f, context.getFloat("f3", 100))); + assertEquals(0, Float.compare(500, context.getFloat("f1", 100))); + assertEquals(0, Float.compare(500, context.getFloat("f2", 100))); + assertEquals(0, Float.compare(500.1f, context.getFloat("f3", 100))); - Assert.assertThrows(BadQueryContextException.class, () -> context.getFloat("f4", 5)); + assertThrows(BadQueryContextException.class, () -> context.getFloat("f4", 5)); } @Test @@ -170,20 +316,20 @@ public void testGetHumanReadableBytes() .put("m6", "abc") .build() ); - Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes()); - Assert.assertEquals(500_000_000, context.getHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes()); + assertEquals(500_000_000, context.getHumanReadableBytes("m1", HumanReadableBytes.ZERO).getBytes()); + assertEquals(500_000_000, context.getHumanReadableBytes("m2", HumanReadableBytes.ZERO).getBytes()); + assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m3", HumanReadableBytes.ZERO).getBytes()); + assertEquals(500 * 1024 * 1024L, context.getHumanReadableBytes("m4", HumanReadableBytes.ZERO).getBytes()); + assertEquals(500_000_000, context.getHumanReadableBytes("m5", HumanReadableBytes.ZERO).getBytes()); - Assert.assertThrows(BadQueryContextException.class, () -> context.getHumanReadableBytes("m6", HumanReadableBytes.ZERO)); + assertThrows(BadQueryContextException.class, () -> context.getHumanReadableBytes("m6", HumanReadableBytes.ZERO)); } @Test public void testDefaultEnableQueryDebugging() { - Assert.assertFalse(QueryContext.empty().isDebug()); - Assert.assertTrue(QueryContext.of(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)).isDebug()); + assertFalse(QueryContext.empty().isDebug()); + assertTrue(QueryContext.of(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)).isDebug()); } // This test is a bit silly. 
It is retained because another test uses the @@ -193,7 +339,7 @@ public void testLegacyReturnsLegacy() { Map context = ImmutableMap.of("foo", "bar"); Query legacy = new LegacyContextQuery(context); - Assert.assertEquals(context, legacy.getContext()); + assertEquals(context, legacy.getContext()); } @Test @@ -206,7 +352,7 @@ public void testNonLegacyIsNotLegacyContext() .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount"))) .context(ImmutableMap.of("foo", "bar")) .build(); - Assert.assertNotNull(timeseries.getContext()); + assertNotNull(timeseries.getContext()); } public static class LegacyContextQuery implements Query diff --git a/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java index bc825d422202..d6fe03a31d63 100644 --- a/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/server/ClientQuerySegmentWalker.java @@ -59,7 +59,6 @@ import org.joda.time.Interval; import javax.annotation.Nullable; - import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index bf7460fa4699..0dc891b50ece 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -62,7 +62,6 @@ import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; - import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index a9178429cac4..743ca9e60ba7 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -46,6 +46,7 @@ import org.apache.druid.query.BadQueryException; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; @@ -77,7 +78,6 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.StreamingOutput; - import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -383,13 +383,19 @@ private Query readQuery( catch (JsonParseException e) { throw new BadJsonQueryException(e); } - String prevEtag = getPreviousEtag(req); - if (prevEtag != null) { - baseQuery.getContext().put(HEADER_IF_NONE_MATCH, prevEtag); + String prevEtag = getPreviousEtag(req); + if (prevEtag == null) { + return baseQuery; } - return baseQuery; + return baseQuery.withOverriddenContext( + QueryContexts.override( + baseQuery.getContext(), + HEADER_IF_NONE_MATCH, + prevEtag + ) + ); } private static String getPreviousEtag(final HttpServletRequest req) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java index c09b74b5355e..e4b037e99f98 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java @@ 
-39,7 +39,6 @@ import java.math.BigDecimal; import java.util.List; -import java.util.Map; public class DruidJoinRuleTest { @@ -68,7 +67,6 @@ public class DruidJoinRuleTest public void setup() { PlannerContext plannerContext = Mockito.mock(PlannerContext.class); - Mockito.when(plannerContext.getQueryContext()).thenReturn(Mockito.mock(Map.class)); Mockito.when(plannerContext.queryContext()).thenReturn(QueryContext.empty()); druidJoinRule = DruidJoinRule.instance(plannerContext); } From 70458c83b98c2df64311a09b822b309970136c7a Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Tue, 4 Oct 2022 09:20:58 -0700 Subject: [PATCH 4/9] Revisions from review comments --- .../org/apache/druid/query/QueryContext.java | 17 ++++++++++++++++- .../org/apache/druid/query/QueryContexts.java | 7 ++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java index fffca9d62515..0296679f1e3e 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContext.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java @@ -30,9 +30,24 @@ import java.util.Collections; import java.util.Map; import java.util.Objects; +import java.util.TreeMap; /** * Immutable holder for query context parameters with typed access methods. + * Code builds up a map of context values from serialization or during + * planning. Once that map is handed to the {@code QueryContext}, that map + * is effectively immutable. + *
<p>
+ * The implementation uses a {@link TreeMap} so that the serialized form of a query + * lists context values in a deterministic order. Jackson will call + * {@code getContext()} on the query, which will call {@link asMap()} here, + * which returns the sorted {@code TreeMap}. + *
<p>
+ * The {@code TreeMap} is a mutable class. We'd prefer an immutable class, but + * we can have either ordering or immutability, not both. The semantics of the + * context are that it is immutable once it is placed in a query: code should + * NEVER get the context map from a query and modify it, even if the actual + * implementation allows it. */ public class QueryContext { @@ -45,7 +60,7 @@ public QueryContext(Map context) // There is no semantic difference between an empty and a null context. // Ensure that a context always exists to avoid the need to check for // a null context. Jackson serialization will omit empty contexts. - this.context = context == null ? Collections.emptyMap() : Collections.unmodifiableMap(context); + this.context = context == null ? Collections.emptyMap() : new TreeMap<>(context); } public static QueryContext empty() diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index c7e2ed3e8427..d0c8c0eea66c 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -30,6 +30,7 @@ import javax.annotation.Nullable; import java.math.BigDecimal; +import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.TreeMap; @@ -373,7 +374,7 @@ public static HumanReadableBytes getAsHumanReadableBytes( * * @param context context to override * @param key key to insert, update or remove - * @param value if {@code null}, remove the key. Otherwise, inert or replace + * @param value if {@code null}, remove the key. Otherwise, insert or replace * the key. * @return a new context map */ @@ -383,7 +384,7 @@ public static Map override( final Object value ) { - Map overridden = new TreeMap<>(); + Map overridden = new HashMap<>(context); if (value == null) { overridden.remove(key); } else { @@ -405,7 +406,7 @@ public static Map override( final Map overrides ) { - Map overridden = new TreeMap<>(); + Map overridden = new HashMap<>(); if (context != null) { overridden.putAll(context); } From f95f35e54291eac54eeaed1abca5dc3f0d12f476 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Wed, 5 Oct 2022 22:55:46 -0700 Subject: [PATCH 5/9] Build fixes Fix a flaky test (race conditions) --- .../org/apache/druid/query/QueryContext.java | 2 +- .../org/apache/druid/query/QueryContexts.java | 2 +- .../calcite/schema/SegmentMetadataCache.java | 2 +- .../schema/SegmentMetadataCacheTest.java | 139 ++++++++++++++---- 4 files changed, 110 insertions(+), 35 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java b/processing/src/main/java/org/apache/druid/query/QueryContext.java index 0296679f1e3e..6239731a6942 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContext.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java @@ -40,7 +40,7 @@ *
<p>
* The implementation uses a {@link TreeMap} so that the serialized form of a query * lists context values in a deterministic order. Jackson will call - * {@code getContext()} on the query, which will call {@link asMap()} here, + * {@code getContext()} on the query, which will call {@link #asMap()} here, * which returns the sorted {@code TreeMap}. *
<p>
* The {@code TreeMap} is a mutable class. We'd prefer an immutable class, but diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index d0c8c0eea66c..976ab43bc074 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -29,11 +29,11 @@ import org.apache.druid.java.util.common.StringUtils; import javax.annotation.Nullable; + import java.math.BigDecimal; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; -import java.util.TreeMap; import java.util.concurrent.TimeUnit; @PublicApi diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCache.java b/sql/src/main/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCache.java index 5d6d386a61b7..9944c0639c23 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCache.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCache.java @@ -314,7 +314,7 @@ private void startCacheExec() } // lastFailure != 0L means exceptions happened before and there're some refresh work was not completed. - // so that even ServerView is initialized, we can't let broker complete initialization. + // so that even if ServerView is initialized, we can't let broker complete initialization. if (isServerViewInitialized && lastFailure == 0L) { // Server view is initialized, but we don't need to do a refresh. Could happen if there are // no segments in the system yet. Just mark us as initialized, then. diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java index 1ba95b8bef0d..784c94fb1928 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java @@ -20,6 +20,8 @@ package org.apache.druid.sql.calcite.schema; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -91,15 +93,15 @@ public class SegmentMetadataCacheTest extends SegmentMetadataCacheCommon { // Timeout to allow (rapid) debugging, while not blocking tests with errors. 
- private static final int WAIT_TIMEOUT_SECS = 60; + private static final int WAIT_TIMEOUT_SECS = 6; private SpecificSegmentsQuerySegmentWalker walker; private TestServerInventoryView serverView; private List druidServers; - private SegmentMetadataCache schema; - private SegmentMetadataCache schema2; + private SegmentMetadataCache runningSchema; private CountDownLatch buildTableLatch = new CountDownLatch(1); private CountDownLatch markDataSourceLatch = new CountDownLatch(1); + private CountDownLatch refreshLatch = new CountDownLatch(1); private static final ObjectMapper MAPPER = TestHelper.makeJsonMapper(); @Before @@ -177,8 +179,12 @@ public void setUp() throws Exception final List realtimeSegments = ImmutableList.of(segment1); serverView = new TestServerInventoryView(walker.getSegments(), realtimeSegments); druidServers = serverView.getDruidServers(); + } - schema = new SegmentMetadataCache( + public SegmentMetadataCache buildSchema1() throws InterruptedException + { + Preconditions.checkState(runningSchema == null); + runningSchema = new SegmentMetadataCache( CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate), serverView, segmentManager, @@ -207,7 +213,15 @@ void markDataSourceAsNeedRebuild(String datasource) } }; - schema2 = new SegmentMetadataCache( + runningSchema.start(); + runningSchema.awaitInitialization(); + return runningSchema; + } + + public SegmentMetadataCache buildSchema2() throws InterruptedException + { + Preconditions.checkState(runningSchema == null); + runningSchema = new SegmentMetadataCache( CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate), serverView, segmentManager, @@ -249,20 +263,61 @@ void markDataSourceAsNeedRebuild(String datasource) } }; - schema.start(); - schema.awaitInitialization(); + runningSchema.start(); + runningSchema.awaitInitialization(); + return runningSchema; + } + + public SegmentMetadataCache buildSchema3() throws InterruptedException + { + Preconditions.checkState(runningSchema == null); + runningSchema = new SegmentMetadataCache( + CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate), + serverView, + segmentManager, + new MapJoinableFactory( + ImmutableSet.of(globalTableJoinable), + ImmutableMap.of(globalTableJoinable.getClass(), GlobalTableDataSource.class) + ), + PLANNER_CONFIG_DEFAULT, + new NoopEscalator(), + new BrokerInternalQueryConfig() + ) + { + @Override + void markDataSourceAsNeedRebuild(String datasource) + { + super.markDataSourceAsNeedRebuild(datasource); + markDataSourceLatch.countDown(); + } + + @Override + @VisibleForTesting + void refresh(final Set segmentsToRefresh, final Set dataSourcesToRebuild) throws IOException + { + super.refresh(segmentsToRefresh, dataSourcesToRebuild); + refreshLatch.countDown(); + } + }; + + runningSchema.start(); + runningSchema.awaitInitialization(); + return runningSchema; } @After public void tearDown() throws Exception { - schema.stop(); + if (runningSchema != null) { + runningSchema.stop(); + } walker.close(); } @Test - public void testGetTableMap() + public void testGetTableMap() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); Assert.assertEquals(ImmutableSet.of("foo", "foo2"), schema.getDatasourceNames()); final Set tableNames = schema.getDatasourceNames(); @@ -272,15 +327,14 @@ public void testGetTableMap() @Test public void testSchemaInit() throws InterruptedException { - schema2.start(); - schema2.awaitInitialization(); + SegmentMetadataCache schema2 = buildSchema1(); 
Assert.assertEquals(ImmutableSet.of("foo", "foo2"), schema2.getDatasourceNames()); - schema2.stop(); } @Test - public void testGetTableMapFoo() + public void testGetTableMapFoo() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); final DatasourceTable.PhysicalDatasourceMetadata fooDs = schema.getDatasource("foo"); final DruidTable fooTable = new DatasourceTable(fooDs); final RelDataType rowType = fooTable.getRowType(new JavaTypeFactoryImpl()); @@ -308,8 +362,9 @@ public void testGetTableMapFoo() } @Test - public void testGetTableMapFoo2() + public void testGetTableMapFoo2() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); final DatasourceTable.PhysicalDatasourceMetadata fooDs = schema.getDatasource("foo2"); final DruidTable fooTable = new DatasourceTable(fooDs); final RelDataType rowType = fooTable.getRowType(new JavaTypeFactoryImpl()); @@ -331,10 +386,12 @@ public void testGetTableMapFoo2() * This tests that {@link AvailableSegmentMetadata#getNumRows()} is correct in case * of multiple replicas i.e. when {@link SegmentMetadataCache#addSegment(DruidServerMetadata, DataSegment)} * is called more than once for same segment + * @throws InterruptedException */ @Test - public void testAvailableSegmentMetadataNumRows() + public void testAvailableSegmentMetadataNumRows() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); Map segmentsMetadata = schema.getSegmentMetadataSnapshot(); final List segments = segmentsMetadata.values() .stream() @@ -382,8 +439,9 @@ public void testAvailableSegmentMetadataNumRows() } @Test - public void testNullDatasource() throws IOException + public void testNullDatasource() throws IOException, InterruptedException { + SegmentMetadataCache schema = buildSchema1(); final Map segmentMetadatas = schema.getSegmentMetadataSnapshot(); final List segments = segmentMetadatas.values() .stream() @@ -406,8 +464,9 @@ public void testNullDatasource() throws IOException } @Test - public void testNullAvailableSegmentMetadata() throws IOException + public void testNullAvailableSegmentMetadata() throws IOException, InterruptedException { + SegmentMetadataCache schema = buildSchema1(); final Map segmentMetadatas = schema.getSegmentMetadataSnapshot(); final List segments = segmentMetadatas.values() .stream() @@ -429,8 +488,9 @@ public void testNullAvailableSegmentMetadata() throws IOException } @Test - public void testAvailableSegmentMetadataIsRealtime() + public void testAvailableSegmentMetadataIsRealtime() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); Map segmentsMetadata = schema.getSegmentMetadataSnapshot(); final List segments = segmentsMetadata.values() .stream() @@ -907,19 +967,31 @@ void removeServerSegment(final DruidServerMetadata server, final DataSegment seg Assert.assertEquals(0, metadata.getNumReplicas()); // brokers are not counted as replicas yet } + /** + * Test actions on the cache. The current design of the cache makes testing far harder + * than it should be. + * + * - The cache is refreshed on a schedule. + * - Datasources are added to the refresh queue via an unsynchronized thread. + * - The refresh loop always refreshes since one of the segments is dynamic. + * + * The use of latches tries to keep things synchronized, but there are many + * moving parts. A simpler technique is sorely needed. 
+ */ @Test public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws InterruptedException { - DatasourceTable.PhysicalDatasourceMetadata fooTable = schema.getDatasource("foo"); + SegmentMetadataCache schema3 = buildSchema3(); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + DatasourceTable.PhysicalDatasourceMetadata fooTable = schema3.getDatasource("foo"); Assert.assertNotNull(fooTable); Assert.assertTrue(fooTable.dataSource() instanceof TableDataSource); Assert.assertFalse(fooTable.dataSource() instanceof GlobalTableDataSource); Assert.assertFalse(fooTable.isJoinable()); Assert.assertFalse(fooTable.isBroadcast()); - Assert.assertTrue(buildTableLatch.await(1, TimeUnit.SECONDS)); - - buildTableLatch = new CountDownLatch(1); + markDataSourceLatch = new CountDownLatch(1); + refreshLatch = new CountDownLatch(1); final DataSegment someNewBrokerSegment = new DataSegment( "foo", Intervals.of("2012/2013"), @@ -938,11 +1010,12 @@ public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws Inte serverView.addSegment(someNewBrokerSegment, ServerType.BROKER); Assert.assertTrue(markDataSourceLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); // wait for build twice - Assert.assertTrue(buildTableLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated) - Assert.assertTrue(getDatasourcesLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + refreshLatch = new CountDownLatch(1); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - fooTable = schema.getDatasource("foo"); + fooTable = schema3.getDatasource("foo"); Assert.assertNotNull(fooTable); Assert.assertTrue(fooTable.dataSource() instanceof TableDataSource); Assert.assertTrue(fooTable.dataSource() instanceof GlobalTableDataSource); @@ -951,19 +1024,19 @@ public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws Inte // now remove it markDataSourceLatch = new CountDownLatch(1); - buildTableLatch = new CountDownLatch(1); - getDatasourcesLatch = new CountDownLatch(1); + refreshLatch = new CountDownLatch(1); joinableDataSourceNames.remove("foo"); segmentDataSourceNames.remove("foo"); serverView.removeSegment(someNewBrokerSegment, ServerType.BROKER); Assert.assertTrue(markDataSourceLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - // wait for build - Assert.assertTrue(buildTableLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + // wait for build twice + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated) - Assert.assertTrue(getDatasourcesLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + refreshLatch = new CountDownLatch(1); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - fooTable = schema.getDatasource("foo"); + fooTable = schema3.getDatasource("foo"); Assert.assertNotNull(fooTable); Assert.assertTrue(fooTable.dataSource() instanceof TableDataSource); Assert.assertFalse(fooTable.dataSource() instanceof GlobalTableDataSource); @@ -974,6 +1047,7 @@ public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws Inte @Test public void testLocalSegmentCacheSetsDataSourceAsBroadcastButNotJoinable() throws InterruptedException { + SegmentMetadataCache schema = buildSchema1(); 
DatasourceTable.PhysicalDatasourceMetadata fooTable = schema.getDatasource("foo"); Assert.assertNotNull(fooTable); Assert.assertTrue(fooTable.dataSource() instanceof TableDataSource); @@ -1196,8 +1270,9 @@ public void testSegmentMetadataFallbackType() } @Test - public void testStaleDatasourceRefresh() throws IOException + public void testStaleDatasourceRefresh() throws IOException, InterruptedException { + SegmentMetadataCache schema = buildSchema1(); Set segments = new HashSet<>(); Set datasources = new HashSet<>(); datasources.add("wat"); From 0525651589c1d6a670d6717ad34e8d09002b95e4 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Thu, 6 Oct 2022 08:17:17 -0700 Subject: [PATCH 6/9] Fix another flaky test --- .../schema/SegmentMetadataCacheTest.java | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java index 784c94fb1928..7732a1ad48ac 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/schema/SegmentMetadataCacheTest.java @@ -1047,7 +1047,8 @@ public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws Inte @Test public void testLocalSegmentCacheSetsDataSourceAsBroadcastButNotJoinable() throws InterruptedException { - SegmentMetadataCache schema = buildSchema1(); + SegmentMetadataCache schema = buildSchema3(); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); DatasourceTable.PhysicalDatasourceMetadata fooTable = schema.getDatasource("foo"); Assert.assertNotNull(fooTable); Assert.assertTrue(fooTable.dataSource() instanceof TableDataSource); @@ -1055,10 +1056,8 @@ public void testLocalSegmentCacheSetsDataSourceAsBroadcastButNotJoinable() throw Assert.assertFalse(fooTable.isJoinable()); Assert.assertFalse(fooTable.isBroadcast()); - // wait for build twice - Assert.assertTrue(buildTableLatch.await(1, TimeUnit.SECONDS)); - - buildTableLatch = new CountDownLatch(1); + markDataSourceLatch = new CountDownLatch(1); + refreshLatch = new CountDownLatch(1); final DataSegment someNewBrokerSegment = new DataSegment( "foo", Intervals.of("2012/2013"), @@ -1076,9 +1075,11 @@ public void testLocalSegmentCacheSetsDataSourceAsBroadcastButNotJoinable() throw serverView.addSegment(someNewBrokerSegment, ServerType.BROKER); Assert.assertTrue(markDataSourceLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - Assert.assertTrue(buildTableLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + // wait for build twice + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated) - Assert.assertTrue(getDatasourcesLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + refreshLatch = new CountDownLatch(1); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); fooTable = schema.getDatasource("foo"); Assert.assertNotNull(fooTable); @@ -1091,16 +1092,16 @@ public void testLocalSegmentCacheSetsDataSourceAsBroadcastButNotJoinable() throw // now remove it markDataSourceLatch = new CountDownLatch(1); - buildTableLatch = new CountDownLatch(1); - getDatasourcesLatch = new CountDownLatch(1); + refreshLatch = new CountDownLatch(1); segmentDataSourceNames.remove("foo"); serverView.removeSegment(someNewBrokerSegment, ServerType.BROKER); 
Assert.assertTrue(markDataSourceLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - // wait for build - Assert.assertTrue(buildTableLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + // wait for build twice + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated) - Assert.assertTrue(getDatasourcesLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); + refreshLatch = new CountDownLatch(1); + Assert.assertTrue(refreshLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); fooTable = schema.getDatasource("foo"); Assert.assertNotNull(fooTable); From f8763ca5575b996b87df49b5859da528cd292a57 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Mon, 10 Oct 2022 12:41:08 -0700 Subject: [PATCH 7/9] Revisions from review comments. --- .../org/apache/druid/msq/sql/MSQModeTest.java | 10 +-- .../org/apache/druid/query/QueryContext.java | 4 +- .../appenderator/SinkQuerySegmentWalker.java | 2 +- .../apache/druid/sql/AbstractStatement.java | 24 +++--- .../org/apache/druid/sql/SqlQueryPlus.java | 9 ++- .../avatica/AbstractDruidJdbcStatement.java | 10 ++- .../druid/sql/avatica/DruidConnection.java | 54 +++++++------ .../druid/sql/avatica/DruidJdbcResultSet.java | 6 +- .../druid/sql/avatica/DruidJdbcStatement.java | 11 +-- .../apache/druid/sql/avatica/DruidMeta.java | 78 +++++++++++-------- .../sql/avatica/DruidAvaticaHandlerTest.java | 2 +- .../druid/sql/avatica/DruidStatementTest.java | 1 - 12 files changed, 119 insertions(+), 92 deletions(-) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java index fa0dac1cf518..7abc83ceec70 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQModeTest.java @@ -25,16 +25,15 @@ import org.junit.Assert; import org.junit.Test; -import java.util.Collections; +import java.util.HashMap; import java.util.Map; public class MSQModeTest { - @Test public void testPopulateQueryContextWhenNoSupercedingValuePresent() { - Map originalQueryContext = Collections.emptyMap(); + Map originalQueryContext = new HashMap<>(); MSQMode.populateDefaultQueryContext("strict", originalQueryContext); Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 0), originalQueryContext); } @@ -42,7 +41,8 @@ public void testPopulateQueryContextWhenNoSupercedingValuePresent() @Test public void testPopulateQueryContextWhenSupercedingValuePresent() { - Map originalQueryContext = Collections.emptyMap(); + Map originalQueryContext = new HashMap<>(); + originalQueryContext.put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10); MSQMode.populateDefaultQueryContext("strict", originalQueryContext); Assert.assertEquals(ImmutableMap.of(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, 10), originalQueryContext); @@ -51,7 +51,7 @@ public void testPopulateQueryContextWhenSupercedingValuePresent() @Test public void testPopulateQueryContextWhenInvalidMode() { - Map originalQueryContext = Collections.emptyMap(); + Map originalQueryContext = new HashMap<>(); Assert.assertThrows(ISE.class, () -> { MSQMode.populateDefaultQueryContext("fake_mode", originalQueryContext); }); diff --git a/processing/src/main/java/org/apache/druid/query/QueryContext.java 
b/processing/src/main/java/org/apache/druid/query/QueryContext.java index 6239731a6942..0ed8e184664d 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContext.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContext.java @@ -60,7 +60,9 @@ public QueryContext(Map context) // There is no semantic difference between an empty and a null context. // Ensure that a context always exists to avoid the need to check for // a null context. Jackson serialization will omit empty contexts. - this.context = context == null ? Collections.emptyMap() : new TreeMap<>(context); + this.context = context == null + ? Collections.emptyMap() + : Collections.unmodifiableMap(new TreeMap<>(context)); } public static QueryContext empty() diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java index a0fb767eaad7..f9745c4bab62 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java @@ -161,7 +161,7 @@ public QueryRunner getQueryRunnerForSegments(final Query query, final } final QueryToolChest> toolChest = factory.getToolchest(); - final boolean skipIncrementalSegment = query.getContextBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false); + final boolean skipIncrementalSegment = query.context().getBoolean(CONTEXT_SKIP_INCREMENTAL_SEGMENT, false); final AtomicLong cpuTimeAccumulator = new AtomicLong(0L); // Make sure this query type can handle the subquery, if present. diff --git a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java index be84970bb533..caa7b207f286 100644 --- a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java @@ -31,9 +31,9 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import java.io.Closeable; +import java.util.HashMap; import java.util.Map; import java.util.Set; -import java.util.TreeMap; import java.util.UUID; import java.util.function.Function; @@ -59,23 +59,23 @@ public abstract class AbstractStatement implements Closeable protected final SqlExecutionReporter reporter; /** - * Copy of the query context modified during planning. Modifications are - * valid in tasks that run in the request thread. Once the query forks - * child threads, then concurrent modifications to the query context will - * result in an undefined muddle of race conditions. + * Copy of the query context provided by the user. This copy is modified during + * planning. Modifications are possible up to the point where the context is passed + * to a native query. At that point, the context becomes immutable and can be changed + * only by copying the entire native query. */ protected final Map queryContext; protected PlannerContext plannerContext; /** * Resource actions used with authorizing a cancellation request. These actions - * include only the data-level actions (i.e. the datasource.) + * include only the data-level actions (e.g. the datasource.) */ - protected Set cancellationResourceActions; + protected Set cancelationResourceActions; /** * Full resource actions authorized as part of this request. Used when logging - * resource actions. Includes the query context, if query context authorization + * resource actions. 
Includes query context keys, if query context authorization * is enabled. */ protected Set fullResourceActions; @@ -89,9 +89,7 @@ public AbstractStatement( this.sqlToolbox = sqlToolbox; this.reporter = new SqlExecutionReporter(this, remoteAddress); this.queryPlus = queryPlus; - - // TreeMap is required to get consistent ordering of keys, as needed by tests. - this.queryContext = new TreeMap<>(queryPlus.context()); + this.queryContext = new HashMap<>(queryPlus.context()); // "bySegment" results are never valid to use with SQL because the result format is incompatible // so, overwrite any user specified context to avoid exceptions down the line @@ -164,7 +162,7 @@ protected void authorize( // to cancel the query, and includes only the query-level resources. The second // is used to report the resources actually authorized and includes the // query context variables, if we are authorizing them. - cancellationResourceActions = planner.resourceActions(false); + cancelationResourceActions = planner.resourceActions(false); fullResourceActions = planner.resourceActions(authorizeContextParams); } @@ -188,7 +186,7 @@ protected Function, Access> authorizer() */ public Set resources() { - return cancellationResourceActions; + return cancelationResourceActions; } public Set allResources() diff --git a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java index b428777b306c..2aadecda7673 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java @@ -40,11 +40,12 @@ * the {@link Builder} class to create an instance from the information * available at each point in the code. *

- * The query context has a complex lifecycle. The copy here should remain
- * unchanged: this is the set of values which the user requested. Planning will
+ * The query context has a complex lifecycle. The copy here is immutable:
+ * it is the set of values which the user requested. Planning will
  * add (and sometimes remove) values: that work should be done on a copy of the
  * context so that we have a clean record of the user's original requested
- * values.
+ * values. This original record is required to apply security checks to the
+ * set of user-provided context keys.
  */
 public class SqlQueryPlus
 {
@@ -63,7 +64,7 @@ public SqlQueryPlus(
     this.sql = Preconditions.checkNotNull(sql);
     this.queryContext = queryContext == null
         ? Collections.emptyMap()
-        : new HashMap<>(queryContext);
+        : Collections.unmodifiableMap(new HashMap<>(queryContext));
     this.parameters = parameters == null
         ? Collections.emptyList()
         : parameters;
diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/AbstractDruidJdbcStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/AbstractDruidJdbcStatement.java
index c23298463bbc..697ad1ca1725 100644
--- a/sql/src/main/java/org/apache/druid/sql/avatica/AbstractDruidJdbcStatement.java
+++ b/sql/src/main/java/org/apache/druid/sql/avatica/AbstractDruidJdbcStatement.java
@@ -68,7 +68,10 @@ public AbstractDruidJdbcStatement(
     this.statementId = statementId;
   }

-  protected static Meta.Signature createSignature(PrepareResult prepareResult, String sql)
+  protected static Meta.Signature createSignature(
+      final PrepareResult prepareResult,
+      final String sql
+  )
   {
     List params = new ArrayList<>();
     final RelDataType parameterRowType = prepareResult.getParameterRowType();
@@ -85,7 +88,10 @@ protected static Meta.Signature createSignature(PrepareResult prepareResult, Str
     );
   }

-  private static AvaticaParameter createParameter(RelDataTypeField field, RelDataType type)
+  private static AvaticaParameter createParameter(
+      final RelDataTypeField field,
+      final RelDataType type
+  )
   {
     // signed is always false because no way to extract from RelDataType, and the only usage of this AvaticaParameter
     // constructor I can find, in CalcitePrepareImpl, does it this way with hard coded false
diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java
index 7cbeecd344f3..23f1a222dd04 100644
--- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java
+++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java
@@ -24,11 +24,10 @@ import com.google.errorprone.annotations.concurrent.GuardedBy;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.logger.Logger;
-import org.apache.druid.sql.PreparedStatement;
 import org.apache.druid.sql.SqlQueryPlus;
 import org.apache.druid.sql.SqlStatementFactory;

-import java.util.HashMap;
+import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -46,7 +45,14 @@ public class DruidConnection
   private final String connectionId;
   private final int maxStatements;
   private final Map userSecret;
-  private final Map context;
+
+  /**
+   * The set of context values for each query within this connection. In JDBC,
+   * Druid query context values are set at the connection level, not on the
+   * individual query. This session context is shared by all queries (statements)
+   * within the connection.
+ */ + private final Map sessionContext; private final AtomicInteger statementCounter = new AtomicInteger(); private final AtomicReference> timeoutFuture = new AtomicReference<>(); @@ -63,13 +69,13 @@ public DruidConnection( final String connectionId, final int maxStatements, final Map userSecret, - final Map context + final Map sessionContext ) { this.connectionId = Preconditions.checkNotNull(connectionId); this.maxStatements = maxStatements; - this.userSecret = userSecret; - this.context = context; + this.userSecret = Collections.unmodifiableMap(userSecret); + this.sessionContext = Collections.unmodifiableMap(sessionContext); } public String getConnectionId() @@ -77,7 +83,19 @@ public String getConnectionId() return connectionId; } - public DruidJdbcStatement createStatement(SqlStatementFactory sqlStatementFactory) + public Map sessionContext() + { + return sessionContext; + } + + public Map userSecret() + { + return userSecret; + } + + public DruidJdbcStatement createStatement( + final SqlStatementFactory sqlStatementFactory + ) { final int statementId = statementCounter.incrementAndGet(); @@ -89,14 +107,13 @@ public DruidJdbcStatement createStatement(SqlStatementFactory sqlStatementFactor } if (statements.size() >= maxStatements) { - throw DruidMeta.logFailure(new ISE("Too many open statements, limit is [%,d]", maxStatements)); + throw DruidMeta.logFailure(new ISE("Too many open statements, limit is %,d", maxStatements)); } @SuppressWarnings("GuardedBy") final DruidJdbcStatement statement = new DruidJdbcStatement( connectionId, statementId, - new HashMap(context), sqlStatementFactory ); @@ -107,9 +124,10 @@ public DruidJdbcStatement createStatement(SqlStatementFactory sqlStatementFactor } public DruidJdbcPreparedStatement createPreparedStatement( - SqlStatementFactory sqlStatementFactory, - SqlQueryPlus sqlQueryPlus, - final long maxRowCount) + final SqlStatementFactory sqlStatementFactory, + final SqlQueryPlus sqlQueryPlus, + final long maxRowCount + ) { final int statementId = statementCounter.incrementAndGet(); @@ -121,17 +139,14 @@ public DruidJdbcPreparedStatement createPreparedStatement( } if (statements.size() >= maxStatements) { - throw DruidMeta.logFailure(new ISE("Too many open statements, limit is [%,d]", maxStatements)); + throw DruidMeta.logFailure(new ISE("Too many open statements, limit is %,d", maxStatements)); } @SuppressWarnings("GuardedBy") - final PreparedStatement statement = sqlStatementFactory.preparedStatement( - sqlQueryPlus.withContext(new HashMap(context)) - ); final DruidJdbcPreparedStatement jdbcStmt = new DruidJdbcPreparedStatement( connectionId, statementId, - statement, + sqlStatementFactory.preparedStatement(sqlQueryPlus), maxRowCount ); @@ -203,9 +218,4 @@ public DruidConnection sync(final Future newTimeoutFuture) } return this; } - - public Map userSecret() - { - return userSecret; - } } diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcResultSet.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcResultSet.java index 95005b7bf595..2b4940155251 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcResultSet.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcResultSet.java @@ -52,6 +52,10 @@ * the application can only read results sequentially, the application * can't jump around or read backwards. As a result, the enclosing * statement closes the result set at EOF to release resources early. + *

+ * Thread safe, but only when accessed sequentially by different request + * threads: not designed for concurrent access as JDBC does not support + * concurrent actions on the same result set. */ public class DruidJdbcResultSet implements Closeable { @@ -83,7 +87,7 @@ public class DruidJdbcResultSet implements Closeable public DruidJdbcResultSet( final AbstractDruidJdbcStatement jdbcStatement, - DirectStatement stmt, + final DirectStatement stmt, final long maxRowCount ) { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java index 3eda8393154f..3b84b7e483b1 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java @@ -25,34 +25,31 @@ import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlStatementFactory; -import java.util.Map; - /** * Represents Druid's version of the JDBC {@code Statement} class: * can be executed multiple times, one after another, producing a - * {@link DruidJdbcResultSet} for each execution. + * {@link DruidJdbcResultSet} for each execution. Thread safe, but + * only when accessed sequentially by different request threads: + * not designed for concurrent access as JDBC does not support + * concurrent actions on the same statement. */ public class DruidJdbcStatement extends AbstractDruidJdbcStatement { private final SqlStatementFactory lifecycleFactory; - protected final Map queryContext; public DruidJdbcStatement( final String connectionId, final int statementId, - final Map queryContext, final SqlStatementFactory lifecycleFactory ) { super(connectionId, statementId); - this.queryContext = queryContext; this.lifecycleFactory = Preconditions.checkNotNull(lifecycleFactory, "lifecycleFactory"); } public synchronized void execute(SqlQueryPlus queryPlus, long maxRowCount) { closeResultSet(); - queryPlus = queryPlus.withContext(queryContext); DirectStatement stmt = lifecycleFactory.directStatement(queryPlus); resultSet = new DruidJdbcResultSet(this, stmt, Long.MAX_VALUE); try { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index f26a1f5837e5..b42665001470 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -100,6 +100,12 @@ public static T logFailure(T error) } private static final Logger LOG = new Logger(DruidMeta.class); + + /** + * Items passed in via the connection context which are not query + * context values. Instead, these are used at connection time to validate + * the user. + */ private static final Set SENSITIVE_CONTEXT_FIELDS = ImmutableSet.of( "user", "password" ); @@ -162,29 +168,32 @@ public DruidMeta( @Override public void openConnection(final ConnectionHandle ch, final Map info) { - try { - // Build connection context. - final Map secret = new HashMap<>(); - final ImmutableMap.Builder context = ImmutableMap.builder(); - if (info != null) { - for (Map.Entry entry : info.entrySet()) { - if (SENSITIVE_CONTEXT_FIELDS.contains(entry.getKey())) { - secret.put(entry.getKey(), entry.getValue()); - } else { - context.put(entry.getKey(), entry.getValue()); - } + // Build connection context. The session query context is built + // mutable here. It becomes immutable when attached to the connection. 
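+    // As a hypothetical illustration: connection properties
+    // {user=bob, password=secret, sqlTimeZone=America/Los_Angeles} split into
+    // the secret map {user, password} and the session context
+    // {sqlTimeZone=America/Los_Angeles}.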
+ final Map secret = new HashMap<>(); + final Map contextMap = new HashMap<>(); + if (info != null) { + for (Map.Entry entry : info.entrySet()) { + if (SENSITIVE_CONTEXT_FIELDS.contains(entry.getKey())) { + secret.put(entry.getKey(), entry.getValue()); + } else { + contextMap.put(entry.getKey(), entry.getValue()); } } - // we don't want to stringify arrays for JDBC ever because Avatica needs to handle this - context.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false); - openDruidConnection(ch.id, secret, context.build()); + } + // Don't stringify arrays for JDBC because Avatica needs to handle arrays. + // When using query context security, all JDBC users must have permission on + // this context key. + contextMap.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false); + try { + openDruidConnection(ch.id, secret, contextMap); } catch (NoSuchConnectionException e) { + // Avoid sanitizing Avatica specific exceptions so that the Avatica code + // can rely on them to handle issues in a JDBC-specific way. throw e; } catch (Throwable t) { - // we want to avoid sanitizing Avatica specific exceptions as the Avatica code can rely on them to handle issues - // differently throw mapException(t); } } @@ -257,16 +266,17 @@ public StatementHandle prepare( { try { final DruidConnection druidConnection = getDruidConnection(ch.id); - SqlQueryPlus sqlReq = new SqlQueryPlus( + final SqlQueryPlus sqlReq = new SqlQueryPlus( sql, - null, // Context provided by connection + druidConnection.sessionContext(), null, // No parameters in this path doAuthenticate(druidConnection) ); - DruidJdbcPreparedStatement stmt = getDruidConnection(ch.id).createPreparedStatement( + final DruidJdbcPreparedStatement stmt = getDruidConnection(ch.id).createPreparedStatement( sqlStatementFactory, sqlReq, - maxRowCount); + maxRowCount + ); stmt.prepare(); LOG.debug("Successfully prepared statement [%s] for execution", stmt.getStatementId()); return new StatementHandle(ch.id, stmt.getStatementId(), stmt.getSignature()); @@ -281,7 +291,7 @@ public StatementHandle prepare( private AuthenticationResult doAuthenticate(final DruidConnection druidConnection) { - AuthenticationResult authenticationResult = authenticateConnection(druidConnection); + final AuthenticationResult authenticationResult = authenticateConnection(druidConnection); if (authenticationResult == null) { throw logFailure( new ForbiddenException("Authentication failed."), @@ -324,6 +334,7 @@ public ExecuteResult prepareAndExecute( AuthenticationResult authenticationResult = doAuthenticate(druidConnection); SqlQueryPlus sqlRequest = SqlQueryPlus.builder(sql) .auth(authenticationResult) + .context(druidConnection.sessionContext()) .build(); druidStatement.execute(sqlRequest, maxRowCount); ExecuteResult result = doFetch(druidStatement, maxRowsInFirstFrame); @@ -411,7 +422,7 @@ public Frame fetch( { try { final int maxRows = getEffectiveMaxRowsPerFrame(fetchMaxRowCount); - LOG.debug("Fetching next frame from offset[%s] with [%s] rows for statement[%s]", offset, maxRows, statement.id); + LOG.debug("Fetching next frame from offset %,d with %,d rows for statement [%s]", offset, maxRows, statement.id); return getDruidStatement(statement, AbstractDruidJdbcStatement.class).nextFrame(offset, maxRows); } catch (NoSuchConnectionException e) { @@ -447,7 +458,7 @@ public ExecuteResult execute( druidStatement.execute(parameterValues); ExecuteResult result = doFetch(druidStatement, maxRowsInFirstFrame); LOG.debug( - "Successfully started execution of statement[%s]", + "Successfully started 
execution of statement [%s]", druidStatement.getStatementId()); return result; } @@ -503,7 +514,7 @@ public boolean syncResults( final long currentOffset = druidStatement.getCurrentOffset(); if (currentOffset != offset) { throw logFailure(new ISE( - "Requested offset[%,d] does not match currentOffset[%,d]", + "Requested offset %,d does not match currentOffset %,d", offset, currentOffset )); @@ -757,11 +768,11 @@ private AuthenticationResult authenticateConnection(final DruidConnection connec { Map context = connection.userSecret(); for (Authenticator authenticator : authenticators) { - LOG.debug("Attempting authentication with authenticator[%s]", authenticator.getClass()); + LOG.debug("Attempting authentication with authenticator [%s]", authenticator.getClass()); AuthenticationResult authenticationResult = authenticator.authenticateJDBCContext(context); if (authenticationResult != null) { LOG.debug( - "Authenticated identity[%s] for connection[%s]", + "Authenticated identity [%s] for connection [%s]", authenticationResult.getIdentity(), connection.getConnectionId() ); @@ -798,7 +809,7 @@ private DruidConnection openDruidConnection( connectionCount.decrementAndGet(); throw logFailure( new ISE("Too many connections"), - "Too many connections, limit is[%,d] per broker", + "Too many connections, limit is %,d per broker", config.getMaxConnections() ); } @@ -812,10 +823,10 @@ private DruidConnection openDruidConnection( if (putResult != null) { // Didn't actually insert the connection. connectionCount.decrementAndGet(); - throw logFailure(new ISE("Connection[%s] already open.", connectionId)); + throw logFailure(new ISE("Connection [%s] already open.", connectionId)); } - LOG.debug("Connection[%s] opened.", connectionId); + LOG.debug("Connection [%s] opened.", connectionId); // Call getDruidConnection to start the timeout timer. return getDruidConnection(connectionId); @@ -901,7 +912,7 @@ private MetaResultSet sqlResultSet(final ConnectionHandle ch, final String sql) * checked against if any additional frames are required (which means one of the input or maximum was set to a value * other than -1). 
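   * For example, assuming hypothetical settings of maxRowsPerFrame = 5,000 and
   * minRowsPerFrame = 100: a client fetch size of 10 is raised to 100, while a
   * client fetch size of 100,000 is limited to the configured maximum of 5,000.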
*/ - private int getEffectiveMaxRowsPerFrame(int clientMaxRowsPerFrame) + private int getEffectiveMaxRowsPerFrame(final int clientMaxRowsPerFrame) { // no configured row limit, use the client provided limit if (config.getMaxRowsPerFrame() < 0) { @@ -917,13 +928,12 @@ private int getEffectiveMaxRowsPerFrame(int clientMaxRowsPerFrame) /** * coerce fetch size to be, at minimum, {@link AvaticaServerConfig#minRowsPerFrame} */ - private int adjustForMinumumRowsPerFrame(int rowsPerFrame) + private int adjustForMinumumRowsPerFrame(final int rowsPerFrame) { - final int adjustedRowsPerFrame = Math.max(config.getMinRowsPerFrame(), rowsPerFrame); - return adjustedRowsPerFrame; + return Math.max(config.getMinRowsPerFrame(), rowsPerFrame); } - private static String withEscapeClause(String toEscape) + private static String withEscapeClause(final String toEscape) { return Calcites.escapeStringLiteral(toEscape) + " ESCAPE '\\'"; } diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java index 7d989cae2300..119d10ae4efe 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -777,7 +777,7 @@ public void testTooManyStatements() throws SQLException } expectedException.expect(AvaticaClientRuntimeException.class); - expectedException.expectMessage("Too many open statements, limit is [4]"); + expectedException.expectMessage("Too many open statements, limit is 4"); client.createStatement(); } diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index eba660d3c0c1..e5dc4a662acf 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -138,7 +138,6 @@ private DruidJdbcStatement jdbcStatement() return new DruidJdbcStatement( "", 0, - Collections.emptyMap(), sqlStatementFactory ); } From a0fe52e583ae6650cb77bd9d08ec60c201e38654 Mon Sep 17 00:00:00 2001 From: Paul Rogers Date: Wed, 12 Oct 2022 11:16:14 -0700 Subject: [PATCH 8/9] Added enhanced context security from issue #13120 Required to get a security IT to pass Revises context security in native, SQL query path Moves two SQL context keys to QueryContexts for visibility to AuthConfig --- .../apache/druid/utils/CollectionUtils.java | 36 +++++ .../druid/utils/CollectionUtilsTest.java | 64 +++++++++ ...etsHistogramQuantileSqlAggregatorTest.java | 4 +- .../apache/druid/msq/test/MSQTestBase.java | 3 +- .../AbstractAuthConfigurationTest.java | 9 +- .../org/apache/druid/query/QueryContexts.java | 4 + .../apache/druid/server/QueryLifecycle.java | 8 +- .../druid/server/security/AuthConfig.java | 129 +++++++++++++++--- .../druid/server/security/AuthConfigTest.java | 58 ++++++++ .../apache/druid/sql/AbstractStatement.java | 45 +++--- .../druid/sql/SqlExecutionReporter.java | 7 +- .../apache/druid/sql/avatica/DruidMeta.java | 4 +- .../sql/calcite/planner/DruidPlanner.java | 71 ++++++---- .../sql/calcite/planner/PlannerContext.java | 28 ++-- .../sql/calcite/planner/PlannerFactory.java | 3 +- .../apache/druid/sql/SqlStatementTest.java | 3 +- .../sql/calcite/BaseCalciteQueryTest.java | 16 +-- .../sql/calcite/CalciteIngestionDmlTest.java | 4 +- .../druid/sql/http/SqlResourceTest.java | 4 +- 19 files changed, 369 insertions(+), 131 deletions(-) create 
mode 100644 core/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java

diff --git a/core/src/main/java/org/apache/druid/utils/CollectionUtils.java b/core/src/main/java/org/apache/druid/utils/CollectionUtils.java
index 39f6fc259b7f..d4e1adf830ca 100644
--- a/core/src/main/java/org/apache/druid/utils/CollectionUtils.java
+++ b/core/src/main/java/org/apache/druid/utils/CollectionUtils.java
@@ -29,9 +29,11 @@ import java.util.Collection;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Set;
 import java.util.Spliterator;
 import java.util.TreeSet;
 import java.util.function.Function;
@@ -148,6 +150,40 @@ public static boolean isNullOrEmpty(@Nullable Collection list)
     return list == null || list.isEmpty();
   }

+  /**
+   * Subtract one set from another: {@code C = A - B}.
+   */
+  public static Set subtract(Set left, Set right)
+  {
+    Set result = new HashSet<>(left);
+    result.removeAll(right);
+    return result;
+  }
+
+  /**
+   * Intersection of two sets: {@code C = A ∩ B}.
+   */
+  public static Set intersect(Set left, Set right)
+  {
+    Set result = new HashSet<>();
+    for (T key : left) {
+      if (right.contains(key)) {
+        result.add(key);
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Union of two sets: {@code C = A ∪ B}.
+   */
+  public static Set union(Set left, Set right)
+  {
+    Set result = new HashSet<>(left);
+    result.addAll(right);
+    return result;
+  }
+
   private CollectionUtils()
   {
   }
diff --git a/core/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java b/core/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java
new file mode 100644
index 000000000000..522b9e0ade70
--- /dev/null
+++ b/core/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.utils;
+
+import com.google.common.collect.ImmutableSet;
+import org.junit.Test;
+
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+public class CollectionUtilsTest
+{
+  // When Java 9 is allowed, use Set.of().
+ Set empty = ImmutableSet.of(); + Set abc = ImmutableSet.of("a", "b", "c"); + Set bcd = ImmutableSet.of("b", "c", "d"); + Set efg = ImmutableSet.of("e", "f", "g"); + + @Test + public void testSubtract() + { + assertEquals(empty, CollectionUtils.subtract(empty, empty)); + assertEquals(abc, CollectionUtils.subtract(abc, empty)); + assertEquals(empty, CollectionUtils.subtract(abc, abc)); + assertEquals(abc, CollectionUtils.subtract(abc, efg)); + assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd)); + } + + @Test + public void testIntersect() + { + assertEquals(empty, CollectionUtils.intersect(empty, empty)); + assertEquals(abc, CollectionUtils.intersect(abc, abc)); + assertEquals(empty, CollectionUtils.intersect(abc, efg)); + assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd)); + } + + @Test + public void testUnion() + { + assertEquals(empty, CollectionUtils.union(empty, empty)); + assertEquals(abc, CollectionUtils.union(abc, abc)); + assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg)); + assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd)); + } +} diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java index f5c850ea7150..22dbc8ff6979 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java @@ -27,6 +27,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.query.Druids; +import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; @@ -53,7 +54,6 @@ import org.apache.druid.sql.calcite.BaseCalciteQueryTest; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.planner.DruidOperatorTable; -import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; import org.apache.druid.timeline.DataSegment; @@ -324,7 +324,7 @@ public void testQuantileOnCastedString() new QuantilePostAggregator("a6", "a6:agg", 0.999f), new QuantilePostAggregator("a7", "a5:agg", 0.999f) ) - .context(ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, "dummy")) + .context(ImmutableMap.of(QueryContexts.CTX_SQL_QUERY_ID, "dummy")) .build() ), ImmutableList.of( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index a08cf0ca3ece..45968dbd4733 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -131,7 +131,6 @@ import org.apache.druid.sql.calcite.external.ExternalDataSource; import 
org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.PlannerConfig; -import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.rel.DruidQuery; import org.apache.druid.sql.calcite.run.SqlEngine; @@ -207,7 +206,7 @@ public class MSQTestBase extends BaseCalciteQueryTest public static final Map DEFAULT_MSQ_CONTEXT = ImmutableMap.builder() .put(MultiStageQueryContext.CTX_ENABLE_DURABLE_SHUFFLE_STORAGE, true) - .put(PlannerContext.CTX_SQL_QUERY_ID, "test-query") + .put(QueryContexts.CTX_SQL_QUERY_ID, "test-query") .put(QueryContexts.FINALIZE_KEY, true) .build(); diff --git a/integration-tests/src/test/java/org/apache/druid/tests/security/AbstractAuthConfigurationTest.java b/integration-tests/src/test/java/org/apache/druid/tests/security/AbstractAuthConfigurationTest.java index aa729774491f..0be9d7beb7a6 100644 --- a/integration-tests/src/test/java/org/apache/druid/tests/security/AbstractAuthConfigurationTest.java +++ b/integration-tests/src/test/java/org/apache/druid/tests/security/AbstractAuthConfigurationTest.java @@ -547,7 +547,7 @@ public void test_avaticaQueryWithContext_datasourceAndContextParamsUser_succeed( public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception { final String query = "select count(*) from auth_test"; - StatusResponseHolder responseHolder = makeSQLQueryRequest( + makeSQLQueryRequest( getHttpClient(User.DATASOURCE_ONLY_USER), query, ImmutableMap.of("auth_test_ctx", "should-be-denied"), @@ -559,7 +559,7 @@ public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception public void test_sqlQueryWithContext_datasourceAndContextParamsUser_succeed() throws Exception { final String query = "select count(*) from auth_test"; - StatusResponseHolder responseHolder = makeSQLQueryRequest( + makeSQLQueryRequest( getHttpClient(User.DATASOURCE_AND_CONTEXT_PARAMS_USER), query, ImmutableMap.of("auth_test_ctx", "should-be-allowed"), @@ -844,11 +844,6 @@ protected void verifyGroupMappingsInvalidAuthNameFails() protected void verifyInvalidAuthNameFails(String endpoint) { - HttpClient adminClient = new CredentialedHttpClient( - new BasicCredentials("admin", "priest"), - httpClient - ); - HttpUtil.makeRequestWithExpectedStatus( getHttpClient(User.ADMIN), HttpMethod.POST, diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java index 976ab43bc074..c06c03624496 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java +++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java @@ -82,7 +82,11 @@ public class QueryContexts public static final String UNCOVERED_INTERVALS_LIMIT_KEY = "uncoveredIntervalsLimit"; public static final String MIN_TOP_N_THRESHOLD = "minTopNThreshold"; + // SQL query context keys + public static final String CTX_SQL_QUERY_ID = BaseQuery.SQL_QUERY_ID; + public static final String CTX_SQL_STRINGIFY_ARRAYS = "sqlStringifyArrays"; + // Defaults public static final boolean DEFAULT_BY_SEGMENT = false; public static final boolean DEFAULT_POPULATE_CACHE = true; public static final boolean DEFAULT_USE_CACHE = true; diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 0dc891b50ece..f56efbdf9f27 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ 
b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java
@@ -62,7 +62,7 @@
 import javax.annotation.Nullable;
 import javax.servlet.http.HttpServletRequest;

-import java.util.Collections;
+
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -227,12 +227,10 @@ public Access authorize(HttpServletRequest req)
             baseQuery.getDataSource().getTableNames(),
             AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
         ),
-        authConfig.authorizeQueryContextParams()
-        ? Iterables.transform(
-            userContextKeys,
+        Iterables.transform(
+            authConfig.filterContextKeys(userContextKeys),
             contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
         )
-        : Collections.emptyList()
     );
     return doAuthorize(
         AuthorizationUtils.authenticationResultFromRequest(req),
diff --git a/server/src/main/java/org/apache/druid/server/security/AuthConfig.java b/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
index 6d990c7a744d..695260c3c85a 100644
--- a/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
+++ b/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
@@ -21,10 +21,14 @@
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.ImmutableSet;
+import org.apache.druid.query.QueryContexts;
+import org.apache.druid.utils.CollectionUtils;

 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;

 public class AuthConfig
 {
@@ -46,25 +50,20 @@ public class AuthConfig
   public static final String TRUSTED_DOMAIN_NAME = "trustedDomain";

-  public AuthConfig()
-  {
-    this(null, null, null, false, false);
-  }
+  /**
+   * Set of context keys which are always permissible because something in the Druid
+   * code itself sets the key before the security check.
+   */
+  public static final Set ALLOWED_CONTEXT_KEYS = ImmutableSet.of(
+      // Set in the Avatica server path
+      QueryContexts.CTX_SQL_STRINGIFY_ARRAYS,
+      // Set by the Router
+      QueryContexts.CTX_SQL_QUERY_ID
+  );

-  @JsonCreator
-  public AuthConfig(
-      @JsonProperty("authenticatorChain") List authenticatorChain,
-      @JsonProperty("authorizers") List authorizers,
-      @JsonProperty("unsecuredPaths") List unsecuredPaths,
-      @JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
-      @JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams
-  )
+  public AuthConfig()
   {
-    this.authenticatorChain = authenticatorChain;
-    this.authorizers = authorizers;
-    this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
-    this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
-    this.authorizeQueryContextParams = authorizeQueryContextParams;
+    this(null, null, null, false, false, null, null);
   }

   @JsonProperty
@@ -82,6 +81,44 @@ public AuthConfig(
   @JsonProperty
   private final boolean authorizeQueryContextParams;

+  /**
+   * The set of query context keys that are allowed, even when security is
+   * enabled. A null value is the same as an empty set.
+   */
+  @JsonProperty
+  private final Set unsecuredContextKeys;
+
+  /**
+   * The set of query context keys to secure, when context security is
+   * enabled. Null has a special meaning: it means to ignore this set.
+   * Else, only the keys in this set are subject to security. If both this
+   * set and the unsecured set are given, the unsecured set is applied
+   * first: a key that appears in both is not secured.
+ */ + @JsonProperty + private final Set securedContextKeys; + + @JsonCreator + public AuthConfig( + @JsonProperty("authenticatorChain") List authenticatorChain, + @JsonProperty("authorizers") List authorizers, + @JsonProperty("unsecuredPaths") List unsecuredPaths, + @JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions, + @JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams, + @JsonProperty("unsecuredContextKeys") Set unsecuredContextKeys, + @JsonProperty("securedContextKeys") Set securedContextKeys + ) + { + this.authenticatorChain = authenticatorChain; + this.authorizers = authorizers; + this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths; + this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions; + this.authorizeQueryContextParams = authorizeQueryContextParams; + this.unsecuredContextKeys = unsecuredContextKeys == null + ? Collections.emptySet() + : unsecuredContextKeys; + this.securedContextKeys = securedContextKeys; + } + public List getAuthenticatorChain() { return authenticatorChain; @@ -107,6 +144,36 @@ public boolean authorizeQueryContextParams() return authorizeQueryContextParams; } + /** + * Filter the user-supplied context keys based on the context key security + * rules. If context key security is disabled, then allow all keys. Else, + * apply the three key lists defined here. + *

+   * <ul>
+   * <li>Allow Druid-defined keys.</li>
+   * <li>Allow anything not in the secured context key list.</li>
+   * <li>Allow anything in the config-defined unsecured key list.</li>
+   * </ul>
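+   * <p>
+   * For example, mirroring the "specify both" case in AuthConfigTest below:
+   * with unsecuredContextKeys = {"a", "b"} and securedContextKeys = {"b", "c"},
+   * user keys {"a", "b", "c", "d"} reduce to {"c"}; "a" and "b" are unsecured,
+   * and "d" falls outside the secured set.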
+   * <p>
+   * In the typical case, a site defines either the secured key list
+   * (to handle a few keys that are not allowed) or the unsecured key
+   * list (to enumerate a few that are allowed.) If both lists are
+   * given, the unsecured list is applied first: think of the unsecured
+   * list as exceptions to the secured key list.
+   *
+   * @return the set of secured keys to check via authorization
+   */
+  public Set filterContextKeys(final Set userKeys)
+  {
+    if (!authorizeQueryContextParams) {
+      return ImmutableSet.of();
+    }
+    Set keysToCheck = CollectionUtils.subtract(userKeys, ALLOWED_CONTEXT_KEYS);
+    keysToCheck = CollectionUtils.subtract(keysToCheck, unsecuredContextKeys);
+    if (securedContextKeys != null) {
+      keysToCheck = CollectionUtils.intersect(keysToCheck, securedContextKeys);
+    }
+    return keysToCheck;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -121,7 +188,9 @@ public boolean equals(Object o)
            && authorizeQueryContextParams == that.authorizeQueryContextParams
            && Objects.equals(authenticatorChain, that.authenticatorChain)
            && Objects.equals(authorizers, that.authorizers)
-           && Objects.equals(unsecuredPaths, that.unsecuredPaths);
+           && Objects.equals(unsecuredPaths, that.unsecuredPaths)
+           && Objects.equals(unsecuredContextKeys, that.unsecuredContextKeys)
+           && Objects.equals(securedContextKeys, that.securedContextKeys);
   }

   @Override
@@ -132,7 +201,9 @@ public int hashCode()
   {
     return Objects.hash(
         authenticatorChain,
         authorizers,
         unsecuredPaths,
         allowUnauthenticatedHttpOptions,
-        authorizeQueryContextParams
+        authorizeQueryContextParams,
+        unsecuredContextKeys,
+        securedContextKeys
     );
   }

@@ -145,6 +216,8 @@ public String toString()
            ", unsecuredPaths=" + unsecuredPaths +
            ", allowUnauthenticatedHttpOptions=" + allowUnauthenticatedHttpOptions +
            ", enableQueryContextAuthorization=" + authorizeQueryContextParams +
+           ", unsecuredContextKeys=" + unsecuredContextKeys +
+           ", securedContextKeys=" + securedContextKeys +
            '}';
   }

@@ -163,6 +236,8 @@ public static class Builder
     private List unsecuredPaths;
     private boolean allowUnauthenticatedHttpOptions;
     private boolean authorizeQueryContextParams;
+    private Set unsecuredContextKeys;
+    private Set securedContextKeys;

     public Builder setAuthenticatorChain(List authenticatorChain)
     {
@@ -194,6 +269,18 @@ public Builder setAuthorizeQueryContextParams(boolean authorizeQueryContextParam
       return this;
     }

+    public Builder setUnsecuredContextKeys(Set unsecuredContextKeys)
+    {
+      this.unsecuredContextKeys = unsecuredContextKeys;
+      return this;
+    }
+
+    public Builder setSecuredContextKeys(Set securedContextKeys)
+    {
+      this.securedContextKeys = securedContextKeys;
+      return this;
+    }
+
     public AuthConfig build()
     {
       return new AuthConfig(
@@ -201,7 +288,9 @@ public AuthConfig build()
           authorizers,
           unsecuredPaths,
           allowUnauthenticatedHttpOptions,
-          authorizeQueryContextParams
+          authorizeQueryContextParams,
+          unsecuredContextKeys,
+          securedContextKeys
       );
     }
   }
diff --git a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
index eb2b8da822de..5e0c81df23ca 100644
--- a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
+++ b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
@@ -19,9 +19,16 @@
 package org.apache.druid.server.security;

+import com.google.common.collect.ImmutableSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.query.QueryContexts;
 import org.junit.Test;

+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
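+/**
+ * Exercises the context key security rules in AuthConfig. As a sketch of how
+ * a site might enable the feature (the property names follow the JSON property
+ * names in AuthConfig under its usual druid.auth binding; the key names and
+ * values here are hypothetical):
+ *
+ * <pre>
+ * druid.auth.authorizeQueryContextParams=true
+ * druid.auth.unsecuredContextKeys=["sqlTimeZone"]
+ * druid.auth.securedContextKeys=["priority"]
+ * </pre>
+ */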
public class AuthConfigTest { @Test @@ -29,4 +36,55 @@ public void testEquals() { EqualsVerifier.configure().usingGetClass().forClass(AuthConfig.class).verify(); } + + @Test + public void testContextSecurity() + { + // No security + { + AuthConfig config = new AuthConfig(); + Set keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID); + assertTrue(config.filterContextKeys(keys).isEmpty()); + } + + // Default security + { + AuthConfig config = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .build(); + Set keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID); + assertEquals(ImmutableSet.of("a", "b"), config.filterContextKeys(keys)); + } + + // Specify unsecured keys (white-list) + { + AuthConfig config = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .setUnsecuredContextKeys(ImmutableSet.of("a")) + .build(); + Set keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID); + assertEquals(ImmutableSet.of("b"), config.filterContextKeys(keys)); + } + + // Specify secured keys (black-list) + { + AuthConfig config = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .setSecuredContextKeys(ImmutableSet.of("a")) + .build(); + Set keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID); + assertEquals(ImmutableSet.of("a"), config.filterContextKeys(keys)); + } + + // Specify both + { + AuthConfig config = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .setUnsecuredContextKeys(ImmutableSet.of("a", "b")) + .setSecuredContextKeys(ImmutableSet.of("b", "c")) + .build(); + Set keys = ImmutableSet.of("a", "b", "c", "d", QueryContexts.CTX_SQL_QUERY_ID); + assertEquals(ImmutableSet.of("c"), config.filterContextKeys(keys)); + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java index caa7b207f286..7388303d0312 100644 --- a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java @@ -24,14 +24,18 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.Access; +import org.apache.druid.server.security.Action; import org.apache.druid.server.security.AuthorizationUtils; import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.server.security.Resource; import org.apache.druid.server.security.ResourceAction; +import org.apache.druid.server.security.ResourceType; import org.apache.druid.sql.calcite.planner.DruidPlanner; import org.apache.druid.sql.calcite.planner.PlannerContext; import java.io.Closeable; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.UUID; @@ -66,19 +70,7 @@ public abstract class AbstractStatement implements Closeable */ protected final Map queryContext; protected PlannerContext plannerContext; - - /** - * Resource actions used with authorizing a cancellation request. These actions - * include only the data-level actions (e.g. the datasource.) - */ - protected Set cancelationResourceActions; - - /** - * Full resource actions authorized as part of this request. Used when logging - * resource actions. Includes query context keys, if query context authorization - * is enabled. 
- */ - protected Set fullResourceActions; + protected DruidPlanner.AuthResult authResult; public AbstractStatement( final SqlToolbox sqlToolbox, @@ -97,7 +89,7 @@ public AbstractStatement( if (this.queryContext.remove(QueryContexts.BY_SEGMENT_KEY) != null) { log.warn("'bySegment' results are not supported for SQL queries, ignoring query context parameter"); } - this.queryContext.putIfAbsent(PlannerContext.CTX_SQL_QUERY_ID, UUID.randomUUID().toString()); + this.queryContext.putIfAbsent(QueryContexts.CTX_SQL_QUERY_ID, UUID.randomUUID().toString()); for (Map.Entry entry : sqlToolbox.defaultQueryConfig.getContext().entrySet()) { this.queryContext.putIfAbsent(entry.getKey(), entry.getValue()); } @@ -105,7 +97,7 @@ public AbstractStatement( public String sqlQueryId() { - return QueryContexts.parseString(queryContext, PlannerContext.CTX_SQL_QUERY_ID); + return QueryContexts.parseString(queryContext, QueryContexts.CTX_SQL_QUERY_ID); } /** @@ -149,21 +141,18 @@ protected void authorize( final Function, Access> authorizer ) { - boolean authorizeContextParams = sqlToolbox.authConfig.authorizeQueryContextParams(); + Set securedKeys = this.sqlToolbox.authConfig.filterContextKeys(queryPlus.context().keySet()); + Set contextResources = new HashSet<>(); + securedKeys.forEach(key -> contextResources.add( + new ResourceAction(new Resource(key, ResourceType.QUERY_CONTEXT), Action.WRITE) + )); // Authentication is done by the planner using the function provided // here. The planner ensures that this step is done before planning. - Access authorizationResult = planner.authorize(authorizer, authorizeContextParams); - if (!authorizationResult.isAllowed()) { - throw new ForbiddenException(authorizationResult.toMessage()); + authResult = planner.authorize(authorizer, contextResources); + if (!authResult.authorizationResult.isAllowed()) { + throw new ForbiddenException(authResult.authorizationResult.toMessage()); } - - // Capture the query resources twice. The first is used to validate the request - // to cancel the query, and includes only the query-level resources. The second - // is used to report the resources actually authorized and includes the - // query context variables, if we are authorizing them. - cancelationResourceActions = planner.resourceActions(false); - fullResourceActions = planner.resourceActions(authorizeContextParams); } /** @@ -186,12 +175,12 @@ protected Function, Access> authorizer() */ public Set resources() { - return cancelationResourceActions; + return authResult.sqlResourceActions; } public Set allResources() { - return fullResourceActions; + return authResult.allResourceActions; } public SqlQueryPlus query() diff --git a/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java b/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java index 0d7646d5f076..85396d4642b6 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlExecutionReporter.java @@ -94,10 +94,13 @@ public void emit() metricBuilder.setDimension("id", plannerContext.getSqlQueryId()); metricBuilder.setDimension("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); } - if (stmt.fullResourceActions != null) { + if (stmt.authResult != null) { + // Note: the dimension is "dataSource" (sic), so we log only the SQL resource + // actions. Even here, for external tables, those actions are not always + // datasources. 
metricBuilder.setDimension( "dataSource", - stmt.fullResourceActions + stmt.authResult.sqlResourceActions .stream() .map(action -> action.getResource().getName()) .collect(Collectors.toList()) diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index b42665001470..6ee3de811cdf 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -42,6 +42,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.Authenticator; import org.apache.druid.server.security.AuthenticatorMapper; @@ -49,7 +50,6 @@ import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.calcite.planner.Calcites; -import org.apache.druid.sql.calcite.planner.PlannerContext; import org.joda.time.Interval; import javax.annotation.Nonnull; @@ -184,7 +184,7 @@ public void openConnection(final ConnectionHandle ch, final Map // Don't stringify arrays for JDBC because Avatica needs to handle arrays. // When using query context security, all JDBC users must have permission on // this context key. - contextMap.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false); + contextMap.put(QueryContexts.CTX_SQL_STRINGIFY_ARRAYS, false); try { openDruidConnection(ch.id, secret, contextMap); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java index a31375bad8ae..1d34713c9e86 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java @@ -31,10 +31,8 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.QueryContext; import org.apache.druid.server.security.Access; -import org.apache.druid.server.security.Action; import org.apache.druid.server.security.Resource; import org.apache.druid.server.security.ResourceAction; -import org.apache.druid.server.security.ResourceType; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlReplace; import org.apache.druid.sql.calcite.run.SqlEngine; @@ -63,6 +61,35 @@ public enum State START, VALIDATED, PREPARED, PLANNED } + public static class AuthResult + { + public final Access authorizationResult; + + /** + * Resource actions used with authorizing a cancellation request. These actions + * include only the data-level actions (e.g. the datasource.) + */ + public final Set sqlResourceActions; + + /** + * Full resource actions authorized as part of this request. Used when logging + * resource actions. Includes query context keys, if query context authorization + * is enabled. 
+ */ + public final Set allResourceActions; + + public AuthResult( + final Access authorizationResult, + final Set sqlResourceActions, + final Set allResourceActions + ) + { + this.authorizationResult = authorizationResult; + this.sqlResourceActions = sqlResourceActions; + this.allResourceActions = allResourceActions; + } + } + private final FrameworkConfig frameworkConfig; private final CalcitePlanner planner; private final PlannerContext plannerContext; @@ -162,41 +189,29 @@ public PrepareResult prepare() * step within the planner's state machine. * * @param authorizer a function from resource actions to a {@link Access} result. + * @param extraActions set of additional resource actions beyond those inferred + * from the query itself. Specifically, the set of context keys to + * authorize. * * @return the return value from the authorizer */ - public Access authorize(Function, Access> authorizer, boolean authorizeContextParams) + public AuthResult authorize( + final Function, Access> authorizer, + final Set extraActions + ) { Preconditions.checkState(state == State.VALIDATED); - Access access = authorizer.apply(resourceActions(authorizeContextParams)); + Set sqlResourceActions = plannerContext.getResourceActions(); + Set allResourceActions = new HashSet<>(sqlResourceActions); + allResourceActions.addAll(extraActions); + Access access = authorizer.apply(allResourceActions); plannerContext.setAuthorizationResult(access); // Authorization is done as a flag, not a state, alas. - // Views do prepare without authorize, Avatica does authorize, then prepare, - // so the only constraint is that authorize be done after validation, before plan. + // Views prepare without authorization, Avatica does authorize, then prepare, + // so the only constraint is that authorization be done before planning. authorized = true; - return access; - } - - /** - * Return the resource actions corresponding to the datasources and views which - * an authenticated request must be authorized for to process the query. The - * actions will be {@code null} if the planner has not yet advanced to the - * validation step. This may occur if validation fails and the caller accesses - * the resource actions as part of clean-up. 
- */ - public Set resourceActions(boolean includeContext) - { - Set resourceActions = plannerContext.getResourceActions(); - if (includeContext) { - Set actions = new HashSet<>(resourceActions); - plannerContext.queryContextKeys().forEach(contextParam -> actions.add( - new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) - )); - return actions; - } else { - return resourceActions; - } + return new AuthResult(access, sqlResourceActions, allResourceActions); } /** diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index 4c6bb50424fc..797bec1fe962 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -33,8 +33,8 @@ import org.apache.druid.java.util.common.Numbers; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.BaseQuery; import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.ResourceAction; @@ -62,14 +62,14 @@ */ public class PlannerContext { - // query context keys - public static final String CTX_SQL_QUERY_ID = BaseQuery.SQL_QUERY_ID; + // Query context keys public static final String CTX_SQL_CURRENT_TIMESTAMP = "sqlCurrentTimestamp"; public static final String CTX_SQL_TIME_ZONE = "sqlTimeZone"; - public static final String CTX_SQL_STRINGIFY_ARRAYS = "sqlStringifyArrays"; - // This context parameter is an undocumented parameter, used internally, to allow the web console to - // apply a limit without having to rewrite the SQL query. + /** + * Undocumented context key, used internally, to allow the web console to + * apply a limit without having to rewrite the SQL query. 
+ */ public static final String CTX_SQL_OUTER_LIMIT = "sqlOuterLimit"; // DataContext keys @@ -84,7 +84,6 @@ public class PlannerContext private final DruidSchemaCatalog rootSchema; private final SqlEngine engine; private final Map queryContext; - private final Set contextKeys; private final String sqlQueryId; private final boolean stringifyArrays; private final CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(); @@ -124,11 +123,10 @@ private PlannerContext( this.rootSchema = rootSchema; this.engine = engine; this.queryContext = queryContext; - this.contextKeys = contextKeys; this.localNow = Preconditions.checkNotNull(localNow, "localNow"); this.stringifyArrays = stringifyArrays; - String sqlQueryId = (String) this.queryContext.get(CTX_SQL_QUERY_ID); + String sqlQueryId = (String) this.queryContext.get(QueryContexts.CTX_SQL_QUERY_ID); // special handling for DruidViewMacro, normal client will allocate sqlid in SqlLifecyle if (Strings.isNullOrEmpty(sqlQueryId)) { sqlQueryId = UUID.randomUUID().toString(); @@ -152,7 +150,7 @@ public static PlannerContext create( final DateTimeZone timeZone; final boolean stringifyArrays; - final Object stringifyParam = queryContext.get(CTX_SQL_STRINGIFY_ARRAYS); + final Object stringifyParam = queryContext.get(QueryContexts.CTX_SQL_STRINGIFY_ARRAYS); final Object tsParam = queryContext.get(CTX_SQL_CURRENT_TIMESTAMP); final Object tzParam = queryContext.get(CTX_SQL_TIME_ZONE); @@ -243,16 +241,6 @@ public QueryContext queryContext() return QueryContext.of(queryContext); } - /** - * Returns the query context keys set by the user. (Actually, set by - * the request made on behalf of the user, which may include options set by - * intermediary services outside of Druid.) - */ - public Set queryContextKeys() - { - return contextKeys; - } - public boolean isStringifyArrays() { return stringifyArrays; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index d0606b2f34e4..742d98ac56ea 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.avatica.util.Quoting; @@ -135,7 +136,7 @@ public DruidPlanner createPlannerForTesting(final SqlEngine engine, final String catch (SqlParseException | ValidationException e) { throw new RuntimeException(e); } - thePlanner.authorize(ra -> Access.OK, false); + thePlanner.authorize(ra -> Access.OK, ImmutableSet.of()); return thePlanner; } diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java index f95e9e9326cb..2311bbba4333 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java @@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.PlannerConfig; -import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.planner.PrepareResult; 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java
index d0606b2f34e4..742d98ac56ea 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java
@@ -21,6 +21,7 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableSet;
 import com.google.inject.Inject;
 import org.apache.calcite.avatica.util.Casing;
 import org.apache.calcite.avatica.util.Quoting;
@@ -135,7 +136,7 @@ public DruidPlanner createPlannerForTesting(final SqlEngine engine, final String
     catch (SqlParseException | ValidationException e) {
       throw new RuntimeException(e);
     }
-    thePlanner.authorize(ra -> Access.OK, false);
+    thePlanner.authorize(ra -> Access.OK, ImmutableSet.of());
     return thePlanner;
   }
diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java
index f95e9e9326cb..2311bbba4333 100644
--- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java
@@ -50,7 +50,6 @@
 import org.apache.druid.sql.calcite.planner.CalciteRulesManager;
 import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
 import org.apache.druid.sql.calcite.planner.PlannerConfig;
-import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.planner.PlannerFactory;
 import org.apache.druid.sql.calcite.planner.PrepareResult;
 import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog;
@@ -496,7 +495,7 @@ public void testIgnoredQueryContextParametersAreIgnored()
     Map<String, Object> context = stmt.context();
     Assert.assertEquals(2, context.size());
     // should contain only query id, not bySegment since it is not valid for SQL
-    Assert.assertTrue(context.containsKey(PlannerContext.CTX_SQL_QUERY_ID));
+    Assert.assertTrue(context.containsKey(QueryContexts.CTX_SQL_QUERY_ID));
   }
 
   @Test
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
index 917190596097..617e5acd58fc 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java
@@ -188,18 +188,18 @@ public static void setupNullValues()
   private static final ImmutableMap.Builder<String, Object> DEFAULT_QUERY_CONTEXT_BUILDER =
       ImmutableMap.<String, Object>builder()
-                  .put(PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID)
+                  .put(QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID)
                   .put(PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z")
                   .put(QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS)
                   .put(QueryContexts.MAX_SCATTER_GATHER_BYTES_KEY, Long.MAX_VALUE);
   public static final Map<String, Object> QUERY_CONTEXT_DEFAULT = DEFAULT_QUERY_CONTEXT_BUILDER.build();
 
   public static final Map<String, Object> QUERY_CONTEXT_NO_STRINGIFY_ARRAY =
-      DEFAULT_QUERY_CONTEXT_BUILDER.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false)
+      DEFAULT_QUERY_CONTEXT_BUILDER.put(QueryContexts.CTX_SQL_STRINGIFY_ARRAYS, false)
                                    .build();
 
   public static final Map<String, Object> QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS = ImmutableMap.of(
-      PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
+      QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
       PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z",
       TimeseriesQuery.SKIP_EMPTY_BUCKETS, false,
       QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS,
@@ -207,7 +207,7 @@
   );
 
   public static final Map<String, Object> QUERY_CONTEXT_DO_SKIP_EMPTY_BUCKETS = ImmutableMap.of(
-      PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
+      QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
       PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z",
       TimeseriesQuery.SKIP_EMPTY_BUCKETS, true,
       QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS,
@@ -215,7 +215,7 @@
   );
 
   public static final Map<String, Object> QUERY_CONTEXT_NO_TOPN = ImmutableMap.of(
-      PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
+      QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
       PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z",
       PlannerConfig.CTX_KEY_USE_APPROXIMATE_TOPN, "false",
       QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS,
@@ -223,7 +223,7 @@
   );
 
   public static final Map<String, Object> QUERY_CONTEXT_LOS_ANGELES = ImmutableMap.of(
-      PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
+      QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
       PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z",
       PlannerContext.CTX_SQL_TIME_ZONE, LOS_ANGELES,
       QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS,
@@ -232,7 +232,7 @@
 
   // Matches QUERY_CONTEXT_DEFAULT
   public static final Map<String, Object> TIMESERIES_CONTEXT_BY_GRAN = ImmutableMap.of(
-      PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
+      QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID,
       PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z",
       TimeseriesQuery.SKIP_EMPTY_BUCKETS, true,
       QueryContexts.DEFAULT_TIMEOUT_KEY, QueryContexts.DEFAULT_TIMEOUT_MILLIS,
@@ -292,7 +292,7 @@ public BaseCalciteQueryTest(@Nullable final SqlEngine engine)
   }
 
   static {
-    TIMESERIES_CONTEXT_LOS_ANGELES.put(PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID);
+    TIMESERIES_CONTEXT_LOS_ANGELES.put(QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID);
     TIMESERIES_CONTEXT_LOS_ANGELES.put(PlannerContext.CTX_SQL_CURRENT_TIMESTAMP, "2000-01-01T00:00:00Z");
     TIMESERIES_CONTEXT_LOS_ANGELES.put(PlannerContext.CTX_SQL_TIME_ZONE, LOS_ANGELES);
     TIMESERIES_CONTEXT_LOS_ANGELES.put(TimeseriesQuery.SKIP_EMPTY_BUCKETS, true);
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteIngestionDmlTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteIngestionDmlTest.java
index 72dd87dc9cb1..3d2610af99c4 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteIngestionDmlTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteIngestionDmlTest.java
@@ -31,6 +31,7 @@
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContexts;
 import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
@@ -47,7 +48,6 @@
 import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
 import org.apache.druid.sql.calcite.planner.Calcites;
 import org.apache.druid.sql.calcite.planner.PlannerConfig;
-import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.sql.calcite.util.CalciteTests;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.Matcher;
@@ -65,7 +65,7 @@ public class CalciteIngestionDmlTest extends BaseCalciteQueryTest
 {
   protected static final Map<String, Object> DEFAULT_CONTEXT =
       ImmutableMap.<String, Object>builder()
-                  .put(PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_ID)
+                  .put(QueryContexts.CTX_SQL_QUERY_ID, DUMMY_SQL_ID)
                   .build();
 
   protected static final RowSignature FOO_TABLE_SIGNATURE =
diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java
index 88d754e0fa2b..10e9002c8b9d 100644
--- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java
@@ -1243,7 +1243,7 @@ public void testCsvResultFormatWithHeaders_nullColumnType() throws Exception
   public void testExplainCountStar() throws Exception
   {
     Map<String, Object> queryContext = ImmutableMap.of(
-        PlannerContext.CTX_SQL_QUERY_ID,
+        QueryContexts.CTX_SQL_QUERY_ID,
         DUMMY_SQL_QUERY_ID,
         PlannerConfig.CTX_KEY_USE_NATIVE_QUERY_EXPLAIN,
         "false"
@@ -1829,7 +1829,7 @@ private void checkSqlRequestLog(boolean success)
     Assert.assertEquals(CalciteTests.REGULAR_USER_AUTH_RESULT.getIdentity(), stats.get("identity"));
     Assert.assertTrue(stats.containsKey("sqlQuery/time"));
     Assert.assertTrue(stats.containsKey("sqlQuery/planningTimeMs"));
-    Assert.assertTrue(queryContext.containsKey(PlannerContext.CTX_SQL_QUERY_ID));
+    Assert.assertTrue(queryContext.containsKey(QueryContexts.CTX_SQL_QUERY_ID));
     if (success) {
       Assert.assertTrue(stats.containsKey("sqlQuery/bytes"));
     } else {
From 347da8297003c3c786ad1c11dcfb81bc3572f79a Mon Sep 17 00:00:00 2001
From: Paul Rogers
Date: Fri, 14 Oct 2022 14:49:36 -0700
Subject: [PATCH 9/9] Test fix and revision from review comment

---
 .../apache/druid/server/QueryLifecycle.java   |   2 +-
 .../apache/druid/server/security/Access.java  |   1 +
 .../druid/server/security/AuthConfig.java     |   2 +-
 .../druid/server/QueryLifecycleTest.java      | 159 ++++++++++++++++--
 .../druid/server/security/AuthConfigTest.java |  10 +-
 .../apache/druid/sql/AbstractStatement.java   |   2 +-
 6 files changed, 152 insertions(+), 24 deletions(-)

diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java
index f56efbdf9f27..68e9496bd14f 100644
--- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java
+++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java
@@ -228,7 +228,7 @@ public Access authorize(HttpServletRequest req)
             AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
         ),
         Iterables.transform(
-            authConfig.filterContextKeys(userContextKeys),
+            authConfig.contextKeysToAuthorize(userContextKeys),
             contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
         )
     );
diff --git a/server/src/main/java/org/apache/druid/server/security/Access.java b/server/src/main/java/org/apache/druid/server/security/Access.java
index 543ce1b0d734..1f5f5f5e0269 100644
--- a/server/src/main/java/org/apache/druid/server/security/Access.java
+++ b/server/src/main/java/org/apache/druid/server/security/Access.java
@@ -27,6 +27,7 @@ public class Access
   static final String DEFAULT_ERROR_MESSAGE = "Unauthorized";
 
   public static final Access OK = new Access(true);
+  public static final Access DENIED = new Access(false);
 
   private final boolean allowed;
   private final String message;
diff --git a/server/src/main/java/org/apache/druid/server/security/AuthConfig.java b/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
index 695260c3c85a..8bbdac70036e 100644
--- a/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
+++ b/server/src/main/java/org/apache/druid/server/security/AuthConfig.java
@@ -161,7 +161,7 @@ public boolean authorizeQueryContextParams()
    *
    * @return the list of secured keys to check via authentication
    */
-  public Set<String> filterContextKeys(final Set<String> userKeys)
+  public Set<String> contextKeysToAuthorize(final Set<String> userKeys)
   {
     if (!authorizeQueryContextParams) {
       return ImmutableSet.of();
     }
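
Note: only the head of the renamed method appears in the hunk above. Its
contract, reconstructed as a sketch from AuthConfigTest further below, is a
white-list/black-list filter over the user's context keys. This is not the
literal method body; in particular, treating a null securedContextKeys as
"no black-list configured" is an assumption:

    // Reconstructed sketch of the filtering contract.
    public Set<String> contextKeysToAuthorize(final Set<String> userKeys)
    {
      if (!authorizeQueryContextParams) {
        return ImmutableSet.of();
      }
      Set<String> keysToCheck = new HashSet<>(userKeys);
      // Keys Druid itself adds (e.g. sqlQueryId) are exempt by default.
      keysToCheck.remove(QueryContexts.CTX_SQL_QUERY_ID);
      // White-list: unsecured keys are never checked.
      keysToCheck.removeAll(unsecuredContextKeys);
      // Black-list: when secured keys are configured, only those are checked.
      if (securedContextKeys != null) {
        keysToCheck.retainAll(securedContextKeys);
      }
      return keysToCheck;
    }
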
diff --git a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java
index 05efe30e0c28..4d44d5122e9e 100644
--- a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java
+++ b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java
@@ -21,6 +21,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.guava.Sequences;
@@ -76,9 +77,6 @@ public class QueryLifecycleTest
   RequestLogger requestLogger;
   AuthorizerMapper authzMapper;
   DefaultQueryConfig queryConfig;
-  AuthConfig authConfig;
-
-  QueryLifecycle lifecycle;
 
   QueryToolChest toolChest;
   QueryRunner runner;
@@ -100,11 +98,18 @@ public void setup()
     authorizer = EasyMock.createMock(Authorizer.class);
     authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer));
     queryConfig = EasyMock.createMock(DefaultQueryConfig.class);
-    authConfig = EasyMock.createMock(AuthConfig.class);
+    toolChest = EasyMock.createMock(QueryToolChest.class);
+    runner = EasyMock.createMock(QueryRunner.class);
+
+    metrics = EasyMock.createNiceMock(QueryMetrics.class);
+    authenticationResult = EasyMock.createMock(AuthenticationResult.class);
+  }
+
+  private QueryLifecycle createLifecycle(AuthConfig authConfig)
+  {
     long nanos = System.nanoTime();
     long millis = System.currentTimeMillis();
-    lifecycle = new QueryLifecycle(
+    return new QueryLifecycle(
         toolChestWarehouse,
         texasRanger,
         metricsFactory,
@@ -116,11 +121,6 @@ public void setup()
         millis,
         nanos
     );
-
-    toolChest = EasyMock.createMock(QueryToolChest.class);
-    runner = EasyMock.createMock(QueryRunner.class);
-    metrics = EasyMock.createNiceMock(QueryMetrics.class);
-    authenticationResult = EasyMock.createMock(AuthenticationResult.class);
   }
 
   @After
@@ -154,9 +154,9 @@ public void testRunSimplePreauthorized()
                       .once();
     EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once();
-
     replayAll();
 
+    QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
     lifecycle.runSimple(query, authenticationResult, Access.OK);
   }
 
@@ -177,6 +177,7 @@ public void testRunSimpleUnauthorized()
 
     replayAll();
 
+    QueryLifecycle lifecycle = createLifecycle(new AuthConfig());
     lifecycle.runSimple(query, authenticationResult, new Access(false));
   }
 
@@ -184,7 +185,6 @@ public void testRunSimpleUnauthorized()
   public void testAuthorizeQueryContext_authorized()
   {
     EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
-    EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
     EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
     EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
     EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
@@ -208,6 +208,10 @@ public void testAuthorizeQueryContext_authorized()
                       .context(userContext)
                       .build();
 
+    AuthConfig authConfig = AuthConfig.newBuilder()
+                                      .setAuthorizeQueryContextParams(true)
+                                      .build();
+    QueryLifecycle lifecycle = createLifecycle(authConfig);
     lifecycle.initialize(query);
 
     final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
@@ -225,13 +229,12 @@ public void testAuthorizeQueryContext_authorized()
   public void testAuthorizeQueryContext_notAuthorized()
   {
     EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
-    EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
     EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
     EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
     EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
             .andReturn(Access.OK);
     EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
-            .andReturn(new Access(false));
+            .andReturn(Access.DENIED);
 
     EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
             .andReturn(toolChest)
             .once();
 
     replayAll();
 
@@ -246,6 +249,128 @@
             .context(ImmutableMap.of("foo", "bar"))
             .build();
 
+    AuthConfig authConfig = AuthConfig.newBuilder()
+                                      .setAuthorizeQueryContextParams(true)
+                                      .build();
+    QueryLifecycle lifecycle = createLifecycle(authConfig);
+    lifecycle.initialize(query);
+    Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
+  }
+
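+  // The next test covers the white-list path: "foo" and "baz" are declared
+  // unsecured, so no QUERY_CONTEXT write check is expected on the authorizer
+  // mock and authorization still succeeds.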
+  @Test
+  public void testAuthorizeQueryContext_unsecuredKeys()
+  {
+    EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
+    EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
+    EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
+    EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
+            .andReturn(Access.OK);
+
+    EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
+            .andReturn(toolChest)
+            .once();
+
+    replayAll();
+
+    final Map<String, Object> userContext = ImmutableMap.of("foo", "bar", "baz", "qux");
+    final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
+                                        .dataSource(DATASOURCE)
+                                        .intervals(ImmutableList.of(Intervals.ETERNITY))
+                                        .aggregators(new CountAggregatorFactory("chocula"))
+                                        .context(userContext)
+                                        .build();
+
+    AuthConfig authConfig = AuthConfig.newBuilder()
+                                      .setAuthorizeQueryContextParams(true)
+                                      .setUnsecuredContextKeys(ImmutableSet.of("foo", "baz"))
+                                      .build();
+    QueryLifecycle lifecycle = createLifecycle(authConfig);
+    lifecycle.initialize(query);
+
+    final Map<String, Object> revisedContext = new HashMap<>(lifecycle.getQuery().getContext());
+    Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
+    revisedContext.remove("queryId");
+    Assert.assertEquals(
+        userContext,
+        revisedContext
+    );
+
+    Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
+  }
+
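+  // The two tests below cover the black-list path: when securedContextKeys is
+  // set, only keys in that set require authorization. First the secured set
+  // misses the user's keys entirely, so no check fires; then "foo" is secured
+  // and the authorizer denies it.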
+ .setSecuredContextKeys(ImmutableSet.of("foo2", "baz2")) + .build(); + QueryLifecycle lifecycle = createLifecycle(authConfig); + lifecycle.initialize(query); + + final Map revisedContext = new HashMap<>(lifecycle.getQuery().getContext()); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); + revisedContext.remove("queryId"); + Assert.assertEquals( + userContext, + revisedContext + ); + + Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); + } + + @Test + public void testAuthorizeQueryContext_securedKeysNotAuthorized() + { + EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); + EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); + EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ)) + .andReturn(Access.OK); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) + .andReturn(Access.DENIED); + + EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) + .andReturn(toolChest) + .once(); + + replayAll(); + + final Map userContext = ImmutableMap.of("foo", "bar", "baz", "qux"); + final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + .dataSource(DATASOURCE) + .intervals(ImmutableList.of(Intervals.ETERNITY)) + .aggregators(new CountAggregatorFactory("chocula")) + .context(userContext) + .build(); + + AuthConfig authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + // We have secured keys. User used one of them. + .setSecuredContextKeys(ImmutableSet.of("foo", "baz2")) + .build(); + QueryLifecycle lifecycle = createLifecycle(authConfig); lifecycle.initialize(query); Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed()); } @@ -254,7 +379,6 @@ public void testAuthorizeQueryContext_notAuthorized() public void testAuthorizeLegacyQueryContext_authorized() { EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); - EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes(); EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ)) @@ -271,6 +395,10 @@ public void testAuthorizeLegacyQueryContext_authorized() final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux")); + AuthConfig authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .build(); + QueryLifecycle lifecycle = createLifecycle(authConfig); lifecycle.initialize(query); final Map revisedContext = lifecycle.getQuery().getContext(); @@ -304,7 +432,6 @@ private void replayAll() emitter, requestLogger, queryConfig, - authConfig, toolChest, runner, metrics, diff --git a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java index 5e0c81df23ca..3987c4d9cda0 100644 --- a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java +++ b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java @@ -44,7 +44,7 @@ public void testContextSecurity() { 
diff --git a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
index 5e0c81df23ca..3987c4d9cda0 100644
--- a/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
+++ b/server/src/test/java/org/apache/druid/server/security/AuthConfigTest.java
@@ -44,7 +44,7 @@ public void testContextSecurity()
     {
       AuthConfig config = new AuthConfig();
       Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
-      assertTrue(config.filterContextKeys(keys).isEmpty());
+      assertTrue(config.contextKeysToAuthorize(keys).isEmpty());
     }
 
     // Default security
@@ -53,7 +53,7 @@
           .setAuthorizeQueryContextParams(true)
           .build();
       Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
-      assertEquals(ImmutableSet.of("a", "b"), config.filterContextKeys(keys));
+      assertEquals(ImmutableSet.of("a", "b"), config.contextKeysToAuthorize(keys));
     }
 
     // Specify unsecured keys (white-list)
@@ -63,7 +63,7 @@
           .setAuthorizeQueryContextParams(true)
           .setUnsecuredContextKeys(ImmutableSet.of("a"))
           .build();
       Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
-      assertEquals(ImmutableSet.of("b"), config.filterContextKeys(keys));
+      assertEquals(ImmutableSet.of("b"), config.contextKeysToAuthorize(keys));
     }
 
     // Specify secured keys (black-list)
@@ -73,7 +73,7 @@
           .setAuthorizeQueryContextParams(true)
           .setSecuredContextKeys(ImmutableSet.of("a"))
           .build();
       Set<String> keys = ImmutableSet.of("a", "b", QueryContexts.CTX_SQL_QUERY_ID);
-      assertEquals(ImmutableSet.of("a"), config.filterContextKeys(keys));
+      assertEquals(ImmutableSet.of("a"), config.contextKeysToAuthorize(keys));
     }
 
     // Specify both
@@ -84,7 +84,7 @@
           .setUnsecuredContextKeys(ImmutableSet.of("a", "b"))
           .setSecuredContextKeys(ImmutableSet.of("b", "c"))
           .build();
       Set<String> keys = ImmutableSet.of("a", "b", "c", "d", QueryContexts.CTX_SQL_QUERY_ID);
-      assertEquals(ImmutableSet.of("c"), config.filterContextKeys(keys));
+      assertEquals(ImmutableSet.of("c"), config.contextKeysToAuthorize(keys));
     }
   }
 }
diff --git a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java
index 7388303d0312..1956b353b770 100644
--- a/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java
+++ b/sql/src/main/java/org/apache/druid/sql/AbstractStatement.java
@@ -141,7 +141,7 @@ protected void authorize(
       final Function<Set<ResourceAction>, Access> authorizer
   )
   {
-    Set<String> securedKeys = this.sqlToolbox.authConfig.filterContextKeys(queryPlus.context().keySet());
+    Set<String> securedKeys = this.sqlToolbox.authConfig.contextKeysToAuthorize(queryPlus.context().keySet());
     Set<ResourceAction> contextResources = new HashSet<>();
     securedKeys.forEach(key -> contextResources.add(
         new ResourceAction(new Resource(key, ResourceType.QUERY_CONTEXT), Action.WRITE)