diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java index 88236aa63e25..4005f08ef2e2 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/MSQTaskQueryMakerTest.java @@ -58,7 +58,10 @@ import org.apache.druid.query.DruidProcessingConfig; import org.apache.druid.query.Druids; import org.apache.druid.query.ForwardingQueryProcessingPool; +import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.JoinDataSource; +import org.apache.druid.query.LookupDataSource; +import org.apache.druid.query.OrderBy; import org.apache.druid.query.Query; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryDataSource; @@ -389,6 +392,76 @@ public void testUnnestOnRestrictedPassedPolicyValidation() throws Exception ); } + + @Test + public void testInlineDataSourcePassedPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + RowSignature resultSignature = RowSignature.builder() + .add("EXPR$0", ColumnType.LONG) + .build(); + fieldMapping = buildFieldMapping(resultSignature); + InlineDataSource inlineDataSource = InlineDataSource.fromIterable( + ImmutableList.of(new Object[]{2L}), + resultSignature + ); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .dataSource(inlineDataSource) + .eternityInterval() + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of(new Object[]{2L}); + assertResultsEquals("select 1 + 1", expectedResults, payload.getResults().getResults()); + } + + @Test + public void testLookupDataSourcePassedPolicyValidation() throws Exception + { + // Arrange + policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + final RowSignature resultSignature = RowSignature.builder().add("v", ColumnType.STRING).build(); + fieldMapping = buildFieldMapping(resultSignature); + Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .eternityInterval() + .dataSource(new LookupDataSource("lookyloo")) + .columns(resultSignature.getColumnNames()) + .columnTypes(resultSignature.getColumnTypes()) + .orderBy(ImmutableList.of(OrderBy.ascending("v"))) + .build(); + DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature); + // Act + msqTaskQueryMaker = getMSQTaskQueryMaker(); + QueryResponse response = msqTaskQueryMaker.runQuery(druidQueryMock); + // Assert + String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0]; + MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId) + .get() + .get(MSQTaskReport.REPORT_KEY) + .getPayload(); + // Assert + Assert.assertTrue(payload.getStatus().getStatus().isSuccess()); + ImmutableList expectedResults = ImmutableList.of( + new Object[]{"mysteryvalue"}, + new Object[]{"x6"}, + new Object[]{"xa"}, + new Object[]{"xabc"} + ); + assertResultsEquals("select v from lookyloo", expectedResults, payload.getResults().getResults()); + } + @Test public void testJoinFailWithPolicyValidationOnLeftChild() throws Exception { diff --git a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java index 4b34b96c23b1..f57bbc367d66 100644 --- a/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java +++ b/processing/src/main/java/org/apache/druid/query/policy/PolicyEnforcer.java @@ -27,6 +27,7 @@ import org.apache.druid.query.TableDataSource; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; +import org.apache.druid.timeline.SegmentId; /** * Interface for enforcing policies on data sources and segments in Druid queries. @@ -77,14 +78,18 @@ default void validateOrElseThrow(TableDataSource ds, Policy policy) throws Druid */ default void validateOrElseThrow(ReferenceCountingSegment segment, Policy policy) throws DruidException { - // Validation will always fail on lookups, external, and inline segments, because they will not have policies applied (except for NoopPolicyEnforcer). - // This is a temporary solution since we don't have a perfect way to identify segments that are backed by a regular table yet. + SegmentId segmentId = segment.getId(); + // SegmentId is null if the segment is not table based, or is already closed + if (segmentId == null) { + return; + } + if (validate(policy)) { return; } throw DruidException.forPersona(DruidException.Persona.OPERATOR) .ofCategory(DruidException.Category.FORBIDDEN) - .build("Failed security validation with segment [%s]", segment.getId()); + .build("Failed security validation with segment [%s]", segmentId); } /** diff --git a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java index 86c597d6e83c..e5a579524c54 100644 --- a/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java +++ b/processing/src/test/java/org/apache/druid/query/policy/RestrictAllTablesPolicyEnforcerTest.java @@ -23,13 +23,17 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.RestrictedDataSource; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.segment.RowBasedSegment; import org.apache.druid.segment.Segment; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting; +import org.apache.druid.segment.column.RowSignature; import org.junit.Assert; import org.junit.Test; @@ -97,6 +101,35 @@ public void test_validate() throws Exception policyEnforcer.validateOrElseThrow(segment, policy); } + @Test + public void test_validate_allowNonTableSegments() throws Exception + { + final RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + + // Test validate segment, success for inline segment + final InlineDataSource inlineDataSource = InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty()); + + final Segment inlineSegment = new RowBasedSegment<>( + Sequences.simple(inlineDataSource.getRows()), + inlineDataSource.rowAdapter(), + inlineDataSource.getRowSignature() + ); + ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(inlineSegment); + + policyEnforcer.validateOrElseThrow(segment, null); + } + + @Test + public void test_validate_closedSegment() throws Exception + { + final RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null); + Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY, "1"); + ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment); + segment.close(); + + policyEnforcer.validateOrElseThrow(segment, null); + } + @Test public void test_validate_withAllowedPolicies() throws Exception {