diff --git a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java index 645a3b100d7f..3d704dea3a0a 100644 --- a/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java +++ b/extensions-contrib/moving-average-query/src/main/java/org/apache/druid/query/movingaverage/MovingAverageQueryRunner.java @@ -68,10 +68,6 @@ */ public class MovingAverageQueryRunner implements QueryRunner { - - public static final String QUERY_FAIL_TIME = "queryFailTime"; - public static final String QUERY_TOTAL_BYTES_GATHERED = "queryTotalBytesGathered"; - private final QuerySegmentWalker walker; private final RequestLogger requestLogger; @@ -127,8 +123,11 @@ public Sequence run(QueryPlus query, ResponseContext responseContext) GroupByQuery gbq = builder.build(); ResponseContext gbqResponseContext = ResponseContext.createEmpty(); - gbqResponseContext.put(QUERY_FAIL_TIME, System.currentTimeMillis() + QueryContexts.getTimeout(gbq)); - gbqResponseContext.put(QUERY_TOTAL_BYTES_GATHERED, new AtomicLong()); + gbqResponseContext.put( + ResponseContext.Key.QUERY_FAIL_DEADLINE_MILLIS, + System.currentTimeMillis() + QueryContexts.getTimeout(gbq) + ); + gbqResponseContext.put(ResponseContext.Key.QUERY_TOTAL_BYTES_GATHERED, new AtomicLong()); Sequence results = gbq.getRunner(walker).run(QueryPlus.wrap(gbq), gbqResponseContext); try { @@ -165,8 +164,11 @@ public Sequence run(QueryPlus query, ResponseContext responseContext) maq.getContext() ); ResponseContext tsqResponseContext = ResponseContext.createEmpty(); - tsqResponseContext.put(QUERY_FAIL_TIME, System.currentTimeMillis() + QueryContexts.getTimeout(tsq)); - tsqResponseContext.put(QUERY_TOTAL_BYTES_GATHERED, new AtomicLong()); + tsqResponseContext.put( + ResponseContext.Key.QUERY_FAIL_DEADLINE_MILLIS, + System.currentTimeMillis() + QueryContexts.getTimeout(tsq) + ); + tsqResponseContext.put(ResponseContext.Key.QUERY_TOTAL_BYTES_GATHERED, new AtomicLong()); Sequence> results = tsq.getRunner(walker).run(QueryPlus.wrap(tsq), tsqResponseContext); try { diff --git a/processing/src/main/java/org/apache/druid/query/CPUTimeMetricQueryRunner.java b/processing/src/main/java/org/apache/druid/query/CPUTimeMetricQueryRunner.java index 594a3273e8ed..7953d563f58f 100644 --- a/processing/src/main/java/org/apache/druid/query/CPUTimeMetricQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/CPUTimeMetricQueryRunner.java @@ -84,6 +84,7 @@ public void after(boolean isDone, Throwable thrown) if (report) { final long cpuTimeNs = cpuTimeAccumulator.get(); if (cpuTimeNs > 0) { + responseContext.add(ResponseContext.Key.CPU_CONSUMED_NANOS, cpuTimeNs); queryWithMetrics.getQueryMetrics().reportCpuTime(cpuTimeNs).emit(emitter); } } diff --git a/processing/src/main/java/org/apache/druid/query/Druids.java b/processing/src/main/java/org/apache/druid/query/Druids.java index 47e3ede9a339..2e35891fafc4 100644 --- a/processing/src/main/java/org/apache/druid/query/Druids.java +++ b/processing/src/main/java/org/apache/druid/query/Druids.java @@ -966,7 +966,7 @@ public static ScanQueryBuilder copy(ScanQuery query) .virtualColumns(query.getVirtualColumns()) .resultFormat(query.getResultFormat()) .batchSize(query.getBatchSize()) - .limit(query.getLimit()) + .limit(query.getScanRowsLimit()) .filters(query.getFilter()) .columns(query.getColumns()) 
.legacy(query.isLegacy()) diff --git a/processing/src/main/java/org/apache/druid/query/ReportTimelineMissingSegmentQueryRunner.java b/processing/src/main/java/org/apache/druid/query/ReportTimelineMissingSegmentQueryRunner.java index 97b6aa27fd08..b360d228f1a7 100644 --- a/processing/src/main/java/org/apache/druid/query/ReportTimelineMissingSegmentQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/ReportTimelineMissingSegmentQueryRunner.java @@ -23,8 +23,7 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.query.context.ResponseContext; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; /** */ @@ -40,13 +39,7 @@ public ReportTimelineMissingSegmentQueryRunner(SegmentDescriptor descriptor) @Override public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) { - List missingSegments = - (List) responseContext.get(ResponseContext.CTX_MISSING_SEGMENTS); - if (missingSegments == null) { - missingSegments = new ArrayList<>(); - responseContext.put(ResponseContext.CTX_MISSING_SEGMENTS, missingSegments); - } - missingSegments.add(descriptor); + responseContext.add(ResponseContext.Key.MISSING_SEGMENTS, Collections.singletonList(descriptor)); return Sequences.empty(); } } diff --git a/processing/src/main/java/org/apache/druid/query/RetryQueryRunner.java b/processing/src/main/java/org/apache/druid/query/RetryQueryRunner.java index 28bcf0b69933..6b991b870575 100644 --- a/processing/src/main/java/org/apache/druid/query/RetryQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/RetryQueryRunner.java @@ -72,7 +72,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat for (int i = 0; i < config.getNumTries(); i++) { log.info("[%,d] missing segments found. 
Retry attempt [%,d]", missingSegments.size(), i); - context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>()); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); final QueryPlus retryQueryPlus = queryPlus.withQuerySegmentSpec( new MultipleSpecificSegmentSpec( missingSegments @@ -102,7 +102,7 @@ public Yielder toYielder(OutType initValue, YieldingAccumulat private List getMissingSegments(final ResponseContext context) { - final Object maybeMissingSegments = context.get(ResponseContext.CTX_MISSING_SEGMENTS); + final Object maybeMissingSegments = context.get(ResponseContext.Key.MISSING_SEGMENTS); if (maybeMissingSegments == null) { return new ArrayList<>(); } diff --git a/processing/src/main/java/org/apache/druid/query/context/ConcurrentResponseContext.java b/processing/src/main/java/org/apache/druid/query/context/ConcurrentResponseContext.java index 48838f171917..b1e648467a78 100644 --- a/processing/src/main/java/org/apache/druid/query/context/ConcurrentResponseContext.java +++ b/processing/src/main/java/org/apache/druid/query/context/ConcurrentResponseContext.java @@ -35,10 +35,10 @@ public static ConcurrentResponseContext createEmpty() return new ConcurrentResponseContext(); } - private final ConcurrentHashMap delegate = new ConcurrentHashMap<>(); + private final ConcurrentHashMap delegate = new ConcurrentHashMap<>(); @Override - protected Map getDelegate() + protected Map getDelegate() { return delegate; } diff --git a/processing/src/main/java/org/apache/druid/query/context/DefaultResponseContext.java b/processing/src/main/java/org/apache/druid/query/context/DefaultResponseContext.java index adff1ff6b3fe..33724c1bf044 100644 --- a/processing/src/main/java/org/apache/druid/query/context/DefaultResponseContext.java +++ b/processing/src/main/java/org/apache/druid/query/context/DefaultResponseContext.java @@ -35,10 +35,10 @@ public static DefaultResponseContext createEmpty() return new DefaultResponseContext(); } - private final HashMap delegate = new HashMap<>(); + private final HashMap delegate = new HashMap<>(); @Override - protected Map getDelegate() + protected Map getDelegate() { return delegate; } diff --git a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java index 93841f482fd8..269a1e564776 100644 --- a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java +++ b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java @@ -19,53 +19,236 @@ package org.apache.druid.query.context; +import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import org.apache.druid.guice.annotations.PublicApi; import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.query.SegmentDescriptor; +import org.joda.time.Interval; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; import java.util.Map; +import java.util.TreeMap; +import java.util.function.BiFunction; /** * The context for storing and passing data between chains of {@link 
org.apache.druid.query.QueryRunner}s. * The context is also transferred between Druid nodes with all the data it contains. - * All the keys associated with data inside the context should be stored here. - * CTX_* keys might be aggregated into an enum. Consider refactoring that. */ @PublicApi public abstract class ResponseContext { /** - * Lists intervals for which NO segment is present. + * The base interface of a response context key. + * Should be implemented by every context key. */ - public static final String CTX_UNCOVERED_INTERVALS = "uncoveredIntervals"; - /** - * Indicates if the number of uncovered intervals exceeded the limit (true/false). - */ - public static final String CTX_UNCOVERED_INTERVALS_OVERFLOWED = "uncoveredIntervalsOverflowed"; - /** - * Lists missing segments. - */ - public static final String CTX_MISSING_SEGMENTS = "missingSegments"; - /** - * Entity tag. A part of HTTP cache validation mechanism. - * Is being removed from the context before sending and used as a separate HTTP header. - */ - public static final String CTX_ETAG = "ETag"; - /** - * Query total bytes gathered. - */ - public static final String CTX_QUERY_TOTAL_BYTES_GATHERED = "queryTotalBytesGathered"; - /** - * This variable indicates when a running query should be expired, - * and is effective only when 'timeout' of queryContext has a positive value. - */ - public static final String CTX_TIMEOUT_AT = "timeoutAt"; + public interface BaseKey + { + @JsonValue + String getName(); + /** + * Merge function associated with a key: Object (Object oldValue, Object newValue) + */ + BiFunction getMergeFunction(); + } + /** - * The number of scanned rows. + * Keys associated with objects in the context. + *
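+   * Each key also carries the merge function that {@link ResponseContext#add} and
+   * {@link ResponseContext#merge} use to combine an existing value with an incoming one.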

+   * If it's necessary to add new keys to the context, they can be defined in a separate enum:
+   * {@code
+   * public enum ExtensionResponseContextKey implements BaseKey
+   * {
+   *   EXTENSION_KEY_1("extension_key_1"), EXTENSION_KEY_2("extension_key_2");
+   *
+   *   static {
+   *     for (BaseKey key : values()) ResponseContext.Key.registerKey(key);
+   *   }
+   *
+   *   private final String name;
+   *   private final BiFunction<Object, Object, Object> mergeFunction;
+   *
+   *   ExtensionResponseContextKey(String name)
+   *   {
+   *     this.name = name;
+   *     this.mergeFunction = (oldValue, newValue) -> newValue;
+   *   }
+   *
+   *   @Override public String getName() { return name; }
+   *
+   *   @Override public BiFunction<Object, Object, Object> getMergeFunction() { return mergeFunction; }
+   * }
+   * }
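+   *
+   * As a hypothetical sketch (not part of this patch), an extension could guarantee that the static
+   * initializer above runs before any query touches its keys by referencing the enum at load time:
+   * {@code
+   * public class MyExtensionModule implements com.google.inject.Module
+   * {
+   *   @Override
+   *   public void configure(com.google.inject.Binder binder)
+   *   {
+   *     // Loading the enum class runs its static block, which registers the keys.
+   *     ExtensionResponseContextKey.values();
+   *   }
+   * }
+   * }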
+   * Make sure all extension enum values are registered with the {@link Key#registerKey} method.
+   */
-  public static final String CTX_COUNT = "count";
+  public enum Key implements BaseKey
+  {
+    /**
+     * Lists intervals for which NO segment is present.
+     */
+    UNCOVERED_INTERVALS(
+        "uncoveredIntervals",
+        (oldValue, newValue) -> {
+          final ArrayList<Interval> result = new ArrayList<>((List) oldValue);
+          result.addAll((List) newValue);
+          return result;
+        }
+    ),
+    /**
+     * Indicates if the number of uncovered intervals exceeded the limit (true/false).
+     */
+    UNCOVERED_INTERVALS_OVERFLOWED(
+        "uncoveredIntervalsOverflowed",
+        (oldValue, newValue) -> (boolean) oldValue || (boolean) newValue
+    ),
+    /**
+     * Lists missing segments.
+     */
+    MISSING_SEGMENTS(
+        "missingSegments",
+        (oldValue, newValue) -> {
+          final ArrayList<SegmentDescriptor> result = new ArrayList<>((List) oldValue);
+          result.addAll((List) newValue);
+          return result;
+        }
+    ),
+    /**
+     * Entity tag. Part of the HTTP cache validation mechanism.
+     * It is removed from the context before sending and used as a separate HTTP header.
+     */
+    ETAG("ETag"),
+    /**
+     * Query fail time (current time + timeout).
+     * Unlike {@link Key#TIMEOUT_AT}, it is not updated continuously.
+     */
+    QUERY_FAIL_DEADLINE_MILLIS("queryFailTime"),
+    /**
+     * Query total bytes gathered.
+     */
+    QUERY_TOTAL_BYTES_GATHERED("queryTotalBytesGathered"),
+    /**
+     * Indicates when a running query should expire. Effective only when the query context
+     * has a positive 'timeout' value. Continuously updated by
+     * {@link org.apache.druid.query.scan.ScanQueryEngine}, which reduces it on every scan iteration.
+     */
+    TIMEOUT_AT("timeoutAt"),
+    /**
+     * The number of scanned rows.
+     * For backward compatibility, the context key name is still "count".
+     */
+    NUM_SCANNED_ROWS(
+        "count",
+        (oldValue, newValue) -> (long) oldValue + (long) newValue
+    ),
+    /**
+     * The total CPU time consumed by threads performing Sequence processing for the query.
+     * The resulting value on a Broker is the sum of the downstream values from historical / realtime nodes.
+     * For additional information see {@link org.apache.druid.query.CPUTimeMetricQueryRunner}.
+     */
+    CPU_CONSUMED_NANOS(
+        "cpuConsumed",
+        (oldValue, newValue) -> (long) oldValue + (long) newValue
+    ),
+    /**
+     * Indicates if a {@link ResponseContext} was truncated during serialization.
+     */
+    TRUNCATED(
+        "truncated",
+        (oldValue, newValue) -> (boolean) oldValue || (boolean) newValue
+    );
+
+    /**
+     * A TreeMap is used so that registered keys are kept in the natural order of their names.
+     */
+    private static final Map<String, BaseKey> registeredKeys = new TreeMap<>();
+
+    static {
+      for (BaseKey key : values()) {
+        registerKey(key);
+      }
+    }
+
+    /**
+     * Primary way of registering context keys.
+     * @throws IllegalArgumentException if the key has already been registered.
+     */
+    public static synchronized void registerKey(BaseKey key)
+    {
+      Preconditions.checkArgument(
+          !registeredKeys.containsKey(key.getName()),
+          "Key [%s] has already been registered as a context key",
+          key.getName()
+      );
+      registeredKeys.put(key.getName(), key);
+    }
+
+    /**
+     * Returns the registered key associated with the given name.
+     * @throws IllegalStateException if a corresponding key has not been registered.
+     */
+    public static BaseKey keyOf(String name)
+    {
+      Preconditions.checkState(
+          registeredKeys.containsKey(name),
+          "Key [%s] has not yet been registered as a context key",
+          name
+      );
+      return registeredKeys.get(name);
+    }
+
+    /**
+     * Returns all keys registered via {@link Key#registerKey}.
+     */
+    public static Collection<BaseKey> getAllRegisteredKeys()
+    {
+      return Collections.unmodifiableCollection(registeredKeys.values());
+    }
+
+    private final String name;
+
+    private final BiFunction<Object, Object, Object> mergeFunction;
+
+    Key(String name)
+    {
+      this.name = name;
+      this.mergeFunction = (oldValue, newValue) -> newValue;
+    }
+
+    Key(String name, BiFunction<Object, Object, Object> mergeFunction)
+    {
+      this.name = name;
+      this.mergeFunction = mergeFunction;
+    }
+
+    @Override
+    public String getName()
+    {
+      return name;
+    }
+
+    @Override
+    public BiFunction<Object, Object, Object> getMergeFunction()
+    {
+      return mergeFunction;
+    }
+  }
+
+  protected abstract Map<BaseKey, Object> getDelegate();
+
+  private static final Comparator<Map.Entry<String, JsonNode>> valueLengthReversedComparator =
+      Comparator.comparing((Map.Entry<String, JsonNode> e) -> e.getValue().toString().length()).reversed();
 
   /**
    * Create an empty DefaultResponseContext instance
@@ -76,56 +259,180 @@ public static ResponseContext createEmpty()
     return DefaultResponseContext.createEmpty();
   }
 
-  protected abstract Map<String, Object> getDelegate();
+  /**
+   * Deserializes a string into {@link ResponseContext} using the given {@link ObjectMapper}.
+   * @throws IllegalStateException if one of the deserialized map keys has not been registered.
+   */
+  public static ResponseContext deserialize(String responseContext, ObjectMapper objectMapper) throws IOException
+  {
+    final Map<String, Object> keyNameToObjects = objectMapper.readValue(
+        responseContext,
+        JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT
+    );
+    final ResponseContext context = ResponseContext.createEmpty();
+    keyNameToObjects.forEach((keyName, value) -> {
+      final BaseKey key = Key.keyOf(keyName);
+      context.add(key, value);
+    });
+    return context;
+  }
 
-  public Object put(String key, Object value)
+  /**
+   * Associates the specified object with the specified key.
+   * @throws IllegalStateException if the key has not been registered.
+   */
+  public Object put(BaseKey key, Object value)
   {
-    return getDelegate().put(key, value);
+    final BaseKey registeredKey = Key.keyOf(key.getName());
+    return getDelegate().put(registeredKey, value);
   }
 
-  public Object get(String key)
+  public Object get(BaseKey key)
   {
     return getDelegate().get(key);
   }
 
-  public Object remove(String key)
+  public Object remove(BaseKey key)
   {
     return getDelegate().remove(key);
   }
 
-  public void putAll(Map<String, Object> m)
+  /**
+   * Adds (merges) a new value associated with a key to the old value.
+   * See the key's merge function for the specific implementation.
+   * @throws IllegalStateException if the key has not been registered.
+   */
+  public Object add(BaseKey key, Object value)
   {
-    getDelegate().putAll(m);
+    final BaseKey registeredKey = Key.keyOf(key.getName());
+    return getDelegate().merge(registeredKey, value, key.getMergeFunction());
+  }
 
-  public void putAll(ResponseContext responseContext)
+  /**
+   * Merges a response context into the current one.
+   * @throws IllegalStateException if a key of the {@code responseContext} has not been registered.
+   */
+  public void merge(ResponseContext responseContext)
   {
-    getDelegate().putAll(responseContext.getDelegate());
+    responseContext.getDelegate().forEach((key, newValue) -> {
+      if (newValue != null) {
+        add(key, newValue);
+      }
+    });
   }
 
-  public int size()
+  /**
+   * Serializes the context, truncating it if the resulting string would exceed the provided limit.
+   * This method removes some elements from context collections if that is needed to satisfy the limit.
+   * There is no explicit priority among the keys whose values get truncated, because for now there are
+   * only two keys that can potentially break the limit ({@link Key#UNCOVERED_INTERVALS}
+   * and {@link Key#MISSING_SEGMENTS}) and their values are arrays.
+   * The current implementation therefore treats these arrays as equally prioritized and starts removing
+   * elements from the array whose serialized value is the longest.
+   * The resulting string can still be correctly deserialized into a {@link ResponseContext}.
+   */
+  public SerializationResult serializeWith(ObjectMapper objectMapper, int maxCharsNumber) throws JsonProcessingException
   {
-    return getDelegate().size();
+    final String fullSerializedString = objectMapper.writeValueAsString(getDelegate());
+    if (fullSerializedString.length() <= maxCharsNumber) {
+      return new SerializationResult(fullSerializedString, fullSerializedString);
+    } else {
+      // Indicates that the context is truncated during serialization.
+      add(Key.TRUNCATED, true);
+      final ObjectNode contextJsonNode = objectMapper.valueToTree(getDelegate());
+      final ArrayList<Map.Entry<String, JsonNode>> sortedNodesByLength = Lists.newArrayList(contextJsonNode.fields());
+      sortedNodesByLength.sort(valueLengthReversedComparator);
+      int needToRemoveCharsNumber = fullSerializedString.length() - maxCharsNumber;
+      // The complexity of this block is O(n*m*log(m)) where n is the context size and m is the size of a context array
+      for (Map.Entry<String, JsonNode> e : sortedNodesByLength) {
+        final String fieldName = e.getKey();
+        final JsonNode node = e.getValue();
+        if (node.isArray()) {
+          if (needToRemoveCharsNumber >= node.toString().length()) {
+            // We need to remove more chars than the field's length, so remove the field completely
+            contextJsonNode.remove(fieldName);
+            // Since the field is completely removed (name + value) we need to do a recalculation
+            needToRemoveCharsNumber = contextJsonNode.toString().length() - maxCharsNumber;
+          } else {
+            final ArrayNode arrayNode = (ArrayNode) node;
+            needToRemoveCharsNumber -= removeNodeElementsToSatisfyCharsLimit(arrayNode, needToRemoveCharsNumber);
+            if (arrayNode.size() == 0) {
+              // The field is empty; remove it because an empty array field may be misleading
+              // for the recipients of the truncated response context.
+              contextJsonNode.remove(fieldName);
+              // Since the field is completely removed (name + value) we need to do a recalculation
+              needToRemoveCharsNumber = contextJsonNode.toString().length() - maxCharsNumber;
+            }
+          }
+        } else {
+          // The node is not an array; a context should not contain nulls, so remove the field completely.
+          contextJsonNode.remove(fieldName);
+          // Since the field is completely removed (name + value) we need to do a recalculation
+          needToRemoveCharsNumber = contextJsonNode.toString().length() - maxCharsNumber;
+        }
+        if (needToRemoveCharsNumber <= 0) {
+          break;
+        }
+      }
+      return new SerializationResult(contextJsonNode.toString(), fullSerializedString);
+    }
   }
 
-  public String serializeWith(ObjectMapper objectMapper) throws JsonProcessingException
+  /**
+   * Removes elements from {@code node} until the total serialized length of the removed elements
+   * is greater than or equal to the passed limit.
+   * If it is impossible to satisfy the limit, the method removes all of {@code node}'s elements.
+   * On every iteration it removes exactly half of the remaining elements to reduce the overall complexity.
+   * @param node {@link ArrayNode} whose elements are being removed.
+   * @param needToRemoveCharsNumber the number of chars that need to be removed.
+   * @return the number of removed chars.
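+   * For example, a 16-element array is reduced to 8, then 4, 2, 1, and finally 0 elements
+   * if the chars limit still cannot be satisfied.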
+   */
+  private static int removeNodeElementsToSatisfyCharsLimit(ArrayNode node, int needToRemoveCharsNumber)
   {
-    return objectMapper.writeValueAsString(getDelegate());
+    int removedCharsNumber = 0;
+    while (node.size() > 0 && needToRemoveCharsNumber > removedCharsNumber) {
+      final int lengthBeforeRemove = node.toString().length();
+      // Reduce complexity by removing half of the array's elements at a time
+      final int removeUntil = node.size() / 2;
+      for (int removeAt = node.size() - 1; removeAt >= removeUntil; removeAt--) {
+        node.remove(removeAt);
+      }
+      final int lengthAfterRemove = node.toString().length();
+      removedCharsNumber += lengthBeforeRemove - lengthAfterRemove;
+    }
+    return removedCharsNumber;
   }
 
-  public static ResponseContext deserialize(String responseContext, ObjectMapper objectMapper) throws IOException
+  /**
+   * Serialization result of {@link ResponseContext}.
+   * A response context might be serialized with a max length limit; in that case the context might be reduced
+   * by removing its longest fields one by one until the serialization result is shorter than the limit.
+   * This structure holds the reduced serialization result along with the full result and a boolean property
+   * indicating whether some fields were removed from the context.
+   */
+  public static class SerializationResult
   {
-    final Map<String, Object> delegate = objectMapper.readValue(
-        responseContext,
-        JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT
-    );
-    return new ResponseContext()
+    private final String truncatedResult;
+    private final String fullResult;
+
+    SerializationResult(String truncatedResult, String fullResult)
     {
-      @Override
-      protected Map<String, Object> getDelegate()
-      {
-        return delegate;
-      }
-    };
+      this.truncatedResult = truncatedResult;
+      this.fullResult = fullResult;
+    }
+
+    public String getTruncatedResult()
+    {
+      return truncatedResult;
+    }
+
+    public String getFullResult()
+    {
+      return fullResult;
+    }
+
+    public Boolean isReduced()
+    {
+      return !truncatedResult.equals(fullResult);
+    }
   }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
index 7b314ce45702..719f5f27e6f7 100644
--- a/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQuery.java
@@ -110,7 +110,8 @@ public static Order fromString(String name)
   private final VirtualColumns virtualColumns;
   private final ResultFormat resultFormat;
   private final int batchSize;
-  private final long limit;
+  @JsonProperty("limit")
+  private final long scanRowsLimit;
   private final DimFilter dimFilter;
   private final List<String> columns;
   private final Boolean legacy;
@@ -125,7 +126,7 @@ public ScanQuery(
       @JsonProperty("virtualColumns") VirtualColumns virtualColumns,
       @JsonProperty("resultFormat") ResultFormat resultFormat,
       @JsonProperty("batchSize") int batchSize,
-      @JsonProperty("limit") long limit,
+      @JsonProperty("limit") long scanRowsLimit,
       @JsonProperty("order") Order order,
       @JsonProperty("filter") DimFilter dimFilter,
       @JsonProperty("columns") List<String> columns,
@@ -141,9 +142,9 @@ public ScanQuery(
         this.batchSize > 0,
         "batchSize must be greater than 0"
     );
-    this.limit = (limit == 0) ? Long.MAX_VALUE : limit;
+    this.scanRowsLimit = (scanRowsLimit == 0) ?
Long.MAX_VALUE : scanRowsLimit; Preconditions.checkArgument( - this.limit > 0, + this.scanRowsLimit > 0, "limit must be greater than 0" ); this.dimFilter = dimFilter; @@ -201,9 +202,9 @@ public int getBatchSize() } @JsonProperty - public long getLimit() + public long getScanRowsLimit() { - return limit; + return scanRowsLimit; } @JsonProperty @@ -311,7 +312,7 @@ public boolean equals(final Object o) } final ScanQuery scanQuery = (ScanQuery) o; return batchSize == scanQuery.batchSize && - limit == scanQuery.limit && + scanRowsLimit == scanQuery.scanRowsLimit && Objects.equals(legacy, scanQuery.legacy) && Objects.equals(virtualColumns, scanQuery.virtualColumns) && Objects.equals(resultFormat, scanQuery.resultFormat) && @@ -322,7 +323,8 @@ public boolean equals(final Object o) @Override public int hashCode() { - return Objects.hash(super.hashCode(), virtualColumns, resultFormat, batchSize, limit, dimFilter, columns, legacy); + return Objects.hash(super.hashCode(), virtualColumns, resultFormat, batchSize, + scanRowsLimit, dimFilter, columns, legacy); } @Override @@ -334,7 +336,7 @@ public String toString() ", virtualColumns=" + getVirtualColumns() + ", resultFormat='" + resultFormat + '\'' + ", batchSize=" + batchSize + - ", limit=" + limit + + ", scanRowsLimit=" + scanRowsLimit + ", dimFilter=" + dimFilter + ", columns=" + columns + ", legacy=" + legacy + diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java index 8d0bf512961a..d4155fac26a3 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryEngine.java @@ -67,14 +67,15 @@ public Sequence process( // "legacy" should be non-null due to toolChest.mergeResults final boolean legacy = Preconditions.checkNotNull(query.isLegacy(), "WTF?! 
Expected non-null legacy"); - if (responseContext.get(ResponseContext.CTX_COUNT) != null) { - long count = (long) responseContext.get(ResponseContext.CTX_COUNT); - if (count >= query.getLimit() && query.getOrder().equals(ScanQuery.Order.NONE)) { + final Object numScannedRows = responseContext.get(ResponseContext.Key.NUM_SCANNED_ROWS); + if (numScannedRows != null) { + long count = (long) numScannedRows; + if (count >= query.getScanRowsLimit() && query.getOrder().equals(ScanQuery.Order.NONE)) { return Sequences.empty(); } } final boolean hasTimeout = QueryContexts.hasTimeout(query); - final long timeoutAt = (long) responseContext.get(ResponseContext.CTX_TIMEOUT_AT); + final long timeoutAt = (long) responseContext.get(ResponseContext.Key.TIMEOUT_AT); final long start = System.currentTimeMillis(); final StorageAdapter adapter = segment.asStorageAdapter(); @@ -121,10 +122,8 @@ public Sequence process( final Filter filter = Filters.convertToCNFFromQueryContext(query, Filters.toFilter(query.getFilter())); - if (responseContext.get(ResponseContext.CTX_COUNT) == null) { - responseContext.put(ResponseContext.CTX_COUNT, 0L); - } - final long limit = calculateLimit(query, responseContext); + responseContext.add(ResponseContext.Key.NUM_SCANNED_ROWS, 0L); + final long limit = calculateRemainingScanRowsLimit(query, responseContext); return Sequences.concat( adapter .makeCursors( @@ -187,13 +186,10 @@ public ScanResultValue next() } else { throw new UOE("resultFormat[%s] is not supported", resultFormat.toString()); } - responseContext.put( - ResponseContext.CTX_COUNT, - (long) responseContext.get(ResponseContext.CTX_COUNT) + (offset - lastOffset) - ); + responseContext.add(ResponseContext.Key.NUM_SCANNED_ROWS, offset - lastOffset); if (hasTimeout) { responseContext.put( - ResponseContext.CTX_TIMEOUT_AT, + ResponseContext.Key.TIMEOUT_AT, timeoutAt - (System.currentTimeMillis() - start) ); } @@ -263,11 +259,11 @@ public void cleanup(Iterator iterFromMake) * If we're performing time-ordering, we want to scan through the first `limit` rows in each segment ignoring the number * of rows already counted on other segments. 
*/ - private long calculateLimit(ScanQuery query, ResponseContext responseContext) + private long calculateRemainingScanRowsLimit(ScanQuery query, ResponseContext responseContext) { if (query.getOrder().equals(ScanQuery.Order.NONE)) { - return query.getLimit() - (long) responseContext.get(ResponseContext.CTX_COUNT); + return query.getScanRowsLimit() - (long) responseContext.get(ResponseContext.Key.NUM_SCANNED_ROWS); } - return query.getLimit(); + return query.getScanRowsLimit(); } } diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java index 4e30e869aa18..b603dd54d7d9 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryLimitRowIterator.java @@ -65,7 +65,7 @@ public ScanQueryLimitRowIterator( { this.query = (ScanQuery) queryPlus.getQuery(); this.resultFormat = query.getResultFormat(); - this.limit = query.getLimit(); + this.limit = query.getScanRowsLimit(); Query historicalQuery = queryPlus.getQuery().withOverriddenContext(ImmutableMap.of(ScanQuery.CTX_KEY_OUTERMOST, false)); Sequence baseSequence = baseRunner.run(QueryPlus.wrap(historicalQuery), responseContext); diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryQueryToolChest.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryQueryToolChest.java index 6d6758b19260..95006cee5766 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryQueryToolChest.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryQueryToolChest.java @@ -61,7 +61,7 @@ public QueryRunner mergeResults(final QueryRunner queryPlusWithNonNullLegacy = queryPlus.withQuery(scanQuery); - if (scanQuery.getLimit() == Long.MAX_VALUE) { + if (scanQuery.getScanRowsLimit() == Long.MAX_VALUE) { return runner.run(queryPlusWithNonNullLegacy, responseContext); } return new BaseSequence<>( diff --git a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java index 570819a3bc52..645f545ef280 100644 --- a/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java +++ b/processing/src/main/java/org/apache/druid/query/scan/ScanQueryRunnerFactory.java @@ -92,9 +92,9 @@ public QueryRunner mergeRunners( ScanQuery query = (ScanQuery) queryPlus.getQuery(); // Note: this variable is effective only when queryContext has a timeout. - // See the comment of CTX_TIMEOUT_AT. + // See the comment of ResponseContext.Key.TIMEOUT_AT. final long timeoutAt = System.currentTimeMillis() + QueryContexts.getTimeout(queryPlus.getQuery()); - responseContext.put(ResponseContext.CTX_TIMEOUT_AT, timeoutAt); + responseContext.put(ResponseContext.Key.TIMEOUT_AT, timeoutAt); if (query.getOrder().equals(ScanQuery.Order.NONE)) { // Use normal strategy @@ -104,8 +104,8 @@ public QueryRunner mergeRunners( input -> input.run(queryPlus, responseContext) ) ); - if (query.getLimit() <= Integer.MAX_VALUE) { - return returnedRows.limit(Math.toIntExact(query.getLimit())); + if (query.getScanRowsLimit() <= Integer.MAX_VALUE) { + return returnedRows.limit(Math.toIntExact(query.getScanRowsLimit())); } else { return returnedRows; } @@ -120,7 +120,7 @@ public QueryRunner mergeRunners( int maxRowsQueuedForOrdering = (query.getMaxRowsQueuedForOrdering() == null ? 
scanQueryConfig.getMaxRowsQueuedForOrdering() : query.getMaxRowsQueuedForOrdering()); - if (query.getLimit() <= maxRowsQueuedForOrdering) { + if (query.getScanRowsLimit() <= maxRowsQueuedForOrdering) { // Use priority queue strategy return priorityQueueSortAndLimit( Sequences.concat(Sequences.map( @@ -189,7 +189,7 @@ public QueryRunner mergeRunners( + " Try reducing the scope of the query to scan fewer partitions than the configurable limit of" + " %,d partitions or lower the row limit below %,d.", maxNumPartitionsInSegment, - query.getLimit(), + query.getScanRowsLimit(), scanQueryConfig.getMaxSegmentPartitionsOrderedInMemory(), scanQueryConfig.getMaxRowsQueuedForOrdering() ); @@ -207,16 +207,16 @@ Sequence priorityQueueSortAndLimit( { Comparator priorityQComparator = new ScanResultValueTimestampComparator(scanQuery); - if (scanQuery.getLimit() > Integer.MAX_VALUE) { + if (scanQuery.getScanRowsLimit() > Integer.MAX_VALUE) { throw new UOE( "Limit of %,d rows not supported for priority queue strategy of time-ordering scan results", - scanQuery.getLimit() + scanQuery.getScanRowsLimit() ); } // Converting the limit from long to int could theoretically throw an ArithmeticException but this branch // only runs if limit < MAX_LIMIT_FOR_IN_MEMORY_TIME_ORDERING (which should be < Integer.MAX_VALUE) - int limit = Math.toIntExact(scanQuery.getLimit()); + int limit = Math.toIntExact(scanQuery.getScanRowsLimit()); PriorityQueue q = new PriorityQueue<>(limit, priorityQComparator); @@ -337,7 +337,7 @@ Sequence nWayMergeAndLimit( ) ) ); - long limit = ((ScanQuery) (queryPlus.getQuery())).getLimit(); + long limit = ((ScanQuery) (queryPlus.getQuery())).getScanRowsLimit(); if (limit == Long.MAX_VALUE) { return resultSequence; } @@ -370,9 +370,9 @@ public Sequence run(QueryPlus queryPlus, Respo } // it happens in unit tests - final Number timeoutAt = (Number) responseContext.get(ResponseContext.CTX_TIMEOUT_AT); + final Number timeoutAt = (Number) responseContext.get(ResponseContext.Key.TIMEOUT_AT); if (timeoutAt == null || timeoutAt.longValue() == 0L) { - responseContext.put(ResponseContext.CTX_TIMEOUT_AT, JodaUtils.MAX_INSTANT); + responseContext.put(ResponseContext.Key.TIMEOUT_AT, JodaUtils.MAX_INSTANT); } return engine.process((ScanQuery) query, segment, responseContext); } diff --git a/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java b/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java index 94c5f8fc8e67..625f0325229e 100644 --- a/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java +++ b/processing/src/main/java/org/apache/druid/query/spec/SpecificSegmentQueryRunner.java @@ -31,13 +31,11 @@ import org.apache.druid.query.Query; import org.apache.druid.query.QueryPlus; import org.apache.druid.query.QueryRunner; -import org.apache.druid.query.SegmentDescriptor; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.segment.SegmentMissingException; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; /** */ @@ -152,13 +150,10 @@ public RetType wrap(Supplier sequenceProcessing) private void appendMissingSegment(ResponseContext responseContext) { - List missingSegments = - (List) responseContext.get(ResponseContext.CTX_MISSING_SEGMENTS); - if (missingSegments == null) { - missingSegments = new ArrayList<>(); - responseContext.put(ResponseContext.CTX_MISSING_SEGMENTS, missingSegments); - } - 
missingSegments.add(specificSpec.getDescriptor()); + responseContext.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(specificSpec.getDescriptor()) + ); } private RetType doNamed(Thread currThread, String currName, String newName, Supplier toRun) diff --git a/processing/src/main/java/org/apache/druid/segment/StringDimensionHandler.java b/processing/src/main/java/org/apache/druid/segment/StringDimensionHandler.java index ff7809351dbe..c14bd319bb13 100644 --- a/processing/src/main/java/org/apache/druid/segment/StringDimensionHandler.java +++ b/processing/src/main/java/org/apache/druid/segment/StringDimensionHandler.java @@ -58,7 +58,7 @@ public class StringDimensionHandler implements DimensionHandler()); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); RetryQueryRunner> runner = new RetryQueryRunner<>( new QueryRunner>() { @Override public Sequence> run(QueryPlus queryPlus, ResponseContext context) { - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1) + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1)) ); return Sequences.empty(); } @@ -124,7 +125,7 @@ public boolean isReturnPartialResults() Assert.assertTrue( "Should have one entry in the list of missing segments", - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).size() == 1 + ((List) context.get(ResponseContext.Key.MISSING_SEGMENTS)).size() == 1 ); Assert.assertTrue("Should return an empty sequence as a result", ((List) actualResults).size() == 0); } @@ -134,8 +135,8 @@ public boolean isReturnPartialResults() public void testRetry() { ResponseContext context = ConcurrentResponseContext.createEmpty(); - context.put("count", 0); - context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>()); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 0); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); RetryQueryRunner> runner = new RetryQueryRunner<>( new QueryRunner>() { @@ -145,11 +146,12 @@ public Sequence> run( ResponseContext context ) { - if ((int) context.get("count") == 0) { - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1) + if ((int) context.get(ResponseContext.Key.NUM_SCANNED_ROWS) == 0) { + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1)) ); - context.put("count", 1); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 1); return Sequences.empty(); } else { return Sequences.simple( @@ -174,7 +176,7 @@ public Sequence> run( Assert.assertTrue("Should return a list with one element", ((List) actualResults).size() == 1); Assert.assertTrue( "Should have nothing in missingSegment list", - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).size() == 0 + ((List) context.get(ResponseContext.Key.MISSING_SEGMENTS)).size() == 0 ); } @@ -182,8 +184,8 @@ public Sequence> run( public void testRetryMultiple() { ResponseContext context = ConcurrentResponseContext.createEmpty(); - context.put("count", 0); - context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>()); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 0); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); RetryQueryRunner> runner = new RetryQueryRunner<>( new QueryRunner>() { @@ -193,11 +195,12 @@ public 
Sequence> run( ResponseContext context ) { - if ((int) context.get("count") < 3) { - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1) + if ((int) context.get(ResponseContext.Key.NUM_SCANNED_ROWS) < 3) { + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1)) ); - context.put("count", (int) context.get("count") + 1); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, (int) context.get(ResponseContext.Key.NUM_SCANNED_ROWS) + 1); return Sequences.empty(); } else { return Sequences.simple( @@ -222,7 +225,7 @@ public Sequence> run( Assert.assertTrue("Should return a list with one element", ((List) actualResults).size() == 1); Assert.assertTrue( "Should have nothing in missingSegment list", - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).size() == 0 + ((List) context.get(ResponseContext.Key.MISSING_SEGMENTS)).size() == 0 ); } @@ -230,7 +233,7 @@ public Sequence> run( public void testException() { ResponseContext context = ConcurrentResponseContext.createEmpty(); - context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>()); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); RetryQueryRunner> runner = new RetryQueryRunner<>( new QueryRunner>() { @@ -240,8 +243,9 @@ public Sequence> run( ResponseContext context ) { - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1) + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1)) ); return Sequences.empty(); } @@ -254,7 +258,7 @@ public Sequence> run( Assert.assertTrue( "Should have one entry in the list of missing segments", - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).size() == 1 + ((List) context.get(ResponseContext.Key.MISSING_SEGMENTS)).size() == 1 ); } @@ -262,8 +266,8 @@ public Sequence> run( public void testNoDuplicateRetry() { ResponseContext context = ConcurrentResponseContext.createEmpty(); - context.put("count", 0); - context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>()); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 0); + context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>()); RetryQueryRunner> runner = new RetryQueryRunner<>( new QueryRunner>() { @@ -274,15 +278,16 @@ public Sequence> run( ) { final Query> query = queryPlus.getQuery(); - if ((int) context.get("count") == 0) { + if ((int) context.get(ResponseContext.Key.NUM_SCANNED_ROWS) == 0) { // assume 2 missing segments at first run - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1) - ); - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 2) + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Arrays.asList( + new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 1), + new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 2) + ) ); - context.put("count", 1); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 1); return Sequences.simple( Collections.singletonList( new Result<>( @@ -293,14 +298,15 @@ public Sequence> run( ) ) ); - } else if ((int) context.get("count") == 1) { + } else if ((int) context.get(ResponseContext.Key.NUM_SCANNED_ROWS) == 1) { // this is first retry 
Assert.assertTrue("Should retry with 2 missing segments", ((MultipleSpecificSegmentSpec) ((BaseQuery) query).getQuerySegmentSpec()).getDescriptors().size() == 2); // assume only left 1 missing at first retry - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).add( - new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 2) + context.add( + ResponseContext.Key.MISSING_SEGMENTS, + Collections.singletonList(new SegmentDescriptor(Intervals.utc(178888, 1999999), "test", 2)) ); - context.put("count", 2); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 2); return Sequences.simple( Collections.singletonList( new Result<>( @@ -315,7 +321,7 @@ public Sequence> run( // this is second retry Assert.assertTrue("Should retry with 1 missing segments", ((MultipleSpecificSegmentSpec) ((BaseQuery) query).getQuerySegmentSpec()).getDescriptors().size() == 1); // assume no more missing at second retry - context.put("count", 3); + context.put(ResponseContext.Key.NUM_SCANNED_ROWS, 3); return Sequences.simple( Collections.singletonList( new Result<>( @@ -338,7 +344,7 @@ public Sequence> run( Assert.assertTrue("Should return a list with 3 elements", ((List) actualResults).size() == 3); Assert.assertTrue( "Should have nothing in missingSegment list", - ((List) context.get(ResponseContext.CTX_MISSING_SEGMENTS)).size() == 0 + ((List) context.get(ResponseContext.Key.MISSING_SEGMENTS)).size() == 0 ); } } diff --git a/processing/src/test/java/org/apache/druid/query/UnionQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/UnionQueryRunnerTest.java index 78b7712ec4cc..a64c31301f83 100644 --- a/processing/src/test/java/org/apache/druid/query/UnionQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/UnionQueryRunnerTest.java @@ -28,12 +28,15 @@ import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; public class UnionQueryRunnerTest { @Test public void testUnionQueryRunner() { + AtomicBoolean ds1 = new AtomicBoolean(false); + AtomicBoolean ds2 = new AtomicBoolean(false); QueryRunner baseRunner = new QueryRunner() { @Override @@ -43,10 +46,10 @@ public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) Assert.assertTrue(queryPlus.getQuery().getDataSource() instanceof TableDataSource); String dsName = Iterables.getOnlyElement(queryPlus.getQuery().getDataSource().getNames()); if ("ds1".equals(dsName)) { - responseContext.put("ds1", "ds1"); + ds1.compareAndSet(false, true); return Sequences.simple(Arrays.asList(1, 2, 3)); } else if ("ds2".equals(dsName)) { - responseContext.put("ds2", "ds2"); + ds2.compareAndSet(false, true); return Sequences.simple(Arrays.asList(4, 5, 6)); } else { throw new AssertionError("Unexpected DataSource"); @@ -71,11 +74,8 @@ public Sequence run(QueryPlus queryPlus, ResponseContext responseContext) Sequence result = runner.run(QueryPlus.wrap(q), responseContext); List res = result.toList(); Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6), res); - - // verify response context - Assert.assertEquals(2, responseContext.size()); - Assert.assertEquals("ds1", responseContext.get("ds1")); - Assert.assertEquals("ds2", responseContext.get("ds2")); + Assert.assertEquals(true, ds1.get()); + Assert.assertEquals(true, ds2.get()); } } diff --git a/processing/src/test/java/org/apache/druid/query/context/ResponseContextTest.java b/processing/src/test/java/org/apache/druid/query/context/ResponseContextTest.java new file mode 100644 index 000000000000..f1354c3ea6b3 --- /dev/null +++ 
b/processing/src/test/java/org/apache/druid/query/context/ResponseContextTest.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.context; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.collect.ImmutableMap; +import org.apache.druid.jackson.DefaultObjectMapper; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.query.SegmentDescriptor; +import org.joda.time.Interval; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +public class ResponseContextTest +{ + + enum ExtensionResponseContextKey implements ResponseContext.BaseKey + { + EXTENSION_KEY_1("extension_key_1"), + EXTENSION_KEY_2("extension_key_2", (oldValue, newValue) -> (long) oldValue + (long) newValue); + + static { + for (ResponseContext.BaseKey key : values()) { + ResponseContext.Key.registerKey(key); + } + } + + private final String name; + private final BiFunction mergeFunction; + + ExtensionResponseContextKey(String name) + { + this.name = name; + this.mergeFunction = (oldValue, newValue) -> newValue; + } + + ExtensionResponseContextKey(String name, BiFunction mergeFunction) + { + this.name = name; + this.mergeFunction = mergeFunction; + } + + @Override + public String getName() + { + return name; + } + + @Override + public BiFunction getMergeFunction() + { + return mergeFunction; + } + } + + private final ResponseContext.BaseKey nonregisteredKey = new ResponseContext.BaseKey() + { + @Override + public String getName() + { + return "non-registered-key"; + } + + @Override + public BiFunction getMergeFunction() + { + return (Object a, Object b) -> a; + } + }; + + @Test(expected = IllegalStateException.class) + public void putISETest() + { + ResponseContext.createEmpty().put(nonregisteredKey, new Object()); + } + + @Test(expected = IllegalStateException.class) + public void addISETest() + { + ResponseContext.createEmpty().add(nonregisteredKey, new Object()); + } + + @Test(expected = IllegalArgumentException.class) + public void registerKeyIAETest() + { + ResponseContext.Key.registerKey(ResponseContext.Key.NUM_SCANNED_ROWS); + } + + @Test + public void mergeValueTest() + { + final ResponseContext ctx = ResponseContext.createEmpty(); + ctx.add(ResponseContext.Key.ETAG, "dummy-etag"); + Assert.assertEquals("dummy-etag", ctx.get(ResponseContext.Key.ETAG)); + ctx.add(ResponseContext.Key.ETAG, "new-dummy-etag"); + Assert.assertEquals("new-dummy-etag", ctx.get(ResponseContext.Key.ETAG)); + + final Interval interval01 = Intervals.of("2019-01-01/P1D"); + 
ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(interval01)); + Assert.assertArrayEquals( + Collections.singletonList(interval01).toArray(), + ((List) ctx.get(ResponseContext.Key.UNCOVERED_INTERVALS)).toArray() + ); + final Interval interval12 = Intervals.of("2019-01-02/P1D"); + final Interval interval23 = Intervals.of("2019-01-03/P1D"); + ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS, Arrays.asList(interval12, interval23)); + Assert.assertArrayEquals( + Arrays.asList(interval01, interval12, interval23).toArray(), + ((List) ctx.get(ResponseContext.Key.UNCOVERED_INTERVALS)).toArray() + ); + + final SegmentDescriptor sd01 = new SegmentDescriptor(interval01, "01", 0); + ctx.add(ResponseContext.Key.MISSING_SEGMENTS, Collections.singletonList(sd01)); + Assert.assertArrayEquals( + Collections.singletonList(sd01).toArray(), + ((List) ctx.get(ResponseContext.Key.MISSING_SEGMENTS)).toArray() + ); + final SegmentDescriptor sd12 = new SegmentDescriptor(interval12, "12", 1); + final SegmentDescriptor sd23 = new SegmentDescriptor(interval23, "23", 2); + ctx.add(ResponseContext.Key.MISSING_SEGMENTS, Arrays.asList(sd12, sd23)); + Assert.assertArrayEquals( + Arrays.asList(sd01, sd12, sd23).toArray(), + ((List) ctx.get(ResponseContext.Key.MISSING_SEGMENTS)).toArray() + ); + + ctx.add(ResponseContext.Key.NUM_SCANNED_ROWS, 0L); + Assert.assertEquals(0L, ctx.get(ResponseContext.Key.NUM_SCANNED_ROWS)); + ctx.add(ResponseContext.Key.NUM_SCANNED_ROWS, 1L); + Assert.assertEquals(1L, ctx.get(ResponseContext.Key.NUM_SCANNED_ROWS)); + ctx.add(ResponseContext.Key.NUM_SCANNED_ROWS, 3L); + Assert.assertEquals(4L, ctx.get(ResponseContext.Key.NUM_SCANNED_ROWS)); + + ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED, false); + Assert.assertEquals(false, ctx.get(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED)); + ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED, true); + Assert.assertEquals(true, ctx.get(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED)); + ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED, false); + Assert.assertEquals(true, ctx.get(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED)); + } + + @Test + public void mergeResponseContextTest() + { + final ResponseContext ctx1 = ResponseContext.createEmpty(); + ctx1.put(ResponseContext.Key.ETAG, "dummy-etag-1"); + final Interval interval01 = Intervals.of("2019-01-01/P1D"); + ctx1.put(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(interval01)); + ctx1.put(ResponseContext.Key.NUM_SCANNED_ROWS, 1L); + + final ResponseContext ctx2 = ResponseContext.createEmpty(); + ctx2.put(ResponseContext.Key.ETAG, "dummy-etag-2"); + final Interval interval12 = Intervals.of("2019-01-02/P1D"); + ctx2.put(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(interval12)); + final SegmentDescriptor sd01 = new SegmentDescriptor(interval01, "01", 0); + ctx2.put(ResponseContext.Key.MISSING_SEGMENTS, Collections.singletonList(sd01)); + ctx2.put(ResponseContext.Key.NUM_SCANNED_ROWS, 2L); + + ctx1.merge(ctx2); + Assert.assertEquals("dummy-etag-2", ctx1.get(ResponseContext.Key.ETAG)); + Assert.assertEquals(3L, ctx1.get(ResponseContext.Key.NUM_SCANNED_ROWS)); + Assert.assertArrayEquals( + Arrays.asList(interval01, interval12).toArray(), + ((List) ctx1.get(ResponseContext.Key.UNCOVERED_INTERVALS)).toArray() + ); + Assert.assertArrayEquals( + Collections.singletonList(sd01).toArray(), + ((List) ctx1.get(ResponseContext.Key.MISSING_SEGMENTS)).toArray() + ); + } + + @Test(expected = 
IllegalStateException.class) + public void mergeISETest() + { + final ResponseContext ctx = new ResponseContext() + { + @Override + protected Map getDelegate() + { + return ImmutableMap.of(nonregisteredKey, "non-registered-key"); + } + }; + ResponseContext.createEmpty().merge(ctx); + } + + @Test + public void serializeWithCorrectnessTest() throws JsonProcessingException + { + final ResponseContext ctx1 = ResponseContext.createEmpty(); + ctx1.add(ResponseContext.Key.ETAG, "string-value"); + final DefaultObjectMapper mapper = new DefaultObjectMapper(); + Assert.assertEquals( + mapper.writeValueAsString(ImmutableMap.of("ETag", "string-value")), + ctx1.serializeWith(mapper, Integer.MAX_VALUE).getTruncatedResult() + ); + + final ResponseContext ctx2 = ResponseContext.createEmpty(); + ctx2.add(ResponseContext.Key.NUM_SCANNED_ROWS, 100); + Assert.assertEquals( + mapper.writeValueAsString(ImmutableMap.of("count", 100)), + ctx2.serializeWith(mapper, Integer.MAX_VALUE).getTruncatedResult() + ); + } + + @Test + public void serializeWithTruncateValueTest() throws IOException + { + final ResponseContext ctx = ResponseContext.createEmpty(); + ctx.put(ResponseContext.Key.NUM_SCANNED_ROWS, 100); + ctx.put(ResponseContext.Key.ETAG, "long-string-that-is-supposed-to-be-removed-from-result"); + final DefaultObjectMapper objectMapper = new DefaultObjectMapper(); + final String fullString = objectMapper.writeValueAsString(ctx.getDelegate()); + final ResponseContext.SerializationResult res1 = ctx.serializeWith(objectMapper, Integer.MAX_VALUE); + Assert.assertEquals(fullString, res1.getTruncatedResult()); + final ResponseContext ctxCopy = ResponseContext.createEmpty(); + ctxCopy.merge(ctx); + final ResponseContext.SerializationResult res2 = ctx.serializeWith(objectMapper, 30); + ctxCopy.remove(ResponseContext.Key.ETAG); + ctxCopy.put(ResponseContext.Key.TRUNCATED, true); + Assert.assertEquals( + ctxCopy.getDelegate(), + ResponseContext.deserialize(res2.getTruncatedResult(), objectMapper).getDelegate() + ); + } + + @Test + public void serializeWithTruncateArrayTest() throws IOException + { + final ResponseContext ctx = ResponseContext.createEmpty(); + ctx.put(ResponseContext.Key.NUM_SCANNED_ROWS, 100); + ctx.put( + ResponseContext.Key.UNCOVERED_INTERVALS, + Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + ); + ctx.put( + ResponseContext.Key.MISSING_SEGMENTS, + Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + ); + final DefaultObjectMapper objectMapper = new DefaultObjectMapper(); + final String fullString = objectMapper.writeValueAsString(ctx.getDelegate()); + final ResponseContext.SerializationResult res1 = ctx.serializeWith(objectMapper, Integer.MAX_VALUE); + Assert.assertEquals(fullString, res1.getTruncatedResult()); + final ResponseContext ctxCopy = ResponseContext.createEmpty(); + ctxCopy.merge(ctx); + final ResponseContext.SerializationResult res2 = ctx.serializeWith(objectMapper, 70); + ctxCopy.put(ResponseContext.Key.UNCOVERED_INTERVALS, Arrays.asList(0, 1, 2, 3, 4)); + ctxCopy.remove(ResponseContext.Key.MISSING_SEGMENTS); + ctxCopy.put(ResponseContext.Key.TRUNCATED, true); + Assert.assertEquals( + ctxCopy.getDelegate(), + ResponseContext.deserialize(res2.getTruncatedResult(), objectMapper).getDelegate() + ); + } + + @Test + public void deserializeTest() throws IOException + { + final DefaultObjectMapper mapper = new DefaultObjectMapper(); + final ResponseContext ctx = ResponseContext.deserialize( + 
+  @Test
+  public void deserializeTest() throws IOException
+  {
+    final DefaultObjectMapper mapper = new DefaultObjectMapper();
+    final ResponseContext ctx = ResponseContext.deserialize(
+        mapper.writeValueAsString(ImmutableMap.of("ETag", "string-value", "count", 100)),
+        mapper
+    );
+    Assert.assertEquals("string-value", ctx.get(ResponseContext.Key.ETAG));
+    Assert.assertEquals(100, ctx.get(ResponseContext.Key.NUM_SCANNED_ROWS));
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void deserializeISETest() throws IOException
+  {
+    final DefaultObjectMapper mapper = new DefaultObjectMapper();
+    ResponseContext.deserialize(
+        mapper.writeValueAsString(ImmutableMap.of("ETag_unexpected", "string-value")),
+        mapper
+    );
+  }
+
+  @Test
+  public void extensionEnumIntegrityTest()
+  {
+    Assert.assertEquals(
+        ExtensionResponseContextKey.EXTENSION_KEY_1,
+        ResponseContext.Key.keyOf(ExtensionResponseContextKey.EXTENSION_KEY_1.getName())
+    );
+    Assert.assertEquals(
+        ExtensionResponseContextKey.EXTENSION_KEY_2,
+        ResponseContext.Key.keyOf(ExtensionResponseContextKey.EXTENSION_KEY_2.getName())
+    );
+    for (ResponseContext.BaseKey key : ExtensionResponseContextKey.values()) {
+      Assert.assertTrue(ResponseContext.Key.getAllRegisteredKeys().contains(key));
+    }
+  }
+
+  @Test
+  public void extensionEnumMergeTest()
+  {
+    final ResponseContext ctx = ResponseContext.createEmpty();
+    ctx.add(ResponseContext.Key.ETAG, "etag");
+    ctx.add(ExtensionResponseContextKey.EXTENSION_KEY_1, "string-value");
+    ctx.add(ExtensionResponseContextKey.EXTENSION_KEY_2, 2L);
+    final ResponseContext ctxFinal = ResponseContext.createEmpty();
+    ctxFinal.add(ResponseContext.Key.ETAG, "old-etag");
+    ctxFinal.add(ExtensionResponseContextKey.EXTENSION_KEY_1, "old-string-value");
+    ctxFinal.add(ExtensionResponseContextKey.EXTENSION_KEY_2, 1L);
+    ctxFinal.merge(ctx);
+    Assert.assertEquals("etag", ctxFinal.get(ResponseContext.Key.ETAG));
+    Assert.assertEquals("string-value", ctxFinal.get(ExtensionResponseContextKey.EXTENSION_KEY_1));
+    Assert.assertEquals(1L + 2L, ctxFinal.get(ExtensionResponseContextKey.EXTENSION_KEY_2));
+  }
+}
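The two extension tests above exercise a fixture, ExtensionResponseContextKey, that is defined elsewhere in the test class. For orientation, here is a minimal sketch of what such an extension-defined key can look like; the registerKey hook and the getMergeFunction method on BaseKey are assumptions inferred from the calls to keyOf and getAllRegisteredKeys and from the asserted merge behavior (replacement for EXTENSION_KEY_1, summation for EXTENSION_KEY_2). Registration is also what keeps merge from throwing the IllegalStateException that mergeISETest provokes with an unregistered key.

    import java.util.function.BiFunction;
    import org.apache.druid.query.context.ResponseContext;

    // Sketch only: an extension-defined key enum. The static registration hook
    // (ResponseContext.Key.registerKey) is an assumption; it is not shown in this patch.
    enum ExtensionResponseContextKey implements ResponseContext.BaseKey
    {
      // Assumed default merge semantics: the incoming value replaces the old one.
      EXTENSION_KEY_1("extension_key_1"),
      // Custom merge function: values accumulate, matching the 1L + 2L == 3L assertion above.
      EXTENSION_KEY_2("extension_key_2", (oldValue, newValue) -> (long) oldValue + (long) newValue);

      static {
        for (ResponseContext.BaseKey key : values()) {
          ResponseContext.Key.registerKey(key); // assumed hook; merge() rejects unregistered keys
        }
      }

      private final String name;
      private final BiFunction<Object, Object, Object> mergeFunction;

      ExtensionResponseContextKey(String name)
      {
        this(name, (oldValue, newValue) -> newValue);
      }

      ExtensionResponseContextKey(String name, BiFunction<Object, Object, Object> mergeFunction)
      {
        this.name = name;
        this.mergeFunction = mergeFunction;
      }

      @Override
      public String getName()
      {
        return name;
      }

      @Override
      public BiFunction<Object, Object, Object> getMergeFunction() // assumed BaseKey method
      {
        return mergeFunction;
      }
    }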
diff --git a/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java b/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java
index dd5bedf7d275..bb71f8e023e7 100644
--- a/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/datasourcemetadata/DataSourceMetadataQueryTest.java
@@ -139,7 +139,7 @@ public void testMaxIngestedEventTime() throws Exception
                                             .dataSource("testing")
                                             .build();
     ResponseContext context = ConcurrentResponseContext.createEmpty();
-    context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>());
+    context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>());
     Iterable<Result<DataSourceMetadataResultValue>> results = runner.run(QueryPlus.wrap(dataSourceMetadataQuery), context).toList();
     DataSourceMetadataResultValue val = results.iterator().next().getValue();
diff --git a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryRunnerFactoryTest.java b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryRunnerFactoryTest.java
index cf76f3750482..287733d441b4 100644
--- a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryRunnerFactoryTest.java
+++ b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryRunnerFactoryTest.java
@@ -145,13 +145,13 @@ public void testSortAndLimitScanResultValues()
               DateTimes.of("2019-01-01").plusHours(1)
           ))
       ).toList();
-      if (query.getLimit() > Integer.MAX_VALUE) {
+      if (query.getScanRowsLimit() > Integer.MAX_VALUE) {
        Assert.fail("Unsupported exception should have been thrown due to high limit");
       }
       validateSortedOutput(output, expectedEventTimestamps);
     }
     catch (UOE e) {
-      if (query.getLimit() <= Integer.MAX_VALUE) {
+      if (query.getScanRowsLimit() <= Integer.MAX_VALUE) {
         Assert.fail("Unsupported operation exception should not have been thrown here");
       }
     }
@@ -247,7 +247,7 @@ private void validateSortedOutput(List<ScanResultValue> output, List<Long> expectedEventTimestamps)
     }

     // check total # of rows <= limit
-    Assert.assertTrue(output.size() <= query.getLimit());
+    Assert.assertTrue(output.size() <= query.getScanRowsLimit());

     // check ordering is correct
     for (int i = 1; i < output.size(); i++) {
@@ -261,7 +261,7 @@ private void validateSortedOutput(List<ScanResultValue> output, List<Long> expectedEventTimestamps)
     }

     // check the values are correct
-    for (int i = 0; i < query.getLimit() && i < output.size(); i++) {
+    for (int i = 0; i < query.getScanRowsLimit() && i < output.size(); i++) {
       Assert.assertEquals((long) expectedEventTimestamps.get(i), output.get(i).getFirstEventTimestamp(resultFormat));
     }
   }
diff --git a/processing/src/test/java/org/apache/druid/query/spec/SpecificSegmentQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/spec/SpecificSegmentQueryRunnerTest.java
index f1b185aa77d4..f4e1c31a2186 100644
--- a/processing/src/test/java/org/apache/druid/query/spec/SpecificSegmentQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/spec/SpecificSegmentQueryRunnerTest.java
@@ -197,7 +197,7 @@ public void run()
   private void validate(ObjectMapper mapper, SegmentDescriptor descriptor, ResponseContext responseContext)
       throws IOException
   {
-    Object missingSegments = responseContext.get(ResponseContext.CTX_MISSING_SEGMENTS);
+    Object missingSegments = responseContext.get(ResponseContext.Key.MISSING_SEGMENTS);
     Assert.assertTrue(missingSegments != null);
     Assert.assertTrue(missingSegments instanceof List);
diff --git a/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryRunnerTest.java
index 5e3a63495079..325a477e8e41 100644
--- a/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/timeboundary/TimeBoundaryQueryRunnerTest.java
@@ -216,7 +216,7 @@ public void testTimeBoundaryMax()
                                             .bound(TimeBoundaryQuery.MAX_TIME)
                                             .build();
     ResponseContext context = ConcurrentResponseContext.createEmpty();
-    context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>());
+    context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>());
     Iterable<Result<TimeBoundaryResultValue>> results = runner.run(QueryPlus.wrap(timeBoundaryQuery), context).toList();
     TimeBoundaryResultValue val = results.iterator().next().getValue();
     DateTime minTime = val.getMinTime();
@@ -235,7 +235,7 @@ public void testTimeBoundaryMin()
                                             .bound(TimeBoundaryQuery.MIN_TIME)
                                             .build();
     ResponseContext context = ConcurrentResponseContext.createEmpty();
-    context.put(ResponseContext.CTX_MISSING_SEGMENTS, new ArrayList<>());
+    context.put(ResponseContext.Key.MISSING_SEGMENTS, new ArrayList<>());
     Iterable<Result<TimeBoundaryResultValue>> results = runner.run(QueryPlus.wrap(timeBoundaryQuery), context).toList();
     TimeBoundaryResultValue val = results.iterator().next().getValue();
     DateTime minTime = val.getMinTime();
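The four test files above are mechanical migrations from string constants to typed keys, and they show the read side of the API. A hypothetical helper capturing that read-back pattern (the method name and its null handling are illustrative, not part of this patch):

    import java.util.Collections;
    import java.util.List;
    import org.apache.druid.query.SegmentDescriptor;
    import org.apache.druid.query.context.ResponseContext;

    // Hypothetical reader: look up accumulated missing segments with the typed key.
    // The raw List cast mirrors the casts used in the tests above.
    @SuppressWarnings("unchecked")
    static List<SegmentDescriptor> getMissingSegments(ResponseContext context)
    {
      final Object value = context.get(ResponseContext.Key.MISSING_SEGMENTS);
      return value == null ? Collections.emptyList() : (List<SegmentDescriptor>) value;
    }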
a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java
@@ -1301,9 +1301,7 @@ public void testTopNBySegment()
         )
     );

-    final ResponseContext responseContext = ResponseContext.createEmpty();
-    responseContext.putAll(specialContext);
-    Sequence<Result<TopNResultValue>> results = runWithMerge(query, responseContext);
+    Sequence<Result<TopNResultValue>> results = runWithMerge(query);
     List<Result<BySegmentTopNResultValue>> resultList = results
         .map((Result input) -> {
           // Stupid type erasure
diff --git a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
index 06c45309c6ad..18a4a028b367 100644
--- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
+++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
@@ -354,12 +354,12 @@ private void computeUncoveredIntervals(TimelineLookup<String, ServerSelector> timeline
       }

      if (!uncoveredIntervals.isEmpty()) {
-        // This returns intervals for which NO segment is present.
+        // Record in the response context the intervals for which NO segment is present.
         // Which is not necessarily an indication that the data doesn't exist or is
         // incomplete. The data could exist and just not be loaded yet. In either
         // case, though, this query will not include any data from the identified intervals.
-        responseContext.put(ResponseContext.CTX_UNCOVERED_INTERVALS, uncoveredIntervals);
-        responseContext.put(ResponseContext.CTX_UNCOVERED_INTERVALS_OVERFLOWED, uncoveredIntervalsOverflowed);
+        responseContext.add(ResponseContext.Key.UNCOVERED_INTERVALS, uncoveredIntervals);
+        responseContext.add(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED, uncoveredIntervalsOverflowed);
       }
     }

@@ -396,7 +396,7 @@ private String computeCurrentEtag(final Set<SegmentServerSelector> segments, @Nullable byte[] queryCacheKey)
        hasher.putBytes(queryCacheKey == null ? strategy.computeCacheKey(query) : queryCacheKey);

        String currEtag = StringUtils.encodeBase64String(hasher.hash().asBytes());
-        responseContext.put(ResponseContext.CTX_ETAG, currEtag);
+        responseContext.put(ResponseContext.Key.ETAG, currEtag);
        return currEtag;
      } else {
        return null;
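Switching computeUncoveredIntervals from put to add is more than a rename: put overwrites whatever the key already holds, while add routes through the key's registered merge function, so interval lists concatenate and the overflow flag ORs instead of being clobbered by a later call. A toy sequence, using only APIs shown in this patch:

    final ResponseContext ctx = ResponseContext.createEmpty();
    ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(Intervals.of("2019-01-01/P1D")));
    ctx.add(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(Intervals.of("2019-01-02/P1D")));
    // get(...) now yields both intervals: the lists were concatenated by the merge function.
    ctx.put(ResponseContext.Key.UNCOVERED_INTERVALS, Collections.singletonList(Intervals.of("2019-01-03/P1D")));
    // get(...) now yields only the last interval: put replaced the accumulated value.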
diff --git a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java
index ab06f54126a8..4c5017435e69 100644
--- a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java
+++ b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java
@@ -101,13 +101,13 @@ public class DirectDruidClient<T> implements QueryRunner<T>
    */
   public static void removeMagicResponseContextFields(ResponseContext responseContext)
   {
-    responseContext.remove(ResponseContext.CTX_QUERY_TOTAL_BYTES_GATHERED);
+    responseContext.remove(ResponseContext.Key.QUERY_TOTAL_BYTES_GATHERED);
   }

   public static ResponseContext makeResponseContextForQuery()
   {
     final ResponseContext responseContext = ConcurrentResponseContext.createEmpty();
-    responseContext.put(ResponseContext.CTX_QUERY_TOTAL_BYTES_GATHERED, new AtomicLong());
+    responseContext.put(ResponseContext.Key.QUERY_TOTAL_BYTES_GATHERED, new AtomicLong());
     return responseContext;
   }

@@ -156,7 +156,7 @@ public Sequence<T> run(final QueryPlus<T> queryPlus, final ResponseContext context)
     final long requestStartTimeNs = System.nanoTime();
     final long timeoutAt = query.getContextValue(QUERY_FAIL_TIME);
     final long maxScatterGatherBytes = QueryContexts.getMaxScatterGatherBytes(query);
-    final AtomicLong totalBytesGathered = (AtomicLong) context.get(ResponseContext.CTX_QUERY_TOTAL_BYTES_GATHERED);
+    final AtomicLong totalBytesGathered = (AtomicLong) context.get(ResponseContext.Key.QUERY_TOTAL_BYTES_GATHERED);
     final long maxQueuedBytes = QueryContexts.getMaxQueuedBytes(query, 0);
     final boolean usingBackpressure = maxQueuedBytes > 0;

@@ -230,7 +230,7 @@ public ClientResponse handleResponse(HttpResponse response, TrafficCop trafficCop)
         final String responseContext = response.headers().get(QueryResource.HEADER_RESPONSE_CONTEXT);
         // context may be null in case of error or query timeout
         if (responseContext != null) {
-          context.putAll(ResponseContext.deserialize(responseContext, objectMapper));
+          context.merge(ResponseContext.deserialize(responseContext, objectMapper));
         }
         continueReading = enqueue(response.getContent(), 0L);
       }
diff --git a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
index addad8defcb4..ac1636c1e644 100644
--- a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
+++ b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
@@ -92,7 +92,7 @@ public Sequence<T> run(QueryPlus<T> queryPlus, ResponseContext responseContext)
           QueryPlus.wrap(query),
           responseContext
       );
-      String newResultSetId = (String) responseContext.get(ResponseContext.CTX_ETAG);
+      String newResultSetId = (String) responseContext.get(ResponseContext.Key.ETAG);

       if (useResultCache && newResultSetId != null && newResultSetId.equals(existingResultSetId)) {
         log.debug("Return cached result set as there is no change in identifiers for query %s ", query.getId());
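Replacing putAll with merge in DirectDruidClient is the behavioral heart of the patch: response contexts stream back from many data servers, and a last-writer-wins putAll silently dropped earlier servers' counters and missing-segment reports. A standalone sketch of the difference, with illustrative values:

    final ResponseContext onBroker = ResponseContext.createEmpty();
    onBroker.put(ResponseContext.Key.NUM_SCANNED_ROWS, 100L);

    final ResponseContext fromHistorical = ResponseContext.createEmpty();
    fromHistorical.put(ResponseContext.Key.NUM_SCANNED_ROWS, 50L);

    onBroker.merge(fromHistorical);
    // NUM_SCANNED_ROWS is now 150L: merge summed the counters.
    // The old putAll-style copy would have left 50L, losing the broker-side count.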
b/server/src/main/java/org/apache/druid/server/QueryResource.java
@@ -210,7 +210,7 @@ public Response doPost(
       final ResponseContext responseContext = queryResponse.getResponseContext();
       final String prevEtag = getPreviousEtag(req);

-      if (prevEtag != null && prevEtag.equals(responseContext.get(ResponseContext.CTX_ETAG))) {
+      if (prevEtag != null && prevEtag.equals(responseContext.get(ResponseContext.Key.ETAG))) {
         queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1);
         successfulQueryCount.incrementAndGet();
         return Response.notModified().build();
@@ -230,7 +230,7 @@ public Response doPost(
           serializeDateTimeAsLong
       );

-      Response.ResponseBuilder builder = Response
+      Response.ResponseBuilder responseBuilder = Response
          .ok(
              new StreamingOutput()
              {
@@ -269,9 +269,9 @@ public void write(OutputStream outputStream) throws WebApplicationException
          )
          .header("X-Druid-Query-Id", queryId);

-      if (responseContext.get(ResponseContext.CTX_ETAG) != null) {
-        builder.header(HEADER_ETAG, responseContext.get(ResponseContext.CTX_ETAG));
-        responseContext.remove(ResponseContext.CTX_ETAG);
+      Object entityTag = responseContext.remove(ResponseContext.Key.ETAG);
+      if (entityTag != null) {
+        responseBuilder.header(HEADER_ETAG, entityTag);
       }

       DirectDruidClient.removeMagicResponseContextFields(responseContext);
@@ -279,14 +279,20 @@ public void write(OutputStream outputStream) throws WebApplicationException
       //Limit the response-context header, see https://github.com/apache/incubator-druid/issues/2331
       //Note that Response.ResponseBuilder.header(String key,Object value).build() calls value.toString()
       //and encodes the string using ASCII, so 1 char is = 1 byte
-      String responseCtxString = responseContext.serializeWith(jsonMapper);
-      if (responseCtxString.length() > RESPONSE_CTX_HEADER_LEN_LIMIT) {
-        log.warn("Response Context truncated for id [%s] . Full context is [%s].", queryId, responseCtxString);
-        responseCtxString = responseCtxString.substring(0, RESPONSE_CTX_HEADER_LEN_LIMIT);
+      final ResponseContext.SerializationResult serializationResult = responseContext.serializeWith(
+          jsonMapper,
+          RESPONSE_CTX_HEADER_LEN_LIMIT
+      );
+      if (serializationResult.isReduced()) {
+        log.info(
+            "Response Context truncated for id [%s]. Full context is [%s].",
+            queryId,
+            serializationResult.getFullResult()
+        );
       }

-      return builder
-          .header(HEADER_RESPONSE_CONTEXT, responseCtxString)
+      return responseBuilder
+          .header(HEADER_RESPONSE_CONTEXT, serializationResult.getTruncatedResult())
          .build();
     }
     catch (Exception e) {
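QueryResource previously truncated the serialized context with substring, which could cut mid-JSON and hand the client an unparseable header; serializeWith now shrinks the context itself, dropping or trimming values and setting the truncated flag, so the emitted header stays valid JSON. A sketch of the caller-side pattern; the limit constant and mapper names are stand-ins:

    // Stand-ins: headerLimit and jsonMapper represent the caller's own configuration.
    final int headerLimit = 7 * 1024;
    final ResponseContext.SerializationResult result = responseContext.serializeWith(jsonMapper, headerLimit);
    if (result.isReduced()) {
      // The untruncated JSON remains available for logging.
      log.info("Response context shrunk; full context: %s", result.getFullResult());
    }
    final String headerValue = result.getTruncatedResult(); // valid JSON within the limit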
diff --git a/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java
index b261442647dd..9a7bb900a5fa 100644
--- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java
+++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientFunctionalityTest.java
@@ -125,7 +125,7 @@ public void testUncoveredInterval()

     ResponseContext responseContext = ResponseContext.createEmpty();
     runQuery(client, builder.build(), responseContext);
-    Assert.assertNull(responseContext.get("uncoveredIntervals"));
+    Assert.assertNull(responseContext.get(ResponseContext.Key.UNCOVERED_INTERVALS));

     builder.intervals("2015-01-01/2015-01-03");
     responseContext = ResponseContext.createEmpty();
@@ -174,8 +174,8 @@ private void assertUncovered(ResponseContext context, boolean uncoveredIntervalsOverflowed, String... intervals)
     for (String interval : intervals) {
       expectedList.add(Intervals.of(interval));
     }
-    Assert.assertEquals((Object) expectedList, context.get("uncoveredIntervals"));
-    Assert.assertEquals(uncoveredIntervalsOverflowed, context.get("uncoveredIntervalsOverflowed"));
+    Assert.assertEquals((Object) expectedList, context.get(ResponseContext.Key.UNCOVERED_INTERVALS));
+    Assert.assertEquals(uncoveredIntervalsOverflowed, context.get(ResponseContext.Key.UNCOVERED_INTERVALS_OVERFLOWED));
   }

   private void addToTimeline(Interval interval, String version)
diff --git a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java
index 60944fd96720..2b8e2f52b811 100644
--- a/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java
+++ b/server/src/test/java/org/apache/druid/client/CachingClusteredClientTest.java
@@ -3194,7 +3194,7 @@ public void testIfNoneMatch()
     ResponseContext responseContext = ResponseContext.createEmpty();

     getDefaultQueryRunner().run(QueryPlus.wrap(query), responseContext);
-    Assert.assertEquals("MDs2yIUvYLVzaG6zmwTH1plqaYE=", responseContext.get(ResponseContext.CTX_ETAG));
+    Assert.assertEquals("MDs2yIUvYLVzaG6zmwTH1plqaYE=", responseContext.get(ResponseContext.Key.ETAG));
   }

   @Test
@@ -3240,9 +3240,9 @@ public void testEtagforDifferentQueryInterval()
     final ResponseContext responseContext = ResponseContext.createEmpty();

     getDefaultQueryRunner().run(QueryPlus.wrap(query), responseContext);
-    final Object etag1 = responseContext.get("ETag");
+    final Object etag1 = responseContext.get(ResponseContext.Key.ETAG);

     getDefaultQueryRunner().run(QueryPlus.wrap(query2), responseContext);
-    final Object etag2 = responseContext.get("ETag");
+    final Object etag2 = responseContext.get(ResponseContext.Key.ETAG);

     Assert.assertNotEquals(etag1, etag2);
   }
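A final note on the test migrations above: a string lookup like get("uncoveredIntervals") compiles even when misspelled and simply returns null, letting an assertNull pass for the wrong reason, whereas a typed key is checked by the compiler (and an unknown serialized name now fails fast, per deserializeISETest). A contrived illustration:

    // Before: a typo compiles, get returns null, and this "passes" even when
    // uncovered intervals were actually reported.
    // Assert.assertNull(responseContext.get("uncoveredIntervalz"));

    // After: the key is an enum constant, so a misspelling is a compile-time error.
    Assert.assertNull(responseContext.get(ResponseContext.Key.UNCOVERED_INTERVALS));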