diff --git a/server/src/main/java/io/druid/client/DirectDruidClient.java b/server/src/main/java/io/druid/client/DirectDruidClient.java index d34efee089de..f18345cc8a64 100644 --- a/server/src/main/java/io/druid/client/DirectDruidClient.java +++ b/server/src/main/java/io/druid/client/DirectDruidClient.java @@ -116,21 +116,16 @@ public class DirectDruidClient implements QueryRunner private final boolean isSmile; /** - * Removes the magical fields added by {@link #makeResponseContextForQuery(Query, long)}. + * Removes the magical fields added by {@link #makeResponseContextForQuery()}. */ public static void removeMagicResponseContextFields(Map responseContext) { - responseContext.remove(DirectDruidClient.QUERY_FAIL_TIME); responseContext.remove(DirectDruidClient.QUERY_TOTAL_BYTES_GATHERED); } - public static Map makeResponseContextForQuery(Query query, long startTimeMillis) + public static Map makeResponseContextForQuery() { final Map responseContext = new ConcurrentHashMap<>(); - responseContext.put( - DirectDruidClient.QUERY_FAIL_TIME, - startTimeMillis + QueryContexts.getTimeout(query) - ); responseContext.put( DirectDruidClient.QUERY_TOTAL_BYTES_GATHERED, new AtomicLong() @@ -199,7 +194,7 @@ public Sequence run(final QueryPlus queryPlus, final Map c final long requestStartTimeNs = System.nanoTime(); - long timeoutAt = ((Long) context.get(QUERY_FAIL_TIME)).longValue(); + long timeoutAt = query.getContextValue(QUERY_FAIL_TIME); long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query); AtomicLong totalBytesGathered = (AtomicLong) context.get(QUERY_TOTAL_BYTES_GATHERED); diff --git a/server/src/main/java/io/druid/server/QueryLifecycle.java b/server/src/main/java/io/druid/server/QueryLifecycle.java index 70486b9b7bb9..3c03314cd595 100644 --- a/server/src/main/java/io/druid/server/QueryLifecycle.java +++ b/server/src/main/java/io/druid/server/QueryLifecycle.java @@ -21,7 +21,6 @@ import com.google.common.base.Strings; import com.google.common.collect.Iterables; -import io.druid.java.util.emitter.service.ServiceEmitter; import io.druid.client.DirectDruidClient; import io.druid.java.util.common.DateTimes; import io.druid.java.util.common.ISE; @@ -29,6 +28,7 @@ import io.druid.java.util.common.guava.SequenceWrapper; import io.druid.java.util.common.guava.Sequences; import io.druid.java.util.common.logger.Logger; +import io.druid.java.util.emitter.service.ServiceEmitter; import io.druid.query.DruidMetrics; import io.druid.query.GenericQueryMetricsFactory; import io.druid.query.Query; @@ -249,10 +249,7 @@ public QueryResponse execute() { transition(State.AUTHORIZED, State.EXECUTING); - final Map responseContext = DirectDruidClient.makeResponseContextForQuery( - queryPlus.getQuery(), - System.currentTimeMillis() - ); + final Map responseContext = DirectDruidClient.makeResponseContextForQuery(); final Sequence res = queryPlus.run(texasRanger, responseContext); diff --git a/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java b/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java index 637b9dd14fb0..7901e3674b09 100644 --- a/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java +++ b/server/src/main/java/io/druid/server/SetAndVerifyContextQueryRunner.java @@ -19,6 +19,8 @@ package io.druid.server; +import com.google.common.collect.ImmutableMap; +import io.druid.client.DirectDruidClient; import io.druid.java.util.common.guava.Sequence; import io.druid.query.Query; import io.druid.query.QueryContexts; @@ -35,11 +37,13 @@ public class SetAndVerifyContextQueryRunner implements QueryRunner { private final ServerConfig serverConfig; private final QueryRunner baseRunner; + private final long startTimeMillis; public SetAndVerifyContextQueryRunner(ServerConfig serverConfig, QueryRunner baseRunner) { this.serverConfig = serverConfig; this.baseRunner = baseRunner; + this.startTimeMillis = System.currentTimeMillis(); } @Override @@ -54,12 +58,12 @@ public Sequence run(QueryPlus queryPlus, Map responseContext) ); } - public static > QueryType withTimeoutAndMaxScatterGatherBytes( + public > QueryType withTimeoutAndMaxScatterGatherBytes( final QueryType query, ServerConfig serverConfig ) { - return (QueryType) QueryContexts.verifyMaxQueryTimeout( + Query newQuery = QueryContexts.verifyMaxQueryTimeout( QueryContexts.withMaxScatterGatherBytes( QueryContexts.withDefaultTimeout( (Query) query, @@ -69,5 +73,6 @@ public static > QueryType withTimeoutAndMaxScatter ), serverConfig.getMaxQueryTimeout() ); + return (QueryType) newQuery.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, this.startTimeMillis + QueryContexts.getTimeout(newQuery))); } } diff --git a/server/src/test/java/io/druid/client/DirectDruidClientTest.java b/server/src/test/java/io/druid/client/DirectDruidClientTest.java index 3347d7d29bc5..de9d1eda80ac 100644 --- a/server/src/test/java/io/druid/client/DirectDruidClientTest.java +++ b/server/src/test/java/io/druid/client/DirectDruidClientTest.java @@ -20,15 +20,12 @@ package io.druid.client; import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.druid.java.util.http.client.HttpClient; -import io.druid.java.util.http.client.Request; -import io.druid.java.util.http.client.response.HttpResponseHandler; -import io.druid.java.util.http.client.response.StatusResponseHolder; import io.druid.client.selector.ConnectionCountServerSelectorStrategy; import io.druid.client.selector.HighestPriorityTierSelectorStrategy; import io.druid.client.selector.QueryableDruidServer; @@ -39,6 +36,10 @@ import io.druid.java.util.common.StringUtils; import io.druid.java.util.common.guava.Sequence; import io.druid.java.util.common.guava.Sequences; +import io.druid.java.util.http.client.HttpClient; +import io.druid.java.util.http.client.Request; +import io.druid.java.util.http.client.response.HttpResponseHandler; +import io.druid.java.util.http.client.response.StatusResponseHolder; import io.druid.query.Druids; import io.druid.query.QueryInterruptedException; import io.druid.query.QueryPlus; @@ -165,7 +166,7 @@ public void testRun() throws Exception serverSelector.addServerAndUpdateSegment(queryableDruidServer2, serverSelector.getSegment()); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); - + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); Sequence s1 = client1.run(QueryPlus.wrap(query), defaultContext); Assert.assertTrue(capturedRequest.hasCaptured()); Assert.assertEquals(url, capturedRequest.getValue().getUrl()); @@ -269,6 +270,7 @@ public void testCancel() throws Exception serverSelector.addServerAndUpdateSegment(queryableDruidServer1, serverSelector.getSegment()); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); cancellationFuture.set(new StatusResponseHolder(HttpResponseStatus.OK, new StringBuilder("cancelled"))); Sequence results = client1.run(QueryPlus.wrap(query), defaultContext); Assert.assertEquals(HttpMethod.DELETE, capturedRequest.getValue().getMethod()); @@ -340,6 +342,7 @@ public void testQueryInterruptionExceptionLogMessage() throws JsonProcessingExce serverSelector.addServerAndUpdateSegment(queryableDruidServer, dataSegment); TimeBoundaryQuery query = Druids.newTimeBoundaryQueryBuilder().dataSource("test").build(); + query = query.withOverriddenContext(ImmutableMap.of(DirectDruidClient.QUERY_FAIL_TIME, Long.MAX_VALUE)); interruptionFuture.set( new ByteArrayInputStream( StringUtils.toUtf8("{\"error\":\"testing1\",\"errorMessage\":\"testing2\"}")