Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions processing/src/main/java/org/apache/druid/query/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ public interface Query<T>
Map<String, Object> getContext();

/**
* Returns QueryContext for this query.
* Returns QueryContext for this query. This type distinguishes between user provided, system default, and system
* generated query context keys so that authorization may be employed directly against the user supplied context
* values.
*
* This method is marked @Nullable, but is only so for backwards compatibility with Druid versions older than 0.23.
* Callers should check if the result of this method is null, and if so, they are dealing with a legacy query
* implementation, and should fall back to using {@link #getContext()} and {@link #withOverriddenContext(Map)} to
* manipulate the query context.
*
* Note for query context serialization and deserialization.
* Currently, once a query is serialized, its queryContext can be different from the original queryContext
Expand All @@ -110,7 +117,11 @@ public interface Query<T>
* after it is deserialized. This is because {@link BaseQuery#getContext()} uses
* {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
*/
QueryContext getQueryContext();
@Nullable
default QueryContext getQueryContext()
{
  // Default implementation deliberately returns null: Query implementations written
  // against the pre-0.23 interface do not override this method, and a null result is
  // the signal that callers must fall back to getContext()/withOverriddenContext(Map).
  return null;
}

<ContextType> ContextType getContextValue(String key);

Expand Down
187 changes: 187 additions & 0 deletions processing/src/test/java/org/apache/druid/query/QueryContextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,26 @@
package org.apache.druid.query;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Ordering;
import nl.jqno.equalsverifier.EqualsVerifier;
import nl.jqno.equalsverifier.Warning;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class QueryContextTest
{
@Test
Expand Down Expand Up @@ -232,4 +247,176 @@ public void testGetMergedParams()

Assert.assertSame(context.getMergedParams(), context.getMergedParams());
}

@Test
public void testLegacyReturnsLegacy()
{
  // A Query implementation that does not override getQueryContext() (i.e. one written
  // against the pre-0.23 interface) must report a null QueryContext, signalling callers
  // to fall back to getContext()/withOverriddenContext().
  // Use Query<?> instead of the raw type to avoid an unchecked-raw-type warning.
  final Query<?> legacy = new LegacyContextQuery(ImmutableMap.of("foo", "bar"));
  Assert.assertNull(legacy.getQueryContext());
}

@Test
public void testNonLegacyIsNotLegacyContext()
{
  // A modern query (built via Druids builders, backed by BaseQuery) overrides
  // getQueryContext(), so the result must be non-null — distinguishing it from
  // legacy implementations that inherit the null-returning default.
  // Use Query<?> instead of the raw type to avoid an unchecked-raw-type warning.
  final Query<?> timeseries = Druids.newTimeseriesQueryBuilder()
                                    .dataSource("test")
                                    .intervals("2015-01-02/2015-01-03")
                                    .granularity(Granularities.DAY)
                                    .aggregators(Collections.singletonList(new CountAggregatorFactory("theCount")))
                                    .context(ImmutableMap.of("foo", "bar"))
                                    .build();
  Assert.assertNotNull(timeseries.getQueryContext());
}

/**
 * Test fixture: a minimal {@link Query} implementation in the "legacy" style — it does
 * NOT override {@link Query#getQueryContext()} and therefore inherits the null-returning
 * default added for pre-0.23 compatibility. Lets tests exercise the fallback paths that
 * manipulate the context via {@link #getContext()} and {@link #withOverriddenContext(Map)}
 * instead of through a QueryContext object. Implements the raw {@code Query} type on
 * purpose, mirroring how a legacy implementation would look.
 */
public static class LegacyContextQuery implements Query
{
  // Raw context map supplied by the caller; returned as-is by getContext().
  // NOTE(review): tests construct this with ImmutableMap, so the in-place mutations in
  // withId()/withSubQueryId() below would throw UnsupportedOperationException if invoked
  // on such an instance — confirm that is intentional for this fixture.
  private final Map<String, Object> context;

  public LegacyContextQuery(Map<String, Object> context)
  {
    this.context = context;
  }

  @Override
  public DataSource getDataSource()
  {
    // Fixed fake datasource; authorization tests authorize against the name "fake".
    return new TableDataSource("fake");
  }

  @Override
  public boolean hasFilters()
  {
    return false;
  }

  @Override
  public DimFilter getFilter()
  {
    // This fixture never filters.
    return null;
  }

  @Override
  public String getType()
  {
    return "legacy-context-query";
  }

  @Override
  public QueryRunner getRunner(QuerySegmentWalker walker)
  {
    // The fixture is never actually executed; a no-op runner satisfies the interface.
    return new NoopQueryRunner();
  }

  @Override
  public List<Interval> getIntervals()
  {
    return Collections.singletonList(Intervals.ETERNITY);
  }

  @Override
  public Duration getDuration()
  {
    return getIntervals().get(0).toDuration();
  }

  @Override
  public Granularity getGranularity()
  {
    return Granularities.ALL;
  }

  @Override
  public DateTimeZone getTimezone()
  {
    return DateTimeZone.UTC;
  }

  @Override
  public Map<String, Object> getContext()
  {
    // Legacy behavior: expose the raw map directly; getQueryContext() stays null
    // because this class relies on the interface's inherited default.
    return context;
  }

  @Override
  public boolean getContextBoolean(String key, boolean defaultValue)
  {
    if (context == null || !context.containsKey(key)) {
      return defaultValue;
    }
    // Unchecked cast: throws ClassCastException if the stored value is not a Boolean.
    return (boolean) context.get(key);
  }

