diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java index f74327523a17..ced91a6383ca 100644 --- a/processing/src/main/java/org/apache/druid/query/Query.java +++ b/processing/src/main/java/org/apache/druid/query/Query.java @@ -101,7 +101,14 @@ public interface Query Map getContext(); /** - * Returns QueryContext for this query. + * Returns QueryContext for this query. This type distinguishes between user provided, system default, and system + * generated query context keys so that authorization may be employed directly against the user supplied context + * values. + * + * This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23. + * Callers should check if the result of this method is null, and if so, they are dealing with a legacy query + * implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to + * manipulate the query context. * * Note for query context serialization and deserialization. * Currently, once a query is serialized, its queryContext can be different from the original queryContext @@ -110,7 +117,11 @@ public interface Query * after it is deserialized. This is because {@link BaseQuery#getContext()} uses * {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization. */ - QueryContext getQueryContext(); + @Nullable + default QueryContext getQueryContext() + { + return null; + } ContextType getContextValue(String key); diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java index 3ff961b3f0d0..3654f85af175 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java @@ -20,11 +20,26 @@ package org.apache.druid.query; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Ordering; import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.Warning; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.spec.QuerySegmentSpec; +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.Interval; import org.junit.Assert; import org.junit.Test; +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.List; +import java.util.Map; + public class QueryContextTest { @Test @@ -232,4 +247,176 @@ public void testGetMergedParams() Assert.assertSame(context.getMergedParams(), context.getMergedParams()); } + + @Test + public void testLegacyReturnsLegacy() + { + Query legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar")); + Assert.assertNull(legacy.getQueryContext()); + } + + @Test + public void testNonLegacyIsNotLegacyContext() + { + Query timeseries = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals("2015-01-02/2015-01-03") + .granularity(Granularities.DAY) + .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount"))) + .context(ImmutableMap.of("foo", "bar")) + .build(); + Assert.assertNotNull(timeseries.getQueryContext()); + } + + public static class LegacyContextQuery implements Query + { + private final Map context; + + public LegacyContextQuery(Map context) + { + this.context = context; + } + + @Override + public DataSource getDataSource() + { + return new TableDataSource("fake"); + } + + @Override + public boolean hasFilters() + { + return false; + } + + @Override + public DimFilter getFilter() + { + return null; + } + + @Override + public String getType() + { + return "legacy-context-query"; + } + + @Override + public QueryRunner getRunner(QuerySegmentWalker walker) + { + return new NoopQueryRunner(); + } + + @Override + public List getIntervals() + { + return Collections.singletonList(Intervals.ETERNITY); + } + + @Override + public Duration getDuration() + { + return getIntervals().get(0).toDuration(); + } + + @Override + public Granularity getGranularity() + { + return Granularities.ALL; + } + + @Override + public DateTimeZone getTimezone() + { + return DateTimeZone.UTC; + } + + @Override + public Map getContext() + { + return context; + } + + @Override + public boolean getContextBoolean(String key, boolean defaultValue) + { + if (context == null || !context.containsKey(key)) { + return defaultValue; + } + return (boolean) context.get(key); + } + + @Override + public boolean isDescending() + { + return false; + } + + @Override + public Ordering getResultOrdering() + { + return Ordering.natural(); + } + + @Override + public Query withQuerySegmentSpec(QuerySegmentSpec spec) + { + return new LegacyContextQuery(context); + } + + @Override + public Query withId(String id) + { + context.put(BaseQuery.QUERY_ID, id); + return this; + } + + @Nullable + @Override + public String getId() + { + return (String) context.get(BaseQuery.QUERY_ID); + } + + @Override + public Query withSubQueryId(String subQueryId) + { + context.put(BaseQuery.SUB_QUERY_ID, subQueryId); + return this; + } + + @Nullable + @Override + public String getSubQueryId() + { + return (String) context.get(BaseQuery.SUB_QUERY_ID); + } + + @Override + public Query withDataSource(DataSource dataSource) + { + return this; + } + + @Override + public Query withOverriddenContext(Map contextOverride) + { + return new LegacyContextQuery(contextOverride); + } + + @Override + public Object getContextValue(String key, Object defaultValue) + { + if (!context.containsKey(key)) { + return defaultValue; + } + return context.get(key); + } + + @Override + public Object getContextValue(String key) + { + return context.get(key); + } + } } diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 8a38f7238e10..ecada161cf62 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -36,6 +36,7 @@ import org.apache.druid.query.DruidMetrics; import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryMetrics; @@ -63,6 +64,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -186,11 +188,18 @@ public void initialize(final Query baseQuery) { transition(State.NEW, State.INITIALIZED); - baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); - baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext()); + if (baseQuery.getQueryContext() == null) { + QueryContext context = new QueryContext(baseQuery.getContext()); + context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); + context.addDefaultParams(defaultQueryConfig.getContext()); - this.baseQuery = baseQuery; - this.toolChest = warehouse.getToolChest(baseQuery); + this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams()); + } else { + baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); + baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext()); + this.baseQuery = baseQuery; + } + this.toolChest = warehouse.getToolChest(this.baseQuery); } /** @@ -204,6 +213,12 @@ public void initialize(final Query baseQuery) public Access authorize(HttpServletRequest req) { transition(State.INITIALIZED, State.AUTHORIZING); + final Set contextKeys; + if (baseQuery.getQueryContext() == null) { + contextKeys = baseQuery.getContext().keySet(); + } else { + contextKeys = baseQuery.getQueryContext().getUserParams().keySet(); + } final Iterable resourcesToAuthorize = Iterables.concat( Iterables.transform( baseQuery.getDataSource().getTableNames(), @@ -211,7 +226,7 @@ public Access authorize(HttpServletRequest req) ), authConfig.authorizeQueryContextParams() ? Iterables.transform( - baseQuery.getQueryContext().getUserParams().keySet(), + contextKeys, contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) ) : Collections.emptyList() diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index a235fae199e2..d71fdab56ca9 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -46,6 +46,7 @@ import org.apache.druid.query.BadQueryException; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; @@ -373,7 +374,14 @@ private Query readQuery( String prevEtag = getPreviousEtag(req); if (prevEtag != null) { - baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + if (baseQuery.getQueryContext() == null) { + QueryContext context = new QueryContext(baseQuery.getContext()); + context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + + return baseQuery.withOverriddenContext(context.getMergedParams()); + } else { + baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + } } return baseQuery; diff --git a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java index e49d3a87aeb3..1d1840b6c72e 100644 --- a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.DefaultQueryConfig; import org.apache.druid.query.Druids; import org.apache.druid.query.GenericQueryMetricsFactory; +import org.apache.druid.query.QueryContextTest; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QuerySegmentWalker; @@ -244,6 +245,40 @@ public void testAuthorizeQueryContext_notAuthorized() Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed()); } + @Test + public void testAuthorizeLegacyQueryContext_authorized() + { + EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); + EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes(); + EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); + EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ)) + .andReturn(Access.OK); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) + .andReturn(Access.OK); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK); + // to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE)) + .andReturn(Access.OK); + + EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) + .andReturn(toolChest) + .once(); + + replayAll(); + + final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux")); + + lifecycle.initialize(query); + + Assert.assertNull(lifecycle.getQuery().getQueryContext()); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo")); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz")); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); + + Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); + } + private HttpServletRequest mockRequest() { HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);