diff --git a/server/src/main/java/io/druid/client/DirectDruidClient.java b/server/src/main/java/io/druid/client/DirectDruidClient.java index 8db791b1c9ba..225f91dc42a8 100644 --- a/server/src/main/java/io/druid/client/DirectDruidClient.java +++ b/server/src/main/java/io/druid/client/DirectDruidClient.java @@ -116,21 +116,16 @@ public class DirectDruidClient implements QueryRunner private final boolean isSmile; /** - * Removes the magical fields added by {@link #makeResponseContextForQuery(Query, long)}. + * Removes the magical fields added by {@link #makeResponseContextForQuery()}. */ public static void removeMagicResponseContextFields(Map responseContext) { - responseContext.remove(DirectDruidClient.QUERY_FAIL_TIME); responseContext.remove(DirectDruidClient.QUERY_TOTAL_BYTES_GATHERED); } - public static Map makeResponseContextForQuery(Query query, long startTimeMillis) + public static Map makeResponseContextForQuery() { final Map responseContext = new ConcurrentHashMap<>(); - responseContext.put( - DirectDruidClient.QUERY_FAIL_TIME, - startTimeMillis + QueryContexts.getTimeout(query) - ); responseContext.put( DirectDruidClient.QUERY_TOTAL_BYTES_GATHERED, new AtomicLong() @@ -199,7 +194,7 @@ public Sequence run(final QueryPlus queryPlus, final Map c final long requestStartTimeNs = System.nanoTime(); - long timeoutAt = ((Long) context.get(QUERY_FAIL_TIME)).longValue(); + long timeoutAt = query.getContextValue(QUERY_FAIL_TIME); long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query); AtomicLong totalBytesGathered = (AtomicLong) context.get(QUERY_TOTAL_BYTES_GATHERED); diff --git a/server/src/main/java/io/druid/server/QueryLifecycle.java b/server/src/main/java/io/druid/server/QueryLifecycle.java index b3049a2fa339..9639aa860bc1 100644 --- a/server/src/main/java/io/druid/server/QueryLifecycle.java +++ b/server/src/main/java/io/druid/server/QueryLifecycle.java @@ -247,10 +247,7 @@ public QueryResponse execute() { transition(State.AUTHORIZED, State.EXECUTING); - final Map responseContext = DirectDruidClient.makeResponseContextForQuery( - baseQuery, - System.currentTimeMillis() - ); + final Map responseContext = DirectDruidClient.makeResponseContextForQuery(); final Sequence res = QueryPlus.wrap(baseQuery) .withIdentity(authenticationResult.getIdentity()) diff --git a/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java b/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java index 9d5a355aaa61..8363e1bc6f4d 100644 --- a/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java +++ b/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java @@ -25,6 +25,8 @@ import io.druid.query.QueryPlus; import io.druid.query.QueryRunner; import io.druid.server.initialization.ServerConfig; +import io.druid.client.DirectDruidClient; +import com.google.common.collect.ImmutableMap; import java.util.Map; @@ -35,11 +37,13 @@ public class SetAndVerifyContextQueryRunner implements QueryRunner { private final ServerConfig serverConfig; private final QueryRunner baseRunner; + private final long startTimeMillis; public SetAndVerifyContextQueryRunner(ServerConfig serverConfig, QueryRunner baseRunner) { this.serverConfig = serverConfig; this.baseRunner = baseRunner; + this.startTimeMillis = System.currentTimeMillis(); } @Override @@ -53,7 +57,7 @@ public Sequence run(QueryPlus queryPlus, Map responseConte public Query withTimeoutAndMaxScatterGatherBytes(Query query, ServerConfig serverConfig) { - return QueryContexts.verifyMaxQueryTimeout( + Query newQuery = QueryContexts.verifyMaxQueryTimeout( QueryContexts.withMaxScatterGatherBytes( QueryContexts.withDefaultTimeout( query, @@ -63,5 +67,6 @@ public Query withTimeoutAndMaxScatterGatherBytes(Query query, ServerConfig ), serverConfig.getMaxQueryTimeout() ); + return newQuery.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, this.startTimeMillis + QueryContexts.getTimeout(newQuery))); } } diff --git a/server/src/test/java/io/druid/client/DirectDruidClientTest.java b/server/src/test/java/io/druid/client/DirectDruidClientTest.java index c98e05362bff..060dac6671c2 100644 --- a/server/src/test/java/io/druid/client/DirectDruidClientTest.java +++ b/server/src/test/java/io/druid/client/DirectDruidClientTest.java @@ -19,6 +19,7 @@ package io.druid.client; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.util.concurrent.Futures; @@ -163,7 +164,7 @@ public void testRun() throws Exception serverSelector.addServerAndUpdateSegment(queryableDruidServer2, serverSelector.getSegment()); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); - + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); Sequence s1 = client1.run(QueryPlus.wrap(query), defaultContext); Assert.assertTrue(capturedRequest.hasCaptured()); Assert.assertEquals(url, capturedRequest.getValue().getUrl()); @@ -267,6 +268,7 @@ public void testCancel() serverSelector.addServerAndUpdateSegment(queryableDruidServer1, serverSelector.getSegment()); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); cancellationFuture.set(new StatusResponseHolder(HttpResponseStatus.OK, new StringBuilder("cancelled"))); Sequence results = client1.run(QueryPlus.wrap(query), defaultContext); Assert.assertEquals(HttpMethod.DELETE, capturedRequest.getValue().getMethod()); @@ -338,6 +340,7 @@ public void testQueryInterruptionExceptionLogMessage() serverSelector.addServerAndUpdateSegment(queryableDruidServer, dataSegment); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); interruptionFuture.set( new ByteArrayInputStream( StringUtils.toUtf8("{\"error\":\"testing1\",\"errorMessage\":\"testing2\"}")