Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.Druids;
import org.apache.druid.query.ForwardingQueryProcessingPool;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.OrderBy;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryDataSource;
Expand Down Expand Up @@ -389,6 +392,76 @@ public void testUnnestOnRestrictedPassedPolicyValidation() throws Exception
);
}


@Test
public void testInlineDataSourcePassedPolicyValidation() throws Exception
{
// Arrange
policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
RowSignature resultSignature = RowSignature.builder()
.add("EXPR$0", ColumnType.LONG)
.build();
fieldMapping = buildFieldMapping(resultSignature);
InlineDataSource inlineDataSource = InlineDataSource.fromIterable(
ImmutableList.of(new Object[]{2L}),
resultSignature
);
Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.dataSource(inlineDataSource)
.eternityInterval()
.columns(resultSignature.getColumnNames())
.columnTypes(resultSignature.getColumnTypes())
.build();
DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
// Act
msqTaskQueryMaker = getMSQTaskQueryMaker();
QueryResponse<Object[]> response = msqTaskQueryMaker.runQuery(druidQueryMock);
// Assert
String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0];
MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId)
.get()
.get(MSQTaskReport.REPORT_KEY)
.getPayload();
Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
ImmutableList<Object[]> expectedResults = ImmutableList.of(new Object[]{2L});
assertResultsEquals("select 1 + 1", expectedResults, payload.getResults().getResults());
}

@Test
public void testLookupDataSourcePassedPolicyValidation() throws Exception
{
// Arrange
policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
final RowSignature resultSignature = RowSignature.builder().add("v", ColumnType.STRING).build();
fieldMapping = buildFieldMapping(resultSignature);
Query query = new Druids.ScanQueryBuilder().resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.eternityInterval()
.dataSource(new LookupDataSource("lookyloo"))
.columns(resultSignature.getColumnNames())
.columnTypes(resultSignature.getColumnTypes())
.orderBy(ImmutableList.of(OrderBy.ascending("v")))
.build();
DruidQuery druidQueryMock = buildDruidQueryMock(query, resultSignature);
// Act
msqTaskQueryMaker = getMSQTaskQueryMaker();
QueryResponse<Object[]> response = msqTaskQueryMaker.runQuery(druidQueryMock);
// Assert
String taskId = (String) Iterables.getOnlyElement(response.getResults().toList())[0];
MSQTaskReportPayload payload = (MSQTaskReportPayload) fakeOverlordClient.taskReportAsMap(taskId)
.get()
.get(MSQTaskReport.REPORT_KEY)
.getPayload();
// Assert
Assert.assertTrue(payload.getStatus().getStatus().isSuccess());
ImmutableList<Object[]> expectedResults = ImmutableList.of(
new Object[]{"mysteryvalue"},
new Object[]{"x6"},
new Object[]{"xa"},
new Object[]{"xabc"}
);
assertResultsEquals("select v from lookyloo", expectedResults, payload.getResults().getResults());
}

@Test
public void testJoinFailWithPolicyValidationOnLeftChild() throws Exception
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.druid.query.TableDataSource;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.timeline.SegmentId;

/**
* Interface for enforcing policies on data sources and segments in Druid queries.
Expand Down Expand Up @@ -77,14 +78,18 @@ default void validateOrElseThrow(TableDataSource ds, Policy policy) throws Druid
*/
default void validateOrElseThrow(ReferenceCountingSegment segment, Policy policy) throws DruidException
{
// Validation will always fail on lookups, external, and inline segments, because they will not have policies applied (except for NoopPolicyEnforcer).
// This is a temporary solution since we don't have a perfect way to identify segments that are backed by a regular table yet.
SegmentId segmentId = segment.getId();
// SegmentId is null if the segment is not table based, or is already closed
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to understand what does segment close mean.

if (segmentId == null) {
return;
}

if (validate(policy)) {
return;
}
throw DruidException.forPersona(DruidException.Persona.OPERATOR)
.ofCategory(DruidException.Category.FORBIDDEN)
.build("Failed security validation with segment [%s]", segment.getId());
.build("Failed security validation with segment [%s]", segmentId);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@
import com.google.common.collect.ImmutableList;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.RestrictedDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.NullFilter;
import org.apache.druid.segment.ReferenceCountingSegment;
import org.apache.druid.segment.RowBasedSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.TestSegmentUtils.SegmentForTesting;
import org.apache.druid.segment.column.RowSignature;
import org.junit.Assert;
import org.junit.Test;

Expand Down Expand Up @@ -97,6 +101,35 @@ public void test_validate() throws Exception
policyEnforcer.validateOrElseThrow(segment, policy);
}

@Test
public void test_validate_allowNonTableSegments() throws Exception
{
final RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);

// Test validate segment, success for inline segment
final InlineDataSource inlineDataSource = InlineDataSource.fromIterable(ImmutableList.of(), RowSignature.empty());

final Segment inlineSegment = new RowBasedSegment<>(
Sequences.simple(inlineDataSource.getRows()),
inlineDataSource.rowAdapter(),
inlineDataSource.getRowSignature()
);
ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(inlineSegment);

policyEnforcer.validateOrElseThrow(segment, null);
}

@Test
public void test_validate_closedSegment() throws Exception
{
final RestrictAllTablesPolicyEnforcer policyEnforcer = new RestrictAllTablesPolicyEnforcer(null);
Segment baseSegment = new SegmentForTesting("table", Intervals.ETERNITY, "1");
ReferenceCountingSegment segment = ReferenceCountingSegment.wrapRootGenerationSegment(baseSegment);
segment.close();

policyEnforcer.validateOrElseThrow(segment, null);
}

@Test
public void test_validate_withAllowedPolicies() throws Exception
{
Expand Down
Loading