  @Override
  public boolean isDescending()
  {
    return false;
  }

  @Override
  public Ordering getResultOrdering()
  {
    return Ordering.natural();
  }

  @Override
  public Query withQuerySegmentSpec(QuerySegmentSpec spec)
  {
    // The spec is ignored; returns a fresh instance sharing the same context map.
    return new LegacyContextQuery(context);
  }

  @Override
  public Query withId(String id)
  {
    // Mutates the context in place rather than copying (legacy-style behavior);
    // see the NOTE(review) on the field about ImmutableMap inputs.
    context.put(BaseQuery.QUERY_ID, id);
    return this;
  }

  @Nullable
  @Override
  public String getId()
  {
    return (String) context.get(BaseQuery.QUERY_ID);
  }

  @Override
  public Query withSubQueryId(String subQueryId)
  {
    // Mutates the context in place rather than copying (legacy-style behavior).
    context.put(BaseQuery.SUB_QUERY_ID, subQueryId);
    return this;
  }

  @Nullable
  @Override
  public String getSubQueryId()
  {
    return (String) context.get(BaseQuery.SUB_QUERY_ID);
  }

  @Override
  public Query withDataSource(DataSource dataSource)
  {
    // The datasource is ignored; the fixture always reports "fake".
    return this;
  }

  @Override
  public Query withOverriddenContext(Map contextOverride)
  {
    // NOTE(review): unlike BaseQuery, this replaces the context wholesale instead of
    // merging the override into the existing map — verify callers expect that.
    return new LegacyContextQuery(contextOverride);
  }

  @Override
  public Object getContextValue(String key, Object defaultValue)
  {
    if (!context.containsKey(key)) {
      return defaultValue;
    }
    return context.get(key);
  }

  @Override
  public Object getContextValue(String key)
  {
    return context.get(key);
  }
}
}
25 changes: 20 additions & 5 deletions server/src/main/java/org/apache/druid/server/QueryLifecycle.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.druid.query.DruidMetrics;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryMetrics;
Expand Down Expand Up @@ -63,6 +64,7 @@
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -186,11 +188,18 @@ public void initialize(final Query baseQuery)
{
transition(State.NEW, State.INITIALIZED);

baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
context.addDefaultParams(defaultQueryConfig.getContext());

this.baseQuery = baseQuery;
this.toolChest = warehouse.getToolChest(baseQuery);
this.baseQuery = baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
this.baseQuery = baseQuery;
}
this.toolChest = warehouse.getToolChest(this.baseQuery);
}

/**
Expand All @@ -204,14 +213,20 @@ public void initialize(final Query baseQuery)
public Access authorize(HttpServletRequest req)
{
transition(State.INITIALIZED, State.AUTHORIZING);
final Set<String> contextKeys;
if (baseQuery.getQueryContext() == null) {
contextKeys = baseQuery.getContext().keySet();
} else {
contextKeys = baseQuery.getQueryContext().getUserParams().keySet();
}
final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat(
Iterables.transform(
baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
),
authConfig.authorizeQueryContextParams()
? Iterables.transform(
baseQuery.getQueryContext().getUserParams().keySet(),
contextKeys,
contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
)
: Collections.emptyList()
Expand Down
10 changes: 9 additions & 1 deletion server/src/main/java/org/apache/druid/server/QueryResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
Expand Down Expand Up @@ -373,7 +374,14 @@ private Query<?> readQuery(
String prevEtag = getPreviousEtag(req);

if (prevEtag != null) {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
if (baseQuery.getQueryContext() == null) {
QueryContext context = new QueryContext(baseQuery.getContext());
context.addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);

return baseQuery.withOverriddenContext(context.getMergedParams());
} else {
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
}
}

return baseQuery;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.query.DefaultQueryConfig;
import org.apache.druid.query.Druids;
import org.apache.druid.query.GenericQueryMetricsFactory;
import org.apache.druid.query.QueryContextTest;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
Expand Down Expand Up @@ -244,6 +245,40 @@ public void testAuthorizeQueryContext_notAuthorized()
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}

@Test
public void testAuthorizeLegacyQueryContext_authorized()
{
  // Legacy queries (getQueryContext() == null) cannot distinguish user-supplied keys from
  // system-generated ones, so QueryLifecycle.authorize() falls back to authorizing EVERY
  // key in getContext() — including the queryId that initialize() injected.
  EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
  EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
  EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
  EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
  // Datasource read plus a WRITE check for each context key must all be granted.
  EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("fake", ResourceType.DATASOURCE), Action.READ))
          .andReturn(Access.OK);
  EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
          .andReturn(Access.OK);
  EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE)).andReturn(Access.OK);
  // to use legacy query context with context authorization, even system generated things like queryId need to be explicitly added
  EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("queryId", ResourceType.QUERY_CONTEXT), Action.WRITE))
          .andReturn(Access.OK);

  EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
          .andReturn(toolChest)
          .once();

  replayAll();

  final QueryContextTest.LegacyContextQuery query = new QueryContextTest.LegacyContextQuery(ImmutableMap.of("foo", "bar", "baz", "qux"));

  lifecycle.initialize(query);

  // Still a legacy query after initialize(): the lifecycle rebuilt it via
  // withOverriddenContext(), which yields another LegacyContextQuery with a null
  // QueryContext but with the system-generated queryId merged into the plain map.
  Assert.assertNull(lifecycle.getQuery().getQueryContext());
  Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("foo"));
  Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("baz"));
  Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));

  Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}

private HttpServletRequest mockRequest()
{
HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
Expand Down