diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java index f691d2cebecb..6b01d6beb180 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/InPlanningBenchmark.java @@ -33,6 +33,7 @@ import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.AutoTypeColumnSchema; import org.apache.druid.segment.IndexSpec; import org.apache.druid.segment.QueryableIndex; @@ -202,6 +203,7 @@ public void setup() throws JsonProcessingException CalciteTests.createJoinableFactoryWrapper(), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBaseBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBaseBenchmark.java index 05a31dda89d3..7ab942ff6f71 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBaseBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBaseBenchmark.java @@ -60,6 +60,7 @@ import org.apache.druid.query.aggregation.datasketches.theta.sql.ThetaSketchEstimateOperatorConversion; import org.apache.druid.query.aggregation.datasketches.tuple.ArrayOfDoublesSketchModule; import org.apache.druid.query.lookup.LookupExtractor; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.IncrementalIndexSegment; import org.apache.druid.segment.IndexSpec; import org.apache.druid.segment.PhysicalSegmentInspector; @@ -452,6 +453,7 @@ public static Pair createSqlSystem( new JoinableFactoryWrapper(QueryFrameworkUtils.createDefaultJoinableFactory(injector)), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java index d9017cb4f28c..1900447fbfaf 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java @@ -32,6 +32,7 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.ResultRow; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.generator.GeneratorBasicSchemas; import org.apache.druid.segment.generator.GeneratorSchemaInfo; @@ -129,6 +130,7 @@ public void setup() CalciteTests.createJoinableFactoryWrapper(), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); groupByQuery = GroupByQuery diff --git a/extensions-core/multi-stage-query/pom.xml b/extensions-core/multi-stage-query/pom.xml index edc5d90d3b7b..bc26993c072b 100644 --- a/extensions-core/multi-stage-query/pom.xml +++ b/extensions-core/multi-stage-query/pom.xml @@ -334,6 +334,10 @@ ${project.parent.version} test + + com.google.inject.extensions + guice-testlib + diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java index ff7d9fdc4e9f..130a639396e5 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java @@ -32,6 +32,7 @@ import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; import org.apache.druid.segment.SegmentWrangler; @@ -78,6 +79,12 @@ public DartFrameContext( this.storageParameters = storageParameters; } + @Override + public PolicyEnforcer policyEnforcer() + { + return workerContext.policyEnforcer(); + } + @Override public SegmentWrangler segmentWrangler() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java index 05cbf210f897..530a4ec3e2e6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java @@ -44,6 +44,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryContext; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.server.DruidNode; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @@ -61,6 +62,7 @@ public class DartWorkerContext implements WorkerContext private final WorkerId workerId; private final DruidNode selfNode; private final ObjectMapper jsonMapper; + private final PolicyEnforcer policyEnforcer; private final Injector injector; private final DartWorkerClient workerClient; private final DruidProcessingConfig processingConfig; @@ -84,6 +86,7 @@ public class DartWorkerContext implements WorkerContext final String controllerHost, final DruidNode selfNode, final ObjectMapper jsonMapper, + final PolicyEnforcer policyEnforcer, final Injector injector, final DartWorkerClient workerClient, final DruidProcessingConfig processingConfig, @@ -102,6 +105,7 @@ public class DartWorkerContext implements WorkerContext this.workerId = WorkerId.fromDruidNode(selfNode, queryId); this.selfNode = selfNode; this.jsonMapper = jsonMapper; + this.policyEnforcer = policyEnforcer; this.injector = injector; this.workerClient = workerClient; this.processingConfig = processingConfig; @@ -133,6 +137,12 @@ public ObjectMapper jsonMapper() return jsonMapper; } + @Override + public PolicyEnforcer policyEnforcer() + { + return policyEnforcer; + } + @Override public Injector injector() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java index 1960924b4b67..06b9226bc37e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java @@ -38,6 +38,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryContext; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.server.DruidNode; @@ -52,6 +53,7 @@ public class DartWorkerFactoryImpl implements DartWorkerFactory private final DruidNode selfNode; private final ObjectMapper jsonMapper; private final ObjectMapper smileMapper; + private final PolicyEnforcer policyEnforcer; private final Injector injector; private final ServiceClientFactory serviceClientFactory; private final DruidProcessingConfig processingConfig; @@ -67,6 +69,7 @@ public DartWorkerFactoryImpl( @Self DruidNode selfNode, @Json ObjectMapper jsonMapper, @Smile ObjectMapper smileMapper, + PolicyEnforcer policyEnforcer, Injector injector, @EscalatedGlobal ServiceClientFactory serviceClientFactory, DruidProcessingConfig processingConfig, @@ -81,6 +84,7 @@ public DartWorkerFactoryImpl( this.selfNode = selfNode; this.jsonMapper = jsonMapper; this.smileMapper = smileMapper; + this.policyEnforcer = policyEnforcer; this.injector = injector; this.serviceClientFactory = serviceClientFactory; this.processingConfig = processingConfig; @@ -100,6 +104,7 @@ public Worker build(String queryId, String controllerHost, File tempDir, QueryCo controllerHost, selfNode, jsonMapper, + policyEnforcer, injector, new DartWorkerClientImpl(queryId, serviceClientFactory, smileMapper, null), processingConfig, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java index 90082fcf0dd0..50cbe781dc7b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java @@ -27,6 +27,7 @@ import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.server.DruidNode; import java.io.File; @@ -51,6 +52,8 @@ public interface WorkerContext ObjectMapper jsonMapper(); + PolicyEnforcer policyEnforcer(); + // Using an Injector directly because tasks do not have a way to provide their own Guice modules. Injector injector(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index e8f3739facb4..f7e019c9aca7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -31,6 +31,7 @@ import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; import org.apache.druid.segment.SegmentWrangler; @@ -73,6 +74,12 @@ public IndexerFrameContext( this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; } + @Override + public PolicyEnforcer policyEnforcer() + { + return context.policyEnforcer(); + } + @Override public SegmentWrangler segmentWrangler() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index a26eded43221..1175f6ed2071 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -52,6 +52,7 @@ import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryToolChestWarehouse; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.ServiceLocations; import org.apache.druid.rpc.ServiceLocator; @@ -191,6 +192,12 @@ public ObjectMapper jsonMapper() return toolbox.getJsonMapper(); } + @Override + public PolicyEnforcer policyEnforcer() + { + return toolbox.getPolicyEnforcer(); + } + @Override public Injector injector() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java index 1b80f72f86f5..e762d1f78407 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java @@ -27,6 +27,7 @@ import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; import org.apache.druid.segment.SegmentWrangler; @@ -44,6 +45,8 @@ */ public interface FrameContext extends Closeable { + PolicyEnforcer policyEnforcer(); + SegmentWrangler segmentWrangler(); GroupingEngine groupingEngine(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java index d35b3cf67221..ef53988c476b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafFrameProcessorFactory.java @@ -162,8 +162,8 @@ public ProcessorsAndChannels makeProcessors( final ProcessorManager processorManager; if (segmentMapFnProcessor == null) { - final Function segmentMapFn = - query.getDataSource().createSegmentMapFunction(query); + final Function segmentMapFn = ExecutionVertex.of(query) + .createSegmentMapFunction(frameContext.policyEnforcer()); processorManager = processorManagerFn.apply(ImmutableList.of(segmentMapFn)); } else { processorManager = new ChainedProcessorManager<>(ProcessorManagers.of(() -> segmentMapFnProcessor), processorManagerFn); @@ -342,7 +342,7 @@ private FrameProcessor> makeSegment if (broadcastInputs.isEmpty()) { if (ExecutionVertex.of(query).isSegmentMapFunctionExpensive()) { // Joins may require significant computation to compute the segmentMapFn. Offload it to a processor. - return new SimpleSegmentMapFnProcessor(query); + return new SimpleSegmentMapFnProcessor(query, frameContext.policyEnforcer()); } else { // Non-joins are expected to have cheap-to-compute segmentMapFn. Do the computation in the factory thread, // without offloading to a processor. @@ -351,6 +351,7 @@ private FrameProcessor> makeSegment } else { return BroadcastJoinSegmentMapFnProcessor.create( query, + frameContext.policyEnforcer(), broadcastInputs, frameContext.memoryParameters().getBroadcastBufferMemory() ); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessor.java index d3ce49d5f8e0..7c1a6083c6ab 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessor.java @@ -40,6 +40,8 @@ import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.JoinAlgorithm; import org.apache.druid.query.Query; +import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.Cursor; import org.apache.druid.segment.SegmentReference; @@ -64,6 +66,7 @@ public class BroadcastJoinSegmentMapFnProcessor implements FrameProcessor> { private final Query query; + private final PolicyEnforcer policyEnforcer; private final Int2IntMap inputNumberToProcessorChannelMap; private final List channels; private final List channelReaders; @@ -87,6 +90,7 @@ public class BroadcastJoinSegmentMapFnProcessor implements FrameProcessor query, + final PolicyEnforcer policyEnforcer, final Int2IntMap inputNumberToProcessorChannelMap, final List channels, final List channelReaders, @@ -94,6 +98,7 @@ public BroadcastJoinSegmentMapFnProcessor( ) { this.query = query; + this.policyEnforcer = policyEnforcer; this.inputNumberToProcessorChannelMap = inputNumberToProcessorChannelMap; this.channels = channels; this.channelReaders = channelReaders; @@ -117,6 +122,7 @@ public BroadcastJoinSegmentMapFnProcessor( */ public static BroadcastJoinSegmentMapFnProcessor create( final Query query, + final PolicyEnforcer policyEnforcer, final Int2ObjectMap sideChannels, final long memoryReservedForBroadcastJoin ) @@ -134,6 +140,7 @@ public static BroadcastJoinSegmentMapFnProcessor create( return new BroadcastJoinSegmentMapFnProcessor( query, + policyEnforcer, inputNumberToProcessorChannelMap, inputChannels, channelReaders, @@ -193,7 +200,8 @@ private void addFrame(final int channelNumber, final Frame frame) private Function createSegmentMapFunction() { - return inlineChannelData(query.getDataSource()).createSegmentMapFunction(query); + DataSource transformed = inlineChannelData(query.getDataSource()); + return ExecutionVertex.of(query.withDataSource(transformed)).createSegmentMapFunction(policyEnforcer); } DataSource inlineChannelData(final DataSource originalDataSource) @@ -230,7 +238,6 @@ DataSource inlineChannelData(final DataSource originalDataSource) * broadcast tables. * * @param readableInputs all readable input channel numbers, including non-side-channels - * * @return whether side channels have been fully read */ boolean buildBroadcastTablesIncrementally(final IntSet readableInputs) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/SimpleSegmentMapFnProcessor.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/SimpleSegmentMapFnProcessor.java index 54264d379c93..54a0065d1994 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/SimpleSegmentMapFnProcessor.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/SimpleSegmentMapFnProcessor.java @@ -26,6 +26,7 @@ import org.apache.druid.frame.processor.ReturnOrAwait; import org.apache.druid.query.Query; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.SegmentReference; import java.util.Collections; @@ -44,10 +45,13 @@ public class SimpleSegmentMapFnProcessor implements FrameProcessor> { private final Query query; + private final PolicyEnforcer policyEnforcer; - public SimpleSegmentMapFnProcessor(final Query query) + public SimpleSegmentMapFnProcessor(final Query query, + final PolicyEnforcer policyEnforcer) { this.query = query; + this.policyEnforcer = policyEnforcer; } @Override @@ -66,7 +70,7 @@ public List outputChannels() public ReturnOrAwait> runIncrementally(final IntSet readableInputs) { ExecutionVertex ev = ExecutionVertex.of(query); - return ReturnOrAwait.returnObject(ev.createSegmentMapFunction()); + return ReturnOrAwait.returnObject(ev.createSegmentMapFunction(policyEnforcer)); } @Override diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java index 10b5f20e4187..c0608339bed0 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java @@ -51,6 +51,7 @@ import org.apache.druid.query.DefaultQueryConfig; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.ResponseContextConfig; @@ -215,6 +216,7 @@ public void register(ControllerHolder holder) CalciteTests.createJoinableFactoryWrapper(), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java index 72f2c8176223..084bf8e92fff 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java @@ -29,6 +29,7 @@ import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; @@ -77,6 +78,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) new CalciteRulesManager(ImmutableSet.of()), CalciteTests.TEST_AUTHORIZER_MAPPER, AuthConfig.newBuilder().build(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); final NativeSqlEngine engine = CalciteTests.createMockSqlEngine( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/TestMSQSqlModule.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/TestMSQSqlModule.java index 0b48d2904dd3..b3e7494c11c7 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/TestMSQSqlModule.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/TestMSQSqlModule.java @@ -50,11 +50,13 @@ public SqlStatementFactory makeMSQSqlStatementFactory( @Provides @LazySingleton - public MSQTaskSqlEngine createEngine( - ObjectMapper queryJsonMapper, - MSQTestOverlordServiceClient indexingServiceClient) + public MSQTaskSqlEngine createEngine(ObjectMapper queryJsonMapper, MSQTestOverlordServiceClient indexingServiceClient) { - return new MSQTaskSqlEngine(indexingServiceClient, queryJsonMapper, new SegmentGenerationTerminalStageSpecFactory()); + return new MSQTaskSqlEngine( + indexingServiceClient, + queryJsonMapper, + new SegmentGenerationTerminalStageSpecFactory() + ); } @Provides diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessorTest.java index 2f3efa679cd8..41f12b8d8816 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessorTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/BroadcastJoinSegmentMapFnProcessorTest.java @@ -46,6 +46,7 @@ import org.apache.druid.query.JoinDataSource; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContext; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.CursorFactory; import org.apache.druid.segment.QueryableIndexCursorFactory; import org.apache.druid.segment.TestIndex; @@ -129,6 +130,7 @@ public void testBuildTableAndInlineData() throws IOException final BroadcastJoinSegmentMapFnProcessor broadcastJoinReader = new BroadcastJoinSegmentMapFnProcessor( null /* Query; not used for the methods we're testing */, + NoopPolicyEnforcer.instance(), sideStageChannelNumberMap, channels, channelReaders, @@ -220,6 +222,7 @@ public void testBuildTableMemoryLimit() throws IOException final BroadcastJoinSegmentMapFnProcessor broadcastJoinHelper = new BroadcastJoinSegmentMapFnProcessor( null /* Query; not used for the methods we're testing */, + NoopPolicyEnforcer.instance(), sideStageChannelNumberMap, channels, channelReaders, @@ -271,6 +274,7 @@ public void testBuildTableMemoryLimitWithSortMergeConfigured() throws IOExceptio EasyMock.replay(mockQuery); final BroadcastJoinSegmentMapFnProcessor broadcastJoinHelper = new BroadcastJoinSegmentMapFnProcessor( mockQuery, + NoopPolicyEnforcer.instance(), sideStageChannelNumberMap, channels, channelReaders, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java new file mode 100644 index 000000000000..ddb266889e8f --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java @@ -0,0 +1,607 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.sql; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Module; +import com.google.inject.testing.fieldbinder.Bind; +import com.google.inject.testing.fieldbinder.BoundFieldModule; +import com.google.inject.util.Modules; +import org.apache.calcite.jdbc.JavaTypeFactoryImpl; +import org.apache.druid.collections.ReferenceCountingResourceHolder; +import org.apache.druid.guice.ConfigModule; +import org.apache.druid.guice.DruidGuiceExtensions; +import org.apache.druid.guice.DruidSecondaryModule; +import org.apache.druid.guice.ExpressionModule; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.SegmentWranglerModule; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.jackson.DefaultObjectMapper; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.indexing.destination.MSQTerminalStageSpecFactory; +import org.apache.druid.msq.indexing.destination.SegmentGenerationTerminalStageSpecFactory; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.test.MSQTestOverlordServiceClient; +import org.apache.druid.msq.test.MSQTestTaskActionClient; +import org.apache.druid.query.Druids; +import org.apache.druid.query.ForwardingQueryProcessingPool; +import org.apache.druid.query.JoinDataSource; +import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryDataSource; +import org.apache.druid.query.QueryProcessingPool; +import org.apache.druid.query.RestrictedDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.UnnestDataSource; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.filter.EqualityFilter; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.groupby.GroupByQueryConfig; +import org.apache.druid.query.groupby.GroupByQueryRunnerTest; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.groupby.TestGroupByBuffers; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; +import org.apache.druid.query.policy.RowFilterPolicy; +import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.rpc.indexing.OverlordClient; +import org.apache.druid.segment.IndexIO; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ColumnConfig; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinConditionAnalysis; +import org.apache.druid.segment.join.JoinType; +import org.apache.druid.segment.join.JoinableFactoryWrapper; +import org.apache.druid.server.QueryResponse; +import org.apache.druid.server.QueryStackTests; +import org.apache.druid.server.SpecificSegmentsQuerySegmentWalker; +import org.apache.druid.server.lookup.cache.LookupLoadingSpec; +import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.DruidQuery; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.calcite.util.LookylooModule; +import org.apache.druid.sql.calcite.util.TestDataBuilder; +import org.apache.druid.sql.destination.IngestDestination; +import org.apache.druid.timeline.SegmentId; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static java.util.stream.Collectors.toMap; +import static org.apache.druid.sql.calcite.BaseCalciteQueryTest.assertResultsEquals; +import static org.apache.druid.sql.calcite.BaseCalciteQueryTest.expressionVirtualColumn; +import static org.apache.druid.sql.calcite.table.RowSignatures.toRelDataType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.when; + +public class MSQTaskQueryMakerTest +{ + private static final Closer CLOSER = Closer.create(); + private static final JavaTypeFactoryImpl JAVA_TYPE_FACTORY = new JavaTypeFactoryImpl(); + + @Rule + public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Bind + private SpecificSegmentsQuerySegmentWalker walker; + @Bind + @Mock + private DataSegmentProvider dataSegmentProviderMock; + @Bind + private ObjectMapper objectMapper; + @Bind + @Json + private ObjectMapper jsonMapper; + @Bind + private IndexIO indexIO; + @Bind + @Nullable + private IngestDestination ingestDestination; + @Bind + private QueryProcessingPool queryProcessingPool; + @Bind + private GroupingEngine groupingEngine; + @Bind + private JoinableFactoryWrapper joinableFactoryWrapper; + @Bind + @Nullable + private DataServerQueryHandlerFactory dataServerQueryHandlerFactory; + @Bind(lazy = true) + private PolicyEnforcer policyEnforcer; // lazy so we can set it in the test + @Bind + private MSQTerminalStageSpecFactory terminalStageSpecFactory; + @Bind + @Mock + private PlannerContext plannerContextMock; + @Bind(lazy = true) + private List> fieldMapping; // lazy so we can set it in the test + @Bind(lazy = true) + private OverlordClient fakeOverlordClient; // lazy since we need to use the injector to create it + + private MSQTaskQueryMaker msqTaskQueryMaker; + + @Before + public void setUp() throws Exception + { + walker = TestDataBuilder.addDataSetsToWalker( + FileUtils.getTempDir().toFile(), + SpecificSegmentsQuerySegmentWalker.createWalker(QueryStackTests.createQueryRunnerFactoryConglomerate(CLOSER)) + ); + when(dataSegmentProviderMock.fetchSegment( + any(), + any(), + anyBoolean() + )).thenAnswer(invocation -> (Supplier) () -> { + SegmentId segmentId = (SegmentId) invocation.getArguments()[0]; + return new ReferenceCountingResourceHolder(walker.getSegment(segmentId), () -> { + // no-op closer, we don't want to close the segment + }); + }); + + objectMapper = TestHelper.makeJsonMapper(); + jsonMapper = new DefaultObjectMapper(); + indexIO = new IndexIO(objectMapper, ColumnConfig.DEFAULT); + queryProcessingPool = new ForwardingQueryProcessingPool(Execs.singleThreaded("Test-runner-processing-pool")); + groupingEngine = GroupByQueryRunnerTest.makeQueryRunnerFactory( + new GroupByQueryConfig(), + TestGroupByBuffers.createDefault() + ).getGroupingEngine(); + joinableFactoryWrapper = CalciteTests.createJoinableFactoryWrapper(); + policyEnforcer = NoopPolicyEnforcer.instance(); + terminalStageSpecFactory = new SegmentGenerationTerminalStageSpecFactory(); + when(plannerContextMock.getLookupLoadingSpec()).thenReturn(LookupLoadingSpec.NONE); + when(plannerContextMock.queryContext()).thenReturn(new QueryContext(ImmutableMap.of())); + when(plannerContextMock.getSql()).thenReturn("stub a sql statement, ignore this value"); + when(plannerContextMock.getJsonMapper()).thenReturn(jsonMapper); + when(plannerContextMock.getAuthenticationResult()).thenReturn(new AuthenticationResult( + "someone", + "ignore", + "ignore", + ImmutableMap.of() + )); + + Module defaultModule = Modules.combine( + new ExpressionModule(), + new DruidGuiceExtensions(), + new LifecycleModule(), + new ConfigModule(), + new SegmentWranglerModule(), + new LookylooModule() + ); + Injector injector = Guice.createInjector(defaultModule, BoundFieldModule.of(this)); + DruidSecondaryModule.setupJackson(injector, objectMapper); + fakeOverlordClient = new MSQTestOverlordServiceClient( + objectMapper, + injector, + new MSQTestTaskActionClient(objectMapper, injector), + MSQTestBase.makeTestWorkerMemoryParameters(), + new ArrayList<>() + ); + } + + @Test + public void testSimpleScanQuery() throws Exception + { + // Arrange + RowSignature resultSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(CalciteTests.DATASOURCE1) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of( + new Object[]{1L, ""}, + new Object[]{1L, "10.1"}, + new Object[]{1L, "2"}, + new Object[]{1L, "1"}, + new Object[]{1L, "def"}, + new Object[]{1L, "abc"} + ); + assertResultsEquals("select cnt, dim1 from foo", expectedResults, payload.getResults().getResults()); + } + + @Test + public void testScanQueryFailWithPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(CalciteTests.DATASOURCE1) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isFailure()); + MSQErrorReport errorReport = payload.getStatus().getErrorReport(); + Assert.assertTrue(errorReport.getFault().getErrorMessage().contains("Failed security validation with segment")); + } + + @Test + public void testScanQueryPassedPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + RowFilterPolicy.from(new EqualityFilter( + "dim1", + ColumnType.STRING, + "abc", + null + )) + ); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(restrictedDataSource) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of(new Object[]{1L, "abc"}); + assertResultsEquals( + "select cnt, dim1 from foo (with restriction)", + expectedResults, + payload.getResults().getResults() + ); + } + + @Test + public void testUnnestOnRestrictedPassedPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("j0.unnest", ColumnType.STRING) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + UnnestDataSource unnestDataSource = UnnestDataSource.create( + RestrictedDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + RowFilterPolicy.from(new EqualityFilter( + "dim1", + ColumnType.STRING, + "10.1", + null + )) + ), + expressionVirtualColumn("j0.unnest", "\"dim3\"", ColumnType.STRING), + null + ); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .dataSource(unnestDataSource) + .eternityInterval() + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of(new Object[]{"10.1", "b"}, new Object[]{"10.1", "c"}); + assertResultsEquals( + "SELECT dim1 FROM foo, UNNEST(MV_TO_ARRAY(dim3)) as unnested (d3) (with restriction)", + expectedResults, + payload.getResults().getResults() + ); + } + + @Test + public void testJoinFailWithPolicyValidationOnLeftChild() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("j0.a0", ColumnType.LONG) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + RowFilterPolicy.from(new EqualityFilter( + "dim1", + ColumnType.STRING, + "abc", + null + )) + ); + QueryDataSource rightChild = new QueryDataSource(new GroupByQuery.Builder().setInterval(Intervals.ETERNITY) + .setDataSource(restrictedDataSource) + .addAggregator(new CountAggregatorFactory( + "a0")) + .setGranularity(Granularities.ALL) + .build()); + JoinDataSource joinDataSourceLeftChildNoRestriction = JoinDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + rightChild, + "j0.", + JoinConditionAnalysis.forExpression( + "1", + "j0.", + ExprMacroTable.nil() + ), + JoinType.INNER, + null, + null, + null + ); + Query queryLeftChildNoRestriction = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(joinDataSourceLeftChildNoRestriction) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(queryLeftChildNoRestriction, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isFailure()); + MSQErrorReport errorReport = payload.getStatus().getErrorReport(); + Assert.assertTrue(errorReport.getFault().getErrorMessage().contains("Failed security validation with segment")); + } + + @Test + public void testJoinFailWithPolicyValidationOnRightChild() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("j0.a0", ColumnType.LONG) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + RowFilterPolicy.from(new EqualityFilter( + "dim1", + ColumnType.STRING, + "abc", + null + )) + ); + QueryDataSource rightChildNoRestriction = new QueryDataSource(new GroupByQuery.Builder().setInterval(Intervals.ETERNITY) + .setDataSource(CalciteTests.DATASOURCE1) + .addAggregator(new CountAggregatorFactory( + "a0")) + .setGranularity( + Granularities.ALL) + .build()); + JoinDataSource joinDataSourceRightChildNoRestriction = JoinDataSource.create( + restrictedDataSource, + rightChildNoRestriction, + "j0.", + JoinConditionAnalysis.forExpression( + "1", + "j0.", + ExprMacroTable.nil() + ), + JoinType.INNER, + null, + null, + null + ); + Query queryRightChildNoRestriction = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(joinDataSourceRightChildNoRestriction) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(queryRightChildNoRestriction, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isFailure()); + MSQErrorReport errorReport = payload.getStatus().getErrorReport(); + Assert.assertTrue(errorReport.getFault().getErrorMessage().contains("Failed security validation with segment")); + } + + @Test + public void testJoinPassedPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("j0.a0", ColumnType.LONG) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( + TableDataSource.create(CalciteTests.DATASOURCE1), + RowFilterPolicy.from(new EqualityFilter( + "dim1", + ColumnType.STRING, + "abc", + null + )) + ); + QueryDataSource rightChild = new QueryDataSource(new GroupByQuery.Builder().setInterval(Intervals.ETERNITY) + .setDataSource(restrictedDataSource) + .addAggregator(new CountAggregatorFactory( + "a0")) + .setGranularity(Granularities.ALL) + .build()); + JoinDataSource joinDataSource = JoinDataSource.create( + restrictedDataSource, + rightChild, + "j0.", + JoinConditionAnalysis.forExpression( + "1", + "j0.", + ExprMacroTable.nil() + ), + JoinType.INNER, + null, + null, + null + ); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(joinDataSource) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of(new Object[]{"abc", 1L}); + assertResultsEquals( + "select dim1, q.c from foo, (select count(*) c from foo) q", + expectedResults, + payload.getResults().getResults() + ); + } + + private static DruidQuery buildDruidQueryMock(Query query, RowSignature resultSignature) + { + DruidQuery druidQueryMock = Mockito.mock(DruidQuery.class); + when(druidQueryMock.getQuery()).thenReturn(query); + when(druidQueryMock.getDataSource()).thenReturn(query.getDataSource()); + when(druidQueryMock.getOutputRowSignature()).thenReturn(resultSignature); + when(druidQueryMock.getOutputRowType()).thenReturn(toRelDataType(resultSignature, JAVA_TYPE_FACTORY, false)); + return druidQueryMock; + } + + private static List> buildFieldMapping(RowSignature resultSignature) + { + List columns = resultSignature.getColumnNames(); + return ImmutableList.copyOf(IntStream.range(0, columns.size()) + .boxed() + .collect(toMap(Function.identity(), columns::get)) + .entrySet()); + } + + private MSQTaskQueryMaker getMSQTaskQueryMaker() + { + // This can't be in setUp() because the fieldMapping are set in the test + return new MSQTaskQueryMaker( + ingestDestination, + fakeOverlordClient, + plannerContextMock, + objectMapper, + fieldMapping, + terminalStageSpecFactory + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 517b5407cbae..3d73e3f7f7e8 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -131,6 +131,8 @@ import org.apache.druid.query.groupby.GroupByQueryRunnerTest; import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.query.groupby.TestGroupByBuffers; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.segment.AggregateProjectionMetadata; import org.apache.druid.segment.CompleteSegment; @@ -535,7 +537,8 @@ public String getFormatString() new LookylooModule(), new SegmentWranglerModule(), new HllSketchModule(), - binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)) + binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)), + binder -> binder.bind(PolicyEnforcer.class).toInstance(NoopPolicyEnforcer.instance()) ); // adding node role injection to the modules, since CliPeon would also do that through run method injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) @@ -587,6 +590,7 @@ public String getFormatString() CalciteTests.createJoinableFactoryWrapper(), catalogResolver, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index f4b7171a6126..865457d4062b 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -38,6 +38,7 @@ import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; import org.apache.druid.segment.SegmentWrangler; @@ -94,6 +95,12 @@ public ObjectMapper jsonMapper() return mapper; } + @Override + public PolicyEnforcer policyEnforcer() + { + return injector.getInstance(PolicyEnforcer.class); + } + @Override public Injector injector() { @@ -175,6 +182,12 @@ public FrameContextImpl(File tempDir) this.tempDir = tempDir; } + @Override + public PolicyEnforcer policyEnforcer() + { + return MSQTestWorkerContext.this.policyEnforcer(); + } + @Override public SegmentWrangler segmentWrangler() { diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolbox.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolbox.java index 041ab12e7905..a94a388bb2ed 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolbox.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolbox.java @@ -46,6 +46,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; @@ -116,6 +117,7 @@ public class TaskToolbox private final Cache cache; private final CacheConfig cacheConfig; private final CachePopulatorStats cachePopulatorStats; + private final PolicyEnforcer policyEnforcer; private final IndexMergerV9 indexMergerV9; private final TaskReportFileWriter taskReportFileWriter; @@ -165,6 +167,7 @@ public TaskToolbox( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, IndexMergerV9 indexMergerV9, DruidNodeAnnouncer druidNodeAnnouncer, DruidNode druidNode, @@ -209,6 +212,7 @@ public TaskToolbox( this.cache = cache; this.cacheConfig = cacheConfig; this.cachePopulatorStats = cachePopulatorStats; + this.policyEnforcer = policyEnforcer; this.indexMergerV9 = Preconditions.checkNotNull(indexMergerV9, "Null IndexMergerV9"); this.druidNodeAnnouncer = druidNodeAnnouncer; this.druidNode = druidNode; @@ -255,6 +259,11 @@ public ServiceEmitter getEmitter() return emitter; } + public PolicyEnforcer getPolicyEnforcer() + { + return policyEnforcer; + } + public DataSegmentPusher getSegmentPusher() { return segmentPusher; @@ -560,6 +569,7 @@ public static class Builder private Cache cache; private CacheConfig cacheConfig; private CachePopulatorStats cachePopulatorStats; + private PolicyEnforcer policyEnforcer; private IndexMergerV9 indexMergerV9; private DruidNodeAnnouncer druidNodeAnnouncer; private DruidNode druidNode; @@ -609,6 +619,7 @@ public Builder(TaskToolbox other) this.cache = other.cache; this.cacheConfig = other.cacheConfig; this.cachePopulatorStats = other.cachePopulatorStats; + this.policyEnforcer = other.policyEnforcer; this.indexMergerV9 = other.indexMergerV9; this.druidNodeAnnouncer = other.druidNodeAnnouncer; this.druidNode = other.druidNode; @@ -657,6 +668,12 @@ public Builder emitter(final ServiceEmitter emitter) return this; } + public Builder policyEnforcer(final PolicyEnforcer policyEnforcer) + { + this.policyEnforcer = policyEnforcer; + return this; + } + public Builder segmentPusher(final DataSegmentPusher segmentPusher) { this.segmentPusher = segmentPusher; @@ -906,6 +923,7 @@ public TaskToolbox build() cache, cacheConfig, cachePopulatorStats, + policyEnforcer, indexMergerV9, druidNodeAnnouncer, druidNode, diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolboxFactory.java b/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolboxFactory.java index 9084c330655c..188b6b0962aa 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolboxFactory.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/common/TaskToolboxFactory.java @@ -47,6 +47,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.StandardRetryPolicy; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.IndexIO; @@ -81,6 +82,7 @@ public class TaskToolboxFactory private final DruidNode taskExecutorNode; private final TaskActionClientFactory taskActionClientFactory; private final ServiceEmitter emitter; + private final PolicyEnforcer policyEnforcer; private final DataSegmentPusher segmentPusher; private final DataSegmentKiller dataSegmentKiller; private final DataSegmentMover dataSegmentMover; @@ -127,6 +129,7 @@ public TaskToolboxFactory( @Parent DruidNode taskExecutorNode, TaskActionClientFactory taskActionClientFactory, ServiceEmitter emitter, + PolicyEnforcer policyEnforcer, DataSegmentPusher segmentPusher, DataSegmentKiller dataSegmentKiller, DataSegmentMover dataSegmentMover, @@ -170,6 +173,7 @@ public TaskToolboxFactory( this.taskExecutorNode = taskExecutorNode; this.taskActionClientFactory = taskActionClientFactory; this.emitter = emitter; + this.policyEnforcer = policyEnforcer; this.segmentPusher = segmentPusher; this.dataSegmentKiller = dataSegmentKiller; this.dataSegmentMover = dataSegmentMover; @@ -227,6 +231,7 @@ public TaskToolbox build(TaskConfig config, Task task) .taskExecutorNode(taskExecutorNode) .taskActionClient(taskActionClientFactory.create(task)) .emitter(emitter) + .policyEnforcer(policyEnforcer) .segmentPusher(segmentPusher) .dataSegmentKiller(dataSegmentKiller) .dataSegmentMover(dataSegmentMover) diff --git a/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTask.java b/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTask.java index 2550408ddc0b..0ea455a5a230 100644 --- a/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTask.java +++ b/indexing-service/src/main/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTask.java @@ -210,6 +210,7 @@ public Appenderator newAppenderator( toolbox.getCache(), toolbox.getCacheConfig(), toolbox.getCachePopulatorStats(), + toolbox.getPolicyEnforcer(), rowIngestionMeters, parseExceptionHandler, isUseMaxMemoryEstimates(), diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/TaskToolboxTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/TaskToolboxTest.java index 8f00f7961ff7..9b64ad318e31 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/common/TaskToolboxTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/TaskToolboxTest.java @@ -39,6 +39,7 @@ import org.apache.druid.query.DruidProcessingConfigTest; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.rpc.indexing.NoopOverlordClient; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; @@ -123,6 +124,7 @@ public void setUp() throws IOException new DruidNode("druid/middlemanager", "localhost", false, 8091, null, true, false), mockTaskActionClientFactory, mockEmitter, + NoopPolicyEnforcer.instance(), mockSegmentPusher, mockDataSegmentKiller, mockDataSegmentMover, diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/TestAppenderatorsManager.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/TestAppenderatorsManager.java index 417195326fb4..56b8675eaabe 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/TestAppenderatorsManager.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/TestAppenderatorsManager.java @@ -29,6 +29,7 @@ import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.incremental.ParseExceptionHandler; @@ -69,6 +70,7 @@ public Appenderator createRealtimeAppenderatorForTask( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, @@ -92,6 +94,7 @@ public Appenderator createRealtimeAppenderatorForTask( cache, cacheConfig, cachePopulatorStats, + policyEnforcer, rowIngestionMeters, parseExceptionHandler, useMaxMemoryEstimates, diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java index c599854055f2..6c251ea584ca 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/common/task/batch/parallel/AbstractParallelIndexSupervisorTaskTest.java @@ -82,6 +82,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.expression.LookupEnabledTestExprMacroTable; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.rpc.indexing.NoopOverlordClient; import org.apache.druid.segment.DataSegmentsWithSchemas; import org.apache.druid.segment.IndexIO; @@ -688,6 +689,7 @@ public File getStorageDirectory() .indexMergerV9(getIndexMergerV9Factory().create(task.getContextValue(Tasks.STORE_EMPTY_COLUMNS_KEY, true))) .intermediaryDataManager(intermediaryDataManager) .taskReportFileWriter(new SingleFileTaskReportFileWriter(reportsFile)) + .policyEnforcer(NoopPolicyEnforcer.instance()) .authorizerMapper(AuthTestUtils.TEST_AUTHORIZER_MAPPER) .chatHandlerProvider(new NoopChatHandlerProvider()) .rowIngestionMetersFactory(new TestUtils().getRowIngestionMetersFactory()) diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/SingleTaskBackgroundRunnerTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/SingleTaskBackgroundRunnerTest.java index e5be345fb9ef..1b24f3b84c54 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/SingleTaskBackgroundRunnerTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/SingleTaskBackgroundRunnerTest.java @@ -43,6 +43,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.Druids; import org.apache.druid.query.QueryRunner; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.scan.ScanResultValue; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.rpc.indexing.NoopOverlordClient; @@ -104,6 +105,7 @@ public void setup() throws IOException null, EasyMock.createMock(TaskActionClientFactory.class), emitter, + NoopPolicyEnforcer.instance(), new NoopDataSegmentPusher(), new NoopDataSegmentKiller(), new NoopDataSegmentMover(), diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLifecycleTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLifecycleTest.java index 75bdd55d3951..fe91e1f2edc8 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLifecycleTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TaskLifecycleTest.java @@ -111,6 +111,7 @@ import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.rpc.indexing.NoopOverlordClient; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9Factory; @@ -571,6 +572,7 @@ private TaskToolboxFactory setUpTaskToolboxFactory( new DruidNode("druid/middlemanager", "localhost", false, 8091, null, true, false), tac, emitter, + NoopPolicyEnforcer.instance(), dataSegmentPusher, new LocalDataSegmentKiller(new LocalDataSegmentPusherConfig()), (dataSegment, targetLoadSpec) -> dataSegment, @@ -1245,6 +1247,7 @@ public void testUnifiedAppenderatorsManagerCleanup() throws Exception MapCache.create(2048), new CacheConfig(), new CachePopulatorStats(), + NoopPolicyEnforcer.instance(), MAPPER, new NoopServiceEmitter(), () -> queryRunnerFactoryConglomerate diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TestTaskToolboxFactory.java b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TestTaskToolboxFactory.java index b56da16de0c3..8dd8064b4d56 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TestTaskToolboxFactory.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/TestTaskToolboxFactory.java @@ -44,6 +44,7 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.rpc.indexing.NoopOverlordClient; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.IndexIO; @@ -85,6 +86,7 @@ public TestTaskToolboxFactory( bob.taskExecutorNode, bob.taskActionClientFactory, bob.emitter, + bob.policyEnforcer, bob.segmentPusher, bob.dataSegmentKiller, bob.dataSegmentMover, @@ -130,6 +132,7 @@ public static class Builder private DruidNode taskExecutorNode; private TaskActionClientFactory taskActionClientFactory = task -> null; private ServiceEmitter emitter; + private PolicyEnforcer policyEnforcer; private DataSegmentPusher segmentPusher; private DataSegmentKiller dataSegmentKiller; private DataSegmentMover dataSegmentMover; @@ -191,6 +194,12 @@ public Builder setEmitter(ServiceEmitter emitter) return this; } + public Builder setPolicyEnforcer(PolicyEnforcer policyEnforcer) + { + this.policyEnforcer = policyEnforcer; + return this; + } + public Builder setSegmentPusher(DataSegmentPusher segmentPusher) { this.segmentPusher = segmentPusher; diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTaskTestBase.java b/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTaskTestBase.java index 8c0bdf56ed59..e0332708c3fe 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTaskTestBase.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/seekablestream/SeekableStreamIndexTaskTestBase.java @@ -100,6 +100,7 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesResultValue; import org.apache.druid.rpc.indexing.NoopOverlordClient; @@ -676,6 +677,7 @@ public void close() null, // taskExecutorNode taskActionClientFactory, emitter, + NoopPolicyEnforcer.instance(), dataSegmentPusher, new TestDataSegmentKiller(), null, // DataSegmentMover diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskManagerTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskManagerTest.java index 930e2fc7d7c2..445e9469c9d3 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskManagerTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskManagerTest.java @@ -44,6 +44,7 @@ import org.apache.druid.indexing.overlord.TestTaskRunner; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.http.client.response.StringFullResponseHolder; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.rpc.HttpResponseException; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.IndexIO; @@ -134,6 +135,7 @@ private WorkerTaskManager createWorkerTaskManager() null, taskActionClientFactory, null, + NoopPolicyEnforcer.instance(), null, null, null, diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskMonitorTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskMonitorTest.java index 8e6577a86e85..6c0462dcf500 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskMonitorTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/worker/WorkerTaskMonitorTest.java @@ -47,6 +47,7 @@ import org.apache.druid.indexing.worker.config.WorkerConfig; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.rpc.indexing.NoopOverlordClient; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.segment.IndexIO; @@ -176,6 +177,7 @@ private WorkerTaskMonitor createTaskMonitor() null, taskActionClientFactory, null, + NoopPolicyEnforcer.instance(), null, null, null, diff --git a/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java b/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java index 91434232b92c..2fc2391cd541 100644 --- a/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java +++ b/integration-tests-ex/tools/src/main/java/org/apache/druid/testing/tools/ServerManagerForQueryErrorTest.java @@ -45,6 +45,7 @@ import org.apache.druid.query.ReportTimelineMissingSegmentQueryRunner; import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import org.apache.druid.server.SegmentManager; @@ -109,7 +110,8 @@ public ServerManagerForQueryErrorTest( cache, cacheConfig, segmentManager, - serverConfig + serverConfig, + NoopPolicyEnforcer.instance() ); } diff --git a/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java b/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java index 6e225304a4a3..1282a7d78337 100644 --- a/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java +++ b/integration-tests/src/main/java/org/apache/druid/server/coordination/ServerManagerForQueryErrorTest.java @@ -45,6 +45,7 @@ import org.apache.druid.query.ReportTimelineMissingSegmentQueryRunner; import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import org.apache.druid.server.SegmentManager; @@ -107,7 +108,8 @@ public ServerManagerForQueryErrorTest( cache, cacheConfig, segmentManager, - serverConfig + serverConfig, + NoopPolicyEnforcer.instance() ); } diff --git a/pom.xml b/pom.xml index 621f31f446ed..68ee72ef2e30 100644 --- a/pom.xml +++ b/pom.xml @@ -1325,6 +1325,12 @@ 1.1.1 test + + com.google.inject.extensions + guice-testlib + ${guice.version} + test + com.google.guava guava-testlib diff --git a/processing/src/main/java/org/apache/druid/guice/JsonConfigurator.java b/processing/src/main/java/org/apache/druid/guice/JsonConfigurator.java index 2d7db35e2acb..940bee2565ee 100644 --- a/processing/src/main/java/org/apache/druid/guice/JsonConfigurator.java +++ b/processing/src/main/java/org/apache/druid/guice/JsonConfigurator.java @@ -20,6 +20,7 @@ package org.apache.druid.guice; import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.introspect.AnnotatedField; @@ -45,15 +46,19 @@ import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Properties; import java.util.Set; /** + * */ public class JsonConfigurator { @@ -125,8 +130,14 @@ public T configurate( try { if (defaultClass != null && jsonMap.isEmpty()) { // No configs were provided. Don't use the jsonMapper; instead create a default instance of the default class - // using the no-arg constructor. We know it exists because verifyClazzIsConfigurable checks for it. - config = defaultClass.getConstructor().newInstance(); + // using the JsonCreator annotated factory method or no-arg constructor. + // We know it exists because verifyClazzIsConfigurable checks for it. + Optional factoryMethod = findJsonCreatorFactoryMethod(defaultClass); + if (factoryMethod.isPresent()) { + config = (T) factoryMethod.get().invoke(null); + } else { + config = defaultClass.getConstructor().newInstance(); + } } else { config = jsonMapper.convertValue(jsonMap, clazz); } @@ -256,7 +267,9 @@ public static void verifyClazzIsConfigurable( { if (defaultClass != null) { try { - defaultClass.getConstructor(); + if (findJsonCreatorFactoryMethod(defaultClass).isEmpty()) { + defaultClass.getConstructor(); + } } catch (NoSuchMethodException e) { throw new ProvisionException( @@ -283,4 +296,15 @@ public static void verifyClazzIsConfigurable( } } } + + private static Optional findJsonCreatorFactoryMethod(Class clazz) + { + return Arrays.stream(clazz.getMethods()) + .filter(m -> m.getAnnotation(JsonCreator.class) != null + && m.getParameterCount() == 0 + && m.getReturnType() + .equals(clazz)) + .findFirst(); + + } } diff --git a/processing/src/main/java/org/apache/druid/query/DataSource.java b/processing/src/main/java/org/apache/druid/query/DataSource.java index ada01e095468..3e36cf0c1f0a 100644 --- a/processing/src/main/java/org/apache/druid/query/DataSource.java +++ b/processing/src/main/java/org/apache/druid/query/DataSource.java @@ -24,10 +24,10 @@ import org.apache.druid.java.util.common.Cacheable; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.SegmentReference; import javax.annotation.Nullable; - import java.util.List; import java.util.Map; import java.util.Optional; @@ -83,9 +83,9 @@ public interface DataSource extends Cacheable /** * Decides if this datasource can be accessed globally. - * + *

* This means that all servers have a full copy of this datasource. - * + *

* Examples: inline table, lookup. */ boolean isGlobal(); @@ -94,7 +94,7 @@ public interface DataSource extends Cacheable * Communicates that this {@link DataSource} can be directly used to run a {@link Query}. *

* A Processable datasource must pack the necessary logic into the {@link DataSource#createSegmentMapFunction(Query)}. - * + *

* Processable examples are: {@link TableDataSource}, {@link InlineDataSource}, {@link FilteredDataSource}. * Non-processable ones are those which need further pre-processing before running them. * examples are: {@link QueryDataSource} and join which are not supported directly. @@ -111,20 +111,22 @@ public interface DataSource extends Cacheable *

* If this datasource contains no table, no changes should occur. * - * @param policyMap a mapping of table names to policy restrictions. A missing key is different from an empty value: - *

    - *
  • a missing key means the table has never been permission checked. - *
  • an empty value indicates the table doesn't have any policy restrictions, it has been permission checked. + * @param policyMap a mapping of table names to policy restrictions. A missing key is different from an empty value: + *
      + *
    • a missing key means the table has never been permission checked. + *
    • an empty value indicates the table doesn't have any policy restrictions, it has been permission checked. + *
    + * @param policyEnforcer the policy enforcer to enforce the result datasource complies * @return the updated datasource, with restrictions applied in the datasource tree * @throws IllegalStateException when mapping a RestrictedDataSource, unless the table has a NoRestrictionPolicy in * the policyMap (used by druid-internal). Missing policy or adding a * non-NoRestrictionPolicy to RestrictedDataSource would throw. */ - default DataSource withPolicies(Map> policyMap) + default DataSource withPolicies(Map> policyMap, PolicyEnforcer policyEnforcer) { List children = this.getChildren() .stream() - .map(child -> child.withPolicies(policyMap)) + .map(child -> child.withPolicies(policyMap, policyEnforcer)) .collect(Collectors.toList()); return this.withChildren(children); } diff --git a/processing/src/main/java/org/apache/druid/query/RestrictedDataSource.java b/processing/src/main/java/org/apache/druid/query/RestrictedDataSource.java index bdcb013f2de5..74f6b3ed87f7 100644 --- a/processing/src/main/java/org/apache/druid/query/RestrictedDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/RestrictedDataSource.java @@ -26,6 +26,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.query.policy.NoRestrictionPolicy; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.RestrictedSegment; import org.apache.druid.segment.SegmentReference; @@ -129,14 +130,14 @@ public Function createSegmentMapFunction(Que } @Override - public DataSource withPolicies(Map> policyMap) + public DataSource withPolicies(Map> policyMap, PolicyEnforcer policyEnforcer) { if (!policyMap.containsKey(base.getName())) { throw new ISE("Missing policy check result for table [%s]", base.getName()); } Optional newPolicy = policyMap.getOrDefault(base.getName(), Optional.empty()); - if (!newPolicy.isPresent()) { + if (newPolicy.isEmpty()) { throw new ISE( "No restriction found on table [%s], but had policy [%s] before.", base.getName(), @@ -157,6 +158,7 @@ public DataSource withPolicies(Map> policyMap) } // The only happy path is, newPolicy is NoRestrictionPolicy, which means this comes from an anthenticated and // authorized druid-internal request. + policyEnforcer.validateOrElseThrow(base, policy); return this; } diff --git a/processing/src/main/java/org/apache/druid/query/TableDataSource.java b/processing/src/main/java/org/apache/druid/query/TableDataSource.java index de67917c43b7..e0aa1332e5f2 100644 --- a/processing/src/main/java/org/apache/druid/query/TableDataSource.java +++ b/processing/src/main/java/org/apache/druid/query/TableDataSource.java @@ -25,6 +25,7 @@ import com.google.common.base.Preconditions; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; import java.util.Collections; import java.util.Map; @@ -80,14 +81,11 @@ public boolean isProcessable() } @Override - public DataSource withPolicies(Map> policyMap) + public DataSource withPolicies(Map> policyMap, PolicyEnforcer policyEnforcer) { Optional policy = policyMap.getOrDefault(name, Optional.empty()); - if (!policy.isPresent()) { - // Skip adding restriction on table if there's no policy restriction found. - return this; - } - return RestrictedDataSource.create(this, policy.get()); + policyEnforcer.validateOrElseThrow(this, policy.orElse(null)); + return policy.isEmpty() ? this : RestrictedDataSource.create(this, policy.get()); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/planning/ExecutionVertex.java b/processing/src/main/java/org/apache/druid/query/planning/ExecutionVertex.java index 5a4df92bbd72..b8d750ef1666 100644 --- a/processing/src/main/java/org/apache/druid/query/planning/ExecutionVertex.java +++ b/processing/src/main/java/org/apache/druid/query/planning/ExecutionVertex.java @@ -31,6 +31,7 @@ import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.QuerySegmentSpec; @@ -392,10 +393,13 @@ public int hashCode() /** * Assembles the segment mapping function which should be applied to the input segments. */ - public Function createSegmentMapFunction() + public Function createSegmentMapFunction(PolicyEnforcer policyEnforcer) { DataSource topDataSource = getTopDataSource(); - return topDataSource.createSegmentMapFunction(topQuery); + return topDataSource.createSegmentMapFunction(topQuery).andThen(segmentReference -> { + segmentReference.validateOrElseThrow(policyEnforcer); + return segmentReference; + }); } /** diff --git a/processing/src/main/java/org/apache/druid/query/policy/NoopPolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/NoopPolicyEnforcer.java new file mode 100644 index 000000000000..b62b505a24d7 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/policy/NoopPolicyEnforcer.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.policy; + +import com.fasterxml.jackson.annotation.JsonCreator; + +/** + * Allows all data sources (no restrictions). + */ +public class NoopPolicyEnforcer implements PolicyEnforcer +{ + NoopPolicyEnforcer() + { + } + + @Override + public boolean validate(Policy unused) + { + return true; + } + + @JsonCreator + public static NoopPolicyEnforcer instance() + { + return new NoopPolicyEnforcer(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + return true; + } + + @Override + public int hashCode() + { + return 0; + } +} diff --git a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java new file mode 100644 index 000000000000..4b34b96c23b1 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.policy; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.annotations.UnstableApi; +import org.apache.druid.query.DataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.segment.SegmentReference; + +/** + * Interface for enforcing policies on data sources and segments in Druid queries. + *

    + * Note: The {@code PolicyEnforcer} is intended to serve as a sanity checker and not as a primary authorization mechanism. + * It should not be used to implement security rules. Instead, it acts as a last line of defense to verify that + * security policies have been implemented correctly and to prevent incorrect policy usage. + *

    + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = NoopPolicyEnforcer.class, name = "none"), + @JsonSubTypes.Type(value = RestrictAllTablesPolicyEnforcer.class, name = "restrictAllTables"), +}) +@UnstableApi +public interface PolicyEnforcer +{ + /** + * Validates a {@link DataSource} against the policy enforcer. Prior to query execution, the {@link org.apache.druid.query.Query#getDataSource()} + * tree is walked. This method is invoked once for each {@link org.apache.druid.query.RestrictedDataSource} and once + * for each {@link TableDataSource} that is not wrapped inside a {@link org.apache.druid.query.RestrictedDataSource}, + * no matter where they appear within the tree. + * + * @param ds the table to validate. + * @param policy the policy attached to the table; either {@link org.apache.druid.query.RestrictedDataSource#policy} or null. + * @throws DruidException if the data source does not comply with the policy + */ + default void validateOrElseThrow(TableDataSource ds, Policy policy) throws DruidException + { + if (validate(policy)) { + return; + } + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.FORBIDDEN) + .build("Failed security validation with dataSource [%s]", ds); + } + + /** + * Validates a {@link SegmentReference} against the policy enforcer. Prior to query execution, the {@link SegmentReference} tree is walked. + * This method is invoked once for each {@link org.apache.druid.segment.RestrictedSegment} and once for each {@link ReferenceCountingSegment} + * that is not wrapped inside a {@link org.apache.druid.segment.RestrictedSegment}. + *

    + * Direct invocation of this method is discouraged; use {@link SegmentReference#validateOrElseThrow(PolicyEnforcer)} instead. + * + * @param segment the segment to validate + * @param policy the policy on the segment, {@link org.apache.druid.segment.RestrictedSegment#policy} or null for other + * @throws DruidException if the segment does not comply with the policy + */ + default void validateOrElseThrow(ReferenceCountingSegment segment, Policy policy) throws DruidException + { + // Validation will always fail on lookups, external, and inline segments, because they will not have policies applied (except for NoopPolicyEnforcer). + // This is a temporary solution since we don't have a perfect way to identify segments that are backed by a regular table yet. + if (validate(policy)) { + return; + } + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.FORBIDDEN) + .build("Failed security validation with segment [%s]", segment.getId()); + } + + /** + * Returns true if the policy complies with the policy enforcer. + */ + boolean validate(Policy policy); +} diff --git a/processing/src/main/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcer.java new file mode 100644 index 000000000000..6e9cd628cb54 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcer.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.policy; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; + +/** + * All tables must be restricted by a policy. + */ +public class RestrictAllTablesPolicyEnforcer implements PolicyEnforcer +{ + private final ImmutableList allowedPolicies; + + @JsonCreator + public RestrictAllTablesPolicyEnforcer( + @Nullable @JsonProperty("allowedPolicies") List allowedPolicies + ) + { + if (allowedPolicies == null) { + this.allowedPolicies = ImmutableList.of(); + } else { + this.allowedPolicies = ImmutableList.copyOf(allowedPolicies); + } + } + + @Override + public boolean validate(Policy policy) + { + return policy != null && (allowedPolicies.isEmpty() || allowedPolicies.contains(policy.getClass().getName())); + } + + @JsonProperty + public ImmutableList getAllowedPolicies() + { + return allowedPolicies; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RestrictAllTablesPolicyEnforcer that = (RestrictAllTablesPolicyEnforcer) o; + return allowedPolicies.equals(that.allowedPolicies); + } + + @Override + public int hashCode() + { + return Objects.hash(allowedPolicies); + } + + @Override + public String toString() + { + return "RestrictAllTablesPolicyEnforcer{" + + "allowedPolicies=" + allowedPolicies + + '}'; + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/ReferenceCountingSegment.java b/processing/src/main/java/org/apache/druid/segment/ReferenceCountingSegment.java index e2969f10d00e..8957ec723d74 100644 --- a/processing/src/main/java/org/apache/druid/segment/ReferenceCountingSegment.java +++ b/processing/src/main/java/org/apache/druid/segment/ReferenceCountingSegment.java @@ -20,6 +20,7 @@ package org.apache.druid.segment; import com.google.common.base.Preconditions; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.timeline.Overshadowable; import org.apache.druid.timeline.SegmentId; import org.apache.druid.timeline.partition.ShardSpec; @@ -33,7 +34,7 @@ * {@link Segment} that is also a {@link ReferenceCountingSegment}, allowing query engines that operate directly on * segments to track references so that dropping a {@link Segment} can be done safely to ensure there are no in-flight * queries. - * + *

    * Extensions can extend this class for populating {@link org.apache.druid.timeline.VersionedIntervalTimeline} with * a custom implementation through SegmentLoader. */ @@ -176,6 +177,12 @@ public Optional acquireReferences() return incrementReferenceAndDecrementOnceCloseable(); } + @Override + public void validateOrElseThrow(PolicyEnforcer policyEnforcer) + { + policyEnforcer.validateOrElseThrow(this, null); + } + @Override public T as(Class clazz) { diff --git a/processing/src/main/java/org/apache/druid/segment/RestrictedSegment.java b/processing/src/main/java/org/apache/druid/segment/RestrictedSegment.java index c4a7cb4828e8..f3f5563d7c36 100644 --- a/processing/src/main/java/org/apache/druid/segment/RestrictedSegment.java +++ b/processing/src/main/java/org/apache/druid/segment/RestrictedSegment.java @@ -19,8 +19,10 @@ package org.apache.druid.segment; +import com.google.common.base.Preconditions; import org.apache.druid.query.policy.NoRestrictionPolicy; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.timeline.SegmentId; import org.joda.time.Interval; @@ -45,11 +47,13 @@ public class RestrictedSegment implements SegmentReference protected final SegmentReference delegate; protected final Policy policy; - public RestrictedSegment( - SegmentReference delegate, - Policy policy - ) + public RestrictedSegment(SegmentReference delegate, Policy policy) { + // This is a sanity check, a restricted data source should alway wrap a druid table directly. + Preconditions.checkArgument( + delegate instanceof ReferenceCountingSegment, + "delegate must be a ReferenceCountingSegment" + ); this.delegate = delegate; this.policy = policy; } @@ -109,6 +113,12 @@ public T as(@Nonnull Class clazz) return null; } + @Override + public void validateOrElseThrow(PolicyEnforcer policyEnforcer) + { + policyEnforcer.validateOrElseThrow((ReferenceCountingSegment) delegate, policy); + } + @Override public boolean isTombstone() { diff --git a/processing/src/main/java/org/apache/druid/segment/Segment.java b/processing/src/main/java/org/apache/druid/segment/Segment.java index 14cd1a4da443..1da4b47f89cd 100644 --- a/processing/src/main/java/org/apache/druid/segment/Segment.java +++ b/processing/src/main/java/org/apache/druid/segment/Segment.java @@ -102,7 +102,6 @@ default boolean isTombstone() return false; } - default String asString() { return getClass().toString(); diff --git a/processing/src/main/java/org/apache/druid/segment/SegmentReference.java b/processing/src/main/java/org/apache/druid/segment/SegmentReference.java index fae2a7b36f65..fba3e3158734 100644 --- a/processing/src/main/java/org/apache/druid/segment/SegmentReference.java +++ b/processing/src/main/java/org/apache/druid/segment/SegmentReference.java @@ -19,12 +19,31 @@ package org.apache.druid.segment; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; + /** - * A {@link Segment} with a associated references, such as {@link ReferenceCountingSegment} where the reference is + * A {@link Segment} with an associated references, such as {@link ReferenceCountingSegment} where the reference is * the segment itself, and {@link org.apache.druid.segment.join.HashJoinSegment} which wraps a * {@link ReferenceCountingSegment} and also includes the associated list of * {@link org.apache.druid.segment.join.JoinableClause} */ public interface SegmentReference extends Segment, ReferenceCountedObject { + + /** + * Validates if the segment complies with the policy restrictions on tables. + *

    + * This should be called right before the segment is about to be processed by the query stack, and after + * {@link org.apache.druid.query.planning.ExecutionVertex#createSegmentMapFunction(PolicyEnforcer)}. + */ + default void validateOrElseThrow(PolicyEnforcer policyEnforcer) + { + // For testing purposes, we allow the NoopPolicyEnforcer to pass through. + if (policyEnforcer instanceof NoopPolicyEnforcer) { + return; + } + throw new UnsupportedOperationException("validateOrElseThrow is not supported"); + } + } diff --git a/processing/src/main/java/org/apache/druid/segment/WrappedSegmentReference.java b/processing/src/main/java/org/apache/druid/segment/WrappedSegmentReference.java index 97bb1afd4b77..b9c77b07cad9 100644 --- a/processing/src/main/java/org/apache/druid/segment/WrappedSegmentReference.java +++ b/processing/src/main/java/org/apache/druid/segment/WrappedSegmentReference.java @@ -19,6 +19,7 @@ package org.apache.druid.segment; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.timeline.SegmentId; import org.joda.time.Interval; @@ -79,6 +80,12 @@ public T as(@Nonnull Class clazz) } } + @Override + public void validateOrElseThrow(PolicyEnforcer policyEnforcer) + { + delegate.validateOrElseThrow(policyEnforcer); + } + @Override public boolean isTombstone() { diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java index b6a8da3fbaed..3f90e7dddf1b 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java @@ -23,6 +23,7 @@ import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.rowsandcols.CursorFactoryRowsAndColumns; import org.apache.druid.segment.CloseableShapeshifter; import org.apache.druid.segment.CursorFactory; @@ -136,6 +137,12 @@ public T as(Class clazz) } } + @Override + public void validateOrElseThrow(PolicyEnforcer policyEnforcer) + { + baseSegment.validateOrElseThrow(policyEnforcer); + } + @Override public void close() throws IOException { diff --git a/processing/src/test/java/org/apache/druid/query/DataSourceTest.java b/processing/src/test/java/org/apache/druid/query/DataSourceTest.java index 25733b340f16..b87b12fc9bca 100644 --- a/processing/src/test/java/org/apache/druid/query/DataSourceTest.java +++ b/processing/src/test/java/org/apache/druid/query/DataSourceTest.java @@ -20,18 +20,28 @@ package org.apache.druid.query; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.policy.NoRestrictionPolicy; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.query.policy.RowFilterPolicy; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinType; +import org.apache.druid.segment.join.JoinableFactoryWrapper; +import org.apache.druid.segment.join.NoopJoinableFactory; import org.junit.Assert; import org.junit.Test; @@ -122,26 +132,26 @@ public void testUnionDataSource() throws Exception } @Test - public void testMapWithRestriction() + public void testWithPolicies_onUnionDataSource() { TableDataSource table1 = TableDataSource.create("table1"); TableDataSource table2 = TableDataSource.create("table2"); - TableDataSource table3 = TableDataSource.create("table3"); - UnionDataSource unionDataSource = new UnionDataSource(Lists.newArrayList(table1, table2, table3)); + InlineDataSource inlineDataSource = InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty()); + + UnionDataSource unionDataSource = new UnionDataSource(Lists.newArrayList(table1, table2, inlineDataSource)); ImmutableMap> restrictions = ImmutableMap.of( "table1", Optional.of(NoRestrictionPolicy.instance()), "table2", - Optional.of(NoRestrictionPolicy.instance()), - "table3", Optional.of(RowFilterPolicy.from(new NullFilter( "some-column", null ))) ); + PolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); Assert.assertEquals( - unionDataSource.withPolicies(restrictions), + unionDataSource.withPolicies(restrictions, policyEnforcer), new UnionDataSource(Lists.newArrayList( RestrictedDataSource.create( table1, @@ -149,21 +159,49 @@ public void testMapWithRestriction() ), RestrictedDataSource.create( table2, - NoRestrictionPolicy.instance() - ), - RestrictedDataSource.create( - table3, RowFilterPolicy.from(new NullFilter( "some-column", null )) - ) + ), + inlineDataSource )) ); } @Test - public void testMapWithRestriction_onRestrictedDataSource_fromDruidSystem() + public void testWithPolicies_onUnionDataSource_throwsOnValidation() + { + TableDataSource table1 = TableDataSource.create("table1"); + TableDataSource table2 = TableDataSource.create("table2"); + InlineDataSource inlineDataSource = InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty()); + + UnionDataSource unionDataSource = new UnionDataSource(Lists.newArrayList(table1, table2, inlineDataSource)); + ImmutableMap> restrictions = ImmutableMap.of( + "table1", + Optional.of(NoRestrictionPolicy.instance()), + "table2", + Optional.of(RowFilterPolicy.from(new NullFilter( + "some-column", + null + ))) + ); + PolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of(NoRestrictionPolicy.class.getName())); + + DruidException e = Assert.assertThrows( + DruidException.class, + () -> unionDataSource.withPolicies(restrictions, policyEnforcer) + ); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals( + "Failed security validation with dataSource [table2]", + e.getMessage() + ); + } + + @Test + public void testWithPolicies_onRestrictedDataSource_fromDruidSystem() { RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( TableDataSource.create("table1"), @@ -174,12 +212,13 @@ public void testMapWithRestriction_onRestrictedDataSource_fromDruidSystem() "table1", Optional.of(NoRestrictionPolicy.instance()) ); + PolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); - Assert.assertEquals(restrictedDataSource, restrictedDataSource.withPolicies(noRestrictionPolicy)); + Assert.assertEquals(restrictedDataSource, restrictedDataSource.withPolicies(noRestrictionPolicy, policyEnforcer)); } @Test - public void testMapWithRestriction_onRestrictedDataSource_samePolicy() + public void testWithPolicies_onRestrictedDataSource_samePolicy() { RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( TableDataSource.create("table1"), @@ -189,12 +228,13 @@ public void testMapWithRestriction_onRestrictedDataSource_samePolicy() "table1", Optional.of(RowFilterPolicy.from(new NullFilter("some-column", null))) ); + PolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); - Assert.assertEquals(restrictedDataSource, restrictedDataSource.withPolicies(policyMap)); + Assert.assertEquals(restrictedDataSource, restrictedDataSource.withPolicies(policyMap, policyEnforcer)); } @Test - public void testMapWithRestriction_onRestrictedDataSource_alwaysThrows() + public void testWithPolicies_onRestrictedDataSource_alwaysThrows() { RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( TableDataSource.create("table1"), @@ -206,19 +246,89 @@ public void testMapWithRestriction_onRestrictedDataSource_alwaysThrows() ); ImmutableMap> noPolicyFound = ImmutableMap.of("table1", Optional.empty()); ImmutableMap> policyWasNotChecked = ImmutableMap.of(); + PolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); - ISE e = Assert.assertThrows(ISE.class, () -> restrictedDataSource.withPolicies(anotherRestrictions)); + ISE e = Assert.assertThrows( + ISE.class, + () -> restrictedDataSource.withPolicies(anotherRestrictions, policyEnforcer) + ); Assert.assertEquals( "Different restrictions on table [table1]: previous policy [RowFilterPolicy{rowFilter=random-column IS NULL}] and new policy [RowFilterPolicy{rowFilter=some-column IS NULL}]", e.getMessage() ); - ISE e2 = Assert.assertThrows(ISE.class, () -> restrictedDataSource.withPolicies(noPolicyFound)); + ISE e2 = Assert.assertThrows(ISE.class, () -> restrictedDataSource.withPolicies(noPolicyFound, policyEnforcer)); Assert.assertEquals( "No restriction found on table [table1], but had policy [RowFilterPolicy{rowFilter=random-column IS NULL}] before.", e2.getMessage() ); - ISE e3 = Assert.assertThrows(ISE.class, () -> restrictedDataSource.withPolicies(policyWasNotChecked)); + ISE e3 = Assert.assertThrows( + ISE.class, + () -> restrictedDataSource.withPolicies(policyWasNotChecked, policyEnforcer) + ); Assert.assertEquals("Missing policy check result for table [table1]", e3.getMessage()); } + + @Test + public void testWithPolicies_onInlineDataSource() + { + InlineDataSource inlineDataSource = InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty()); + DataSource withPolicies = inlineDataSource.withPolicies(ImmutableMap.of(), NoopPolicyEnforcer.instance()); + Assert.assertEquals(inlineDataSource, withPolicies); + } + + @Test + public void testWithPolicies_onJoinDataSource() + { + JoinableFactoryWrapper joinableFactoryWrapper = new JoinableFactoryWrapper(NoopJoinableFactory.INSTANCE); + JoinDataSource joinDataSource = JoinDataSource.create( + new TableDataSource("table1"), + new TableDataSource("table2"), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + joinableFactoryWrapper, + JoinAlgorithm.BROADCAST + ); + final PolicyEnforcer policyEnforcer = NoopPolicyEnforcer.instance(); + DataSource mapped = joinDataSource.withPolicies(ImmutableMap.of(), policyEnforcer); + Assert.assertEquals(joinDataSource, mapped); + + // Use stricter enforcer + final PolicyEnforcer restrictAllTablesPolicyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + // Fail, policy must exist on both tables + Assert.assertThrows( + DruidException.class, + () -> joinDataSource.withPolicies(ImmutableMap.of(), restrictAllTablesPolicyEnforcer) + ); + Assert.assertThrows( + DruidException.class, + () -> joinDataSource.withPolicies(ImmutableMap.of( + "table1", + Optional.of(NoRestrictionPolicy.instance()) + ), restrictAllTablesPolicyEnforcer) + ); + + DataSource mapped2 = joinDataSource.withPolicies(ImmutableMap.of( + "table1", + Optional.of(NoRestrictionPolicy.instance()), + "table2", + Optional.of(NoRestrictionPolicy.instance()) + ), restrictAllTablesPolicyEnforcer); + + JoinDataSource expected = JoinDataSource.create( + RestrictedDataSource.create(new TableDataSource("table1"), NoRestrictionPolicy.instance()), + RestrictedDataSource.create(new TableDataSource("table2"), NoRestrictionPolicy.instance()), + "j.", + "x == \"j.x\"", + JoinType.LEFT, + null, + ExprMacroTable.nil(), + joinableFactoryWrapper, + JoinAlgorithm.BROADCAST + ); + Assert.assertEquals(expected, mapped2); + } } diff --git a/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java b/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java index dde0928ea8ac..80cd8a55e463 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java +++ b/processing/src/test/java/org/apache/druid/query/QueryRunnerTestHelper.java @@ -54,6 +54,7 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.query.spec.SpecificSegmentSpec; @@ -558,7 +559,7 @@ public static > QueryRunner makeQueryRunnerWith ) { ExecutionVertex ev = ExecutionVertex.of(query); - final SegmentReference segmentReference = ev.createSegmentMapFunction() + final SegmentReference segmentReference = ev.createSegmentMapFunction(NoopPolicyEnforcer.instance()) .apply(ReferenceCountingSegment.wrapRootGenerationSegment(adapter)); return makeQueryRunner(factory, segmentReference, runnerName); } diff --git a/processing/src/test/java/org/apache/druid/query/policy/NoopPolicyEnforcerTest.java b/processing/src/test/java/org/apache/druid/query/policy/NoopPolicyEnforcerTest.java new file mode 100644 index 000000000000..8c19977440c0 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/policy/NoopPolicyEnforcerTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.policy; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.query.RestrictedDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.NullFilter; +import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.segment.Segment; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.TestSegmentUtils; +import org.junit.Assert; +import org.junit.Test; + +public class NoopPolicyEnforcerTest +{ + @Test + public void test_serialize() throws Exception + { + NoopPolicyEnforcer policyEnforcer = NoopPolicyEnforcer.instance(); + ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + + String expected = "{\"type\":\"none\"}"; + Assert.assertEquals(expected, jsonMapper.writeValueAsString(policyEnforcer)); + } + + @Test + public void test_serde_roundTrip() throws Exception + { + NoopPolicyEnforcer policyEnforcer = NoopPolicyEnforcer.instance(); + ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + PolicyEnforcer deserialized = jsonMapper.readValue( + jsonMapper.writeValueAsString(policyEnforcer), + PolicyEnforcer.class + ); + Assert.assertEquals(policyEnforcer, deserialized); + } + + @Test + public void test_validate() throws Exception + { + NoopPolicyEnforcer policyEnforcer = NoopPolicyEnforcer.instance(); + RowFilterPolicy policy = RowFilterPolicy.from(new NullFilter("some-col", null)); + TableDataSource table = TableDataSource.create("table"); + RestrictedDataSource restricted = RestrictedDataSource.create(table, policy); + + Assert.assertTrue(policyEnforcer.validate(null)); + Assert.assertTrue(policyEnforcer.validate(policy)); + policyEnforcer.validateOrElseThrow(table, null); + policyEnforcer.validateOrElseThrow(restricted.getBase(), restricted.getPolicy()); + + Segment baseSegment = new TestSegmentUtils.SegmentForTesting("table", Intervals.ETERNITY, "1"); + ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment); + Assert.assertTrue(policyEnforcer.validate(null)); + Assert.assertTrue(policyEnforcer.validate(policy)); + policyEnforcer.validateOrElseThrow(segment, null); + policyEnforcer.validateOrElseThrow(segment, policy); + + } + +} diff --git a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java new file mode 100644 index 000000000000..86c597d6e83c --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.policy; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.query.RestrictedDataSource; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.NullFilter; +import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.segment.Segment; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting; +import org.junit.Assert; +import org.junit.Test; + +public class RestrictAllTablesPolicyEnforcerTest +{ + @Test + public void test_serialize() throws Exception + { + RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of( + NoRestrictionPolicy.class.getName())); + ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + + String expected = "{\"type\":\"restrictAllTables\",\"allowedPolicies\":[\"org.apache.druid.query.policy.NoRestrictionPolicy\"]}"; + Assert.assertEquals(expected, jsonMapper.writeValueAsString(policyEnforcer)); + } + + @Test + public void test_serde_roundTrip() throws Exception + { + RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + PolicyEnforcer deserialized = jsonMapper.readValue( + jsonMapper.writeValueAsString(policyEnforcer), + PolicyEnforcer.class + ); + Assert.assertEquals(policyEnforcer, deserialized); + } + + @Test + public void test_validate() throws Exception + { + final RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowFilterPolicy policy = RowFilterPolicy.from(new NullFilter("some-col", null)); + TableDataSource table = TableDataSource.create("table"); + RestrictedDataSource restricted = RestrictedDataSource.create(table, policy); + // Test validate data source, fail for TableDataSource, success for RestrictedDataSource + Assert.assertFalse(policyEnforcer.validate(null)); + Assert.assertTrue(policyEnforcer.validate(policy)); + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> policyEnforcer.validateOrElseThrow(table, null) + ); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals( + "Failed security validation with dataSource [table]", + e.getMessage() + ); + policyEnforcer.validateOrElseThrow(restricted.getBase(), restricted.getPolicy()); + // Test validate segment, fail for ReferenceCountingSegment not wrapped with any policy, success when wrapped with a policy + Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY, "1"); + ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment); + Assert.assertFalse(policyEnforcer.validate(null)); + Assert.assertTrue(policyEnforcer.validate(policy)); + final DruidException e2 = Assert.assertThrows( + DruidException.class, + () -> policyEnforcer.validateOrElseThrow(segment, null) + ); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e2.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e2.getTargetPersona()); + Assert.assertEquals( + "Failed security validation with segment [table_-146136543-09-08T08:23:32.096Z_146140482-04-24T15:36:27.903Z_1]", + e2.getMessage() + ); + policyEnforcer.validateOrElseThrow(segment, policy); + } + + @Test + public void test_validate_withAllowedPolicies() throws Exception + { + RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of( + NoRestrictionPolicy.class.getName())); + RowFilterPolicy policy = RowFilterPolicy.from(new NullFilter("some-col", null)); + TableDataSource table = TableDataSource.create("table"); + RestrictedDataSource restricted = RestrictedDataSource.create(table, policy); + + Assert.assertThrows(DruidException.class, () -> policyEnforcer.validateOrElseThrow(table, null)); + Assert.assertThrows( + DruidException.class, + () -> policyEnforcer.validateOrElseThrow(restricted.getBase(), restricted.getPolicy()) + ); + policyEnforcer.validateOrElseThrow(table, NoRestrictionPolicy.instance()); + + Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY, "1"); + ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment); + Assert.assertThrows(DruidException.class, () -> policyEnforcer.validateOrElseThrow(segment, null)); + Assert.assertThrows(DruidException.class, () -> policyEnforcer.validateOrElseThrow(segment, policy)); + policyEnforcer.validateOrElseThrow(segment, NoRestrictionPolicy.instance()); + } +} diff --git a/server/src/test/java/org/apache/druid/server/TestSegmentUtils.java b/processing/src/test/java/org/apache/druid/segment/TestSegmentUtils.java similarity index 84% rename from server/src/test/java/org/apache/druid/server/TestSegmentUtils.java rename to processing/src/test/java/org/apache/druid/segment/TestSegmentUtils.java index 7cc61d940945..2d20be1f7f57 100644 --- a/server/src/test/java/org/apache/druid/server/TestSegmentUtils.java +++ b/processing/src/test/java/org/apache/druid/segment/TestSegmentUtils.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.druid.server; +package org.apache.druid.segment; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -28,15 +28,6 @@ import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.query.OrderBy; -import org.apache.druid.segment.CursorFactory; -import org.apache.druid.segment.Cursors; -import org.apache.druid.segment.DimensionHandler; -import org.apache.druid.segment.IndexIO; -import org.apache.druid.segment.Metadata; -import org.apache.druid.segment.QueryableIndex; -import org.apache.druid.segment.QueryableIndexCursorFactory; -import org.apache.druid.segment.Segment; -import org.apache.druid.segment.SegmentLazyLoadFailCallback; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.data.Indexed; import org.apache.druid.segment.loading.LoadSpec; @@ -45,9 +36,11 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.SegmentId; import org.apache.druid.timeline.partition.NoneShardSpec; +import org.apache.druid.timeline.partition.TombstoneShardSpec; import org.joda.time.Interval; import org.junit.Assert; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.File; import java.io.IOException; @@ -57,7 +50,9 @@ import java.util.Map; import java.util.concurrent.ThreadLocalRandom; - +/** + * Test utility class for creating test segments and load specs. + */ public class TestSegmentUtils { @JsonTypeName("test") @@ -126,14 +121,14 @@ public Segment factorize( } } - public static class SegmentForTesting implements Segment + public static class SegmentForTesting extends QueryableIndexSegment implements Segment { private final String datasource; private final String version; private final Interval interval; private final Object lock = new Object(); private volatile boolean closed = false; - private final QueryableIndex index = new QueryableIndex() + private static final QueryableIndex INDEX = new QueryableIndex() { @Override public Interval getDataInterval() @@ -200,9 +195,10 @@ public ColumnHolder getColumnHolder(String columnName) public SegmentForTesting(String datasource, Interval interval, String version) { + super(INDEX, SegmentId.of(datasource, interval, version, 0)); this.datasource = datasource; - this.version = version; this.interval = interval; + this.version = version; } public String getVersion() @@ -235,13 +231,24 @@ public Interval getDataInterval() @Override public QueryableIndex asQueryableIndex() { - return index; + return INDEX; } @Override public CursorFactory asCursorFactory() { - return new QueryableIndexCursorFactory(index); + return new QueryableIndexCursorFactory(INDEX); + } + + @Override + public T as(@Nonnull Class clazz) + { + if (clazz.equals(QueryableIndex.class)) { + return (T) asQueryableIndex(); + } else if (clazz.equals(CursorFactory.class)) { + return (T) asCursorFactory(); + } + return null; } @Override @@ -253,6 +260,25 @@ public void close() } } + public static DataSegment makeTombstoneSegment(String dataSource, String version, Interval interval) + { + return new DataSegment( + dataSource, + interval, + version, + ImmutableMap.of("version", version, + "interval", interval, + "type", + DataSegment.TOMBSTONE_LOADSPEC_TYPE + ), + Arrays.asList("dim1", "dim2", "dim3"), + Arrays.asList("metric1", "metric2"), + TombstoneShardSpec.INSTANCE, + IndexIO.CURRENT_VERSION_ID, + 1L + ); + } + public static DataSegment makeSegment(String dataSource, String version, Interval interval) { return new DataSegment( diff --git a/server/pom.xml b/server/pom.xml index 76b851fd4ad6..5de233b37059 100644 --- a/server/pom.xml +++ b/server/pom.xml @@ -434,6 +434,11 @@ system-rules test + + com.google.inject.extensions + guice-testlib + test + com.sun.jersey jersey-grizzly2 diff --git a/server/src/main/java/org/apache/druid/guice/security/PolicyModule.java b/server/src/main/java/org/apache/druid/guice/security/PolicyModule.java new file mode 100644 index 000000000000..9e6919c253b8 --- /dev/null +++ b/server/src/main/java/org/apache/druid/guice/security/PolicyModule.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice.security; + +import com.google.inject.Binder; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; + +/** + * Module for configuring the policy enforcer. + */ +public class PolicyModule implements DruidModule +{ + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bindWithDefault(binder, "druid.policy.enforcer", PolicyEnforcer.class, NoopPolicyEnforcer.class); + } +} diff --git a/server/src/main/java/org/apache/druid/initialization/CoreInjectorBuilder.java b/server/src/main/java/org/apache/druid/initialization/CoreInjectorBuilder.java index 6ca4261e0140..caef8be64b15 100644 --- a/server/src/main/java/org/apache/druid/initialization/CoreInjectorBuilder.java +++ b/server/src/main/java/org/apache/druid/initialization/CoreInjectorBuilder.java @@ -47,6 +47,7 @@ import org.apache.druid.guice.security.AuthorizerModule; import org.apache.druid.guice.security.DruidAuthModule; import org.apache.druid.guice.security.EscalatorModule; +import org.apache.druid.guice.security.PolicyModule; import org.apache.druid.metadata.storage.derby.DerbyMetadataStorageDruidModule; import org.apache.druid.rpc.guice.ServiceClientModule; import org.apache.druid.segment.writeout.SegmentWriteOutMediumModule; @@ -100,6 +101,7 @@ public CoreInjectorBuilder forServer() add( ExtensionsModule.SecondaryModule.class, new DruidAuthModule(), + new PolicyModule(), TLSCertificateCheckerModule.class, EmitterModule.class, HttpClientModule.global(), diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/Appenderators.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/Appenderators.java index 28b4379f7b98..09442e7226b6 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/Appenderators.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/Appenderators.java @@ -27,6 +27,7 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.incremental.ParseExceptionHandler; @@ -58,6 +59,7 @@ public static Appenderator createRealtime( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, @@ -84,7 +86,8 @@ public static Appenderator createRealtime( queryProcessingPool, Preconditions.checkNotNull(cache, "cache"), cacheConfig, - cachePopulatorStats + cachePopulatorStats, + policyEnforcer ), indexIO, indexMerger, diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/AppenderatorsManager.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/AppenderatorsManager.java index ec328c3b3cd4..8b565e8193b5 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/AppenderatorsManager.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/AppenderatorsManager.java @@ -29,6 +29,7 @@ import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.incremental.ParseExceptionHandler; @@ -84,6 +85,7 @@ Appenderator createRealtimeAppenderatorForTask( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/DummyForInjectionAppenderatorsManager.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/DummyForInjectionAppenderatorsManager.java index d613f3ff59ce..ff716ce42dc1 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/DummyForInjectionAppenderatorsManager.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/DummyForInjectionAppenderatorsManager.java @@ -30,6 +30,7 @@ import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.incremental.ParseExceptionHandler; @@ -74,6 +75,7 @@ public Appenderator createRealtimeAppenderatorForTask( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/PeonAppenderatorsManager.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/PeonAppenderatorsManager.java index 998f674daf7c..d1dd208e2a9e 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/PeonAppenderatorsManager.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/PeonAppenderatorsManager.java @@ -30,6 +30,7 @@ import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.incremental.ParseExceptionHandler; @@ -80,6 +81,7 @@ public Appenderator createRealtimeAppenderatorForTask( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, @@ -108,6 +110,7 @@ public Appenderator createRealtimeAppenderatorForTask( cache, cacheConfig, cachePopulatorStats, + policyEnforcer, rowIngestionMeters, parseExceptionHandler, useMaxMemoryEstimates, diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java index 69e9a6840d19..37355c4b33de 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/SinkQuerySegmentWalker.java @@ -61,6 +61,7 @@ import org.apache.druid.query.SinkQueryRunners; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.spec.SpecificSegmentQueryRunner; import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.SegmentReference; @@ -125,6 +126,7 @@ public class SinkQuerySegmentWalker implements QuerySegmentWalker private final Cache cache; private final CacheConfig cacheConfig; private final CachePopulatorStats cachePopulatorStats; + private final PolicyEnforcer policyEnforcer; public SinkQuerySegmentWalker( String dataSource, @@ -135,7 +137,8 @@ public SinkQuerySegmentWalker( QueryProcessingPool queryProcessingPool, Cache cache, CacheConfig cacheConfig, - CachePopulatorStats cachePopulatorStats + CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer ) { this.dataSource = Preconditions.checkNotNull(dataSource, "dataSource"); @@ -147,6 +150,7 @@ public SinkQuerySegmentWalker( this.cache = Preconditions.checkNotNull(cache, "cache"); this.cacheConfig = Preconditions.checkNotNull(cacheConfig, "cacheConfig"); this.cachePopulatorStats = Preconditions.checkNotNull(cachePopulatorStats, "cachePopulatorStats"); + this.policyEnforcer = policyEnforcer; if (!cache.isLocal()) { log.warn("Configured cache[%s] is not local, caching will not be enabled.", cache.getClass().getName()); @@ -204,7 +208,7 @@ public QueryRunner getQueryRunnerForSegments(final Query query, final // segmentMapFn maps each base Segment into a joined Segment if necessary. final Function segmentMapFn = JvmUtils.safeAccumulateThreadCpuTime( cpuTimeAccumulator, - () -> ev.createSegmentMapFunction() + () -> ev.createSegmentMapFunction(policyEnforcer) ); // We compute the join cache key here itself so it doesn't need to be re-computed for every segment diff --git a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManager.java b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManager.java index 71157cea8d9e..3b8084ad8647 100644 --- a/server/src/main/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManager.java +++ b/server/src/main/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManager.java @@ -46,6 +46,7 @@ import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.BaseProgressIndicator; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMerger; @@ -113,6 +114,7 @@ public class UnifiedIndexerAppenderatorsManager implements AppenderatorsManager private final Cache cache; private final CacheConfig cacheConfig; private final CachePopulatorStats cachePopulatorStats; + private final PolicyEnforcer policyEnforcer; private final ObjectMapper objectMapper; private final ServiceEmitter serviceEmitter; private final Provider queryRunnerFactoryConglomerateProvider; @@ -127,6 +129,7 @@ public UnifiedIndexerAppenderatorsManager( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, ObjectMapper objectMapper, ServiceEmitter serviceEmitter, Provider queryRunnerFactoryConglomerateProvider @@ -138,6 +141,7 @@ public UnifiedIndexerAppenderatorsManager( this.cache = cache; this.cacheConfig = cacheConfig; this.cachePopulatorStats = cachePopulatorStats; + this.policyEnforcer = policyEnforcer; this.objectMapper = objectMapper; this.serviceEmitter = serviceEmitter; this.queryRunnerFactoryConglomerateProvider = queryRunnerFactoryConglomerateProvider; @@ -166,6 +170,7 @@ public Appenderator createRealtimeAppenderatorForTask( Cache cache, CacheConfig cacheConfig, CachePopulatorStats cachePopulatorStats, + PolicyEnforcer policyEnforcer, RowIngestionMeters rowIngestionMeters, ParseExceptionHandler parseExceptionHandler, boolean useMaxMemoryEstimates, @@ -356,7 +361,8 @@ public DatasourceBundle( queryProcessingPool, Preconditions.checkNotNull(cache, "cache"), cacheConfig, - cachePopulatorStats + cachePopulatorStats, + policyEnforcer ); } diff --git a/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java b/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java index 7e9266c35ffe..5f149c8b3d82 100644 --- a/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java +++ b/server/src/main/java/org/apache/druid/server/LocalQuerySegmentWalker.java @@ -34,6 +34,7 @@ import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import org.apache.druid.segment.SegmentWrangler; @@ -59,6 +60,7 @@ public class LocalQuerySegmentWalker implements QuerySegmentWalker private final SegmentWrangler segmentWrangler; private final JoinableFactoryWrapper joinableFactoryWrapper; private final QueryScheduler scheduler; + private final PolicyEnforcer policyEnforcer; private final ServiceEmitter emitter; @Inject @@ -67,6 +69,7 @@ public LocalQuerySegmentWalker( SegmentWrangler segmentWrangler, JoinableFactoryWrapper joinableFactoryWrapper, QueryScheduler scheduler, + PolicyEnforcer policyEnforcer, ServiceEmitter emitter ) { @@ -74,6 +77,7 @@ public LocalQuerySegmentWalker( this.segmentWrangler = segmentWrangler; this.joinableFactoryWrapper = joinableFactoryWrapper; this.scheduler = scheduler; + this.policyEnforcer = policyEnforcer; this.emitter = emitter; } @@ -94,7 +98,7 @@ public QueryRunner getQueryRunnerForIntervals(final Query query, final final AtomicLong cpuAccumulator = new AtomicLong(0L); - final Function segmentMapFn = ev.createSegmentMapFunction(); + final Function segmentMapFn = ev.createSegmentMapFunction(policyEnforcer); final QueryRunnerFactory> queryRunnerFactory = conglomerate.findFactory(query); final QueryRunner baseRunner = queryRunnerFactory.mergeRunners( diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 5be2944b0211..5b52f3a03532 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -48,6 +48,7 @@ import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.server.QueryResource.ResourceIOReaderWriter; import org.apache.druid.server.log.RequestLogger; import org.apache.druid.server.security.Action; @@ -98,6 +99,7 @@ public class QueryLifecycle private final AuthorizerMapper authorizerMapper; private final DefaultQueryConfig defaultQueryConfig; private final AuthConfig authConfig; + private final PolicyEnforcer policyEnforcer; private final long startMs; private final long startNs; @@ -119,6 +121,7 @@ public QueryLifecycle( final AuthorizerMapper authorizerMapper, final DefaultQueryConfig defaultQueryConfig, final AuthConfig authConfig, + final PolicyEnforcer policyEnforcer, final long startMs, final long startNs ) @@ -131,6 +134,7 @@ public QueryLifecycle( this.authorizerMapper = authorizerMapper; this.defaultQueryConfig = defaultQueryConfig; this.authConfig = authConfig; + this.policyEnforcer = policyEnforcer; this.startMs = startMs; this.startNs = startNs; } @@ -320,8 +324,11 @@ private AuthorizationResult doAuthorize( transition(State.AUTHORIZING, State.UNAUTHORIZED); } else { transition(State.AUTHORIZING, State.AUTHORIZED); - this.baseQuery = this.baseQuery.withDataSource(this.baseQuery.getDataSource() - .withPolicies(authorizationResult.getPolicyMap())); + this.baseQuery = this.baseQuery.withDataSource(baseQuery.getDataSource() + .withPolicies( + authorizationResult.getPolicyMap(), + policyEnforcer + )); } this.authenticationResult = authenticationResult; diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycleFactory.java b/server/src/main/java/org/apache/druid/server/QueryLifecycleFactory.java index 1dc32348bc26..73a49eca5930 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycleFactory.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycleFactory.java @@ -27,6 +27,7 @@ import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.server.log.RequestLogger; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthorizerMapper; @@ -42,6 +43,7 @@ public class QueryLifecycleFactory private final AuthorizerMapper authorizerMapper; private final DefaultQueryConfig defaultQueryConfig; private final AuthConfig authConfig; + private final PolicyEnforcer policyEnforcer; @Inject public QueryLifecycleFactory( @@ -51,6 +53,7 @@ public QueryLifecycleFactory( final ServiceEmitter emitter, final RequestLogger requestLogger, final AuthConfig authConfig, + final PolicyEnforcer policyEnforcer, final AuthorizerMapper authorizerMapper, final Supplier queryConfigSupplier ) @@ -63,6 +66,7 @@ public QueryLifecycleFactory( this.authorizerMapper = authorizerMapper; this.defaultQueryConfig = queryConfigSupplier.get(); this.authConfig = authConfig; + this.policyEnforcer = policyEnforcer; } public QueryLifecycle factorize() @@ -76,6 +80,7 @@ public QueryLifecycle factorize() authorizerMapper, defaultQueryConfig, authConfig, + policyEnforcer, System.currentTimeMillis(), System.nanoTime() ); diff --git a/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java b/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java index 2ce38d44984a..953bee305e89 100644 --- a/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java +++ b/server/src/main/java/org/apache/druid/server/coordination/ServerManager.java @@ -54,6 +54,7 @@ import org.apache.druid.query.ReportTimelineMissingSegmentQueryRunner; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.spec.SpecificSegmentQueryRunner; import org.apache.druid.query.spec.SpecificSegmentSpec; import org.apache.druid.segment.ReferenceCountingSegment; @@ -91,6 +92,7 @@ public class ServerManager implements QuerySegmentWalker private final CacheConfig cacheConfig; private final SegmentManager segmentManager; private final ServerConfig serverConfig; + private final PolicyEnforcer policyEnforcer; @Inject public ServerManager( @@ -102,7 +104,8 @@ public ServerManager( Cache cache, CacheConfig cacheConfig, SegmentManager segmentManager, - ServerConfig serverConfig + ServerConfig serverConfig, + PolicyEnforcer policyEnforcer ) { this.conglomerate = conglomerate; @@ -116,6 +119,7 @@ public ServerManager( this.cacheConfig = cacheConfig; this.segmentManager = segmentManager; this.serverConfig = serverConfig; + this.policyEnforcer = policyEnforcer; } @Override @@ -196,7 +200,7 @@ public QueryRunner getQueryRunnerForSegments(Query theQuery, Iterable< } final Function segmentMapFn = JvmUtils.safeAccumulateThreadCpuTime( cpuTimeAccumulator, - () -> ev.createSegmentMapFunction() + () -> ev.createSegmentMapFunction(policyEnforcer) ); // We compute the datasource's cache key here itself so it doesn't need to be re-computed for every segment diff --git a/server/src/test/java/org/apache/druid/guice/security/PolicyModuleTest.java b/server/src/test/java/org/apache/druid/guice/security/PolicyModuleTest.java new file mode 100644 index 000000000000..9c524235fb17 --- /dev/null +++ b/server/src/test/java/org/apache/druid/guice/security/PolicyModuleTest.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.guice.security; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.ProvisionException; +import org.apache.druid.guice.ConfigModule; +import org.apache.druid.guice.DruidGuiceExtensions; +import org.apache.druid.jackson.JacksonModule; +import org.apache.druid.query.filter.NullFilter; +import org.apache.druid.query.policy.NoRestrictionPolicy; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; +import org.apache.druid.query.policy.RowFilterPolicy; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Properties; + +public class PolicyModuleTest +{ + @Test + public void testDefaultConfigNoopPolicyEnforcer() + { + Properties properties = new Properties(); + PolicyEnforcer policyEnforcer = Guice.createInjector( + binder -> binder.bind(Properties.class).toInstance(properties), + new DruidGuiceExtensions(), + new ConfigModule(), + new PolicyModule() + ).getInstance(Key.get(PolicyEnforcer.class)); + Assert.assertNotNull(policyEnforcer); + Assert.assertTrue(policyEnforcer instanceof NoopPolicyEnforcer); + } + + @Test + public void testConfigThrowForUnrecognizedType() + { + Properties properties = new Properties(); + properties.setProperty("druid.policy.enforcer.type", "unrecognizedType"); + Injector injector = Guice.createInjector( + binder -> binder.bind(Properties.class).toInstance(properties), + new DruidGuiceExtensions(), + new ConfigModule(), + new PolicyModule() + ); + ProvisionException e = Assert.assertThrows( + ProvisionException.class, + () -> injector.getInstance(Key.get(PolicyEnforcer.class)) + ); + Assert.assertTrue(e.getCause() + .getMessage() + .contains( + "Could not resolve type id 'unrecognizedType' as a subtype of `org.apache.druid.query.policy.PolicyEnforcer`")); + } + + @Test + public void testConfigRestrictAllTablesPolicyEnforcer() + { + Properties properties = new Properties(); + properties.setProperty("druid.policy.enforcer.type", "restrictAllTables"); + PolicyEnforcer policyEnforcer = Guice.createInjector( + binder -> binder.bind(Properties.class).toInstance(properties), + new DruidGuiceExtensions(), + new ConfigModule(), + new JacksonModule(), + new PolicyModule() + ).getInstance(Key.get(PolicyEnforcer.class)); + + Assert.assertNotNull(policyEnforcer); + Assert.assertEquals(new RestrictAllTablesPolicyEnforcer(null), policyEnforcer); + } + + @Test + public void testConfigRestrictAllTablesPolicyEnforcerWithAllowedPolicies() + { + Properties properties = new Properties(); + properties.setProperty("druid.policy.enforcer.type", "restrictAllTables"); + properties.setProperty( + "druid.policy.enforcer.allowedPolicies", + "[\"some-policy-class\", \"org.apache.druid.query.policy.NoRestrictionPolicy\"]" + ); + PolicyEnforcer policyEnforcer = Guice.createInjector( + binder -> binder.bind(Properties.class).toInstance(properties), + new DruidGuiceExtensions(), + new ConfigModule(), + new JacksonModule(), + new PolicyModule() + ).getInstance(Key.get(PolicyEnforcer.class)); + + Assert.assertNotNull(policyEnforcer); + Assert.assertEquals(new RestrictAllTablesPolicyEnforcer(ImmutableList.of( + "some-policy-class", + "org.apache.druid.query.policy.NoRestrictionPolicy" + )), policyEnforcer); + Assert.assertTrue(policyEnforcer.validate(NoRestrictionPolicy.instance())); + Assert.assertFalse(policyEnforcer.validate(RowFilterPolicy.from(new NullFilter("some-col", null)))); + } +} diff --git a/server/src/test/java/org/apache/druid/segment/loading/SegmentLocalCacheManagerTest.java b/server/src/test/java/org/apache/druid/segment/loading/SegmentLocalCacheManagerTest.java index 1ca73d4b934e..67d72c6ced0f 100644 --- a/server/src/test/java/org/apache/druid/segment/loading/SegmentLocalCacheManagerTest.java +++ b/server/src/test/java/org/apache/druid/segment/loading/SegmentLocalCacheManagerTest.java @@ -35,7 +35,7 @@ import org.apache.druid.segment.SegmentLazyLoadFailCallback; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestIndex; -import org.apache.druid.server.TestSegmentUtils; +import org.apache.druid.segment.TestSegmentUtils; import org.apache.druid.server.metrics.NoopServiceEmitter; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.partition.NoneShardSpec; diff --git a/server/src/test/java/org/apache/druid/segment/metadata/SegmentMetadataCacheTestBase.java b/server/src/test/java/org/apache/druid/segment/metadata/SegmentMetadataCacheTestBase.java index 60195a7a3efe..e8380b2f3031 100644 --- a/server/src/test/java/org/apache/druid/segment/metadata/SegmentMetadataCacheTestBase.java +++ b/server/src/test/java/org/apache/druid/segment/metadata/SegmentMetadataCacheTestBase.java @@ -37,6 +37,7 @@ import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.IndexBuilder; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.incremental.IncrementalIndexSchema; @@ -293,6 +294,7 @@ public QueryLifecycleFactory getQueryLifecycleFactory(QuerySegmentWalker walker) new NoopServiceEmitter(), new TestRequestLogger(), new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ); diff --git a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTest.java b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTest.java index 72a221e9040a..cdd0e85c9c5c 100644 --- a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTest.java +++ b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTest.java @@ -29,6 +29,7 @@ import org.apache.druid.data.input.Committer; import org.apache.druid.data.input.InputRow; import org.apache.druid.data.input.MapBasedInputRow; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Pair; @@ -38,10 +39,14 @@ import org.apache.druid.query.Druids; import org.apache.druid.query.Order; import org.apache.druid.query.QueryPlus; +import org.apache.druid.query.RestrictedDataSource; import org.apache.druid.query.Result; import org.apache.druid.query.SegmentDescriptor; +import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.query.policy.NoRestrictionPolicy; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.scan.ScanResultValue; import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; @@ -1859,6 +1864,53 @@ public void testQueryByIntervals_withSegmentVersionUpgrades() throws Exception } } + @Test + public void testQueryFailWithSecurityValidation() throws Exception + { + final StubServiceEmitter serviceEmitter = new StubServiceEmitter(); + final StreamAppenderatorTester tester = + new StreamAppenderatorTester.Builder().maxRowsInMemory(2) + .basePersistDirectory(temporaryFolder.newFolder()) + .withServiceEmitter(serviceEmitter) + .withPolicyEnforcer(new RestrictAllTablesPolicyEnforcer(null)) + .build(); + final Appenderator appenderator = tester.getAppenderator(); + + appenderator.startJob(); + appenderator.add(IDENTIFIERS.get(0), ir("2000", "foo", 1), Suppliers.ofInstance(Committers.nil())); + + // Query1: no policy restriction, fail + final TimeseriesQuery query1 = Druids.newTimeseriesQueryBuilder() + .dataSource(StreamAppenderatorTester.DATASOURCE) + .intervals(ImmutableList.of(Intervals.of("2000/2001"))) + .aggregators(ImmutableList.of(new LongSumAggregatorFactory("count", "count"))) + .granularity(Granularities.DAY) + .build(); + DruidException e = Assert.assertThrows( + DruidException.class, + () -> QueryPlus.wrap(query1) + .run(appenderator, ResponseContext.createEmpty()) + .toList() + ); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals( + "Failed security validation with segment [foo_2000-01-01T00:00:00.000Z_2001-01-01T00:00:00.000Z_A]", + e.getMessage() + ); + + // Query2: with policy restriction, success + RestrictedDataSource restrictedDataSource = RestrictedDataSource.create( + TableDataSource.create(StreamAppenderatorTester.DATASOURCE), NoRestrictionPolicy.instance()); + final TimeseriesQuery query2 = Druids.newTimeseriesQueryBuilder() + .dataSource(restrictedDataSource) + .intervals(ImmutableList.of(Intervals.of("2000/2001"))) + .aggregators(ImmutableList.of(new LongSumAggregatorFactory("count", "count"))) + .granularity(Granularities.DAY) + .build(); + QueryPlus.wrap(query2).run(appenderator, ResponseContext.createEmpty()).toList(); + } + @Test public void testQueryByIntervals() throws Exception { diff --git a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTester.java b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTester.java index 6f9bf400b0fd..8b331b322a6e 100644 --- a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTester.java +++ b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/StreamAppenderatorTester.java @@ -48,6 +48,8 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.expression.TestExprMacroTable; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.scan.ScanQueryConfig; import org.apache.druid.query.scan.ScanQueryEngine; @@ -110,7 +112,8 @@ public StreamAppenderatorTester( final boolean skipBytesInMemoryOverheadCheck, final DataSegmentAnnouncer announcer, final CentralizedDatasourceSchemaConfig centralizedDatasourceSchemaConfig, - final ServiceEmitter serviceEmitter + final ServiceEmitter serviceEmitter, + final PolicyEnforcer policyEnforcer ) { objectMapper = new DefaultObjectMapper(); @@ -244,6 +247,7 @@ ScanQuery.class, new ScanQueryRunnerFactory( MapCache.create(2048), new CacheConfig(), new CachePopulatorStats(), + policyEnforcer, rowIngestionMeters, new ParseExceptionHandler(rowIngestionMeters, false, Integer.MAX_VALUE, 0), true, @@ -286,6 +290,7 @@ ScanQuery.class, new ScanQueryRunnerFactory( MapCache.create(2048), new CacheConfig(), new CachePopulatorStats(), + NoopPolicyEnforcer.instance(), rowIngestionMeters, new ParseExceptionHandler(rowIngestionMeters, false, Integer.MAX_VALUE, 0), true, @@ -353,6 +358,7 @@ public static class Builder private boolean skipBytesInMemoryOverheadCheck; private int delayInMilli = 0; private ServiceEmitter serviceEmitter; + private PolicyEnforcer policyEnforcer = NoopPolicyEnforcer.instance(); public Builder maxRowsInMemory(final int maxRowsInMemory) { @@ -402,6 +408,12 @@ public Builder withServiceEmitter(ServiceEmitter serviceEmitter) return this; } + public Builder withPolicyEnforcer(PolicyEnforcer policyEnforcer) + { + this.policyEnforcer = policyEnforcer; + return this; + } + public StreamAppenderatorTester build() { return new StreamAppenderatorTester( @@ -414,7 +426,8 @@ public StreamAppenderatorTester build() skipBytesInMemoryOverheadCheck, new NoopDataSegmentAnnouncer(), CentralizedDatasourceSchemaConfig.create(), - serviceEmitter + serviceEmitter, + policyEnforcer ); } @@ -433,7 +446,8 @@ public StreamAppenderatorTester build( skipBytesInMemoryOverheadCheck, dataSegmentAnnouncer, config, - serviceEmitter + serviceEmitter, + policyEnforcer ); } } diff --git a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java index e92df0f65183..3a684027e23b 100644 --- a/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java +++ b/server/src/test/java/org/apache/druid/segment/realtime/appenderator/UnifiedIndexerAppenderatorsManagerTest.java @@ -34,6 +34,7 @@ import org.apache.druid.query.DirectQueryProcessingPool; import org.apache.druid.query.Druids; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.segment.IndexMerger; @@ -81,6 +82,7 @@ public class UnifiedIndexerAppenderatorsManagerTest extends InitializedNullHandl MapCache.create(10), new CacheConfig(), new CachePopulatorStats(), + NoopPolicyEnforcer.instance(), TestHelper.makeJsonMapper(), new NoopServiceEmitter(), () -> DefaultQueryRunnerFactoryConglomerate.buildFromQueryRunnerFactories(ImmutableMap.of()) diff --git a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java index c317bc291702..cd39030bd3a2 100644 --- a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java @@ -19,10 +19,18 @@ package org.apache.druid.server; +import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.inject.Guice; +import com.google.inject.Inject; +import com.google.inject.Injector; +import com.google.inject.Scopes; +import com.google.inject.testing.fieldbinder.Bind; +import com.google.inject.testing.fieldbinder.BoundFieldModule; import org.apache.druid.error.DruidException; +import org.apache.druid.guice.LazySingleton; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.guava.Sequences; @@ -45,7 +53,10 @@ import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.metadata.metadata.SegmentMetadataQuery; import org.apache.druid.query.policy.NoRestrictionPolicy; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.policy.Policy; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.query.policy.RowFilterPolicy; import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.server.log.RequestLogger; @@ -72,6 +83,7 @@ import java.util.Map; import java.util.Optional; +@LazySingleton public class QueryLifecycleTest { private static final String DATASOURCE = "some_datasource"; @@ -85,19 +97,40 @@ public class QueryLifecycleTest .intervals(ImmutableList.of(Intervals.ETERNITY)) .aggregators(new CountAggregatorFactory("chocula")) .build(); + QueryToolChest toolChest; + @Bind QueryRunnerFactoryConglomerate conglomerate; + + QueryRunner runner; + @Bind QuerySegmentWalker texasRanger; + @Bind GenericQueryMetricsFactory metricsFactory; + @Bind ServiceEmitter emitter; + @Bind RequestLogger requestLogger; + + Authorizer authorizer; + @Bind AuthorizerMapper authzMapper; + DefaultQueryConfig queryConfig; + @Bind(lazy = true) + Supplier queryConfigSupplier; + + @Bind(lazy = true) + AuthConfig authConfig; + @Bind(lazy = true) + PolicyEnforcer policyEnforcer; - QueryToolChest toolChest; - QueryRunner runner; QueryMetrics metrics; AuthenticationResult authenticationResult; - Authorizer authorizer; + + @Inject + QueryLifecycleFactory queryLifecycleFactory; + + Injector injector; @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -113,29 +146,25 @@ public void setup() authorizer = EasyMock.createMock(Authorizer.class); authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer)); queryConfig = EasyMock.createMock(DefaultQueryConfig.class); + queryConfigSupplier = () -> queryConfig; toolChest = EasyMock.createMock(QueryToolChest.class); runner = EasyMock.createMock(QueryRunner.class); metrics = EasyMock.createNiceMock(QueryMetrics.class); authenticationResult = EasyMock.createMock(AuthenticationResult.class); + authConfig = new AuthConfig(); + policyEnforcer = NoopPolicyEnforcer.instance(); + + injector = Guice.createInjector( + BoundFieldModule.of(this), + binder -> binder.bindScope(LazySingleton.class, Scopes.SINGLETON) + ); } - private QueryLifecycle createLifecycle(AuthConfig authConfig) + private QueryLifecycle createLifecycle() { - long nanos = System.nanoTime(); - long millis = System.currentTimeMillis(); - return new QueryLifecycle( - conglomerate, - texasRanger, - metricsFactory, - emitter, - requestLogger, - authzMapper, - queryConfig, - authConfig, - millis, - nanos - ); + injector.injectMembers(this); + return queryLifecycleFactory.factorize(); } @After @@ -171,7 +200,7 @@ public void testRunSimple_preauthorizedAsSuperuser() replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, AuthorizationResult.ALLOW_NO_RESTRICTION); } @@ -189,7 +218,7 @@ public void testRunSimpleUnauthorized() EasyMock.expect(toolChest.makeMetrics(EasyMock.anyObject())).andReturn(metrics).anyTimes(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, AuthorizationResult.DENY); } @@ -223,10 +252,10 @@ public void testRunSimple_withPolicyRestriction() EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once(); replayAll(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, authorizationResult); } @@ -255,7 +284,7 @@ public void testRunSimple_withPolicyRestriction_segmentMetadataQueryRunAsInterna EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).once(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, authorizationResult); } @@ -279,14 +308,17 @@ public void testRunSimple_withPolicyRestriction_segmentMetadataQueryRunAsExterna EasyMock.expect(toolChest.makeMetrics(EasyMock.anyObject())).andReturn(metrics).once(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); DruidException e = Assert.assertThrows( DruidException.class, () -> lifecycle.runSimple(query, authenticationResult, authorizationResult) ); Assert.assertEquals(DruidException.Persona.USER, e.getTargetPersona()); Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); - Assert.assertEquals("You do not have permission to run a segmentMetadata query on table[some_datasource].", e.getMessage()); + Assert.assertEquals( + "You do not have permission to run a segmentMetadata query on table[some_datasource].", + e.getMessage() + ); } @Test @@ -305,7 +337,7 @@ public void testRunSimple_withoutPolicy() EasyMock.expect(toolChest.makeMetrics(EasyMock.anyObject())).andReturn(metrics).anyTimes(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(AuthConfig.newBuilder().build()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, authorizationResult); } @@ -345,7 +377,7 @@ public void testRunSimple_foundDifferentPolicyRestrictions() EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).anyTimes(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, authorizationResult); } @@ -387,7 +419,7 @@ public void testRunSimple_queryWithRestrictedDataSource_policyRestrictionMightHa EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).anyTimes(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.runSimple(query, authenticationResult, authorizationResult); } @@ -422,12 +454,48 @@ public void testAuthorized_withPolicyRestriction() EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).anyTimes(); replayAll(); - QueryLifecycle lifecycle = createLifecycle(new AuthConfig()); + policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of( + NoRestrictionPolicy.class.getName(), + RowFilterPolicy.class.getName() + )); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(authenticationResult).allowBasicAccess()); + // Success, query has a RowFilterPolicy, and is allowed by PolicyEnforcer. lifecycle.execute(); } + @Test + public void testAuthorized_withPolicyRestriction_failedSecurityValidation() + { + // Test the path broker receives a native json query from external client, should add restriction on data source + Policy rowFilterPolicy = RowFilterPolicy.from(new NullFilter("some-column", null)); + Access access = Access.allowWithRestriction(rowFilterPolicy); + + final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + .dataSource(DATASOURCE) + .intervals(ImmutableList.of(Intervals.ETERNITY)) + .aggregators(new CountAggregatorFactory("chocula")) + .build(); + EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); + EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); + EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); + EasyMock.expect(authorizer.authorize(authenticationResult, RESOURCE, Action.READ)) + .andReturn(access).anyTimes(); + EasyMock.expect(conglomerate.getToolChest(EasyMock.anyObject())) + .andReturn(toolChest).anyTimes(); + replayAll(); + + policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of(NoRestrictionPolicy.class.getName())); + QueryLifecycle lifecycle = createLifecycle(); + lifecycle.initialize(query); + // Fail, only NoRestrictionPolicy is allowed. + DruidException e = Assert.assertThrows(DruidException.class, () -> lifecycle.authorize(authenticationResult)); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals("Failed security validation with dataSource [some_datasource]", e.getMessage()); + } + @Test public void testAuthorized_queryWithRestrictedDataSource_runWithSuperUserPermission() { @@ -459,10 +527,8 @@ public void testAuthorized_queryWithRestrictedDataSource_runWithSuperUserPermiss EasyMock.expect(runner.run(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(Sequences.empty()).anyTimes(); replayAll(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder().setAuthorizeQueryContextParams(true).build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(authenticationResult).allowBasicAccess()); lifecycle.execute(); @@ -507,10 +573,8 @@ public void testAuthorizeQueryContext_authorized() .context(userContext) .build(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder().setAuthorizeQueryContextParams(true).build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); final Map revisedContext = new HashMap<>(lifecycle.getQuery().getContext()); @@ -523,7 +587,7 @@ public void testAuthorizeQueryContext_authorized() Assert.assertTrue(lifecycle.authorize(mockRequest()).allowAccessWithNoRestriction()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(authenticationResult).allowAccessWithNoRestriction()); } @@ -562,14 +626,12 @@ public void testAuthorizeQueryContext_notAuthorized() .context(ImmutableMap.of("foo", "bar")) .build(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder().setAuthorizeQueryContextParams(true).build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertFalse(lifecycle.authorize(mockRequest()).allowBasicAccess()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertFalse(lifecycle.authorize(authenticationResult).allowBasicAccess()); } @@ -598,11 +660,11 @@ public void testAuthorizeQueryContext_unsecuredKeys() .context(userContext) .build(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .setUnsecuredContextKeys(ImmutableSet.of("foo", "baz")) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .setUnsecuredContextKeys(ImmutableSet.of("foo", "baz")) + .build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); final Map revisedContext = new HashMap<>(lifecycle.getQuery().getContext()); @@ -615,7 +677,7 @@ public void testAuthorizeQueryContext_unsecuredKeys() Assert.assertTrue(lifecycle.authorize(mockRequest()).allowAccessWithNoRestriction()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(authenticationResult).allowAccessWithNoRestriction()); } @@ -648,12 +710,12 @@ public void testAuthorizeQueryContext_securedKeys() .context(userContext) .build(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - // We have secured keys, just not what the user gave. - .setSecuredContextKeys(ImmutableSet.of("foo2", "baz2")) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + // We have secured keys, just not what the user gave. + .setSecuredContextKeys(ImmutableSet.of("foo2", "baz2")) + .build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); final Map revisedContext = new HashMap<>(lifecycle.getQuery().getContext()); @@ -666,7 +728,7 @@ public void testAuthorizeQueryContext_securedKeys() Assert.assertTrue(lifecycle.authorize(mockRequest()).allowBasicAccess()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(authenticationResult).allowBasicAccess()); } @@ -706,16 +768,16 @@ public void testAuthorizeQueryContext_securedKeysNotAuthorized() .context(userContext) .build(); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - // We have secured keys. User used one of them. - .setSecuredContextKeys(ImmutableSet.of("foo", "baz2")) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + // We have secured keys. User used one of them. + .setSecuredContextKeys(ImmutableSet.of("foo", "baz2")) + .build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertFalse(lifecycle.authorize(mockRequest()).allowBasicAccess()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertFalse(lifecycle.authorize(authenticationResult).allowBasicAccess()); } @@ -761,10 +823,10 @@ public void testAuthorizeLegacyQueryContext_authorized() "qux" )); - AuthConfig authConfig = AuthConfig.newBuilder() - .setAuthorizeQueryContextParams(true) - .build(); - QueryLifecycle lifecycle = createLifecycle(authConfig); + authConfig = AuthConfig.newBuilder() + .setAuthorizeQueryContextParams(true) + .build(); + QueryLifecycle lifecycle = createLifecycle(); lifecycle.initialize(query); final Map revisedContext = lifecycle.getQuery().getContext(); @@ -775,7 +837,7 @@ public void testAuthorizeLegacyQueryContext_authorized() Assert.assertTrue(lifecycle.authorize(mockRequest()).allowBasicAccess()); - lifecycle = createLifecycle(authConfig); + lifecycle = createLifecycle(); lifecycle.initialize(query); Assert.assertTrue(lifecycle.authorize(mockRequest()).allowBasicAccess()); } diff --git a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java index 6f87127b0968..684092904848 100644 --- a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java @@ -65,6 +65,7 @@ import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.TruncatedResponseContextException; import org.apache.druid.query.filter.NullFilter; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.policy.RowFilterPolicy; import org.apache.druid.query.timeboundary.TimeBoundaryResultValue; import org.apache.druid.server.initialization.ServerConfig; @@ -247,6 +248,7 @@ private QueryResource createQueryResource(ResponseContextConfig responseContextC new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ), @@ -282,6 +284,7 @@ public void testGoodQueryWithQueryConfigOverrideDefault() throws IOException new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(overrideConfig) ), @@ -356,6 +359,7 @@ public QueryRunner getQueryRunnerForSegments( new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(overrideConfig) ), @@ -448,6 +452,7 @@ public QueryRunner getQueryRunnerForSegments( new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ), @@ -493,6 +498,7 @@ public void testSuccessResponseWithTrailerHeader() throws IOException new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ), @@ -553,7 +559,7 @@ public QueryRunner getQueryRunnerForSegments( queryResource = new QueryResource( - new QueryLifecycleFactory(null, null, null, null, null, null, null, Suppliers.ofInstance(overrideConfig)) + new QueryLifecycleFactory(null, null, null, null, null, null, NoopPolicyEnforcer.instance(), null, Suppliers.ofInstance(overrideConfig)) { @Override public QueryLifecycle factorize() @@ -567,6 +573,7 @@ public QueryLifecycle factorize() AuthTestUtils.TEST_AUTHORIZER_MAPPER, overrideConfig, new AuthConfig(), + NoopPolicyEnforcer.instance(), System.currentTimeMillis(), System.nanoTime() ) @@ -618,6 +625,7 @@ public void testGoodQueryWithQueryConfigDoesNotOverrideQueryContext() throws IOE new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), AuthTestUtils.TEST_AUTHORIZER_MAPPER, Suppliers.ofInstance(overrideConfig) ), @@ -855,6 +863,7 @@ public Access authorize(AuthenticationResult authenticationResult, Resource reso new NoopServiceEmitter(), testRequestLogger, new AuthConfig(), + NoopPolicyEnforcer.instance(), authMapper, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ), @@ -930,6 +939,7 @@ public QueryRunner getQueryRunnerForSegments(Query query, Iterable QueryRunner getQueryRunnerForSegments(Query query, Iterable QueryRunner getQueryRunnerForSegments(final Query query, final throw new ISE("Cannot handle subquery: %s", dataSourceFromQuery); } - final Function segmentMapFn = ev.createSegmentMapFunction(); + final Function segmentMapFn = ev.createSegmentMapFunction(NoopPolicyEnforcer.instance()); final QueryRunner baseRunner = new FinalizeResultsQueryRunner<>( toolChest.postMergeQueryDecoration( diff --git a/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperCacheTest.java b/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperCacheTest.java index 17e862a50a9c..3e51cf15be4e 100644 --- a/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperCacheTest.java @@ -27,6 +27,7 @@ import org.apache.druid.java.util.metrics.StubServiceEmitter; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestIndex; +import org.apache.druid.segment.TestSegmentUtils; import org.apache.druid.segment.loading.DataSegmentPusher; import org.apache.druid.segment.loading.LeastBytesUsedStorageLocationSelectorStrategy; import org.apache.druid.segment.loading.SegmentLoaderConfig; @@ -35,7 +36,6 @@ import org.apache.druid.segment.loading.StorageLocation; import org.apache.druid.segment.loading.StorageLocationConfig; import org.apache.druid.server.SegmentManager; -import org.apache.druid.server.TestSegmentUtils; import org.apache.druid.server.metrics.DataSourceTaskIdHolder; import org.apache.druid.timeline.DataSegment; import org.junit.Assert; diff --git a/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperTest.java b/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperTest.java index ed30e1aa3af1..176ae19d5d06 100644 --- a/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperTest.java +++ b/server/src/test/java/org/apache/druid/server/coordination/SegmentBootstrapperTest.java @@ -37,6 +37,7 @@ import org.apache.druid.segment.loading.StorageLocationConfig; import org.apache.druid.server.SegmentManager; import org.apache.druid.server.metrics.DataSourceTaskIdHolder; +import org.apache.druid.test.utils.TestSegmentCacheManager; import org.apache.druid.timeline.DataSegment; import org.junit.Assert; import org.junit.Before; @@ -51,7 +52,7 @@ import java.util.List; import java.util.Set; -import static org.apache.druid.server.TestSegmentUtils.makeSegment; +import static org.apache.druid.segment.TestSegmentUtils.makeSegment; public class SegmentBootstrapperTest { diff --git a/server/src/test/java/org/apache/druid/server/coordination/SegmentLoadDropHandlerTest.java b/server/src/test/java/org/apache/druid/server/coordination/SegmentLoadDropHandlerTest.java index 99a2a1f7ad46..bade415887be 100644 --- a/server/src/test/java/org/apache/druid/server/coordination/SegmentLoadDropHandlerTest.java +++ b/server/src/test/java/org/apache/druid/server/coordination/SegmentLoadDropHandlerTest.java @@ -31,6 +31,7 @@ import org.apache.druid.server.SegmentManager; import org.apache.druid.server.coordination.SegmentChangeStatus.State; import org.apache.druid.server.http.SegmentLoadingMode; +import org.apache.druid.test.utils.TestSegmentCacheManager; import org.apache.druid.timeline.DataSegment; import org.junit.Assert; import org.junit.Before; @@ -52,7 +53,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import static org.apache.druid.server.TestSegmentUtils.makeSegment; +import static org.apache.druid.segment.TestSegmentUtils.makeSegment; public class SegmentLoadDropHandlerTest { diff --git a/server/src/test/java/org/apache/druid/server/coordination/ServerManagerTest.java b/server/src/test/java/org/apache/druid/server/coordination/ServerManagerTest.java index 119272acae3c..8364ff40835c 100644 --- a/server/src/test/java/org/apache/druid/server/coordination/ServerManagerTest.java +++ b/server/src/test/java/org/apache/druid/server/coordination/ServerManagerTest.java @@ -20,19 +20,27 @@ package org.apache.druid.server.coordination; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Function; import com.google.common.base.Functions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Guice; +import com.google.inject.Inject; +import com.google.inject.testing.fieldbinder.Bind; +import com.google.inject.testing.fieldbinder.BoundFieldModule; +import org.apache.druid.client.cache.Cache; import org.apache.druid.client.cache.CacheConfig; +import org.apache.druid.client.cache.CachePopulator; import org.apache.druid.client.cache.CachePopulatorStats; import org.apache.druid.client.cache.ForegroundCachePopulator; import org.apache.druid.client.cache.LocalCacheProvider; import org.apache.druid.error.DruidException; +import org.apache.druid.guice.annotations.Smile; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; -import org.apache.druid.java.util.common.MapUtils; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.granularity.Granularities; @@ -43,7 +51,7 @@ import org.apache.druid.java.util.common.guava.YieldingAccumulator; import org.apache.druid.java.util.common.guava.YieldingSequenceBase; import org.apache.druid.java.util.emitter.EmittingLogger; -import org.apache.druid.query.BaseQuery; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.query.ConcatQueryRunner; import org.apache.druid.query.DataSource; import org.apache.druid.query.DefaultQueryMetrics; @@ -61,59 +69,45 @@ import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryUnsupportedException; +import org.apache.druid.query.RestrictedDataSource; import org.apache.druid.query.Result; import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.aggregation.MetricManipulationFn; import org.apache.druid.query.context.DefaultResponseContext; import org.apache.druid.query.context.ResponseContext; -import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.planning.ExecutionVertex; +import org.apache.druid.query.policy.NoRestrictionPolicy; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.query.search.SearchQuery; import org.apache.druid.query.search.SearchResultValue; import org.apache.druid.query.spec.MultipleSpecificSegmentSpec; -import org.apache.druid.query.spec.QuerySegmentSpec; -import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.Segment; -import org.apache.druid.segment.TestHelper; -import org.apache.druid.segment.TestIndex; -import org.apache.druid.segment.loading.LeastBytesUsedStorageLocationSelectorStrategy; -import org.apache.druid.segment.loading.SegmentLoaderConfig; +import org.apache.druid.segment.TestSegmentUtils; +import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting; import org.apache.druid.segment.loading.SegmentLoadingException; -import org.apache.druid.segment.loading.SegmentLocalCacheManager; -import org.apache.druid.segment.loading.StorageLocation; -import org.apache.druid.segment.loading.StorageLocationConfig; -import org.apache.druid.segment.loading.TombstoneSegmentizerFactory; import org.apache.druid.server.SegmentManager; -import org.apache.druid.server.TestSegmentUtils; import org.apache.druid.server.initialization.ServerConfig; import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.apache.druid.test.utils.TestSegmentCacheManager; import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineObjectHolder; import org.apache.druid.timeline.VersionedIntervalTimeline; -import org.apache.druid.timeline.partition.NoneShardSpec; import org.apache.druid.timeline.partition.PartitionChunk; -import org.apache.druid.timeline.partition.TombstoneShardSpec; -import org.hamcrest.MatcherAssert; -import org.hamcrest.text.StringContainsInOrder; import org.joda.time.Interval; import org.junit.Assert; +import org.junit.Assume; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.rules.TemporaryFolder; -import java.io.File; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -124,103 +118,77 @@ public class ServerManagerTest { - @Rule - public ExpectedException expectedException = ExpectedException.none(); + private static ImmutableSet DATA_SEGMENTS = new ImmutableSet.Builder() + .add(TestSegmentUtils.makeSegment("test", "1", Intervals.of("P1d/2011-04-01"))) + .add(TestSegmentUtils.makeSegment("test", "1", Intervals.of("P1d/2011-04-02"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("P1d/2011-04-02"))) + .add(TestSegmentUtils.makeSegment("test", "1", Intervals.of("P1d/2011-04-03"))) + .add(TestSegmentUtils.makeSegment("test", "1", Intervals.of("P1d/2011-04-04"))) + .add(TestSegmentUtils.makeSegment("test", "1", Intervals.of("P1d/2011-04-05"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("PT1h/2011-04-04T01"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("PT1h/2011-04-04T02"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("PT1h/2011-04-04T03"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("PT1h/2011-04-04T05"))) + .add(TestSegmentUtils.makeSegment("test", "2", Intervals.of("PT1h/2011-04-04T06"))) + .add(TestSegmentUtils.makeSegment("test2", "1", Intervals.of("P1d/2011-04-01"))) + .add(TestSegmentUtils.makeSegment("test2", "1", Intervals.of("P1d/2011-04-02"))) + .build(); + + @Bind + private QueryRunnerFactoryConglomerate conglomerate; + @Bind + private SegmentManager segmentManager; + @Bind + private PolicyEnforcer policyEnforcer; + @Bind + private ServerConfig serverConfig; + @Bind + private ServiceEmitter serviceEmitter; + @Bind + private QueryProcessingPool queryProcessingPool; + @Bind + private CachePopulator cachePopulator; + @Bind + @Smile + private ObjectMapper objectMapper; + @Bind + private Cache cache; + @Bind + private CacheConfig cacheConfig; - private ServerManager serverManager; private MyQueryRunnerFactory factory; - private CountDownLatch queryWaitLatch; - private CountDownLatch queryWaitYieldLatch; - private CountDownLatch queryNotifyLatch; private ExecutorService serverManagerExec; - private SegmentManager segmentManager; - @Rule - public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @Inject + private ServerManager serverManager; @Before public void setUp() { - final SegmentLoaderConfig loaderConfig = new SegmentLoaderConfig() - { - @Override - public File getInfoDir() - { - return temporaryFolder.getRoot(); - } - - @Override - public List getLocations() - { - return Collections.singletonList( - new StorageLocationConfig(temporaryFolder.getRoot(), null, null) - ); - } - }; - - final List storageLocations = loaderConfig.toStorageLocations(); - final SegmentLocalCacheManager localCacheManager = new SegmentLocalCacheManager( - storageLocations, - loaderConfig, - new LeastBytesUsedStorageLocationSelectorStrategy(storageLocations), - TestIndex.INDEX_IO, - TestHelper.makeJsonMapper() - ) - { - @Override - public ReferenceCountingSegment getSegment(final DataSegment dataSegment) - { - if (dataSegment.isTombstone()) { - return ReferenceCountingSegment - .wrapSegment(TombstoneSegmentizerFactory.segmentForTombstone(dataSegment), dataSegment.getShardSpec()); - } else { - return ReferenceCountingSegment.wrapSegment(new TestSegmentUtils.SegmentForTesting( - dataSegment.getDataSource(), - (Interval) dataSegment.getLoadSpec().get("interval"), - MapUtils.getString(dataSegment.getLoadSpec(), "version") - ), dataSegment.getShardSpec()); - } - } - }; + serviceEmitter = new NoopServiceEmitter(); + EmittingLogger.registerEmitter(new NoopServiceEmitter()); + segmentManager = new SegmentManager(new TestSegmentCacheManager(DATA_SEGMENTS)); + for (DataSegment segment : DATA_SEGMENTS) { + loadQueryable(segment.getDataSource(), segment.getVersion(), segment.getInterval()); + } - segmentManager = new SegmentManager(localCacheManager); + factory = new MyQueryRunnerFactory(new CountDownLatch(1), new CountDownLatch(1), new CountDownLatch(1)); + // Only SearchQuery is supported in this test. + conglomerate = DefaultQueryRunnerFactoryConglomerate.buildFromQueryRunnerFactories(ImmutableMap.of( + SearchQuery.class, + factory + )); - EmittingLogger.registerEmitter(new NoopServiceEmitter()); - queryWaitLatch = new CountDownLatch(1); - queryWaitYieldLatch = new CountDownLatch(1); - queryNotifyLatch = new CountDownLatch(1); - factory = new MyQueryRunnerFactory(queryWaitLatch, queryWaitYieldLatch, queryNotifyLatch); serverManagerExec = Execs.multiThreaded(2, "ServerManagerTest-%d"); - QueryRunnerFactoryConglomerate conglomerate = DefaultQueryRunnerFactoryConglomerate.buildFromQueryRunnerFactories(ImmutableMap - ., QueryRunnerFactory>builder() - .put(SearchQuery.class, factory) - .build()); - serverManager = new ServerManager( - conglomerate, - new NoopServiceEmitter(), - new ForwardingQueryProcessingPool(serverManagerExec), - new ForegroundCachePopulator(new DefaultObjectMapper(), new CachePopulatorStats(), -1), - new DefaultObjectMapper(), - new LocalCacheProvider().get(), - new CacheConfig(), - segmentManager, - new ServerConfig() - ); - - loadQueryable("test", "1", Intervals.of("P1d/2011-04-01")); - loadQueryable("test", "1", Intervals.of("P1d/2011-04-02")); - loadQueryable("test", "2", Intervals.of("P1d/2011-04-02")); - loadQueryable("test", "1", Intervals.of("P1d/2011-04-03")); - loadQueryable("test", "1", Intervals.of("P1d/2011-04-04")); - loadQueryable("test", "1", Intervals.of("P1d/2011-04-05")); - loadQueryable("test", "2", Intervals.of("PT1h/2011-04-04T01")); - loadQueryable("test", "2", Intervals.of("PT1h/2011-04-04T02")); - loadQueryable("test", "2", Intervals.of("PT1h/2011-04-04T03")); - loadQueryable("test", "2", Intervals.of("PT1h/2011-04-04T05")); - loadQueryable("test", "2", Intervals.of("PT1h/2011-04-04T06")); - loadQueryable("test2", "1", Intervals.of("P1d/2011-04-01")); - loadQueryable("test2", "1", Intervals.of("P1d/2011-04-02")); - loadQueryable("testTombstone", "1", Intervals.of("P1d/2011-04-02")); + queryProcessingPool = new ForwardingQueryProcessingPool(serverManagerExec); + cachePopulator = new ForegroundCachePopulator(new DefaultObjectMapper(), new CachePopulatorStats(), -1); + objectMapper = new DefaultObjectMapper(); + cache = new LocalCacheProvider().get(); + cacheConfig = new CacheConfig(); + serverConfig = new ServerConfig(); + policyEnforcer = NoopPolicyEnforcer.instance(); + + Guice.createInjector(BoundFieldModule.of(this)).injectMembers(this); } @Test @@ -230,9 +198,7 @@ public void testSimpleGet() Granularities.DAY, "test", Intervals.of("P1d/2011-04-01"), - ImmutableList.of( - new Pair<>("1", Intervals.of("P1d/2011-04-01")) - ) + ImmutableList.of(new Pair<>("1", Intervals.of("P1d/2011-04-01"))) ); waitForTestVerificationAndCleanup(future); @@ -257,7 +223,6 @@ public void testSimpleGetTombstone() Collections.emptyList() // tombstone returns no data ); waitForTestVerificationAndCleanup(future); - } @Test @@ -351,7 +316,7 @@ public void testReferenceCounting() throws Exception ) ); - queryNotifyLatch.await(1000, TimeUnit.MILLISECONDS); + factory.notifyLatch.await(1000, TimeUnit.MILLISECONDS); Assert.assertEquals(1, factory.getSegmentReferences().size()); @@ -359,7 +324,7 @@ public void testReferenceCounting() throws Exception Assert.assertEquals(1, referenceCountingSegment.getNumReferences()); } - queryWaitYieldLatch.countDown(); + factory.waitYieldLatch.countDown(); Assert.assertEquals(1, factory.getAdapters().size()); @@ -367,7 +332,7 @@ public void testReferenceCounting() throws Exception Assert.assertFalse(segment.isClosed()); } - queryWaitLatch.countDown(); + factory.waitLatch.countDown(); future.get(); dropQueryable("test", "3", Intervals.of("2011-04-04/2011-04-05")); @@ -378,7 +343,7 @@ public void testReferenceCounting() throws Exception } @Test - public void testReferenceCountingWhileQueryExecuting() throws Exception + public void testReferenceCounting_whileQueryExecuting() throws Exception { loadQueryable("test", "3", Intervals.of("2011-04-04/2011-04-05")); @@ -390,7 +355,7 @@ public void testReferenceCountingWhileQueryExecuting() throws Exception ) ); - queryNotifyLatch.await(1000, TimeUnit.MILLISECONDS); + factory.notifyLatch.await(1000, TimeUnit.MILLISECONDS); Assert.assertEquals(1, factory.getSegmentReferences().size()); @@ -398,7 +363,7 @@ public void testReferenceCountingWhileQueryExecuting() throws Exception Assert.assertEquals(1, referenceCountingSegment.getNumReferences()); } - queryWaitYieldLatch.countDown(); + factory.waitYieldLatch.countDown(); Assert.assertEquals(1, factory.getAdapters().size()); @@ -412,7 +377,7 @@ public void testReferenceCountingWhileQueryExecuting() throws Exception Assert.assertFalse(segment.isClosed()); } - queryWaitLatch.countDown(); + factory.waitLatch.countDown(); future.get(); for (TestSegmentUtils.SegmentForTesting segment : factory.getAdapters()) { @@ -421,7 +386,7 @@ public void testReferenceCountingWhileQueryExecuting() throws Exception } @Test - public void testMultipleDrops() throws Exception + public void testReferenceCounting_multipleDrops() throws Exception { loadQueryable("test", "3", Intervals.of("2011-04-04/2011-04-05")); @@ -433,7 +398,7 @@ public void testMultipleDrops() throws Exception ) ); - queryNotifyLatch.await(1000, TimeUnit.MILLISECONDS); + factory.notifyLatch.await(1000, TimeUnit.MILLISECONDS); Assert.assertEquals(1, factory.getSegmentReferences().size()); @@ -441,7 +406,7 @@ public void testMultipleDrops() throws Exception Assert.assertEquals(1, referenceCountingSegment.getNumReferences()); } - queryWaitYieldLatch.countDown(); + factory.waitYieldLatch.countDown(); Assert.assertEquals(1, factory.getAdapters().size()); @@ -456,7 +421,7 @@ public void testMultipleDrops() throws Exception Assert.assertFalse(segment.isClosed()); } - queryWaitLatch.countDown(); + factory.waitLatch.countDown(); future.get(); for (TestSegmentUtils.SegmentForTesting segment : factory.getAdapters()) { @@ -465,7 +430,49 @@ public void testMultipleDrops() throws Exception } @Test - public void testGetQueryRunnerForIntervalsWhenTimelineIsMissingReturningNoopQueryRunner() + public void testReferenceCounting_restrictedSegment() throws Exception + { + factory = new MyQueryRunnerFactory(new CountDownLatch(1), new CountDownLatch(1), new CountDownLatch(2)); + conglomerate = DefaultQueryRunnerFactoryConglomerate.buildFromQueryRunnerFactories(ImmutableMap.of( + SearchQuery.class, + factory + )); + serverManager = Guice.createInjector(BoundFieldModule.of(this)).getInstance(ServerManager.class); + + Interval interval = Intervals.of("P1d/2011-04-01"); + loadQueryable("test", "1", interval); + SearchQuery query = searchQuery("test", interval, Granularities.ALL); + SearchQuery queryOnRestricted = searchQuery(RestrictedDataSource.create( + TableDataSource.create("test"), + NoRestrictionPolicy.instance() + ), interval, Granularities.ALL); + + Future future = assertQuery(query, interval, ImmutableList.of(new Pair<>("1", interval))); + // sleep for 1s to make sure the first query hits/finishes factory.createRunner first, since we can't test adapter + // and segmentReference for RestrictedSegment unless there's already a ReferenceCountingSegment. + Thread.sleep(1000L); + Future futureOnRestricted = assertQuery( + queryOnRestricted, + interval, + ImmutableList.of(new Pair<>("1", interval)) + ); + + Assert.assertTrue(factory.notifyLatch.await(1000, TimeUnit.MILLISECONDS)); + Assert.assertEquals(1, factory.getSegmentReferences().size()); + // Expect 2 references here: 1 for query and 1 for queryOnRestricted + Assert.assertEquals(2, factory.getSegmentReferences().get(0).getNumReferences()); + + factory.waitYieldLatch.countDown(); + factory.waitLatch.countDown(); + future.get(); + futureOnRestricted.get(); + Assert.assertEquals(1, factory.getSegmentReferences().size()); + // no references since both query are finished + Assert.assertEquals(0, factory.getSegmentReferences().get(0).getNumReferences()); + } + + @Test + public void testGetQueryRunnerForIntervals_whenTimelineIsMissingReturningNoopQueryRunner() { final Interval interval = Intervals.of("0000-01-01/P1D"); final QueryRunner> queryRunner = serverManager.getQueryRunnerForIntervals( @@ -476,7 +483,7 @@ public void testGetQueryRunnerForIntervalsWhenTimelineIsMissingReturningNoopQuer } @Test - public void testGetQueryRunnerForSegmentsWhenTimelineIsMissingReportingMissingSegmentsOnQueryDataSource() + public void testGetQueryRunnerForSegments_whenTimelineIsMissingReportingMissingSegmentsOnQueryDataSource() { final Interval interval = Intervals.of("0000-01-01/P1D"); final SearchQuery query = searchQueryWithQueryDataSource("unknown_datasource", interval, Granularities.ALL); @@ -487,52 +494,49 @@ public void testGetQueryRunnerForSegmentsWhenTimelineIsMissingReportingMissingSe DruidException.class, () -> serverManager.getQueryRunnerForSegments(query, unknownSegments) ); - MatcherAssert.assertThat( - e.getMessage(), - StringContainsInOrder.stringContainsInOrder(Arrays.asList("Base dataSource", "is not a table!")) - ); + Assert.assertTrue(e.getMessage().startsWith("Base dataSource")); + Assert.assertTrue(e.getMessage().endsWith("is not a table!")); } @Test - public void testGetQueryRunnerForSegmentsWhenTimelineIsMissingReportingMissingSegments() + public void testGetQueryRunnerForSegments_whenTimelineIsMissingReportingMissingSegments() { final Interval interval = Intervals.of("0000-01-01/P1D"); final SearchQuery query = searchQuery("unknown_datasource", interval, Granularities.ALL); final List unknownSegments = Collections.singletonList( new SegmentDescriptor(interval, "unknown_version", 0) ); - final QueryRunner> queryRunner = serverManager.getQueryRunnerForSegments( - query, - unknownSegments - ); final ResponseContext responseContext = DefaultResponseContext.createEmpty(); - final List> results = queryRunner.run(QueryPlus.wrap(query), responseContext).toList(); + + final List> results = serverManager.getQueryRunnerForSegments(query, unknownSegments) + .run(QueryPlus.wrap(query), responseContext) + .toList(); + Assert.assertTrue(results.isEmpty()); Assert.assertNotNull(responseContext.getMissingSegments()); Assert.assertEquals(unknownSegments, responseContext.getMissingSegments()); } @Test - public void testGetQueryRunnerForSegmentsWhenTimelineEntryIsMissingReportingMissingSegments() + public void testGetQueryRunnerForSegments_whenTimelineEntryIsMissingReportingMissingSegments() { final Interval interval = Intervals.of("P1d/2011-04-01"); final SearchQuery query = searchQuery("test", interval, Granularities.ALL); final List unknownSegments = Collections.singletonList( new SegmentDescriptor(interval, "unknown_version", 0) ); - final QueryRunner> queryRunner = serverManager.getQueryRunnerForSegments( - query, - unknownSegments - ); final ResponseContext responseContext = DefaultResponseContext.createEmpty(); - final List> results = queryRunner.run(QueryPlus.wrap(query), responseContext).toList(); + + final List> results = serverManager.getQueryRunnerForSegments(query, unknownSegments) + .run(QueryPlus.wrap(query), responseContext) + .toList(); Assert.assertTrue(results.isEmpty()); Assert.assertNotNull(responseContext.getMissingSegments()); Assert.assertEquals(unknownSegments, responseContext.getMissingSegments()); } @Test - public void testGetQueryRunnerForSegmentsWhenTimelinePartitionChunkIsMissingReportingMissingSegments() + public void testGetQueryRunnerForSegments_whenTimelinePartitionChunkIsMissingReportingMissingSegments() { final Interval interval = Intervals.of("P1d/2011-04-01"); final int unknownPartitionId = 1000; @@ -540,25 +544,24 @@ public void testGetQueryRunnerForSegmentsWhenTimelinePartitionChunkIsMissingRepo final List unknownSegments = Collections.singletonList( new SegmentDescriptor(interval, "1", unknownPartitionId) ); - final QueryRunner> queryRunner = serverManager.getQueryRunnerForSegments( - query, - unknownSegments - ); final ResponseContext responseContext = DefaultResponseContext.createEmpty(); - final List> results = queryRunner.run(QueryPlus.wrap(query), responseContext).toList(); + final List> results = serverManager.getQueryRunnerForSegments(query, unknownSegments) + .run(QueryPlus.wrap(query), responseContext) + .toList(); Assert.assertTrue(results.isEmpty()); Assert.assertNotNull(responseContext.getMissingSegments()); Assert.assertEquals(unknownSegments, responseContext.getMissingSegments()); } @Test - public void testGetQueryRunnerForSegmentsWhenSegmentIsClosedReportingMissingSegments() + public void testGetQueryRunnerForSegments_whenSegmentIsClosedReportingMissingSegments() { final Interval interval = Intervals.of("P1d/2011-04-01"); final SearchQuery query = searchQuery("test", interval, Granularities.ALL); final Optional> maybeTimeline = segmentManager .getTimeline(ExecutionVertex.of(query).getBaseTableDataSource()); - Assert.assertTrue(maybeTimeline.isPresent()); + Assume.assumeTrue(maybeTimeline.isPresent()); + // close all segments in interval final List> holders = maybeTimeline.get().lookup(interval); final List closedSegments = new ArrayList<>(); for (TimelineObjectHolder holder : holders) { @@ -571,113 +574,110 @@ public void testGetQueryRunnerForSegmentsWhenSegmentIsClosedReportingMissingSegm segment.close(); } } - final QueryRunner> queryRunner = serverManager.getQueryRunnerForSegments( - query, - closedSegments - ); final ResponseContext responseContext = DefaultResponseContext.createEmpty(); - final List> results = queryRunner.run(QueryPlus.wrap(query), responseContext).toList(); + + final List> results = serverManager.getQueryRunnerForSegments(query, closedSegments) + .run(QueryPlus.wrap(query), responseContext) + .toList(); Assert.assertTrue(results.isEmpty()); Assert.assertNotNull(responseContext.getMissingSegments()); Assert.assertEquals(closedSegments, responseContext.getMissingSegments()); } @Test - public void testGetQueryRunnerForSegmentsForUnknownQueryThrowingException() + public void testGetQueryRunnerForSegments_forUnknownQueryThrowingException() { final Interval interval = Intervals.of("P1d/2011-04-01"); final List descriptors = Collections.singletonList(new SegmentDescriptor(interval, "1", 0)); - expectedException.expect(QueryUnsupportedException.class); - expectedException.expectMessage("Unknown query type"); - serverManager.getQueryRunnerForSegments( - new BaseQuery<>( - new TableDataSource("test"), - new MultipleSpecificSegmentSpec(descriptors), - new HashMap<>() - ) - { - @Override - public boolean hasFilters() - { - return false; - } - - @Override - public DimFilter getFilter() - { - return null; - } - - @Override - public String getType() - { - return null; - } - - @Override - public Query withOverriddenContext(Map contextOverride) - { - return this; - } - - @Override - public Query withQuerySegmentSpec(QuerySegmentSpec spec) - { - return null; - } + Query query = Druids.newTimeBoundaryQueryBuilder() + .dataSource("random-ds") + .intervals(interval.toString()) + .build(); + // We only have QueryRunnerFactory for SearchQuery in test. + QueryUnsupportedException e = Assert.assertThrows( + QueryUnsupportedException.class, + () -> serverManager.getQueryRunnerForSegments(query, descriptors) + ); + Assert.assertTrue(e.getMessage().startsWith("Unknown query type")); + } - @Override - public Query withDataSource(DataSource dataSource) - { - return null; - } - }, - descriptors + @Test + public void testGetQueryRunnerForSegments_restricted() throws Exception + { + conglomerate = DefaultQueryRunnerFactoryConglomerate.buildFromQueryRunnerFactories(ImmutableMap.of( + SearchQuery.class, + factory + )); + serverManager = Guice.createInjector(BoundFieldModule.of(this)).getInstance(ServerManager.class); + Interval interval = Intervals.of("P1d/2011-04-01"); + SearchQuery query = searchQuery("test", interval, Granularities.ALL); + SearchQuery queryOnRestricted = searchQuery(RestrictedDataSource.create( + TableDataSource.create("test"), + NoRestrictionPolicy.instance() + ), interval, Granularities.ALL); + + serverManager.getQueryRunnerForIntervals(query, ImmutableList.of(interval)).run(QueryPlus.wrap(query)).toList(); + // switch to a policy enforcer that restricts all tables + policyEnforcer = new RestrictAllTablesPolicyEnforcer(ImmutableList.of(NoRestrictionPolicy.class.getName())); + serverManager = Guice.createInjector(BoundFieldModule.of(this)).getInstance(ServerManager.class); + // fail on query + DruidException e = Assert.assertThrows( + DruidException.class, + () -> serverManager.getQueryRunnerForIntervals(query, ImmutableList.of(interval)) + .run(QueryPlus.wrap(query)) + .toList() ); + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals( + "Failed security validation with segment [test_2011-03-31T00:00:00.000Z_2011-04-01T00:00:00.000Z_1]", + e.getMessage() + ); + // succeed on queryOnRestricted + serverManager.getQueryRunnerForIntervals(queryOnRestricted, ImmutableList.of(interval)) + .run(QueryPlus.wrap(queryOnRestricted)) + .toList(); } private void waitForTestVerificationAndCleanup(Future future) { try { - queryNotifyLatch.await(1000, TimeUnit.MILLISECONDS); - queryWaitYieldLatch.countDown(); - queryWaitLatch.countDown(); + factory.notifyLatch.await(1000, TimeUnit.MILLISECONDS); + factory.waitLatch.countDown(); future.get(); factory.clearAdapters(); } catch (Exception e) { - throw new RuntimeException(e); + throw new RuntimeException(e.getCause()); } } - private SearchQuery searchQuery(String datasource, Interval interval, Granularity granularity) - { - return Druids.newSearchQueryBuilder() - .dataSource(datasource) - .intervals(Collections.singletonList(interval)) - .granularity(granularity) - .limit(10000) - .query("wow") - .build(); - } - - - private SearchQuery searchQueryWithQueryDataSource(String datasource, Interval interval, Granularity granularity) + private static SearchQuery searchQueryWithQueryDataSource( + String datasource, + Interval interval, + Granularity granularity + ) { final ImmutableList descriptors = ImmutableList.of( new SegmentDescriptor(Intervals.of("2000/3000"), "0", 0), new SegmentDescriptor(Intervals.of("2000/3000"), "0", 1) ); + return searchQuery(new QueryDataSource(Druids.newTimeseriesQueryBuilder() + .dataSource(datasource) + .intervals(new MultipleSpecificSegmentSpec(descriptors)) + .granularity(Granularities.ALL) + .build()), interval, granularity); + } + + private static SearchQuery searchQuery(String datasource, Interval interval, Granularity granularity) + { + return searchQuery(TableDataSource.create(datasource), interval, granularity); + } + + private static SearchQuery searchQuery(DataSource datasource, Interval interval, Granularity granularity) + { return Druids.newSearchQueryBuilder() - .dataSource( - new QueryDataSource( - Druids.newTimeseriesQueryBuilder() - .dataSource(datasource) - .intervals(new MultipleSpecificSegmentSpec(descriptors)) - .granularity(Granularities.ALL) - .build() - ) - ) + .dataSource(datasource) .intervals(Collections.singletonList(interval)) .granularity(granularity) .limit(10000) @@ -685,29 +685,35 @@ private SearchQuery searchQueryWithQueryDataSource(String datasource, Interval i .build(); } - private Future assertQueryable( + private Future assertQueryable( Granularity granularity, String dataSource, Interval interval, List> expected ) + { + final SearchQuery query = searchQuery(dataSource, interval, granularity); + return assertQuery(query, interval, expected); + } + + private Future assertQuery( + SearchQuery query, + Interval interval, + List> expected + ) { final Iterator> expectedIter = expected.iterator(); final List intervals = Collections.singletonList(interval); - final SearchQuery query = searchQuery(dataSource, interval, granularity); - final QueryRunner> runner = serverManager.getQueryRunnerForIntervals( - query, - intervals - ); + final QueryRunner> runner = serverManager.getQueryRunnerForIntervals(query, intervals); return serverManagerExec.submit( () -> { Sequence> seq = runner.run(QueryPlus.wrap(query)); seq.toList(); - Iterator adaptersIter = factory.getAdapters().iterator(); + Iterator adaptersIter = factory.getAdapters().iterator(); while (expectedIter.hasNext() && adaptersIter.hasNext()) { Pair expectedVals = expectedIter.next(); - TestSegmentUtils.SegmentForTesting value = adaptersIter.next(); + SegmentForTesting value = adaptersIter.next(); Assert.assertEquals(expectedVals.lhs, value.getVersion()); Assert.assertEquals(expectedVals.rhs, value.getInterval()); @@ -722,39 +728,10 @@ private Future assertQueryable( private void loadQueryable(String dataSource, String version, Interval interval) { try { - if ("testTombstone".equals(dataSource)) { - segmentManager.loadSegment( - new DataSegment( - dataSource, - interval, - version, - ImmutableMap.of("version", version, - "interval", interval, - "type", - DataSegment.TOMBSTONE_LOADSPEC_TYPE - ), - Arrays.asList("dim1", "dim2", "dim3"), - Arrays.asList("metric1", "metric2"), - TombstoneShardSpec.INSTANCE, - IndexIO.CURRENT_VERSION_ID, - 1L - ) - ); - } else { - segmentManager.loadSegment( - new DataSegment( - dataSource, - interval, - version, - ImmutableMap.of("version", version, "interval", interval), - Arrays.asList("dim1", "dim2", "dim3"), - Arrays.asList("metric1", "metric2"), - NoneShardSpec.instance(), - IndexIO.CURRENT_VERSION_ID, - 1L - ) - ); - } + DataSegment segment = "testTombstone".equals(dataSource) ? + TestSegmentUtils.makeTombstoneSegment(dataSource, version, interval) : + TestSegmentUtils.makeSegment(dataSource, version, interval); + segmentManager.loadSegment(segment); } catch (SegmentLoadingException | IOException e) { throw new RuntimeException(e); @@ -763,19 +740,7 @@ private void loadQueryable(String dataSource, String version, Interval interval) private void dropQueryable(String dataSource, String version, Interval interval) { - segmentManager.dropSegment( - new DataSegment( - dataSource, - interval, - version, - ImmutableMap.of("version", version, "interval", interval), - Arrays.asList("dim1", "dim2", "dim3"), - Arrays.asList("metric1", "metric2"), - NoneShardSpec.instance(), - IndexIO.CURRENT_VERSION_ID, - 123L - ) - ); + segmentManager.dropSegment(TestSegmentUtils.makeSegment(dataSource, version, interval)); } private static class MyQueryRunnerFactory implements QueryRunnerFactory, SearchQuery> @@ -801,14 +766,22 @@ public MyQueryRunnerFactory( @Override public QueryRunner> createRunner(Segment adapter) { - if (!(adapter instanceof ReferenceCountingSegment)) { - throw new IAE("Expected instance of ReferenceCountingSegment, got %s", adapter.getClass()); + final ReferenceCountingSegment segment; + if (this.adapters.stream() + .map(SegmentForTesting::getId) + .anyMatch(segmentId -> adapter.getId().equals(segmentId))) { + // Already have adapter for this segment, skip. + // For RestrictedSegment, we don't have access to RestrictedSegment.delegate, but it'd be recorded in segmentReferences. + // This means we can't test adapter and segmentReference unless there's already a ReferenceCountingSegment. + } else if (adapter instanceof ReferenceCountingSegment) { + segment = (ReferenceCountingSegment) adapter; + Assert.assertTrue(segment.getNumReferences() > 0); + segmentReferences.add(segment); + adapters.add((SegmentForTesting) segment.getBaseSegment()); + } else { + throw new IAE("Unsupported segment instance: [%s]", adapter.getClass()); } - final ReferenceCountingSegment segment = (ReferenceCountingSegment) adapter; - Assert.assertTrue(segment.getNumReferences() > 0); - segmentReferences.add(segment); - adapters.add((TestSegmentUtils.SegmentForTesting) segment.getBaseSegment()); return new BlockingQueryRunner<>(new NoopQueryRunner<>(), waitLatch, waitYieldLatch, notifyLatch); } @@ -827,7 +800,7 @@ public QueryToolChest, SearchQuery> getToolchest() return new NoopQueryToolChest<>(); } - public List getAdapters() + public List getAdapters() { return adapters; } @@ -866,7 +839,9 @@ public Function makePreComputeManipulatorFn(QueryType query, MetricManipul @Override public TypeReference getResultTypeReference() { - return new TypeReference<>() {}; + return new TypeReference<>() + { + }; } } @@ -897,6 +872,10 @@ public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) } } + /** + * A Sequence that count-down {@code notifyLatch} when {@link #toYielder} is called, and the returned Yielder waits + * for {@code waitYieldLatch} and {@code waitLatch} count-down. + */ private static class BlockingSequence extends YieldingSequenceBase { private final Sequence baseSequence; diff --git a/server/src/test/java/org/apache/druid/server/coordination/TestSegmentCacheManager.java b/server/src/test/java/org/apache/druid/test/utils/TestSegmentCacheManager.java similarity index 95% rename from server/src/test/java/org/apache/druid/server/coordination/TestSegmentCacheManager.java rename to server/src/test/java/org/apache/druid/test/utils/TestSegmentCacheManager.java index b7cce457f136..f212f8c15404 100644 --- a/server/src/test/java/org/apache/druid/server/coordination/TestSegmentCacheManager.java +++ b/server/src/test/java/org/apache/druid/test/utils/TestSegmentCacheManager.java @@ -17,16 +17,16 @@ * under the License. */ -package org.apache.druid.server.coordination; +package org.apache.druid.test.utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.apache.druid.java.util.common.MapUtils; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentLazyLoadFailCallback; +import org.apache.druid.segment.TestSegmentUtils; import org.apache.druid.segment.loading.NoopSegmentCacheManager; import org.apache.druid.segment.loading.TombstoneSegmentizerFactory; -import org.apache.druid.server.TestSegmentUtils; import org.apache.druid.timeline.DataSegment; import org.joda.time.Interval; @@ -40,7 +40,7 @@ * methods to support these operations; any other method invoked will throw an exception from the base class, * {@link NoopSegmentCacheManager}. */ -class TestSegmentCacheManager extends NoopSegmentCacheManager +public class TestSegmentCacheManager extends NoopSegmentCacheManager { private final List cachedSegments; @@ -51,12 +51,12 @@ class TestSegmentCacheManager extends NoopSegmentCacheManager private final List observedSegmentsRemovedFromCache; private final AtomicInteger observedShutdownBootstrapCount; - TestSegmentCacheManager() + public TestSegmentCacheManager() { this(ImmutableSet.of()); } - TestSegmentCacheManager(final Set segmentsToCache) + public TestSegmentCacheManager(final Set segmentsToCache) { this.cachedSegments = ImmutableList.copyOf(segmentsToCache); diff --git a/services/src/main/java/org/apache/druid/cli/CliPeon.java b/services/src/main/java/org/apache/druid/cli/CliPeon.java index 3d06ad7dee3e..74ba257737fb 100644 --- a/services/src/main/java/org/apache/druid/cli/CliPeon.java +++ b/services/src/main/java/org/apache/druid/cli/CliPeon.java @@ -265,6 +265,7 @@ public void configure(Binder binder) .setTaskFile(Paths.get(taskDirPath, "task.json").toFile()) .setStatusFile(Paths.get(taskDirPath, "attempt", attemptId, "status.json").toFile()); + binder.bind(Properties.class).toInstance(properties); if (properties.getProperty("druid.indexer.runner.type", "").contains("k8s")) { log.info("Running peon in k8s mode"); executorLifecycleConfig.setParentStreamDefined(false); diff --git a/services/src/test/java/org/apache/druid/cli/CliPeonTest.java b/services/src/test/java/org/apache/druid/cli/CliPeonTest.java index 576298921628..0743b0d62a17 100644 --- a/services/src/test/java/org/apache/druid/cli/CliPeonTest.java +++ b/services/src/test/java/org/apache/druid/cli/CliPeonTest.java @@ -42,6 +42,8 @@ import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.granularity.AllGranularity; import org.apache.druid.query.DruidMetrics; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.segment.indexing.DataSchema; import org.apache.druid.storage.local.LocalTmpStorageConfig; import org.joda.time.Duration; @@ -102,6 +104,23 @@ public void testCliPeonK8sANdWorkerIsK8sMode() throws IOException Assert.assertNotNull(runnable.makeInjector()); } + @Test + public void testCliPeonPolicyEnforcerInToolbox() throws IOException + { + CliPeon runnable = new CliPeon(); + File file = temporaryFolder.newFile("task.json"); + FileUtils.write(file, "{\"type\":\"noop\"}", StandardCharsets.UTF_8); + runnable.taskAndStatusFile = ImmutableList.of(file.getParent(), "1"); + + Properties properties = new Properties(); + properties.setProperty("druid.policy.enforcer.type", "restrictAllTables"); + runnable.configure(properties); + runnable.configure(properties, GuiceInjectors.makeStartupInjector()); + + Injector secondaryInjector = runnable.makeInjector(); + Assert.assertEquals(new RestrictAllTablesPolicyEnforcer(null), secondaryInjector.getInstance(PolicyEnforcer.class)); + } + @Test public void testCliPeonHeartbeatDimensions() throws IOException { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index 3e69d275471f..84587d639062 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -40,6 +40,7 @@ import org.apache.druid.guice.annotations.Json; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.QueryContexts; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthorizationResult; @@ -81,6 +82,7 @@ public PlannerFactory( final JoinableFactoryWrapper joinableFactoryWrapper, final CatalogResolver catalog, final AuthConfig authConfig, + final PolicyEnforcer policyEnforcer, final DruidHookDispatcher hookDispatcher ) { @@ -96,6 +98,7 @@ public PlannerFactory( calciteRuleManager, authorizerMapper, authConfig, + policyEnforcer, hookDispatcher ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerToolbox.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerToolbox.java index d8e3a25f4583..17887afd06a5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerToolbox.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerToolbox.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.policy.PolicyEnforcer; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthorizerMapper; @@ -41,6 +42,7 @@ public class PlannerToolbox protected final CalciteRulesManager calciteRuleManager; protected final AuthorizerMapper authorizerMapper; protected final AuthConfig authConfig; + protected final PolicyEnforcer policyEnforcer; protected final DruidHookDispatcher hookDispatcher; public PlannerToolbox( @@ -55,6 +57,7 @@ public PlannerToolbox( final CalciteRulesManager calciteRuleManager, final AuthorizerMapper authorizerMapper, final AuthConfig authConfig, + final PolicyEnforcer policyEnforcer, final DruidHookDispatcher hookDispatcher ) { @@ -69,6 +72,7 @@ public PlannerToolbox( this.calciteRuleManager = calciteRuleManager; this.authorizerMapper = authorizerMapper; this.authConfig = authConfig; + this.policyEnforcer = policyEnforcer; this.hookDispatcher = hookDispatcher; } @@ -122,6 +126,11 @@ public AuthConfig getAuthConfig() return authConfig; } + public PolicyEnforcer getPolicyEnforcer() + { + return policyEnforcer; + } + public DruidHookDispatcher getHookDispatcher() { return hookDispatcher; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 067c6c2a0fe5..09790ae75b1d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -303,7 +303,10 @@ public static DruidQuery fromPartialQuery( } return new DruidQuery( - applyPolicies ? dataSource.withPolicies(plannerContext.getAuthorizationResult().getPolicyMap()) : dataSource, + applyPolicies ? dataSource.withPolicies( + plannerContext.getAuthorizationResult().getPolicyMap(), + plannerContext.getPlannerToolbox().getPolicyEnforcer() + ) : dataSource, plannerContext, filter, selectProjection, diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java index e924102284f3..dfbad004f3ca 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java @@ -19,6 +19,7 @@ package org.apache.druid.sql; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListeningExecutorService; @@ -37,6 +38,9 @@ import org.apache.druid.query.Query; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; +import org.apache.druid.query.policy.PolicyEnforcer; +import org.apache.druid.query.policy.RestrictAllTablesPolicyEnforcer; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryStackTests; @@ -78,6 +82,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.apache.druid.sql.calcite.BaseCalciteQueryTest.assertResultsEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertSame; @@ -91,12 +96,13 @@ public class SqlStatementTest private static Closer resourceCloser; @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder(); - private TestRequestLogger testRequestLogger; private ListeningExecutorService executorService; - private SqlStatementFactory sqlStatementFactory; private final DefaultQueryConfig defaultQueryConfig = new DefaultQueryConfig( ImmutableMap.of("DEFAULT_KEY", "DEFAULT_VALUE")); + private PolicyEnforcer policyEnforcer; + private SqlStatementFactory sqlStatementFactory; + @BeforeClass public static void setUpClass() throws Exception { @@ -135,45 +141,8 @@ public void setUp() { executorService = MoreExecutors.listeningDecorator(Execs.multiThreaded(8, "test_sql_resource_%s")); - final PlannerConfig plannerConfig = PlannerConfig.builder().build(); - final DruidSchemaCatalog rootSchema = CalciteTests.createMockRootSchema( - conglomerate, - walker, - plannerConfig, - CalciteTests.TEST_AUTHORIZER_MAPPER - ); - final DruidOperatorTable operatorTable = CalciteTests.createOperatorTable(); - final ExprMacroTable macroTable = CalciteTests.createExprMacroTable(); - - testRequestLogger = new TestRequestLogger(); - final JoinableFactoryWrapper joinableFactoryWrapper = CalciteTests.createJoinableFactoryWrapper(); - - final PlannerFactory plannerFactory = new PlannerFactory( - rootSchema, - operatorTable, - macroTable, - plannerConfig, - CalciteTests.TEST_AUTHORIZER_MAPPER, - CalciteTests.getJsonMapper(), - CalciteTests.DRUID_SCHEMA_NAME, - new CalciteRulesManager(ImmutableSet.of()), - joinableFactoryWrapper, - CatalogResolver.NULL_RESOLVER, - new AuthConfig(), - new DruidHookDispatcher() - ); - - this.sqlStatementFactory = new SqlStatementFactory( - new SqlToolbox( - CalciteTests.createMockSqlEngine(walker, conglomerate), - plannerFactory, - new NoopServiceEmitter(), - testRequestLogger, - QueryStackTests.DEFAULT_NOOP_SCHEDULER, - defaultQueryConfig, - new SqlLifecycleManager() - ) - ); + policyEnforcer = NoopPolicyEnforcer.instance(); + this.sqlStatementFactory = buildSqlStatementFactory(); } @After @@ -275,6 +244,42 @@ public void testDirectExecTwice() } } + @Test + public void testDirectPolicyEnforcerThrowsForNoPolicy() + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + SqlQueryPlus sqlReq = queryPlus( + "SELECT COUNT(*) AS cnt FROM druid.foo", + CalciteTests.REGULAR_USER_AUTH_RESULT + ); + DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); + ResultSet resultSet = stmt.plan(); + DruidException e = Assert.assertThrows(DruidException.class, () -> resultSet.run()); + + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals("Failed security validation with dataSource [foo]", e.getMessage()); + } + + @Test + public void testDirectPolicyEnforcerValidatesWithPolicy() + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + SqlQueryPlus sqlReq = queryPlus( + "SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", + CalciteTests.REGULAR_USER_AUTH_RESULT + ); + + DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); + ResultSet resultSet = stmt.plan(); + List results = resultSet.run().getResults().toList(); + + ImmutableList expectedResults = ImmutableList.of(new Object[]{1L}); + assertResultsEquals("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", expectedResults, results); + } + @Test public void testDirectSyntaxError() { @@ -423,6 +428,38 @@ public void testHttpPermissionError() } } + @Test + public void testHttpPolicyEnforcerThrowsForNoPolicy() throws Exception + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + HttpStatement stmt = sqlStatementFactory.httpStatement( + makeQuery("SELECT COUNT(*) AS cnt FROM druid.foo"), + request(true) + ); + ResultSet resultSet = stmt.plan(); + DruidException e = Assert.assertThrows(DruidException.class, () -> resultSet.run()); + + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals("Failed security validation with dataSource [foo]", e.getMessage()); + } + + @Test + public void testHttpPolicyEnforcerValidatesWithPolicy() + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + HttpStatement stmt = sqlStatementFactory.httpStatement( + makeQuery("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6"), + request(true) + ); + List results = stmt.plan().run().getResults().toList(); + + ImmutableList expectedResults = ImmutableList.of(new Object[]{1L}); + assertResultsEquals("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", expectedResults, results); + } + //----------------------------------------------------------------- // Prepared statements: using a prepare/execute model. @@ -518,6 +555,39 @@ public void testPreparePermissionError() } } + @Test + public void testPreparePolicyEnforcerThrowsForNoPolicy() throws Exception + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + SqlQueryPlus sqlReq = queryPlus( + "SELECT COUNT(*) AS cnt FROM druid.foo", + CalciteTests.REGULAR_USER_AUTH_RESULT + ); + PreparedStatement stmt = sqlStatementFactory.preparedStatement(sqlReq); + DruidException e = Assert.assertThrows(DruidException.class, () -> stmt.execute(Collections.emptyList()).execute()); + + Assert.assertEquals(DruidException.Category.FORBIDDEN, e.getCategory()); + Assert.assertEquals(DruidException.Persona.OPERATOR, e.getTargetPersona()); + Assert.assertEquals("Failed security validation with dataSource [foo]", e.getMessage()); + } + + @Test + public void testPreparePolicyEnforcerValidatesWithPolicy() + { + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + sqlStatementFactory = buildSqlStatementFactory(); + SqlQueryPlus sqlReq = queryPlus( + "SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", + CalciteTests.REGULAR_USER_AUTH_RESULT + ); + PreparedStatement stmt = sqlStatementFactory.preparedStatement(sqlReq); + List results = stmt.execute(Collections.emptyList()).execute().getResults().toList(); + + ImmutableList expectedResults = ImmutableList.of(new Object[]{1L}); + assertResultsEquals("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", expectedResults, results); + } + //----------------------------------------------------------------- // Generic tests. @@ -552,4 +622,48 @@ public void testDefaultQueryContextIsApplied() Assert.assertTrue(context.containsKey(defaultContextKey)); } } + + private SqlStatementFactory buildSqlStatementFactory() + { + final PlannerConfig plannerConfig = PlannerConfig.builder().build(); + final DruidSchemaCatalog rootSchema = CalciteTests.createMockRootSchema( + conglomerate, + walker, + plannerConfig, + CalciteTests.TEST_AUTHORIZER_MAPPER + ); + final DruidOperatorTable operatorTable = CalciteTests.createOperatorTable(); + final ExprMacroTable macroTable = CalciteTests.createExprMacroTable(); + + TestRequestLogger testRequestLogger = new TestRequestLogger(); + final JoinableFactoryWrapper joinableFactoryWrapper = CalciteTests.createJoinableFactoryWrapper(); + + final PlannerFactory plannerFactory = new PlannerFactory( + rootSchema, + operatorTable, + macroTable, + plannerConfig, + CalciteTests.TEST_AUTHORIZER_MAPPER, + CalciteTests.getJsonMapper(), + CalciteTests.DRUID_SCHEMA_NAME, + new CalciteRulesManager(ImmutableSet.of()), + joinableFactoryWrapper, + CatalogResolver.NULL_RESOLVER, + new AuthConfig(), + policyEnforcer, + new DruidHookDispatcher() + ); + + return new SqlStatementFactory( + new SqlToolbox( + CalciteTests.createMockSqlEngine(walker, conglomerate), + plannerFactory, + new NoopServiceEmitter(), + testRequestLogger, + QueryStackTests.DEFAULT_NOOP_SCHEDULER, + defaultQueryConfig, + new SqlLifecycleManager() + ) + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java index 50f05fd98fe9..0aae1021bb10 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -42,6 +42,7 @@ import org.apache.calcite.avatica.server.AbstractAvaticaHandler; import org.apache.druid.guice.LazySingleton; import org.apache.druid.guice.StartupInjectorBuilder; +import org.apache.druid.guice.security.PolicyModule; import org.apache.druid.initialization.CoreInjectorBuilder; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Pair; @@ -55,6 +56,7 @@ import org.apache.druid.query.BaseQuery; import org.apache.druid.query.DefaultQueryConfig; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryLifecycleFactory; @@ -273,6 +275,7 @@ public void setUp() throws Exception binder.bind(AuthenticatorMapper.class).toInstance(CalciteTests.TEST_AUTHENTICATOR_MAPPER); binder.bind(AuthorizerMapper.class).toInstance(CalciteTests.TEST_AUTHORIZER_MAPPER); binder.bind(Escalator.class).toInstance(CalciteTests.TEST_AUTHENTICATOR_ESCALATOR); + binder.install(new PolicyModule()); binder.bind(RequestLogger.class).toInstance(testRequestLogger); binder.bind(DruidSchemaCatalog.class).toInstance(rootSchema); for (NamedSchema schema : rootSchema.getNamedSchemas().values()) { @@ -1059,6 +1062,7 @@ private SqlStatementFactory makeStatementFactory() CalciteTests.createJoinableFactoryWrapper(), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ) ); diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index 94b20eb9d65f..b74362d7df29 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -29,6 +29,7 @@ import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.SpecificSegmentsQuerySegmentWalker; @@ -113,6 +114,7 @@ public void setUp() joinableFactoryWrapper, CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); this.sqlStatementFactory = CalciteTests.createSqlStatementFactory( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index aa8542d92bef..53f292d80fc9 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -1052,7 +1052,7 @@ public void assertResultsValid(final ResultMatchMode matchMode, final List expectedResults, List results) + public static void assertResultsEquals(String sql, List expectedResults, List results) { int minSize = Math.min(results.size(), expectedResults.size()); for (int i = 0; i < minSize; i++) { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java index 07f57a5deeb3..26b125bf6576 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/SqlVectorizedExpressionSanityTest.java @@ -31,6 +31,7 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.generator.GeneratorBasicSchemas; import org.apache.druid.segment.generator.GeneratorSchemaInfo; @@ -154,6 +155,7 @@ public static void setupClass() joinableFactoryWrapper, CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index f4ada1f1c179..20d7c0fa0786 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -37,6 +37,7 @@ import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.RowAdapters; import org.apache.druid.segment.RowBasedColumnSelectorFactory; import org.apache.druid.segment.VirtualColumn; @@ -99,6 +100,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) new CalciteRulesManager(ImmutableSet.of()), CalciteTests.TEST_AUTHORIZER_MAPPER, AuthConfig.newBuilder().build(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); public static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java index fe26ba9db8c7..4f30cc1feb11 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java @@ -26,6 +26,7 @@ import org.apache.calcite.schema.SchemaPlus; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; @@ -73,6 +74,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) new CalciteRulesManager(ImmutableSet.of()), CalciteTests.TEST_AUTHORIZER_MAPPER, AuthConfig.newBuilder().build(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); final PlannerContext plannerContext = PlannerContext.create( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java index 89df405b7f1f..52aadd54839c 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java @@ -33,12 +33,12 @@ import org.apache.calcite.rel.rules.ProjectMergeRule; import org.apache.calcite.schema.Schema; import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.security.PolicyModule; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.JacksonModule; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.segment.join.JoinableFactoryWrapper; import org.apache.druid.server.QueryLifecycleFactory; -import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.ResourceType; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; @@ -49,8 +49,6 @@ import org.apache.druid.sql.calcite.schema.DruidSchemaName; import org.apache.druid.sql.calcite.schema.NamedSchema; import org.apache.druid.sql.calcite.util.CalciteTestBase; -import org.apache.druid.sql.calcite.util.CalciteTests; -import org.apache.druid.sql.hook.DruidHookDispatcher; import org.easymock.EasyMock; import org.easymock.EasyMockExtension; import org.easymock.Mock; @@ -127,6 +125,7 @@ public void onMatch(RelOptRuleCall call) }; injector = Guice.createInjector( new JacksonModule(), + new PolicyModule(), binder -> { binder.bind(Validator.class).toInstance(Validation.buildDefaultValidatorFactory().getValidator()); binder.bindScope(LazySingleton.class, Scopes.SINGLETON); @@ -182,20 +181,7 @@ public void testPlannerConfigIsInjected() public void testExtensionCalciteRule() { ObjectMapper mapper = new DefaultObjectMapper(); - PlannerToolbox toolbox = new PlannerToolbox( - injector.getInstance(DruidOperatorTable.class), - macroTable, - mapper, - injector.getInstance(PlannerConfig.class), - rootSchema, - joinableFactoryWrapper, - CatalogResolver.NULL_RESOLVER, - "druid", - new CalciteRulesManager(ImmutableSet.of()), - CalciteTests.TEST_AUTHORIZER_MAPPER, - AuthConfig.newBuilder().build(), - new DruidHookDispatcher() - ); + PlannerToolbox toolbox = injector.getInstance(PlannerFactory.class); PlannerContext context = PlannerContext.create( toolbox, @@ -215,20 +201,7 @@ public void testExtensionCalciteRule() public void testConfigurableBloat() { ObjectMapper mapper = new DefaultObjectMapper(); - PlannerToolbox toolbox = new PlannerToolbox( - injector.getInstance(DruidOperatorTable.class), - macroTable, - mapper, - injector.getInstance(PlannerConfig.class), - rootSchema, - joinableFactoryWrapper, - CatalogResolver.NULL_RESOLVER, - "druid", - new CalciteRulesManager(ImmutableSet.of()), - CalciteTests.TEST_AUTHORIZER_MAPPER, - AuthConfig.newBuilder().build(), - new DruidHookDispatcher() - ); + PlannerToolbox toolbox = injector.getInstance(PlannerFactory.class); PlannerContext contextWithBloat = PlannerContext.create( toolbox, diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java index 7888ec8090ba..816ce3fd1a1d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java @@ -39,6 +39,7 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.server.security.AuthConfig; @@ -105,6 +106,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) new CalciteRulesManager(ImmutableSet.of()), CalciteTests.TEST_AUTHORIZER_MAPPER, AuthConfig.newBuilder().build(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); private static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelTest.java index 549484b5f9a9..292ba32cff88 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rel/DruidRelTest.java @@ -53,6 +53,7 @@ import org.apache.druid.query.TableDataSource; import org.apache.druid.query.UnionDataSource; import org.apache.druid.query.UnnestDataSource; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.policy.Policy; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; @@ -60,6 +61,7 @@ import org.apache.druid.sql.calcite.planner.ExpressionParser; import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.planner.PlannerToolbox; import org.apache.druid.sql.calcite.rel.logical.DruidUnion; import org.apache.druid.sql.calcite.table.DatasourceTable; import org.apache.druid.sql.calcite.table.DatasourceTable.PhysicalDatasourceMetadata; @@ -121,6 +123,9 @@ public void setup() throws Exception when(mockRelOptCluster.getTypeFactory()).thenReturn(DEFAULT_TYPE_FACTORY); when(mockRelOptTable.getRowType()).thenReturn(REC_TYPE); + PlannerToolbox mockPlannerToolbox = mock(PlannerToolbox.class); + when(mockPlannerToolbox.getPolicyEnforcer()).thenReturn(NoopPolicyEnforcer.instance()); + when(mockPlannerContext.getPlannerToolbox()).thenReturn(mockPlannerToolbox); when(mockPlannerContext.getPlannerConfig()).thenReturn(PlannerConfig.builder().build()); when(mockPlannerContext.getJsonMapper()).thenReturn(JsonMapper.builder().build()); when(mockPlannerContext.getAuthorizationResult()).thenReturn(AUTHORIZATION_RESULT); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java index 67b5b9d11e22..39e1b4f41307 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java @@ -35,6 +35,7 @@ import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.loading.SegmentCacheManager; import org.apache.druid.segment.metadata.CentralizedDatasourceSchemaConfig; @@ -96,6 +97,7 @@ public static QueryLifecycleFactory createMockQueryLifecycleFactory( new ServiceEmitter("dummy", "dummy", new NoopEmitter()), new NoopRequestLogger(), new AuthConfig(), + NoopPolicyEnforcer.instance(), authorizerMapper, Suppliers.ofInstance(new DefaultQueryConfig(ImmutableMap.of())) ); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/SqlTestFramework.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/SqlTestFramework.java index 7803b9082aae..fb5101b78dbd 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/SqlTestFramework.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/SqlTestFramework.java @@ -48,6 +48,7 @@ import org.apache.druid.guice.annotations.Global; import org.apache.druid.guice.annotations.Merging; import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.security.PolicyModule; import org.apache.druid.initialization.CoreInjectorBuilder; import org.apache.druid.initialization.DruidModule; import org.apache.druid.initialization.ServiceInjectorBuilder; @@ -72,6 +73,7 @@ import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.query.groupby.TestGroupByBuffers; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.query.topn.TopNQueryConfig; import org.apache.druid.quidem.ProjectPathUtils; import org.apache.druid.quidem.TestSqlModule; @@ -402,6 +404,7 @@ public void gatherProperties(Properties properties) public DruidModule getCoreModule() { return DruidModuleCollection.of( + new PolicyModule(), new LookylooModule(), new SegmentWranglerModule(), new ExpressionModule(), @@ -849,6 +852,7 @@ public PlannerFixture( framework.injector.getInstance(JoinableFactoryWrapper.class), framework.builder.catalogResolver, authConfig != null ? authConfig : new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() ); componentSupplier.finalizePlanner(this); diff --git a/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java b/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java index b4437045fe1a..84629a90810a 100644 --- a/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java @@ -41,6 +41,7 @@ import org.apache.druid.guice.LifecycleModule; import org.apache.druid.guice.PolyBind; import org.apache.druid.guice.ServerModule; +import org.apache.druid.guice.security.PolicyModule; import org.apache.druid.initialization.DruidModule; import org.apache.druid.jackson.JacksonModule; import org.apache.druid.java.util.emitter.service.ServiceEmitter; @@ -185,6 +186,7 @@ private Injector makeInjectorWithProperties(final Properties props) new LifecycleModule(), new ServerModule(), new JacksonModule(), + new PolicyModule(), new AuthenticatorMapperModule(), binder -> { binder.bind(Validator.class).toInstance(Validation.buildDefaultValidatorFactory().getValidator()); diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index 8168fb53683d..8b767de3b5a5 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -63,6 +63,7 @@ import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.groupby.GroupByQueryConfig; +import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryResource; import org.apache.druid.server.QueryResponse; @@ -259,6 +260,7 @@ public void setUp() throws Exception CalciteTests.createJoinableFactoryWrapper(), CatalogResolver.NULL_RESOLVER, new AuthConfig(), + NoopPolicyEnforcer.instance(), new DruidHookDispatcher() );