From abf1ee61008eccabdf864b69ca328a2c7378fbd6 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Wed, 25 May 2022 02:54:41 -0700 Subject: [PATCH] make query context changes backwards compatible (#12564) Adds a default implementation of getQueryContext, which was added to the Query interface in #12396. Query is marked with @ExtensionPoint, and lately we have been trying to be less volatile on these interfaces by providing default implementations to make upgrades easier for extension writers. The way this default implementation is done in this PR is a bit strange due to the way that getQueryContext is used (mutated with system default and system generated keys); the default implementation is marked @Nullable and returns null, and callers check whether getQueryContext returns null. If it does, they are dealing with a legacy query implementation and fall back to using getContext and withOverriddenContext to set these default and system values. I am open to other ideas as well, but this way should work at least without exploding, and I added some tests to ensure that it is wired up correctly for QueryLifecycle, including the context authorization stuff. The added test shows the strange behavior if query context authorization is enabled, mainly that the system default and system generated query context keys also need to be granted as permissions for things to function correctly. This is not great, so I mentioned it in the javadocs as well. Not sure if it needs to be called out anywhere else. 
--- .../java/org/apache/druid/query/Query.java | 15 +- .../apache/druid/query/QueryContextTest.java | 187 ++++++++++++++++++ .../apache/druid/server/QueryLifecycle.java | 25 ++- .../apache/druid/server/QueryResource.java | 10 +- .../druid/server/QueryLifecycleTest.java | 35 ++++ 5 files changed, 264 insertions(+), 8 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/Query.java b/processing/src/main/java/org/apache/druid/query/Query.java index f74327523a17..ced91a6383ca 100644 --- a/processing/src/main/java/org/apache/druid/query/Query.java +++ b/processing/src/main/java/org/apache/druid/query/Query.java @@ -101,7 +101,14 @@ public interface Query Map getContext(); /** - * Returns QueryContext for this query. + * Returns QueryContext for this query. This type distinguishes between user provided, system default, and system + * generated query context keys so that authorization may be employed directly against the user supplied context + * values. + * + * This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23. + * Callers should check if the result of this method is null, and if so, they are dealing with a legacy query + * implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to + * manipulate the query context. * * Note for query context serialization and deserialization. * Currently, once a query is serialized, its queryContext can be different from the original queryContext @@ -110,7 +117,11 @@ public interface Query * after it is deserialized. This is because {@link BaseQuery#getContext()} uses * {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization. 
*/ - QueryContext getQueryContext(); + @Nullable + default QueryContext getQueryContext() + { + return null; + } ContextType getContextValue(String key); diff --git a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java index 3ff961b3f0d0..3654f85af175 100644 --- a/processing/src/test/java/org/apache/druid/query/QueryContextTest.java +++ b/processing/src/test/java/org/apache/druid/query/QueryContextTest.java @@ -20,11 +20,26 @@ package org.apache.druid.query; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Ordering; import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.Warning; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.spec.QuerySegmentSpec; +import org.joda.time.DateTimeZone; +import org.joda.time.Duration; +import org.joda.time.Interval; import org.junit.Assert; import org.junit.Test; +import javax.annotation.Nullable; +import java.util.Collections; +import java.util.List; +import java.util.Map; + public class QueryContextTest { @Test @@ -232,4 +247,176 @@ public void testGetMergedParams() Assert.assertSame(context.getMergedParams(), context.getMergedParams()); } + + @Test + public void testLegacyReturnsLegacy() + { + Query legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar")); + Assert.assertNull(legacy.getQueryContext()); + } + + @Test + public void testNonLegacyIsNotLegacyContext() + { + Query timeseries = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals("2015-01-02/2015-01-03") + .granularity(Granularities.DAY) + .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount"))) + 
.context(ImmutableMap.of("foo", "bar")) + .build(); + Assert.assertNotNull(timeseries.getQueryContext()); + } + + public static class LegacyContextQuery implements Query + { + private final Map context; + + public LegacyContextQuery(Map context) + { + this.context = context; + } + + @Override + public DataSource getDataSource() + { + return new TableDataSource("fake"); + } + + @Override + public boolean hasFilters() + { + return false; + } + + @Override + public DimFilter getFilter() + { + return null; + } + + @Override + public String getType() + { + return "legacy-context-query"; + } + + @Override + public QueryRunner getRunner(QuerySegmentWalker walker) + { + return new NoopQueryRunner(); + } + + @Override + public List getIntervals() + { + return Collections.singletonList(Intervals.ETERNITY); + } + + @Override + public Duration getDuration() + { + return getIntervals().get(0).toDuration(); + } + + @Override + public Granularity getGranularity() + { + return Granularities.ALL; + } + + @Override + public DateTimeZone getTimezone() + { + return DateTimeZone.UTC; + } + + @Override + public Map getContext() + { + return context; + } + + @Override + public boolean getContextBoolean(String key, boolean defaultValue) + { + if (context == null || !context.containsKey(key)) { + return defaultValue; + } + return (boolean) context.get(key); + } + + @Override + public boolean isDescending() + { + return false; + } + + @Override + public Ordering getResultOrdering() + { + return Ordering.natural(); + } + + @Override + public Query withQuerySegmentSpec(QuerySegmentSpec spec) + { + return new LegacyContextQuery(context); + } + + @Override + public Query withId(String id) + { + context.put(BaseQuery.QUERY_ID, id); + return this; + } + + @Nullable + @Override + public String getId() + { + return (String) context.get(BaseQuery.QUERY_ID); + } + + @Override + public Query withSubQueryId(String subQueryId) + { + context.put(BaseQuery.SUB_QUERY_ID, subQueryId); + return this; + } + + 
@Nullable + @Override + public String getSubQueryId() + { + return (String) context.get(BaseQuery.SUB_QUERY_ID); + } + + @Override + public Query withDataSource(DataSource dataSource) + { + return this; + } + + @Override + public Query withOverriddenContext(Map contextOverride) + { + return new LegacyContextQuery(contextOverride); + } + + @Override + public Object getContextValue(String key, Object defaultValue) + { + if (!context.containsKey(key)) { + return defaultValue; + } + return context.get(key); + } + + @Override + public Object getContextValue(String key) + { + return context.get(key); + } + } } diff --git a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java index 8a38f7238e10..ecada161cf62 100644 --- a/server/src/main/java/org/apache/druid/server/QueryLifecycle.java +++ b/server/src/main/java/org/apache/druid/server/QueryLifecycle.java @@ -36,6 +36,7 @@ import org.apache.druid.query.DruidMetrics; import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryMetrics; @@ -63,6 +64,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -186,11 +188,18 @@ public void initialize(final Query baseQuery) { transition(State.NEW, State.INITIALIZED); - baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); - baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext()); + if (baseQuery.getQueryContext() == null) { + QueryContext context = new QueryContext(baseQuery.getContext()); + context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); + 
context.addDefaultParams(defaultQueryConfig.getContext()); - this.baseQuery = baseQuery; - this.toolChest = warehouse.getToolChest(baseQuery); + this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams()); + } else { + baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString()); + baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext()); + this.baseQuery = baseQuery; + } + this.toolChest = warehouse.getToolChest(this.baseQuery); } /** @@ -204,6 +213,12 @@ public void initialize(final Query baseQuery) public Access authorize(HttpServletRequest req) { transition(State.INITIALIZED, State.AUTHORIZING); + final Set contextKeys; + if (baseQuery.getQueryContext() == null) { + contextKeys = baseQuery.getContext().keySet(); + } else { + contextKeys = baseQuery.getQueryContext().getUserParams().keySet(); + } final Iterable resourcesToAuthorize = Iterables.concat( Iterables.transform( baseQuery.getDataSource().getTableNames(), @@ -211,7 +226,7 @@ public Access authorize(HttpServletRequest req) ), authConfig.authorizeQueryContextParams() ? 
Iterables.transform( - baseQuery.getQueryContext().getUserParams().keySet(), + contextKeys, contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE) ) : Collections.emptyList() diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index a235fae199e2..d71fdab56ca9 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -46,6 +46,7 @@ import org.apache.druid.query.BadQueryException; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; @@ -373,7 +374,14 @@ private Query readQuery( String prevEtag = getPreviousEtag(req); if (prevEtag != null) { - baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + if (baseQuery.getQueryContext() == null) { + QueryContext context = new QueryContext(baseQuery.getContext()); + context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + + return baseQuery.withOverriddenContext(context.getMergedParams()); + } else { + baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag); + } } return baseQuery; diff --git a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java index e49d3a87aeb3..1d1840b6c72e 100644 --- a/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryLifecycleTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.DefaultQueryConfig; import org.apache.druid.query.Druids; import org.apache.druid.query.GenericQueryMetricsFactory; +import 
org.apache.druid.query.QueryContextTest; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryRunner; import org.apache.druid.query.QuerySegmentWalker; @@ -244,6 +245,40 @@ public void testAuthorizeQueryContext_notAuthorized() Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed()); } + @Test + public void testAuthorizeLegacyQueryContext_authorized() + { + EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes(); + EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes(); + EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes(); + EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes(); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ)) + .andReturn(Access.OK); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE)) + .andReturn(Access.OK); + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK); + // to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added + EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE)) + .andReturn(Access.OK); + + EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject())) + .andReturn(toolChest) + .once(); + + replayAll(); + + final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux")); + + lifecycle.initialize(query); + + Assert.assertNull(lifecycle.getQuery().getQueryContext()); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo")); + Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz")); + 
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId")); + + Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed()); + } + private HttpServletRequest mockRequest() { HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);