diff --git a/common/src/main/java/io/druid/common/utils/IntArrayUtils.java b/common/src/main/java/io/druid/common/utils/IntArrayUtils.java new file mode 100644 index 000000000000..83cf5ecfd228 --- /dev/null +++ b/common/src/main/java/io/druid/common/utils/IntArrayUtils.java @@ -0,0 +1,64 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.common.utils; + +public class IntArrayUtils +{ + /** + * Inverses the values of the given array with their indexes. + * For example, the result for [2, 0, 1] is [1, 2, 0] because + * + * a[0]: 2 => a[2]: 0 + * a[1]: 0 => a[0]: 1 + * a[2]: 1 => a[1]: 2 + */ + public static void inverse(int[] a) + { + for (int i = 0; i < a.length; i++) { + if (a[i] >= 0) { + inverseLoop(a, i); + } + } + + for (int i = 0; i < a.length; i++) { + a[i] = ~a[i]; + } + } + + private static void inverseLoop(int[] a, int startValue) + { + final int startIndex = a[startValue]; + + int nextIndex = startIndex; + int nextValue = startValue; + + do { + final int curIndex = nextIndex; + final int curValue = nextValue; + + nextValue = curIndex; + nextIndex = a[curIndex]; + + a[curIndex] = ~curValue; + } while (nextIndex != startIndex); + } + + private IntArrayUtils() {} +} diff --git a/common/src/test/java/io/druid/common/utils/IntArrayUtilsTest.java b/common/src/test/java/io/druid/common/utils/IntArrayUtilsTest.java new file mode 100644 index 000000000000..dc7f22839f60 --- /dev/null +++ b/common/src/test/java/io/druid/common/utils/IntArrayUtilsTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.common.utils; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class IntArrayUtilsTest +{ + @Test + public void testInverse() + { + final int numVals = 10000; + final Random random = new Random(System.currentTimeMillis()); + final int[] inverted = new int[numVals]; + final int[] original = new int[numVals]; + + final List ints = IntStream.range(0, numVals).boxed().collect(Collectors.toList()); + Collections.shuffle(ints, random); + + for (int i = 0; i < numVals; i++) { + inverted[i] = ints.get(i); + original[i] = inverted[i]; + } + IntArrayUtils.inverse(inverted); + + for (int i = 0; i < numVals; i++) { + Assert.assertEquals(i, inverted[original[i]]); + } + } +} diff --git a/docs/content/operations/performance-faq.md b/docs/content/operations/performance-faq.md index 44456367be31..edeaf77d8aee 100644 --- a/docs/content/operations/performance-faq.md +++ b/docs/content/operations/performance-faq.md @@ -34,6 +34,7 @@ A useful formula for estimating direct memory usage follows: `druid.processing.buffer.sizeBytes * (druid.processing.numMergeBuffers + druid.processing.numThreads + 1)` The `+1` is a fuzzy parameter meant to account for the decompression and dictionary merging buffers and may need to be adjusted based on the characteristics of the data being ingested/queried. +Operators can ensure at least this amount of direct memory is available by providing `-XX:MaxDirectMemorySize=` at the command line. ## What is the intermediate computation buffer? The intermediate computation buffer specifies a buffer size for the storage of intermediate results. The computation engine in both the Historical and Realtime nodes will use a scratch buffer of this size to do all of their intermediate computations off-heap. Larger values allow for more aggregations in a single pass over the data while smaller values can require more passes depending on the query that is being executed. The default size is 1073741824 bytes (1GB). diff --git a/docs/content/querying/groupbyquery.md b/docs/content/querying/groupbyquery.md index fc323217ec06..d67671fdd188 100644 --- a/docs/content/querying/groupbyquery.md +++ b/docs/content/querying/groupbyquery.md @@ -49,9 +49,9 @@ An example groupBy query object is shown below: ], "intervals": [ "2012-01-01T00:00:00.000/2012-01-03T00:00:00.000" ], "having": { - "type": "greaterThan", - "aggregation": "total_usage", - "value": 100 + "type": "greaterThan", + "aggregation": "total_usage", + "value": 100 } } ``` @@ -180,7 +180,7 @@ disk space. With groupBy v2, cluster operators should make sure that the off-heap hash tables and on-heap merging dictionaries will not exceed available memory for the maximum possible concurrent query load (given by -druid.processing.numMergeBuffers). +druid.processing.numMergeBuffers). See [How much direct memory does Druid use?](../operations/performance-faq.html) for more details. When using groupBy v1, all aggregation is done on-heap, and resource limits are done through the parameter druid.query.groupBy.maxResults. This is a cap on the maximum number of results in a result set. Queries that exceed @@ -188,6 +188,31 @@ this limit will fail with a "Resource limit exceeded" error indicating they exce operators should make sure that the on-heap aggregations will not exceed available JVM heap space for the expected concurrent query load. +#### Performance tuning for groupBy v2 + +##### Limit pushdown optimization + +Druid pushes down the `limit` spec in groupBy queries to the segments on historicals wherever possible to early prune unnecessary intermediate results and minimize the amount of data transferred to brokers. By default, this technique is applied only when all fields in the `orderBy` spec is a subset of the grouping keys. This is because the `limitPushDown` doesn't guarantee the exact results if the `orderBy` spec includes any fields that are not in the grouping keys. However, you can enable this technique even in such cases if you can sacrifice some accuracy for fast query processing like in topN queries. See `forceLimitPushDown` in [advanced groupBy v2 configurations](#groupby-v2-configurations). + + +##### Optimizing hash table + +The groupBy v2 engine uses an open addressing hash table for aggregation. The hash table is initalized with a given initial bucket number and gradually grows on buffer full. On hash collisions, the linear probing technique is used. + +The default number of initial buckets is 1024 and the default max load factor of the hash table is 0.7. If you can see too many collisions in the hash table, you can adjust these numbers. See `bufferGrouperInitialBuckets` and `bufferGrouperMaxLoadFactor` in [Advanced groupBy v2 configurations](#groupby-v2-configurations). + + +##### Parallel combine + +Once a historical finishes aggregation using the hash table, it sorts aggregates and merge them before sending to the broker for N-way merge aggregation in the broker. By default, historicals use all their available processing threads (configured by `druid.processing.numThreads`) for aggregation, but use a single thread for sorting and merging aggregates which is an http thread to send data to brokers. + +This is to prevent some heavy groupBy queries from blocking other queries. In Druid, the processing threads are shared between all submitted queries and they are _not interruptible_. It means, if a heavy query takes all available processing threads, all other queries might be blocked until the heavy query is finished. GroupBy queries usually take longer time than timeseries or topN queries, they should release processing threads as soon as possible. + +However, you might care about the performance of some really heavy groupBy queries. Usually, the performance bottleneck of heavy groupBy queries is merging sorted aggregates. In such cases, you can use processing threads for it as well. This is called _parallel combine_. To enable parallel combine, see `numParallelCombineThreads` in [Advanced groupBy v2 configurations](#groupby-v2-configurations). Note that parallel combine can be enabled only when data is actually spilled (see [Memory tuning and resource limits](#memory-tuning-and-resource-limits)). + +Once parallel combine is enabled, the groupBy v2 engine can create a combining tree for merging sorted aggregates. Each intermediate node of the tree is a thread merging aggregates from the child nodes. The leaf node threads read and merge aggregates from hash tables including spilled ones. Usually, leaf nodes are slower than intermediate nodes because they need to read data from disk. As a result, less threads are used for intermediate nodes by default. You can change the degree of intermeidate nodes. See `intermediateCombineDegree` in [Advanced groupBy v2 configurations](#groupby-v2-configurations). + + #### Alternatives There are some situations where other query types may be a better choice than groupBy. @@ -208,55 +233,87 @@ indexing mechanism, and runs the outer query on these materialized results. "v2" inner query's results stream with off-heap fact map and on-heap string dictionary that can spill to disk. Both strategy perform the outer query on the broker in a single-threaded fashion. -#### Server configuration +#### Configurations + +This section describes the configurations for groupBy queries. You can set system-wide configurations by adding them to runtime properties or query-specific configurations by adding them to query contexts. All runtime properties are prefixed by `druid.query.groupBy`. + +#### Commonly tuned configurations -When using the "v2" strategy, the following runtime properties apply: +##### Configurations for groupBy v2 + +Supported runtime properties: |Property|Description|Default| |--------|-----------|-------| -|`druid.query.groupBy.defaultStrategy`|Default groupBy query strategy.|v2| -|`druid.query.groupBy.bufferGrouperInitialBuckets`|Initial number of buckets in the off-heap hash table used for grouping results. Set to 0 to use a reasonable default.|0| -|`druid.query.groupBy.bufferGrouperMaxLoadFactor`|Maximum load factor of the off-heap hash table used for grouping results. When the load factor exceeds this size, the table will be grown or spilled to disk. Set to 0 to use a reasonable default.|0| |`druid.query.groupBy.maxMergingDictionarySize`|Maximum amount of heap space (approximately) to use for the string dictionary during merging. When the dictionary exceeds this size, a spill to disk will be triggered.|100000000| |`druid.query.groupBy.maxOnDiskStorage`|Maximum amount of disk space to use, per-query, for spilling result sets to disk when either the merging buffer or the dictionary fills up. Queries that exceed this limit will fail. Set to zero to disable disk spilling.|0 (disabled)| -|`druid.query.groupBy.singleThreaded`|Merge results using a single thread.|false| -This may require allocating more direct memory. The amount of direct memory needed by Druid is at least -`druid.processing.buffer.sizeBytes * (druid.processing.numMergeBuffers + druid.processing.numThreads + 1)`. You can -ensure at least this amount of direct memory is available by providing `-XX:MaxDirectMemorySize=` at the command -line. +Supported query contexts: + +|Key|Description| +|---|-----------| +|`maxMergingDictionarySize`|Can be used to lower the value of `druid.query.groupBy.maxMergingDictionarySize` for this query.| +|`maxOnDiskStorage`|Can be used to lower the value of `druid.query.groupBy.maxOnDiskStorage` for this query.| + + +#### Advanced configurations + +##### Common configuragions for all groupBy strategies -When using the "v1" strategy, the following runtime properties apply: +Supported runtime properties: |Property|Description|Default| |--------|-----------|-------| |`druid.query.groupBy.defaultStrategy`|Default groupBy query strategy.|v2| -|`druid.query.groupBy.maxIntermediateRows`|Maximum number of intermediate rows for the per-segment grouping engine. This is a tuning parameter that does not impose a hard limit; rather, it potentially shifts merging work from the per-segment engine to the overall merging index. Queries that exceed this limit will not fail.|50000| -|`druid.query.groupBy.maxResults`|Maximum number of results. Queries that exceed this limit will fail.|500000| |`druid.query.groupBy.singleThreaded`|Merge results using a single thread.|false| -#### Query context - -When using the "v2" strategy, the following query context parameters apply: +Supported query contexts: -|Property|Description| -|--------|-----------| +|Key|Description| +|---|-----------| |`groupByStrategy`|Overrides the value of `druid.query.groupBy.defaultStrategy` for this query.| |`groupByIsSingleThreaded`|Overrides the value of `druid.query.groupBy.singleThreaded` for this query.| -|`bufferGrouperInitialBuckets`|Overrides the value of `druid.query.groupBy.bufferGrouperInitialBuckets` for this query.| -|`bufferGrouperMaxLoadFactor`|Overrides the value of `druid.query.groupBy.bufferGrouperMaxLoadFactor` for this query.| -|`maxMergingDictionarySize`|Can be used to lower the value of `druid.query.groupBy.maxMergingDictionarySize` for this query.| -|`maxOnDiskStorage`|Can be used to lower the value of `druid.query.groupBy.maxOnDiskStorage` for this query.| -|`sortByDimsFirst`|Sort the results first by dimension values and then by timestamp.| -|`forcePushDownLimit`|When all fields in the orderby are part of the grouping key, the broker will push limit application down to the historical nodes. When the sorting order uses fields that are not in the grouping key, applying this optimization can result in approximate results with unknown accuracy, so this optimization is disabled by default in that case. Enabling this context flag turns on limit push down for limit/orderbys that contain non-grouping key columns.| -|`forceHashAggregation`|Force to use hash-based aggregation.| -When using the "v1" strategy, the following query context parameters apply: -|Property|Description| -|--------|-----------| -|`groupByStrategy`|Overrides the value of `druid.query.groupBy.defaultStrategy` for this query.| -|`groupByIsSingleThreaded`|Overrides the value of `druid.query.groupBy.singleThreaded` for this query.| -|`maxIntermediateRows`|Can be used to lower the value of `druid.query.groupBy.maxIntermediateRows` for this query.| -|`maxResults`|Can be used to lower the value of `druid.query.groupBy.maxResults` for this query.| -|`useOffheap`|Set to true to store aggregations off-heap when merging results.| +##### GroupBy v2 configurations + +Supported runtime properties: + +|Property|Description|Default| +|--------|-----------|-------| +|`druid.query.groupBy.bufferGrouperInitialBuckets`|Initial number of buckets in the off-heap hash table used for grouping results. Set to 0 to use a reasonable default (1024).|0| +|`druid.query.groupBy.bufferGrouperMaxLoadFactor`|Maximum load factor of the off-heap hash table used for grouping results. When the load factor exceeds this size, the table will be grown or spilled to disk. Set to 0 to use a reasonable default (0.7).|0| +|`druid.query.groupBy.forceHashAggregation`|Force to use hash-based aggregation.|false| +|`druid.query.groupBy.intermediateCombineDegree`|Number of intermediate nodes combined together in the combining tree. Higher degrees will need less threads which might be helpful to improve the query performance by reducing the overhead of too many threads if the server has sufficiently powerful cpu cores.|8| +|`druid.query.groupBy.numParallelCombineThreads`|Hint for the number of parallel combining threads. This should be larger than 1 to turn on the parallel combining feature. The actual number of threads used for parallel combining is min(`druid.query.groupBy.numParallelCombineThreads`, `druid.processing.numThreads`).|1 (disabled)| + +Supported query contexts: + +|Key|Description|Default| +|---|-----------|-------| +|`bufferGrouperInitialBuckets`|Overrides the value of `druid.query.groupBy.bufferGrouperInitialBuckets` for this query.|None| +|`bufferGrouperMaxLoadFactor`|Overrides the value of `druid.query.groupBy.bufferGrouperMaxLoadFactor` for this query.|None| +|`forceHashAggregation`|Overrides the value of `druid.query.groupBy.forceHashAggregation`|None| +|`intermediateCombineDegree`|Overrides the value of `druid.query.groupBy.intermediateCombineDegree`|None| +|`numParallelCombineThreads`|Overrides the value of `druid.query.groupBy.numParallelCombineThreads`|None| +|`sortByDimsFirst`|Sort the results first by dimension values and then by timestamp.|false| +|`forceLimitPushDown`|When all fields in the orderby are part of the grouping key, the broker will push limit application down to the historical nodes. When the sorting order uses fields that are not in the grouping key, applying this optimization can result in approximate results with unknown accuracy, so this optimization is disabled by default in that case. Enabling this context flag turns on limit push down for limit/orderbys that contain non-grouping key columns.|false| + + +##### GroupBy v1 configurations + +Supported runtime properties: + +|Property|Description|Default| +|--------|-----------|-------| +|`druid.query.groupBy.maxIntermediateRows`|Maximum number of intermediate rows for the per-segment grouping engine. This is a tuning parameter that does not impose a hard limit; rather, it potentially shifts merging work from the per-segment engine to the overall merging index. Queries that exceed this limit will not fail.|50000| +|`druid.query.groupBy.maxResults`|Maximum number of results. Queries that exceed this limit will fail.|500000| + +Supported query contexts: + +|Key|Description|Default| +|---|-----------|-------| +|`maxIntermediateRows`|Can be used to lower the value of `druid.query.groupBy.maxIntermediateRows` for this query.|None| +|`maxResults`|Can be used to lower the value of `druid.query.groupBy.maxResults` for this query.|None| +|`useOffheap`|Set to true to store aggregations off-heap when merging results.|false| + diff --git a/java-util/src/main/java/io/druid/java/util/common/CloseableIterators.java b/java-util/src/main/java/io/druid/java/util/common/CloseableIterators.java new file mode 100644 index 000000000000..437e4c8c81e9 --- /dev/null +++ b/java-util/src/main/java/io/druid/java/util/common/CloseableIterators.java @@ -0,0 +1,96 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.java.util.common; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterators; +import io.druid.java.util.common.io.Closer; +import io.druid.java.util.common.parsers.CloseableIterator; + +import javax.annotation.Nullable; +import java.io.Closeable; +import java.io.IOException; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +public class CloseableIterators +{ + public static CloseableIterator concat(List> iterators) + { + final Closer closer = Closer.create(); + iterators.forEach(closer::register); + + final Iterator innerIterator = Iterators.concat(iterators.iterator()); + return wrap(innerIterator, closer); + } + + public static CloseableIterator mergeSorted( + List> iterators, + Comparator comparator + ) + { + Preconditions.checkNotNull(comparator); + + final Closer closer = Closer.create(); + iterators.forEach(closer::register); + + final Iterator innerIterator = Iterators.mergeSorted(iterators, comparator); + return wrap(innerIterator, closer); + } + + public static CloseableIterator wrap(Iterator innerIterator, @Nullable Closeable closeable) + { + return new CloseableIterator() + { + private boolean closed; + + @Override + public boolean hasNext() + { + return innerIterator.hasNext(); + } + + @Override + public T next() + { + return innerIterator.next(); + } + + @Override + public void close() throws IOException + { + if (!closed) { + if (closeable != null) { + closeable.close(); + } + closed = true; + } + } + }; + } + + public static CloseableIterator withEmptyBaggage(Iterator innerIterator) + { + return wrap(innerIterator, null); + } + + private CloseableIterators() {} +} diff --git a/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java b/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java index 85d8b5216ca3..9c744dc24aa1 100644 --- a/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java +++ b/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java @@ -38,6 +38,8 @@ public class GroupByQueryConfig private static final String CTX_KEY_MAX_ON_DISK_STORAGE = "maxOnDiskStorage"; private static final String CTX_KEY_MAX_MERGING_DICTIONARY_SIZE = "maxMergingDictionarySize"; private static final String CTX_KEY_FORCE_HASH_AGGREGATION = "forceHashAggregation"; + private static final String CTX_KEY_INTERMEDIATE_COMBINE_DEGREE = "intermediateCombineDegree"; + private static final String CTX_KEY_NUM_PARALLEL_COMBINE_THREADS = "numParallelCombineThreads"; @JsonProperty private String defaultStrategy = GroupByStrategySelector.STRATEGY_V2; @@ -75,6 +77,12 @@ public class GroupByQueryConfig @JsonProperty private boolean forceHashAggregation = false; + @JsonProperty + private int intermediateCombineDegree = 8; + + @JsonProperty + private int numParallelCombineThreads = 1; + public String getDefaultStrategy() { return defaultStrategy; @@ -144,7 +152,17 @@ public boolean isForceHashAggregation() { return forceHashAggregation; } - + + public int getIntermediateCombineDegree() + { + return intermediateCombineDegree; + } + + public int getNumParallelCombineThreads() + { + return numParallelCombineThreads; + } + public GroupByQueryConfig withOverrides(final GroupByQuery query) { final GroupByQueryConfig newConfig = new GroupByQueryConfig(); @@ -180,6 +198,14 @@ public GroupByQueryConfig withOverrides(final GroupByQuery query) ); newConfig.forcePushDownLimit = query.getContextBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit()); newConfig.forceHashAggregation = query.getContextBoolean(CTX_KEY_FORCE_HASH_AGGREGATION, isForceHashAggregation()); + newConfig.intermediateCombineDegree = query.getContextValue( + CTX_KEY_INTERMEDIATE_COMBINE_DEGREE, + getIntermediateCombineDegree() + ); + newConfig.numParallelCombineThreads = query.getContextValue( + CTX_KEY_NUM_PARALLEL_COMBINE_THREADS, + getNumParallelCombineThreads() + ); return newConfig; } @@ -198,6 +224,8 @@ public String toString() ", maxOnDiskStorage=" + maxOnDiskStorage + ", forcePushDownLimit=" + forcePushDownLimit + ", forceHashAggregation=" + forceHashAggregation + + ", intermediateCombineDegree=" + intermediateCombineDegree + + ", numParallelCombineThreads=" + numParallelCombineThreads + '}'; } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java index eb5e4d194651..04a7fa5b7c81 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferArrayGrouper.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Supplier; +import io.druid.java.util.common.parsers.CloseableIterator; import io.druid.java.util.common.ISE; import io.druid.java.util.common.logger.Logger; import io.druid.query.aggregation.AggregatorFactory; @@ -28,9 +29,9 @@ import io.druid.query.groupby.epinephelinae.column.GroupByColumnSelectorStrategy; import io.druid.segment.ColumnSelectorFactory; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; -import java.util.Iterator; import java.util.NoSuchElementException; /** @@ -225,13 +226,13 @@ public void close() } @Override - public Iterator> iterator(boolean sorted) + public CloseableIterator> iterator(boolean sorted) { if (sorted) { throw new UnsupportedOperationException("sorted iterator is not supported yet"); } - return new Iterator>() + return new CloseableIterator>() { int cur = -1; boolean findNext = false; @@ -276,6 +277,12 @@ public Entry next() } return new Entry<>(cur - 1, values); } + + @Override + public void close() throws IOException + { + // do nothing + } }; } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferHashGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferHashGrouper.java index c6885d2306d7..c7a25807d81d 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferHashGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferHashGrouper.java @@ -22,16 +22,18 @@ import com.google.common.base.Supplier; import com.google.common.collect.Iterators; import com.google.common.primitives.Ints; +import io.druid.java.util.common.parsers.CloseableIterator; +import io.druid.java.util.common.CloseableIterators; import io.druid.java.util.common.IAE; import io.druid.java.util.common.logger.Logger; import io.druid.query.aggregation.AggregatorFactory; import io.druid.segment.ColumnSelectorFactory; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.AbstractList; import java.util.Collections; import java.util.Comparator; -import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; @@ -167,12 +169,12 @@ public void reset() } @Override - public Iterator> iterator(boolean sorted) + public CloseableIterator> iterator(boolean sorted) { if (!initialized) { // it's possible for iterator() to be called before initialization when // a nested groupBy's subquery has an empty result set (see testEmptySubquery() in GroupByQueryRunnerTest) - return Iterators.>emptyIterator(); + return CloseableIterators.withEmptyBaggage(Iterators.>emptyIterator()); } if (sorted) { @@ -225,7 +227,7 @@ public int compare(Integer lhs, Integer rhs) } ); - return new Iterator>() + return new CloseableIterator>() { int curr = 0; final int size = getSize(); @@ -250,10 +252,16 @@ public void remove() { throw new UnsupportedOperationException(); } + + @Override + public void close() throws IOException + { + // do nothing + } }; } else { // Unsorted iterator - return new Iterator>() + return new CloseableIterator>() { int curr = 0; final int size = getSize(); @@ -282,6 +290,12 @@ public void remove() { throw new UnsupportedOperationException(); } + + @Override + public void close() throws IOException + { + // do nothing + } }; } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/CloseableGrouperIterator.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/CloseableGrouperIterator.java index a377f033f196..c8002c79afc2 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/CloseableGrouperIterator.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/CloseableGrouperIterator.java @@ -19,29 +19,33 @@ package io.druid.query.groupby.epinephelinae; -import com.google.common.base.Function; -import com.google.common.base.Throwables; +import io.druid.java.util.common.io.Closer; +import io.druid.java.util.common.parsers.CloseableIterator; +import io.druid.query.groupby.epinephelinae.Grouper.Entry; import java.io.Closeable; import java.io.IOException; -import java.util.Iterator; +import java.util.function.Function; -public class CloseableGrouperIterator implements Iterator, Closeable +public class CloseableGrouperIterator implements CloseableIterator { - private final Function, T> transformer; - private final Closeable closer; - private final Iterator> iterator; + private final Function, T> transformer; + private final CloseableIterator> iterator; + private final Closer closer; public CloseableGrouperIterator( final Grouper grouper, final boolean sorted, final Function, T> transformer, - final Closeable closer + final Closeable closeable ) { this.transformer = transformer; - this.closer = closer; this.iterator = grouper.iterator(sorted); + this.closer = Closer.create(); + + closer.register(iterator); + closer.register(closeable); } @Override @@ -65,13 +69,11 @@ public void remove() @Override public void close() { - if (closer != null) { - try { - closer.close(); - } - catch (IOException e) { - throw Throwables.propagate(e); - } + try { + closer.close(); + } + catch (IOException e) { + throw new RuntimeException(e); } } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java index 9d8e328b33cc..de223f3894d8 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java @@ -23,21 +23,29 @@ import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; +import io.druid.collections.ResourceHolder; +import io.druid.java.util.common.CloseableIterators; import io.druid.java.util.common.ISE; +import io.druid.java.util.common.parsers.CloseableIterator; import io.druid.query.AbstractPrioritizedCallable; import io.druid.query.QueryInterruptedException; import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.groupby.GroupByQueryConfig; import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.segment.ColumnSelectorFactory; +import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; -import java.util.Iterator; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -75,16 +83,64 @@ public class ConcurrentGrouper implements Grouper private final DefaultLimitSpec limitSpec; private final boolean sortHasNonGroupingFields; private final Comparator> keyObjComparator; - private final ListeningExecutorService grouperSorter; + private final ListeningExecutorService executor; private final int priority; private final boolean hasQueryTimeout; private final long queryTimeoutAt; + private final long maxDictionarySizeForCombiner; + @Nullable + private final ParallelCombiner parallelCombiner; private volatile boolean initialized = false; public ConcurrentGrouper( + final GroupByQueryConfig groupByQueryConfig, final Supplier bufferSupplier, + final Supplier> combineBufferSupplier, final KeySerdeFactory keySerdeFactory, + final KeySerdeFactory combineKeySerdeFactory, + final ColumnSelectorFactory columnSelectorFactory, + final AggregatorFactory[] aggregatorFactories, + final LimitedTemporaryStorage temporaryStorage, + final ObjectMapper spillMapper, + final int concurrencyHint, + final DefaultLimitSpec limitSpec, + final boolean sortHasNonGroupingFields, + final ListeningExecutorService executor, + final int priority, + final boolean hasQueryTimeout, + final long queryTimeoutAt + ) + { + this( + bufferSupplier, + combineBufferSupplier, + keySerdeFactory, + combineKeySerdeFactory, + columnSelectorFactory, + aggregatorFactories, + groupByQueryConfig.getBufferGrouperMaxSize(), + groupByQueryConfig.getBufferGrouperMaxLoadFactor(), + groupByQueryConfig.getBufferGrouperInitialBuckets(), + temporaryStorage, + spillMapper, + concurrencyHint, + limitSpec, + sortHasNonGroupingFields, + executor, + priority, + hasQueryTimeout, + queryTimeoutAt, + groupByQueryConfig.getIntermediateCombineDegree(), + groupByQueryConfig.getNumParallelCombineThreads() + ); + } + + ConcurrentGrouper( + final Supplier bufferSupplier, + final Supplier> combineBufferSupplier, + final KeySerdeFactory keySerdeFactory, + final KeySerdeFactory combineKeySerdeFactory, final ColumnSelectorFactory columnSelectorFactory, final AggregatorFactory[] aggregatorFactories, final int bufferGrouperMaxSize, @@ -95,24 +151,24 @@ public ConcurrentGrouper( final int concurrencyHint, final DefaultLimitSpec limitSpec, final boolean sortHasNonGroupingFields, - final ListeningExecutorService grouperSorter, + final ListeningExecutorService executor, final int priority, final boolean hasQueryTimeout, final long queryTimeoutAt, - final int mergeBufferSize + final int intermediateCombineDegree, + final int numParallelCombineThreads ) { Preconditions.checkArgument(concurrencyHint > 0, "concurrencyHint > 0"); + Preconditions.checkArgument( + concurrencyHint >= numParallelCombineThreads, + "numParallelCombineThreads[%s] cannot larger than concurrencyHint[%s]", + numParallelCombineThreads, + concurrencyHint + ); this.groupers = new ArrayList<>(concurrencyHint); - this.threadLocalGrouper = new ThreadLocal>() - { - @Override - protected SpillingGrouper initialValue() - { - return groupers.get(threadNumber.getAndIncrement()); - } - }; + this.threadLocalGrouper = ThreadLocal.withInitial(() -> groupers.get(threadNumber.getAndIncrement())); this.bufferSupplier = bufferSupplier; this.columnSelectorFactory = columnSelectorFactory; @@ -127,10 +183,27 @@ protected SpillingGrouper initialValue() this.limitSpec = limitSpec; this.sortHasNonGroupingFields = sortHasNonGroupingFields; this.keyObjComparator = keySerdeFactory.objectComparator(sortHasNonGroupingFields); - this.grouperSorter = Preconditions.checkNotNull(grouperSorter); + this.executor = Preconditions.checkNotNull(executor); this.priority = priority; this.hasQueryTimeout = hasQueryTimeout; this.queryTimeoutAt = queryTimeoutAt; + this.maxDictionarySizeForCombiner = combineKeySerdeFactory.getMaxDictionarySize(); + + if (numParallelCombineThreads > 1) { + this.parallelCombiner = new ParallelCombiner<>( + combineBufferSupplier, + getCombiningFactories(aggregatorFactories), + combineKeySerdeFactory, + executor, + sortHasNonGroupingFields, + Math.min(numParallelCombineThreads, concurrencyHint), + priority, + queryTimeoutAt, + intermediateCombineDegree + ); + } else { + this.parallelCombiner = null; + } } @Override @@ -143,11 +216,9 @@ public void init() final int sliceSize = (buffer.capacity() / concurrencyHint); for (int i = 0; i < concurrencyHint; i++) { - final ByteBuffer slice = buffer.duplicate(); - slice.position(sliceSize * i); - slice.limit(slice.position() + sliceSize); + final ByteBuffer slice = Groupers.getSlice(buffer, sliceSize, i); final SpillingGrouper grouper = new SpillingGrouper<>( - Suppliers.ofInstance(slice.slice()), + Suppliers.ofInstance(slice), keySerdeFactory, columnSelectorFactory, aggregatorFactories, @@ -222,15 +293,11 @@ public void reset() throw new ISE("Grouper is closed"); } - for (Grouper grouper : groupers) { - synchronized (grouper) { - grouper.reset(); - } - } + groupers.forEach(Grouper::reset); } @Override - public Iterator> iterator(final boolean sorted) + public CloseableIterator> iterator(final boolean sorted) { if (!initialized) { throw new ISE("Grouper is not initialized"); @@ -240,28 +307,43 @@ public Iterator> iterator(final boolean sorted) throw new ISE("Grouper is closed"); } - return Groupers.mergeIterators( - sorted && isParallelSortAvailable() ? parallelSortAndGetGroupersIterator() : getGroupersIterator(sorted), - sorted ? keyObjComparator : null - ); + final List>> sortedIterators = sorted && isParallelizable() ? + parallelSortAndGetGroupersIterator() : + getGroupersIterator(sorted); + + // Parallel combine is used only when data is spilled. This is because ConcurrentGrouper uses two different modes + // depending on data is spilled or not. If data is not spilled, all inputs are completely aggregated and no more + // aggregation is required. + if (sorted && spilling && parallelCombiner != null) { + // First try to merge dictionaries generated by all underlying groupers. If it is merged successfully, the same + // merged dictionary is used for all combining threads + final List dictionary = tryMergeDictionary(); + if (dictionary != null) { + return parallelCombiner.combine(sortedIterators, dictionary); + } + } + + return sorted ? + CloseableIterators.mergeSorted(sortedIterators, keyObjComparator) : + CloseableIterators.concat(sortedIterators); } - private boolean isParallelSortAvailable() + private boolean isParallelizable() { return concurrencyHint > 1; } - private List>> parallelSortAndGetGroupersIterator() + private List>> parallelSortAndGetGroupersIterator() { - // The number of groupers is same with the number of processing threads in grouperSorter - final ListenableFuture>>> future = Futures.allAsList( + // The number of groupers is same with the number of processing threads in the executor + final ListenableFuture>>> future = Futures.allAsList( groupers.stream() .map(grouper -> - grouperSorter.submit( - new AbstractPrioritizedCallable>>(priority) + executor.submit( + new AbstractPrioritizedCallable>>(priority) { @Override - public Iterator> call() throws Exception + public CloseableIterator> call() throws Exception { return grouper.iterator(true); } @@ -287,21 +369,47 @@ public Iterator> call() throws Exception } } - private List>> getGroupersIterator(boolean sorted) + private List>> getGroupersIterator(boolean sorted) { return groupers.stream() .map(grouper -> grouper.iterator(sorted)) .collect(Collectors.toList()); } + /** + * Merge dictionaries of {@link Grouper.KeySerde}s of {@link Grouper}s. The result dictionary contains unique string + * keys. + * + * @return merged dictionary if its size does not exceed max dictionary size. Otherwise null. + */ + @Nullable + private List tryMergeDictionary() + { + final Set mergedDictionary = new HashSet<>(); + long totalDictionarySize = 0L; + + for (SpillingGrouper grouper : groupers) { + final List dictionary = grouper.mergeAndGetDictionary(); + + for (String key : dictionary) { + if (mergedDictionary.add(key)) { + totalDictionarySize += RowBasedGrouperHelper.estimateStringKeySize(key); + if (totalDictionarySize > maxDictionarySizeForCombiner) { + return null; + } + } + } + } + + return ImmutableList.copyOf(mergedDictionary); + } + @Override public void close() { - closed = true; - for (Grouper grouper : groupers) { - synchronized (grouper) { - grouper.close(); - } + if (!closed) { + closed = true; + groupers.forEach(Grouper::close); } } @@ -309,4 +417,11 @@ private int grouperNumberForKeyHash(int keyHash) { return keyHash % groupers.size(); } + + private AggregatorFactory[] getCombiningFactories(AggregatorFactory[] aggregatorFactories) + { + final AggregatorFactory[] combiningFactories = new AggregatorFactory[aggregatorFactories.length]; + Arrays.setAll(combiningFactories, i -> aggregatorFactories[i].getCombiningFactory()); + return combiningFactories; + } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java index 828d1f358e8d..752820a28e2f 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Function; import com.google.common.base.Predicates; +import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -33,8 +34,10 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import io.druid.collections.BlockingPool; +import io.druid.collections.NonBlockingPool; import io.druid.collections.ReferenceCountingResourceHolder; import io.druid.collections.Releaser; +import io.druid.collections.ResourceHolder; import io.druid.data.input.Row; import io.druid.java.util.common.ISE; import io.druid.java.util.common.Pair; @@ -80,6 +83,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner private final ListeningExecutorService exec; private final QueryWatcher queryWatcher; private final int concurrencyHint; + private final NonBlockingPool processingBufferPool; private final BlockingPool mergeBufferPool; private final ObjectMapper spillMapper; private final String processingTmpDir; @@ -91,6 +95,7 @@ public GroupByMergingQueryRunnerV2( QueryWatcher queryWatcher, Iterable> queryables, int concurrencyHint, + NonBlockingPool processingBufferPool, BlockingPool mergeBufferPool, int mergeBufferSize, ObjectMapper spillMapper, @@ -102,6 +107,7 @@ public GroupByMergingQueryRunnerV2( this.queryWatcher = queryWatcher; this.queryables = Iterables.unmodifiableIterable(Iterables.filter(queryables, Predicates.notNull())); this.concurrencyHint = concurrencyHint; + this.processingBufferPool = processingBufferPool; this.mergeBufferPool = mergeBufferPool; this.spillMapper = spillMapper; this.processingTmpDir = processingTmpDir; @@ -154,6 +160,22 @@ public Sequence run(final QueryPlus queryPlus, final Map> combineBufferSupplier = new Supplier>() + { + private boolean initialized; + private ResourceHolder buffer; + + @Override + public ResourceHolder get() + { + if (!initialized) { + buffer = processingBufferPool.take(); + initialized = true; + } + return buffer; + } + }; + return new BaseSequence<>( new BaseSequence.IteratorMaker>() { @@ -194,6 +216,7 @@ public CloseableGrouperIterator make() null, config, Suppliers.ofInstance(mergeBufferHolder.get()), + combineBufferSupplier, concurrencyHint, temporaryStorage, spillMapper, diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java index ffa614d4727c..6db22e1efb19 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import io.druid.collections.NonBlockingPool; import io.druid.collections.ResourceHolder; @@ -732,6 +733,12 @@ public Class keyClazz() return ByteBuffer.class; } + @Override + public List getDictionary() + { + return ImmutableList.of(); + } + @Override public ByteBuffer toByteBuffer(ByteBuffer key) { diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java index fe06fdbd280f..7e1a071fb9d5 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java @@ -22,13 +22,14 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; +import io.druid.java.util.common.parsers.CloseableIterator; import io.druid.query.aggregation.AggregatorFactory; import java.io.Closeable; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Comparator; -import java.util.Iterator; +import java.util.List; import java.util.function.ToIntFunction; /** @@ -100,10 +101,10 @@ default ToIntFunction hashFunction() /** * Iterate through entries. *

- * Once this method is called, writes are no longer safe. After you are done with the iterator returned by this - * method, you should either call {@link #close()} (if you are done with the Grouper), {@link #reset()} (if you - * want to reuse it), or {@link #iterator(boolean)} again if you want another iterator. This method is not thread-safe - * and must not be called by multiple threads concurrently. + * Some implementations allow writes even after this method is called. After you are done with the iterator + * returned by this method, you should either call {@link #close()} (if you are done with the Grouper) or + * {@link #reset()} (if you want to reuse it). Some implementations allow calling {@link #iterator(boolean)} again if + * you want another iterator. But, this method must not be called by multiple threads concurrently. *

* If "sorted" is true then the iterator will return sorted results. It will use KeyType's natural ordering on * deserialized objects, and will use the {@link KeySerde#comparator()} on serialized objects. Woe be unto you @@ -116,7 +117,7 @@ default ToIntFunction hashFunction() * * @return entry iterator */ - Iterator> iterator(boolean sorted); + CloseableIterator> iterator(boolean sorted); class Entry { @@ -186,10 +187,22 @@ public String toString() interface KeySerdeFactory { /** - * Create a new KeySerde, which may be stateful. + * Return max dictionary size threshold. + * + * @return max dictionary size + */ + long getMaxDictionarySize(); + + /** + * Create a new {@link KeySerde}, which may be stateful. */ KeySerde factorize(); + /** + * Create a new {@link KeySerde} with the given dictionary. + */ + KeySerde factorizeWithDictionary(List dictionary); + /** * Return an object that knows how to compare two serialized key instances. Will be called by the * {@link #iterator(boolean)} method if sorting is enabled. @@ -217,6 +230,11 @@ interface KeySerde */ Class keyClazz(); + /** + * Return the dictionary of this KeySerde. The return value should not be null. + */ + List getDictionary(); + /** * Serialize a key. This will be called by the {@link #aggregate(Comparable)} method. The buffer will not * be retained after the aggregate method returns, so reusing buffers is OK. diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java index d03be91abe35..41b9c3d9c8e2 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java @@ -19,10 +19,7 @@ package io.druid.query.groupby.epinephelinae; -import com.google.common.collect.Iterators; - -import java.util.Comparator; -import java.util.Iterator; +import java.nio.ByteBuffer; public class Groupers { @@ -72,25 +69,11 @@ static int getUsedFlag(int keyHash) return keyHash | 0x80000000; } - public static Iterator> mergeIterators( - final Iterable>> iterators, - final Comparator> keyTypeComparator - ) + public static ByteBuffer getSlice(ByteBuffer buffer, int sliceSize, int i) { - if (keyTypeComparator != null) { - return Iterators.mergeSorted( - iterators, - new Comparator>() - { - @Override - public int compare(Grouper.Entry lhs, Grouper.Entry rhs) - { - return keyTypeComparator.compare(lhs, rhs); - } - } - ); - } else { - return Iterators.concat(iterators.iterator()); - } + final ByteBuffer slice = buffer.duplicate(); + slice.position(sliceSize * i); + slice.limit(slice.position() + sliceSize); + return slice.slice(); } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferHashGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferHashGrouper.java index b6015fe2edf7..12c50ce940c0 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferHashGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferHashGrouper.java @@ -22,16 +22,18 @@ import com.google.common.base.Supplier; import com.google.common.collect.Iterators; import com.google.common.primitives.Ints; +import io.druid.java.util.common.parsers.CloseableIterator; +import io.druid.java.util.common.CloseableIterators; import io.druid.java.util.common.IAE; import io.druid.java.util.common.ISE; import io.druid.query.aggregation.AggregatorFactory; import io.druid.segment.ColumnSelectorFactory; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.AbstractList; import java.util.Collections; import java.util.Comparator; -import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; @@ -199,13 +201,13 @@ public void reset() } @Override - public Iterator> iterator(boolean sorted) + public CloseableIterator> iterator(boolean sorted) { if (!initialized) { // it's possible for iterator() to be called before initialization when // a nested groupBy's subquery has an empty result set (see testEmptySubqueryWithLimitPushDown() // in GroupByQueryRunnerTest) - return Iterators.>emptyIterator(); + return CloseableIterators.withEmptyBaggage(Iterators.>emptyIterator()); } if (sortHasNonGroupingFields) { @@ -251,7 +253,7 @@ public int getHeapIndexForOffset(int bucketOffset) } } - private Iterator> makeDefaultOrderingIterator() + private CloseableIterator> makeDefaultOrderingIterator() { final int size = offsetHeap.getHeapSize(); @@ -299,7 +301,7 @@ public int compare(Integer lhs, Integer rhs) } ); - return new Iterator>() + return new CloseableIterator>() { int curr = 0; @@ -320,13 +322,19 @@ public void remove() { throw new UnsupportedOperationException(); } + + @Override + public void close() throws IOException + { + // do nothing + } }; } - private Iterator> makeHeapIterator() + private CloseableIterator> makeHeapIterator() { final int initialHeapSize = offsetHeap.getHeapSize(); - return new Iterator>() + return new CloseableIterator>() { int curr = 0; @@ -354,6 +362,12 @@ public void remove() { throw new UnsupportedOperationException(); } + + @Override + public void close() throws IOException + { + // do nothing + } }; } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ParallelCombiner.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ParallelCombiner.java new file mode 100644 index 000000000000..d043a9070d76 --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ParallelCombiner.java @@ -0,0 +1,489 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import com.google.common.base.Throwables; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import io.druid.collections.ResourceHolder; +import io.druid.java.util.common.CloseableIterators; +import io.druid.java.util.common.ISE; +import io.druid.java.util.common.Pair; +import io.druid.java.util.common.io.Closer; +import io.druid.java.util.common.parsers.CloseableIterator; +import io.druid.query.AbstractPrioritizedCallable; +import io.druid.query.QueryInterruptedException; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.dimension.DimensionSpec; +import io.druid.query.groupby.epinephelinae.Grouper.Entry; +import io.druid.query.groupby.epinephelinae.Grouper.KeySerdeFactory; +import io.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import io.druid.segment.ColumnSelectorFactory; +import io.druid.segment.ColumnValueSelector; +import io.druid.segment.DimensionSelector; +import io.druid.segment.ObjectColumnSelector; +import io.druid.segment.column.ColumnCapabilities; +import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; +import it.unimi.dsi.fastutil.objects.Object2IntMap; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +/** + * ParallelCombiner builds a combining tree which asynchronously aggregates input entries. Each node of the combining + * tree is a combining task executed in parallel which aggregates inputs from the child nodes. + */ +public class ParallelCombiner +{ + // The combining tree created by this class can have two different degrees for intermediate nodes. + // The "leaf combine degree (LCD)" is the number of leaf nodes combined together, while the "intermediate combine + // degree (ICD)" is the number of non-leaf nodes combined together. The below picture shows an example where LCD = 2 + // and ICD = 4. + // + // o <- non-leaf node + // / / \ \ <- ICD = 4 + // o o o o <- non-leaf nodes + // / \ / \ / \ / \ <- LCD = 2 + // o o o o o o o o <- leaf nodes + // + // The reason why we need two different degrees is to optimize the number of non-leaf nodes which are run by + // different threads at the same time. Note that the leaf nodes are sorted iterators of SpillingGroupers which + // generally returns multiple rows of the same grouping key which in turn should be combined, while the non-leaf nodes + // are iterators of StreamingMergeSortedGroupers and always returns a single row per grouping key. Generally, the + // performance will get better as LCD becomes low while ICD is some value larger than LCD because the amount of work + // each thread has to do can be properly tuned. The optimal values for LCD and ICD may vary with query and data. Here, + // we use a simple heuristic to avoid complex optimization. That is, ICD is fixed as a user-configurable value and the + // minimum LCD satisfying the memory restriction is searched. See findLeafCombineDegreeAndNumBuffers() for more + // details. + private static final int MINIMUM_LEAF_COMBINE_DEGREE = 2; + + private final Supplier> combineBufferSupplier; + private final AggregatorFactory[] combiningFactories; + private final KeySerdeFactory combineKeySerdeFactory; + private final ListeningExecutorService executor; + private final Comparator> keyObjComparator; + private final int concurrencyHint; + private final int priority; + private final long queryTimeoutAt; + + // The default value is 8 which comes from an experiment. A non-leaf node will combine up to intermediateCombineDegree + // rows for the same grouping key. + private final int intermediateCombineDegree; + + public ParallelCombiner( + Supplier> combineBufferSupplier, + AggregatorFactory[] combiningFactories, + KeySerdeFactory combineKeySerdeFactory, + ListeningExecutorService executor, + boolean sortHasNonGroupingFields, + int concurrencyHint, + int priority, + long queryTimeoutAt, + int intermediateCombineDegree + ) + { + this.combineBufferSupplier = combineBufferSupplier; + this.combiningFactories = combiningFactories; + this.combineKeySerdeFactory = combineKeySerdeFactory; + this.executor = executor; + this.keyObjComparator = combineKeySerdeFactory.objectComparator(sortHasNonGroupingFields); + this.concurrencyHint = concurrencyHint; + this.priority = priority; + this.intermediateCombineDegree = intermediateCombineDegree; + + this.queryTimeoutAt = queryTimeoutAt; + } + + /** + * Build a combining tree for the input iterators which combine input entries asynchronously. Each node in the tree + * is a combining task which iterates through child iterators, aggregates the inputs from those iterators, and returns + * an iterator for the result of aggregation. + *

+ * This method is called when data is spilled and thus streaming combine is preferred to avoid too many disk accesses. + * + * @return an iterator of the root grouper of the combining tree + */ + public CloseableIterator> combine( + List>> sortedIterators, + List mergedDictionary + ) + { + // CombineBuffer is initialized when this method is called and closed after the result iterator is done + final ResourceHolder combineBufferHolder = combineBufferSupplier.get(); + final ByteBuffer combineBuffer = combineBufferHolder.get(); + final int minimumRequiredBufferCapacity = StreamingMergeSortedGrouper.requiredBufferCapacity( + combineKeySerdeFactory.factorizeWithDictionary(mergedDictionary), + combiningFactories + ); + // We want to maximize the parallelism while the size of buffer slice is greater than the minimum buffer size + // required by StreamingMergeSortedGrouper. Here, we find the leafCombineDegree of the cominbing tree and the + // required number of buffers maximizing the parallelism. + final Pair degreeAndNumBuffers = findLeafCombineDegreeAndNumBuffers( + combineBuffer, + minimumRequiredBufferCapacity, + concurrencyHint, + sortedIterators.size() + ); + + final int leafCombineDegree = degreeAndNumBuffers.lhs; + final int numBuffers = degreeAndNumBuffers.rhs; + final int sliceSize = combineBuffer.capacity() / numBuffers; + + final Supplier bufferSupplier = createCombineBufferSupplier(combineBuffer, numBuffers, sliceSize); + + final Pair>>, List> combineIteratorAndFutures = buildCombineTree( + sortedIterators, + bufferSupplier, + combiningFactories, + leafCombineDegree, + mergedDictionary + ); + + final CloseableIterator> combineIterator = Iterables.getOnlyElement(combineIteratorAndFutures.lhs); + final List combineFutures = combineIteratorAndFutures.rhs; + + final Closer closer = Closer.create(); + closer.register(combineBufferHolder); + closer.register(() -> checkCombineFutures(combineFutures)); + + return CloseableIterators.wrap(combineIterator, closer); + } + + private static void checkCombineFutures(List combineFutures) + { + for (Future future : combineFutures) { + try { + if (!future.isDone()) { + // Cancel futures if close() for the iterator is called early due to some reason (e.g., test failure) + future.cancel(true); + } else { + future.get(); + } + } + catch (InterruptedException | CancellationException e) { + throw new QueryInterruptedException(e); + } + catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + } + + private static Supplier createCombineBufferSupplier( + ByteBuffer combineBuffer, + int numBuffers, + int sliceSize + ) + { + return new Supplier() + { + private int i = 0; + + @Override + public ByteBuffer get() + { + if (i < numBuffers) { + return Groupers.getSlice(combineBuffer, sliceSize, i++); + } else { + throw new ISE("Requested number[%d] of buffer slices exceeds the planned one[%d]", i++, numBuffers); + } + } + }; + } + + /** + * Find a minimum size of the buffer slice and corresponding leafCombineDegree and number of slices. Note that each + * node in the combining tree is executed by different threads. This method assumes that combining the leaf nodes + * requires threads as many as possible, while combining intermediate nodes is not. See the comment on + * {@link #MINIMUM_LEAF_COMBINE_DEGREE} for more details. + * + * @param combineBuffer entire buffer used for combining tree + * @param requiredMinimumBufferCapacity minimum buffer capacity for {@link StreamingMergeSortedGrouper} + * @param numAvailableThreads number of available threads + * @param numLeafNodes number of leaf nodes of combining tree + * + * @return a pair of leafCombineDegree and number of buffers if found. + */ + private Pair findLeafCombineDegreeAndNumBuffers( + ByteBuffer combineBuffer, + int requiredMinimumBufferCapacity, + int numAvailableThreads, + int numLeafNodes + ) + { + for (int leafCombineDegree = MINIMUM_LEAF_COMBINE_DEGREE; leafCombineDegree <= numLeafNodes; leafCombineDegree++) { + final int requiredBufferNum = computeRequiredBufferNum(numLeafNodes, leafCombineDegree); + if (requiredBufferNum <= numAvailableThreads) { + final int expectedSliceSize = combineBuffer.capacity() / requiredBufferNum; + if (expectedSliceSize >= requiredMinimumBufferCapacity) { + return Pair.of(leafCombineDegree, requiredBufferNum); + } + } + } + + throw new ISE( + "Cannot find a proper leaf combine degree for the combining tree. " + + "Each node of the combining tree requires a buffer of [%d] bytes. " + + "Try increasing druid.processing.buffer.sizeBytes for larger buffer or " + + "druid.query.groupBy.intermediateCombineDegree for a smaller tree", + requiredMinimumBufferCapacity + ); + } + + /** + * Recursively compute the number of required buffers for a combining tree in a bottom-up manner. Since each node of + * the combining tree represents a combining task and each combining task requires one buffer, the number of required + * buffers is the number of nodes of the combining tree. + * + * @param numChildNodes number of child nodes + * @param combineDegree combine degree for the current level + * + * @return minimum number of buffers required for combining tree + * + * @see #buildCombineTree(List, Supplier, AggregatorFactory[], int, List) + */ + private int computeRequiredBufferNum(int numChildNodes, int combineDegree) + { + // numChildrenForLastNode used to determine that the last node is needed for the current level. + // Please see buildCombineTree() for more details. + final int numChildrenForLastNode = numChildNodes % combineDegree; + final int numCurLevelNodes = numChildNodes / combineDegree + (numChildrenForLastNode > 1 ? 1 : 0); + final int numChildOfParentNodes = numCurLevelNodes + (numChildrenForLastNode == 1 ? 1 : 0); + + if (numChildOfParentNodes == 1) { + return numCurLevelNodes; + } else { + return numCurLevelNodes + + computeRequiredBufferNum(numChildOfParentNodes, intermediateCombineDegree); + } + } + + /** + * Recursively build a combining tree in a bottom-up manner. Each node of the tree is a task that combines input + * iterators asynchronously. + * + * @param childIterators all iterators of the child level + * @param bufferSupplier combining buffer supplier + * @param combiningFactories array of combining aggregator factories + * @param combineDegree combining degree for the current level + * @param dictionary merged dictionary + * + * @return a pair of a list of iterators of the current level in the combining tree and a list of futures of all + * executed combining tasks + */ + private Pair>>, List> buildCombineTree( + List>> childIterators, + Supplier bufferSupplier, + AggregatorFactory[] combiningFactories, + int combineDegree, + List dictionary + ) + { + final int numChildLevelIterators = childIterators.size(); + final List>> childIteratorsOfNextLevel = new ArrayList<>(); + final List combineFutures = new ArrayList<>(); + + // The below algorithm creates the combining nodes of the current level. It first checks that the number of children + // to be combined together is 1. If it is, the intermediate combining node for that child is not needed. Instead, it + // can be directly connected to a node of the parent level. Here is an example of generated tree when + // numLeafNodes = 6 and leafCombineDegree = intermediateCombineDegree = 2. See the description of + // MINIMUM_LEAF_COMBINE_DEGREE for more details about leafCombineDegree and intermediateCombineDegree. + // + // o + // / \ + // o \ + // / \ \ + // o o o + // / \ / \ / \ + // o o o o o o + // + // We can expect that the aggregates can be combined as early as possible because the tree is built in a bottom-up + // manner. + + for (int i = 0; i < numChildLevelIterators; i += combineDegree) { + if (i < numChildLevelIterators - 1) { + final List>> subIterators = childIterators.subList( + i, + Math.min(i + combineDegree, numChildLevelIterators) + ); + final Pair>, Future> iteratorAndFuture = runCombiner( + subIterators, + bufferSupplier.get(), + combiningFactories, + dictionary + ); + + childIteratorsOfNextLevel.add(iteratorAndFuture.lhs); + combineFutures.add(iteratorAndFuture.rhs); + } else { + // If there remains one child, it can be directly connected to a node of the parent level. + childIteratorsOfNextLevel.add(childIterators.get(i)); + } + } + + if (childIteratorsOfNextLevel.size() == 1) { + // This is the root + return Pair.of(childIteratorsOfNextLevel, combineFutures); + } else { + // Build the parent level iterators + final Pair>>, List> parentIteratorsAndFutures = + buildCombineTree( + childIteratorsOfNextLevel, + bufferSupplier, + combiningFactories, + intermediateCombineDegree, + dictionary + ); + combineFutures.addAll(parentIteratorsAndFutures.rhs); + return Pair.of(parentIteratorsAndFutures.lhs, combineFutures); + } + } + + private Pair>, Future> runCombiner( + List>> iterators, + ByteBuffer combineBuffer, + AggregatorFactory[] combiningFactories, + List dictionary + ) + { + final SettableColumnSelectorFactory settableColumnSelectorFactory = + new SettableColumnSelectorFactory(combiningFactories); + final StreamingMergeSortedGrouper grouper = new StreamingMergeSortedGrouper<>( + Suppliers.ofInstance(combineBuffer), + combineKeySerdeFactory.factorizeWithDictionary(dictionary), + settableColumnSelectorFactory, + combiningFactories, + queryTimeoutAt + ); + grouper.init(); // init() must be called before iterator(), so cannot be called inside the below callable. + + final ListenableFuture future = executor.submit( + new AbstractPrioritizedCallable(priority) + { + @Override + public Void call() throws Exception + { + try ( + CloseableIterator> mergedIterator = CloseableIterators.mergeSorted( + iterators, + keyObjComparator + ) + ) { + while (mergedIterator.hasNext()) { + final Entry next = mergedIterator.next(); + + settableColumnSelectorFactory.set(next.values); + grouper.aggregate(next.key); // grouper always returns ok or throws an exception + settableColumnSelectorFactory.set(null); + } + } + catch (IOException e) { + throw Throwables.propagate(e); + } + + grouper.finish(); + return null; + } + } + ); + + return new Pair<>(grouper.iterator(), future); + } + + private static class SettableColumnSelectorFactory implements ColumnSelectorFactory + { + private static final int UNKNOWN_COLUMN_INDEX = -1; + private final Object2IntMap columnIndexMap; + + private Object[] values; + + SettableColumnSelectorFactory(AggregatorFactory[] aggregatorFactories) + { + columnIndexMap = new Object2IntArrayMap<>(aggregatorFactories.length); + columnIndexMap.defaultReturnValue(UNKNOWN_COLUMN_INDEX); + for (int i = 0; i < aggregatorFactories.length; i++) { + columnIndexMap.put(aggregatorFactories[i].getName(), i); + } + } + + public void set(Object[] values) + { + this.values = values; + } + + private int checkAndGetColumnIndex(String columnName) + { + final int columnIndex = columnIndexMap.getInt(columnName); + Preconditions.checkState( + columnIndex != UNKNOWN_COLUMN_INDEX, + "Cannot find a proper column index for column[%s]", + columnName + ); + return columnIndex; + } + + @Override + public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec) + { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnValueSelector makeColumnValueSelector(String columnName) + { + return new ObjectColumnSelector() + { + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + // do nothing + } + + @Override + public Class classOfObject() + { + return Object.class; + } + + @Override + public Object getObject() + { + return values[checkAndGetColumnIndex(columnName)]; + } + }; + } + + @Override + public ColumnCapabilities getColumnCapabilities(String column) + { + throw new UnsupportedOperationException(); + } + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index 0ff4ad19e005..cadc6b7246c3 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -22,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Function; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.base.Supplier; @@ -34,9 +33,12 @@ import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import com.google.common.util.concurrent.ListeningExecutorService; +import io.druid.collections.ResourceHolder; +import io.druid.common.utils.IntArrayUtils; import io.druid.data.input.MapBasedRow; import io.druid.data.input.Row; import io.druid.java.util.common.IAE; +import io.druid.java.util.common.ISE; import io.druid.java.util.common.Pair; import io.druid.java.util.common.granularity.AllGranularity; import io.druid.java.util.common.guava.Accumulator; @@ -49,6 +51,7 @@ import io.druid.query.groupby.GroupByQuery; import io.druid.query.groupby.GroupByQueryConfig; import io.druid.query.groupby.RowBasedColumnSelectorFactory; +import io.druid.query.groupby.epinephelinae.Grouper.BufferComparator; import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.query.groupby.orderby.OrderByColumnSpec; import io.druid.query.groupby.strategy.GroupByStrategyV2; @@ -64,6 +67,10 @@ import io.druid.segment.column.ColumnCapabilities; import io.druid.segment.column.ValueType; import io.druid.segment.data.IndexedInts; +import it.unimi.dsi.fastutil.ints.IntArrays; +import it.unimi.dsi.fastutil.ints.IntComparator; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; import org.joda.time.DateTime; import javax.annotation.Nullable; @@ -76,10 +83,15 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; +import java.util.stream.IntStream; // this class contains shared code between GroupByMergingQueryRunnerV2 and GroupByRowProcessor public class RowBasedGrouperHelper { + // Entry in dictionary, node pointer in reverseDictionary, hash + k/v/next pointer in reverseDictionary nodes + private static final int ROUGH_OVERHEAD_PER_DICTIONARY_ENTRY = Longs.BYTES * 5 + Ints.BYTES; + private static final int SINGLE_THREAD_CONCURRENCY_HINT = -1; private static final int UNKNOWN_THREAD_PRIORITY = -1; private static final long UNKNOWN_TIMEOUT = -1L; @@ -105,6 +117,7 @@ public static Pair, Accumulator> crea rawInputRowSignature, config, bufferSupplier, + null, SINGLE_THREAD_CONCURRENCY_HINT, temporaryStorage, spillMapper, @@ -128,6 +141,7 @@ public static Pair, Accumulator> crea final Map rawInputRowSignature, final GroupByQueryConfig config, final Supplier bufferSupplier, + final Supplier> combineBufferSupplier, final int concurrencyHint, final LimitedTemporaryStorage temporaryStorage, final ObjectMapper spillMapper, @@ -193,14 +207,24 @@ public static Pair, Accumulator> crea mergeBufferSize ); } else { + final Grouper.KeySerdeFactory combineKeySerdeFactory = new RowBasedKeySerdeFactory( + includeTimestamp, + query.getContextSortByDimsFirst(), + query.getDimensions(), + querySpecificConfig.getMaxMergingDictionarySize(), // use entire dictionary space for combining key serde + valueTypes, + aggregatorFactories, + limitSpec + ); + grouper = new ConcurrentGrouper<>( + querySpecificConfig, bufferSupplier, + combineBufferSupplier, keySerdeFactory, + combineKeySerdeFactory, columnSelectorFactory, aggregatorFactories, - querySpecificConfig.getBufferGrouperMaxSize(), - querySpecificConfig.getBufferGrouperMaxLoadFactor(), - querySpecificConfig.getBufferGrouperInitialBuckets(), temporaryStorage, spillMapper, concurrencyHint, @@ -209,8 +233,7 @@ public static Pair, Accumulator> crea grouperSorter, priority, hasQueryTimeout, - queryTimeoutAt, - mergeBufferSize + queryTimeoutAt ); } @@ -647,6 +670,12 @@ private static class RowBasedKeySerdeFactory implements Grouper.KeySerdeFactory< this.valueTypes = valueTypes; } + @Override + public long getMaxDictionarySize() + { + return maxDictionarySize; + } + @Override public Grouper.KeySerde factorize() { @@ -656,7 +685,22 @@ public Grouper.KeySerde factorize() dimensions, maxDictionarySize, limitSpec, - valueTypes + valueTypes, + null + ); + } + + @Override + public Grouper.KeySerde factorizeWithDictionary(List dictionary) + { + return new RowBasedKeySerde( + includeTimestamp, + sortByDimsFirst, + dimensions, + maxDictionarySize, + limitSpec, + valueTypes, + dictionary ); } @@ -894,10 +938,15 @@ private static int compareDimsInRowsWithAggs( } } + static long estimateStringKeySize(String key) + { + return (long) key.length() * Chars.BYTES + ROUGH_OVERHEAD_PER_DICTIONARY_ENTRY; + } + private static class RowBasedKeySerde implements Grouper.KeySerde { - // Entry in dictionary, node pointer in reverseDictionary, hash + k/v/next pointer in reverseDictionary nodes - private static final int ROUGH_OVERHEAD_PER_DICTIONARY_ENTRY = Longs.BYTES * 5 + Ints.BYTES; + private static final int DICTIONARY_INITIAL_CAPACITY = 10000; + private static final int UNKNOWN_DICTIONARY_ID = -1; private final boolean includeTimestamp; private final boolean sortByDimsFirst; @@ -905,18 +954,25 @@ private static class RowBasedKeySerde implements Grouper.KeySerde dictionary = Lists.newArrayList(); - private final Map reverseDictionary = Maps.newHashMap(); - private final List serdeHelpers; + private final RowBasedKeySerdeHelper[] serdeHelpers; + private final BufferComparator[] serdeHelperComparators; private final DefaultLimitSpec limitSpec; private final List valueTypes; + private final boolean enableRuntimeDictionaryGeneration; + + private final List dictionary; + private final Object2IntMap reverseDictionary; + // Size limiting for the dictionary, in (roughly estimated) bytes. private final long maxDictionarySize; + private long currentEstimatedSize = 0; - // dictionary id -> its position if it were sorted by dictionary value - private int[] sortableIds = null; + // dictionary id -> rank of the sorted dictionary + // This is initialized in the constructor and bufferComparator() with static dictionary and dynamic dictionary, + // respectively. + private int[] rankOfDictionaryIds = null; RowBasedKeySerde( final boolean includeTimestamp, @@ -924,19 +980,71 @@ private static class RowBasedKeySerde implements Grouper.KeySerde dimensions, final long maxDictionarySize, final DefaultLimitSpec limitSpec, - final List valueTypes + final List valueTypes, + @Nullable final List dictionary ) { this.includeTimestamp = includeTimestamp; this.sortByDimsFirst = sortByDimsFirst; this.dimensions = dimensions; this.dimCount = dimensions.size(); - this.maxDictionarySize = maxDictionarySize; this.valueTypes = valueTypes; this.limitSpec = limitSpec; - this.serdeHelpers = makeSerdeHelpers(); + this.enableRuntimeDictionaryGeneration = dictionary == null; + this.dictionary = enableRuntimeDictionaryGeneration ? new ArrayList<>(DICTIONARY_INITIAL_CAPACITY) : dictionary; + this.reverseDictionary = enableRuntimeDictionaryGeneration ? + new Object2IntOpenHashMap<>(DICTIONARY_INITIAL_CAPACITY) : + new Object2IntOpenHashMap<>(dictionary.size()); + this.reverseDictionary.defaultReturnValue(UNKNOWN_DICTIONARY_ID); + this.maxDictionarySize = maxDictionarySize; + this.serdeHelpers = makeSerdeHelpers(limitSpec != null, enableRuntimeDictionaryGeneration); + this.serdeHelperComparators = new BufferComparator[serdeHelpers.length]; + Arrays.setAll(serdeHelperComparators, i -> serdeHelpers[i].getBufferComparator()); this.keySize = (includeTimestamp ? Longs.BYTES : 0) + getTotalKeySize(); this.keyBuffer = ByteBuffer.allocate(keySize); + + if (!enableRuntimeDictionaryGeneration) { + final long initialDictionarySize = dictionary.stream() + .mapToLong(RowBasedGrouperHelper::estimateStringKeySize) + .sum(); + Preconditions.checkState( + maxDictionarySize >= initialDictionarySize, + "Dictionary size[%s] exceeds threshold[%s]", + initialDictionarySize, + maxDictionarySize + ); + + for (int i = 0; i < dictionary.size(); i++) { + reverseDictionary.put(dictionary.get(i), i); + } + + initializeRankOfDictionaryIds(); + } + } + + private void initializeRankOfDictionaryIds() + { + final int dictionarySize = dictionary.size(); + rankOfDictionaryIds = IntStream.range(0, dictionarySize).toArray(); + IntArrays.quickSort( + rankOfDictionaryIds, + new IntComparator() + { + @Override + public int compare(int i1, int i2) + { + return dictionary.get(i1).compareTo(dictionary.get(i2)); + } + + @Override + public int compare(Integer o1, Integer o2) + { + return compare(o1.intValue(), o2.intValue()); + } + } + ); + + IntArrayUtils.inverse(rankOfDictionaryIds); } @Override @@ -951,6 +1059,12 @@ public Class keyClazz() return RowBasedKey.class; } + @Override + public List getDictionary() + { + return dictionary; + } + @Override public ByteBuffer toByteBuffer(RowBasedKey key) { @@ -964,7 +1078,7 @@ public ByteBuffer toByteBuffer(RowBasedKey key) dimStart = 0; } for (int i = dimStart; i < key.getKey().length; i++) { - if (!serdeHelpers.get(i - dimStart).putToKeyBuffer(key, i)) { + if (!serdeHelpers[i - dimStart].putToKeyBuffer(key, i)) { return null; } } @@ -993,7 +1107,7 @@ public RowBasedKey fromByteBuffer(ByteBuffer buffer, int position) for (int i = dimStart; i < key.length; i++) { // Writes value from buffer to key[i] - serdeHelpers.get(i - dimStart).getFromByteBuffer(buffer, dimsPosition, i, key); + serdeHelpers[i - dimStart].getFromByteBuffer(buffer, dimsPosition, i, key); } return new RowBasedKey(key); @@ -1002,16 +1116,8 @@ public RowBasedKey fromByteBuffer(ByteBuffer buffer, int position) @Override public Grouper.BufferComparator bufferComparator() { - if (sortableIds == null) { - Map sortedMap = Maps.newTreeMap(); - for (int id = 0; id < dictionary.size(); id++) { - sortedMap.put(dictionary.get(id), id); - } - sortableIds = new int[dictionary.size()]; - int index = 0; - for (final Integer id : sortedMap.values()) { - sortableIds[id] = index++; - } + if (rankOfDictionaryIds == null) { + initializeRankOfDictionaryIds(); } if (includeTimestamp) { @@ -1022,9 +1128,7 @@ public Grouper.BufferComparator bufferComparator() public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) { final int cmp = compareDimsInBuffersForNullFudgeTimestamp( - serdeHelpers, - sortableIds, - dimCount, + serdeHelperComparators, lhsBuffer, rhsBuffer, lhsPosition, @@ -1050,9 +1154,7 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } return compareDimsInBuffersForNullFudgeTimestamp( - serdeHelpers, - sortableIds, - dimCount, + serdeHelperComparators, lhsBuffer, rhsBuffer, lhsPosition, @@ -1068,7 +1170,7 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) { for (int i = 0; i < dimCount; i++) { - final int cmp = serdeHelpers.get(i).compare( + final int cmp = serdeHelperComparators[i].compare( lhsBuffer, rhsBuffer, lhsPosition, @@ -1104,7 +1206,7 @@ public Grouper.BufferComparator bufferComparatorWithAggregators( needsReverse = orderSpec.getDirection() != OrderByColumnSpec.Direction.ASCENDING; int dimIndex = OrderByColumnSpec.getDimIndexForOrderBy(orderSpec, dimensions); if (dimIndex >= 0) { - RowBasedKeySerdeHelper serdeHelper = serdeHelpers.get(dimIndex); + RowBasedKeySerdeHelper serdeHelper = serdeHelpers[dimIndex]; orderByHelpers.add(serdeHelper); orderByIndices.add(dimIndex); needsReverses.add(needsReverse); @@ -1112,35 +1214,19 @@ public Grouper.BufferComparator bufferComparatorWithAggregators( int aggIndex = OrderByColumnSpec.getAggIndexForOrderBy(orderSpec, Arrays.asList(aggregatorFactories)); if (aggIndex >= 0) { final RowBasedKeySerdeHelper serdeHelper; - final StringComparator cmp = orderSpec.getDimensionComparator(); - final boolean cmpIsNumeric = cmp.equals(StringComparators.NUMERIC); + final StringComparator stringComparator = orderSpec.getDimensionComparator(); final String typeName = aggregatorFactories[aggIndex].getTypeName(); final int aggOffset = aggregatorOffsets[aggIndex] - Ints.BYTES; aggCount++; - if (typeName.equals("long")) { - if (cmpIsNumeric) { - serdeHelper = new LongRowBasedKeySerdeHelper(aggOffset); - } else { - serdeHelper = new LimitPushDownLongRowBasedKeySerdeHelper(aggOffset, cmp); - } - } else if (typeName.equals("float")) { - if (cmpIsNumeric) { - serdeHelper = new FloatRowBasedKeySerdeHelper(aggOffset); - } else { - serdeHelper = new LimitPushDownFloatRowBasedKeySerdeHelper(aggOffset, cmp); - } - } else if (typeName.equals("double")) { - if (cmpIsNumeric) { - serdeHelper = new DoubleRowBasedKeySerdeHelper(aggOffset); - } else { - serdeHelper = new LimitPushDownDoubleRowBasedKeySerdeHelper(aggOffset, cmp); - } - } else { + final ValueType valueType = ValueType.fromString(typeName); + if (!ValueType.isNumeric(valueType)) { throw new IAE("Cannot order by a non-numeric aggregator[%s]", orderSpec); } + serdeHelper = makeNumericSerdeHelper(valueType, aggOffset, true, stringComparator); + orderByHelpers.add(serdeHelper); needsReverses.add(needsReverse); } @@ -1149,7 +1235,7 @@ public Grouper.BufferComparator bufferComparatorWithAggregators( for (int i = 0; i < dimCount; i++) { if (!orderByIndices.contains(i)) { - otherDimHelpers.add(serdeHelpers.get(i)); + otherDimHelpers.add(serdeHelpers[i]); needsReverses.add(false); // default to Ascending order if dim is not in an orderby spec } } @@ -1157,6 +1243,9 @@ public Grouper.BufferComparator bufferComparatorWithAggregators( adjustedSerdeHelpers = orderByHelpers; adjustedSerdeHelpers.addAll(otherDimHelpers); + final BufferComparator[] adjustedSerdeHelperComparators = new BufferComparator[adjustedSerdeHelpers.size()]; + Arrays.setAll(adjustedSerdeHelperComparators, i -> adjustedSerdeHelpers.get(i).getBufferComparator()); + final int fieldCount = dimCount + aggCount; if (includeTimestamp) { @@ -1167,7 +1256,7 @@ public Grouper.BufferComparator bufferComparatorWithAggregators( public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) { final int cmp = compareDimsInBuffersForNullFudgeTimestampForPushDown( - adjustedSerdeHelpers, + adjustedSerdeHelperComparators, needsReverses, fieldCount, lhsBuffer, @@ -1195,7 +1284,7 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } int cmp = compareDimsInBuffersForNullFudgeTimestampForPushDown( - adjustedSerdeHelpers, + adjustedSerdeHelperComparators, needsReverses, fieldCount, lhsBuffer, @@ -1217,14 +1306,14 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, for (int i = 0; i < fieldCount; i++) { final int cmp; if (needsReverses.get(i)) { - cmp = adjustedSerdeHelpers.get(i).compare( + cmp = adjustedSerdeHelperComparators[i].compare( rhsBuffer, lhsBuffer, rhsPosition, lhsPosition ); } else { - cmp = adjustedSerdeHelpers.get(i).compare( + cmp = adjustedSerdeHelperComparators[i].compare( lhsBuffer, rhsBuffer, lhsPosition, @@ -1243,98 +1332,15 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } } - private static int compareDimsInBuffersForNullFudgeTimestamp( - List serdeHelpers, - int[] sortableIds, - int dimCount, - ByteBuffer lhsBuffer, - ByteBuffer rhsBuffer, - int lhsPosition, - int rhsPosition - ) - { - for (int i = 0; i < dimCount; i++) { - final int cmp = serdeHelpers.get(i).compare( - lhsBuffer, - rhsBuffer, - lhsPosition + Longs.BYTES, - rhsPosition + Longs.BYTES - ); - if (cmp != 0) { - return cmp; - } - } - - return 0; - } - - private static int compareDimsInBuffersForNullFudgeTimestampForPushDown( - List serdeHelpers, - List needsReverses, - int dimCount, - ByteBuffer lhsBuffer, - ByteBuffer rhsBuffer, - int lhsPosition, - int rhsPosition - ) - { - for (int i = 0; i < dimCount; i++) { - final int cmp; - if (needsReverses.get(i)) { - cmp = serdeHelpers.get(i).compare( - rhsBuffer, - lhsBuffer, - rhsPosition + Longs.BYTES, - lhsPosition + Longs.BYTES - ); - } else { - cmp = serdeHelpers.get(i).compare( - lhsBuffer, - rhsBuffer, - lhsPosition + Longs.BYTES, - rhsPosition + Longs.BYTES - ); - } - if (cmp != 0) { - return cmp; - } - } - - return 0; - } - @Override public void reset() { - dictionary.clear(); - reverseDictionary.clear(); - sortableIds = null; - currentEstimatedSize = 0; - } - - /** - * Adds s to the dictionary. If the dictionary's size limit would be exceeded by adding this key, then - * this returns -1. - * - * @param s a string - * - * @return id for this string, or -1 - */ - private int addToDictionary(final String s) - { - Integer idx = reverseDictionary.get(s); - if (idx == null) { - final long additionalEstimatedSize = (long) s.length() * Chars.BYTES + ROUGH_OVERHEAD_PER_DICTIONARY_ENTRY; - if (currentEstimatedSize + additionalEstimatedSize > maxDictionarySize) { - return -1; - } - - idx = dictionary.size(); - reverseDictionary.put(s, idx); - dictionary.add(s); - currentEstimatedSize += additionalEstimatedSize; + if (enableRuntimeDictionaryGeneration) { + dictionary.clear(); + reverseDictionary.clear(); + rankOfDictionaryIds = null; + currentEstimatedSize = 0; } - return idx; } private int getTotalKeySize() @@ -1346,143 +1352,122 @@ private int getTotalKeySize() return size; } - private List makeSerdeHelpers() + private RowBasedKeySerdeHelper[] makeSerdeHelpers( + boolean pushLimitDown, + boolean enableRuntimeDictionaryGeneration + ) { - if (limitSpec != null) { - return makeSerdeHelpersForLimitPushDown(); - } - - List helpers = new ArrayList<>(); + final List helpers = new ArrayList<>(); int keyBufferPosition = 0; - for (ValueType valType : valueTypes) { - RowBasedKeySerdeHelper helper; - switch (valType) { - case STRING: - helper = new StringRowBasedKeySerdeHelper(keyBufferPosition); - break; - case LONG: - helper = new LongRowBasedKeySerdeHelper(keyBufferPosition); - break; - case FLOAT: - helper = new FloatRowBasedKeySerdeHelper(keyBufferPosition); - break; - case DOUBLE: - helper = new DoubleRowBasedKeySerdeHelper(keyBufferPosition); - break; - default: - throw new IAE("invalid type: %s", valType); + + for (int i = 0; i < dimCount; i++) { + final StringComparator stringComparator; + if (limitSpec != null) { + final String dimName = dimensions.get(i).getOutputName(); + stringComparator = DefaultLimitSpec.getComparatorForDimName(limitSpec, dimName); + } else { + stringComparator = null; } + + RowBasedKeySerdeHelper helper = makeSerdeHelper( + valueTypes.get(i), + keyBufferPosition, + pushLimitDown, + stringComparator, + enableRuntimeDictionaryGeneration + ); + keyBufferPosition += helper.getKeyBufferValueSize(); helpers.add(helper); } - return helpers; + + return helpers.toArray(new RowBasedKeySerdeHelper[helpers.size()]); } - private List makeSerdeHelpersForLimitPushDown() + private RowBasedKeySerdeHelper makeSerdeHelper( + ValueType valueType, + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator, + boolean enableRuntimeDictionaryGeneration + ) { - List helpers = new ArrayList<>(); - int keyBufferPosition = 0; - - for (int i = 0; i < valueTypes.size(); i++) { - final ValueType valType = valueTypes.get(i); - final String dimName = dimensions.get(i).getOutputName(); - StringComparator cmp = DefaultLimitSpec.getComparatorForDimName(limitSpec, dimName); - final boolean cmpIsNumeric = cmp != null && cmp.equals(StringComparators.NUMERIC); - - RowBasedKeySerdeHelper helper; - switch (valType) { - case STRING: - if (cmp == null) { - cmp = StringComparators.LEXICOGRAPHIC; - } - helper = new LimitPushDownStringRowBasedKeySerdeHelper(keyBufferPosition, cmp); - break; - case LONG: - if (cmp == null || cmpIsNumeric) { - helper = new LongRowBasedKeySerdeHelper(keyBufferPosition); - } else { - helper = new LimitPushDownLongRowBasedKeySerdeHelper(keyBufferPosition, cmp); - } - break; - case FLOAT: - if (cmp == null || cmpIsNumeric) { - helper = new FloatRowBasedKeySerdeHelper(keyBufferPosition); - } else { - helper = new LimitPushDownFloatRowBasedKeySerdeHelper(keyBufferPosition, cmp); - } - break; - case DOUBLE: - if (cmp == null || cmpIsNumeric) { - helper = new DoubleRowBasedKeySerdeHelper(keyBufferPosition); - } else { - helper = new LimitPushDownDoubleRowBasedKeySerdeHelper(keyBufferPosition, cmp); - } - break; - default: - throw new IAE("invalid type: %s", valType); - } - keyBufferPosition += helper.getKeyBufferValueSize(); - helpers.add(helper); + switch (valueType) { + case STRING: + if (enableRuntimeDictionaryGeneration) { + return new DynamicDictionaryStringRowBasedKeySerdeHelper( + keyBufferPosition, + pushLimitDown, + stringComparator + ); + } else { + return new StaticDictionaryStringRowBasedKeySerdeHelper( + keyBufferPosition, + pushLimitDown, + stringComparator + ); + } + case LONG: + case FLOAT: + case DOUBLE: + return makeNumericSerdeHelper(valueType, keyBufferPosition, pushLimitDown, stringComparator); + default: + throw new IAE("invalid type: %s", valueType); } - return helpers; } - private interface RowBasedKeySerdeHelper + private RowBasedKeySerdeHelper makeNumericSerdeHelper( + ValueType valueType, + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) { - /** - * @return The size in bytes for a value of the column handled by this SerdeHelper. - */ - int getKeyBufferValueSize(); - - /** - * Read a value from RowBasedKey at `idx` and put the value at the current position of RowBasedKeySerde's keyBuffer. - * advancing the position by the size returned by getKeyBufferValueSize(). - * - * If an internal resource limit has been reached and the value could not be added to the keyBuffer, - * (e.g., maximum dictionary size exceeded for Strings), this method returns false. - * - * @param key RowBasedKey containing the grouping key values for a row. - * @param idx Index of the grouping key column within that this SerdeHelper handles - * - * @return true if the value was added to the key, false otherwise - */ - boolean putToKeyBuffer(RowBasedKey key, int idx); - - /** - * Read a value from a ByteBuffer containing a grouping key in the same format as RowBasedKeySerde's keyBuffer and - * put the value in `dimValues` at `dimValIdx`. - * - * The value to be read resides in the buffer at position (`initialOffset` + the SerdeHelper's keyBufferPosition). - * - * @param buffer ByteBuffer containing an array of grouping keys for a row - * @param initialOffset Offset where non-timestamp grouping key columns start, needed because timestamp is not - * always included in the buffer. - * @param dimValIdx Index within dimValues to store the value read from the buffer - * @param dimValues Output array containing grouping key values for a row - */ - void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues); + switch (valueType) { + case LONG: + return new LongRowBasedKeySerdeHelper(keyBufferPosition, pushLimitDown, stringComparator); + case FLOAT: + return new FloatRowBasedKeySerdeHelper(keyBufferPosition, pushLimitDown, stringComparator); + case DOUBLE: + return new DoubleRowBasedKeySerdeHelper(keyBufferPosition, pushLimitDown, stringComparator); + default: + throw new IAE("invalid type: %s", valueType); + } + } - /** - * Compare the values at lhsBuffer[lhsPosition] and rhsBuffer[rhsPosition] using the natural ordering - * for this SerdeHelper's value type. - * - * @param lhsBuffer ByteBuffer containing an array of grouping keys for a row - * @param rhsBuffer ByteBuffer containing an array of grouping keys for a row - * @param lhsPosition Position of value within lhsBuffer - * @param rhsPosition Position of value within rhsBuffer - * - * @return Negative number if lhs < rhs, positive if lhs > rhs, 0 if lhs == rhs - */ - int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition); + private static boolean isPrimitiveComparable(boolean pushLimitDown, @Nullable StringComparator stringComparator) + { + return !pushLimitDown || stringComparator == null || stringComparator.equals(StringComparators.NUMERIC); } - private class StringRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper + private abstract class AbstractStringRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; - public StringRowBasedKeySerdeHelper(int keyBufferPosition) + final BufferComparator bufferComparator; + + AbstractStringRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) { this.keyBufferPosition = keyBufferPosition; + if (!pushLimitDown) { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> Ints.compare( + rankOfDictionaryIds[lhsBuffer.getInt(lhsPosition + keyBufferPosition)], + rankOfDictionaryIds[rhsBuffer.getInt(rhsPosition + keyBufferPosition)] + ); + } else { + final StringComparator realComparator = stringComparator == null ? + StringComparators.LEXICOGRAPHIC : + stringComparator; + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> { + String lhsStr = dictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)); + String rhsStr = dictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition)); + return realComparator.compare(lhsStr, rhsStr); + }; + } } @Override @@ -1491,6 +1476,30 @@ public int getKeyBufferValueSize() return Ints.BYTES; } + @Override + public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues) + { + dimValues[dimValIdx] = dictionary.get(buffer.getInt(initialOffset + keyBufferPosition)); + } + + @Override + public BufferComparator getBufferComparator() + { + return bufferComparator; + } + } + + private class DynamicDictionaryStringRowBasedKeySerdeHelper extends AbstractStringRowBasedKeySerdeHelper + { + DynamicDictionaryStringRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) + { + super(keyBufferPosition, pushLimitDown, stringComparator); + } + @Override public boolean putToKeyBuffer(RowBasedKey key, int idx) { @@ -1502,48 +1511,82 @@ public boolean putToKeyBuffer(RowBasedKey key, int idx) return true; } - @Override - public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues) + /** + * Adds s to the dictionary. If the dictionary's size limit would be exceeded by adding this key, then + * this returns -1. + * + * @param s a string + * + * @return id for this string, or -1 + */ + private int addToDictionary(final String s) { - dimValues[dimValIdx] = dictionary.get(buffer.getInt(initialOffset + keyBufferPosition)); - } + int idx = reverseDictionary.getInt(s); + if (idx == UNKNOWN_DICTIONARY_ID) { + final long additionalEstimatedSize = estimateStringKeySize(s); + if (currentEstimatedSize + additionalEstimatedSize > maxDictionarySize) { + return -1; + } - @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) - { - return Ints.compare( - sortableIds[lhsBuffer.getInt(lhsPosition + keyBufferPosition)], - sortableIds[rhsBuffer.getInt(rhsPosition + keyBufferPosition)] - ); + idx = dictionary.size(); + reverseDictionary.put(s, idx); + dictionary.add(s); + currentEstimatedSize += additionalEstimatedSize; + } + return idx; } } - private class LimitPushDownStringRowBasedKeySerdeHelper extends StringRowBasedKeySerdeHelper + private class StaticDictionaryStringRowBasedKeySerdeHelper extends AbstractStringRowBasedKeySerdeHelper { - final StringComparator cmp; - - public LimitPushDownStringRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) + StaticDictionaryStringRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) { - super(keyBufferPosition); - this.cmp = cmp; + super(keyBufferPosition, pushLimitDown, stringComparator); } @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + public boolean putToKeyBuffer(RowBasedKey key, int idx) { - String lhsStr = dictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)); - String rhsStr = dictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition)); - return cmp.compare(lhsStr, rhsStr); + final String stringKey = (String) key.getKey()[idx]; + + final int dictIndex = reverseDictionary.getInt(stringKey); + if (dictIndex == UNKNOWN_DICTIONARY_ID) { + throw new ISE("Cannot find key[%s] from dictionary", stringKey); + } + keyBuffer.putInt(dictIndex); + return true; } } private class LongRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; + final BufferComparator bufferComparator; - public LongRowBasedKeySerdeHelper(int keyBufferPosition) + LongRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) { this.keyBufferPosition = keyBufferPosition; + if (isPrimitiveComparable(pushLimitDown, stringComparator)) { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> Longs.compare( + lhsBuffer.getLong(lhsPosition + keyBufferPosition), + rhsBuffer.getLong(rhsPosition + keyBufferPosition) + ); + } else { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> { + long lhs = lhsBuffer.getLong(lhsPosition + keyBufferPosition); + long rhs = rhsBuffer.getLong(rhsPosition + keyBufferPosition); + + return stringComparator.compare(String.valueOf(lhs), String.valueOf(rhs)); + }; + } } @Override @@ -1566,42 +1609,35 @@ public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValId } @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + public BufferComparator getBufferComparator() { - return Longs.compare( - lhsBuffer.getLong(lhsPosition + keyBufferPosition), - rhsBuffer.getLong(rhsPosition + keyBufferPosition) - ); - } - } - - private class LimitPushDownLongRowBasedKeySerdeHelper extends LongRowBasedKeySerdeHelper - { - final StringComparator cmp; - - public LimitPushDownLongRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) - { - super(keyBufferPosition); - this.cmp = cmp; - } - - @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) - { - long lhs = lhsBuffer.getLong(lhsPosition + keyBufferPosition); - long rhs = rhsBuffer.getLong(rhsPosition + keyBufferPosition); - - return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + return bufferComparator; } } private class FloatRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; + final BufferComparator bufferComparator; - public FloatRowBasedKeySerdeHelper(int keyBufferPosition) + FloatRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator) { this.keyBufferPosition = keyBufferPosition; + if (isPrimitiveComparable(pushLimitDown, stringComparator)) { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> Float.compare( + lhsBuffer.getFloat(lhsPosition + keyBufferPosition), + rhsBuffer.getFloat(rhsPosition + keyBufferPosition) + ); + } else { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> { + float lhs = lhsBuffer.getFloat(lhsPosition + keyBufferPosition); + float rhs = rhsBuffer.getFloat(rhsPosition + keyBufferPosition); + return stringComparator.compare(String.valueOf(lhs), String.valueOf(rhs)); + }; + } } @Override @@ -1624,41 +1660,36 @@ public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValId } @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) - { - return Float.compare( - lhsBuffer.getFloat(lhsPosition + keyBufferPosition), - rhsBuffer.getFloat(rhsPosition + keyBufferPosition) - ); - } - } - - private class LimitPushDownFloatRowBasedKeySerdeHelper extends FloatRowBasedKeySerdeHelper - { - final StringComparator cmp; - - public LimitPushDownFloatRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) - { - super(keyBufferPosition); - this.cmp = cmp; - } - - @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + public BufferComparator getBufferComparator() { - float lhs = lhsBuffer.getFloat(lhsPosition + keyBufferPosition); - float rhs = rhsBuffer.getFloat(rhsPosition + keyBufferPosition); - return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + return bufferComparator; } } private class DoubleRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; + final BufferComparator bufferComparator; - public DoubleRowBasedKeySerdeHelper(int keyBufferPosition) + DoubleRowBasedKeySerdeHelper( + int keyBufferPosition, + boolean pushLimitDown, + @Nullable StringComparator stringComparator + ) { this.keyBufferPosition = keyBufferPosition; + if (isPrimitiveComparable(pushLimitDown, stringComparator)) { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> Double.compare( + lhsBuffer.getDouble(lhsPosition + keyBufferPosition), + rhsBuffer.getDouble(rhsPosition + keyBufferPosition) + ); + } else { + bufferComparator = (lhsBuffer, rhsBuffer, lhsPosition, rhsPosition) -> { + double lhs = lhsBuffer.getDouble(lhsPosition + keyBufferPosition); + double rhs = rhsBuffer.getDouble(rhsPosition + keyBufferPosition); + return stringComparator.compare(String.valueOf(lhs), String.valueOf(rhs)); + }; + } } @Override @@ -1681,32 +1712,68 @@ public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValId } @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + public BufferComparator getBufferComparator() { - return Double.compare( - lhsBuffer.getDouble(lhsPosition + keyBufferPosition), - rhsBuffer.getDouble(rhsPosition + keyBufferPosition) - ); + return bufferComparator; } } + } - private class LimitPushDownDoubleRowBasedKeySerdeHelper extends DoubleRowBasedKeySerdeHelper - { - final StringComparator cmp; - - public LimitPushDownDoubleRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) - { - super(keyBufferPosition); - this.cmp = cmp; + private static int compareDimsInBuffersForNullFudgeTimestamp( + BufferComparator[] serdeHelperComparators, + ByteBuffer lhsBuffer, + ByteBuffer rhsBuffer, + int lhsPosition, + int rhsPosition + ) + { + for (BufferComparator comparator : serdeHelperComparators) { + final int cmp = comparator.compare( + lhsBuffer, + rhsBuffer, + lhsPosition + Longs.BYTES, + rhsPosition + Longs.BYTES + ); + if (cmp != 0) { + return cmp; } + } - @Override - public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) - { - double lhs = lhsBuffer.getDouble(lhsPosition + keyBufferPosition); - double rhs = rhsBuffer.getDouble(rhsPosition + keyBufferPosition); - return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + return 0; + } + + private static int compareDimsInBuffersForNullFudgeTimestampForPushDown( + BufferComparator[] serdeHelperComparators, + List needsReverses, + int dimCount, + ByteBuffer lhsBuffer, + ByteBuffer rhsBuffer, + int lhsPosition, + int rhsPosition + ) + { + for (int i = 0; i < dimCount; i++) { + final int cmp; + if (needsReverses.get(i)) { + cmp = serdeHelperComparators[i].compare( + rhsBuffer, + lhsBuffer, + rhsPosition + Longs.BYTES, + lhsPosition + Longs.BYTES + ); + } else { + cmp = serdeHelperComparators[i].compare( + lhsBuffer, + rhsBuffer, + lhsPosition + Longs.BYTES, + rhsPosition + Longs.BYTES + ); + } + if (cmp != 0) { + return cmp; } } + + return 0; } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedKeySerdeHelper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedKeySerdeHelper.java new file mode 100644 index 000000000000..c7e3437ded97 --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedKeySerdeHelper.java @@ -0,0 +1,66 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import io.druid.query.groupby.epinephelinae.Grouper.BufferComparator; +import io.druid.query.groupby.epinephelinae.RowBasedGrouperHelper.RowBasedKey; + +import java.nio.ByteBuffer; + +interface RowBasedKeySerdeHelper +{ + /** + * @return The size in bytes for a value of the column handled by this SerdeHelper. + */ + int getKeyBufferValueSize(); + + /** + * Read a value from RowBasedKey at `idx` and put the value at the current position of RowBasedKeySerde's keyBuffer. + * advancing the position by the size returned by getKeyBufferValueSize(). + * + * If an internal resource limit has been reached and the value could not be added to the keyBuffer, + * (e.g., maximum dictionary size exceeded for Strings), this method returns false. + * + * @param key RowBasedKey containing the grouping key values for a row. + * @param idx Index of the grouping key column within that this SerdeHelper handles + * + * @return true if the value was added to the key, false otherwise + */ + boolean putToKeyBuffer(RowBasedKey key, int idx); + + /** + * Read a value from a ByteBuffer containing a grouping key in the same format as RowBasedKeySerde's keyBuffer and + * put the value in `dimValues` at `dimValIdx`. + * + * The value to be read resides in the buffer at position (`initialOffset` + the SerdeHelper's keyBufferPosition). + * + * @param buffer ByteBuffer containing an array of grouping keys for a row + * @param initialOffset Offset where non-timestamp grouping key columns start, needed because timestamp is not + * always included in the buffer. + * @param dimValIdx Index within dimValues to store the value read from the buffer + * @param dimValues Output array containing grouping key values for a row + */ + void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues); + + /** + * Return a {@link BufferComparator} to compare keys stored in ByteBuffer. + */ + BufferComparator getBufferComparator(); +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java index 6aa9ee6fc815..4ef2f26681d9 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java @@ -27,8 +27,10 @@ import com.google.common.base.Throwables; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; -import io.druid.java.util.common.guava.CloseQuietly; +import io.druid.java.util.common.CloseableIterators; +import io.druid.java.util.common.io.Closer; import io.druid.java.util.common.logger.Logger; +import io.druid.java.util.common.parsers.CloseableIterator; import io.druid.query.BaseQuery; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.groupby.orderby.DefaultLimitSpec; @@ -36,15 +38,16 @@ import net.jpountz.lz4.LZ4BlockInputStream; import net.jpountz.lz4.LZ4BlockOutputStream; -import java.io.Closeable; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.Set; /** * Grouper based around a single underlying {@link BufferHashGrouper}. Not thread-safe. @@ -67,7 +70,7 @@ public class SpillingGrouper implements Grouper private final Comparator> defaultOrderKeyObjComparator; private final List files = Lists.newArrayList(); - private final List closeables = Lists.newArrayList(); + private final List dictionaryFiles = Lists.newArrayList(); private final boolean sortHasNonGroupingFields; private boolean spillingAllowed = false; @@ -164,7 +167,7 @@ public AggregateResult aggregate(KeyType key, int keyHash) { final AggregateResult result = grouper.aggregate(key, keyHash); - if (result.isOk() || temporaryStorage.maxSize() <= 0 || !spillingAllowed) { + if (result.isOk() || !spillingAllowed || temporaryStorage.maxSize() <= 0) { return result; } else { // Warning: this can potentially block up a processing thread for a while. @@ -197,71 +200,115 @@ public void close() deleteFiles(); } + /** + * Returns a dictionary of string keys added to this grouper. Note that the dictionary of keySerde is spilled on + * local storage whenever the inner grouper is spilled. If there are spilled dictionaries, this method loads them + * from disk and returns a merged dictionary. + * + * @return a dictionary which is a list of unique strings + */ + public List mergeAndGetDictionary() + { + final Set mergedDictionary = new HashSet<>(); + mergedDictionary.addAll(keySerde.getDictionary()); + + for (File dictFile : dictionaryFiles) { + try ( + final MappingIterator dictIterator = spillMapper.readValues( + spillMapper.getFactory().createParser(new LZ4BlockInputStream(new FileInputStream(dictFile))), + spillMapper.getTypeFactory().constructType(String.class) + ) + ) { + while (dictIterator.hasNext()) { + mergedDictionary.add(dictIterator.next()); + } + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + return new ArrayList<>(mergedDictionary); + } + public void setSpillingAllowed(final boolean spillingAllowed) { this.spillingAllowed = spillingAllowed; } @Override - public Iterator> iterator(final boolean sorted) + public CloseableIterator> iterator(final boolean sorted) { - final List>> iterators = new ArrayList<>(1 + files.size()); + final List>> iterators = new ArrayList<>(1 + files.size()); iterators.add(grouper.iterator(sorted)); + final Closer closer = Closer.create(); for (final File file : files) { final MappingIterator> fileIterator = read(file, keySerde.keyClazz()); iterators.add( - Iterators.transform( - fileIterator, - new Function, Entry>() - { - @Override - public Entry apply(Entry entry) - { - final Object[] deserializedValues = new Object[entry.getValues().length]; - for (int i = 0; i < deserializedValues.length; i++) { - deserializedValues[i] = aggregatorFactories[i].deserialize(entry.getValues()[i]); - if (deserializedValues[i] instanceof Integer) { - // Hack to satisfy the groupBy unit tests; perhaps we could do better by adjusting Jackson config. - deserializedValues[i] = ((Integer) deserializedValues[i]).longValue(); + CloseableIterators.withEmptyBaggage( + Iterators.transform( + fileIterator, + new Function, Entry>() + { + @Override + public Entry apply(Entry entry) + { + final Object[] deserializedValues = new Object[entry.getValues().length]; + for (int i = 0; i < deserializedValues.length; i++) { + deserializedValues[i] = aggregatorFactories[i].deserialize(entry.getValues()[i]); + if (deserializedValues[i] instanceof Integer) { + // Hack to satisfy the groupBy unit tests; perhaps we could do better by adjusting Jackson config. + deserializedValues[i] = ((Integer) deserializedValues[i]).longValue(); + } + } + return new Entry<>(entry.getKey(), deserializedValues); } } - return new Entry<>(entry.getKey(), deserializedValues); - } - } + ) ) ); - closeables.add(fileIterator); + closer.register(fileIterator); } + final Iterator> baseIterator; if (sortHasNonGroupingFields) { - return Groupers.mergeIterators(iterators, defaultOrderKeyObjComparator); + baseIterator = CloseableIterators.mergeSorted(iterators, defaultOrderKeyObjComparator); } else { - return Groupers.mergeIterators(iterators, sorted ? keyObjComparator : null); + baseIterator = sorted ? + CloseableIterators.mergeSorted(iterators, keyObjComparator) : + CloseableIterators.concat(iterators); } + + return CloseableIterators.wrap(baseIterator, closer); } private void spill() throws IOException { - final File outFile; + try (CloseableIterator> iterator = grouper.iterator(true)) { + files.add(spill(iterator)); + dictionaryFiles.add(spill(keySerde.getDictionary().iterator())); + + grouper.reset(); + } + } + private File spill(Iterator iterator) throws IOException + { try ( final LimitedTemporaryStorage.LimitedOutputStream out = temporaryStorage.createFile(); final LZ4BlockOutputStream compressedOut = new LZ4BlockOutputStream(out); final JsonGenerator jsonGenerator = spillMapper.getFactory().createGenerator(compressedOut) ) { - outFile = out.getFile(); - final Iterator> it = grouper.iterator(true); - while (it.hasNext()) { + while (iterator.hasNext()) { BaseQuery.checkInterrupted(); - jsonGenerator.writeObject(it.next()); + jsonGenerator.writeObject(iterator.next()); } - } - files.add(outFile); - grouper.reset(); + return out.getFile(); + } } private MappingIterator> read(final File file, final Class keyClazz) @@ -279,10 +326,6 @@ private MappingIterator> read(final File file, final Class implements Grouper +{ + private static final Logger LOG = new Logger(StreamingMergeSortedGrouper.class); + private static final long DEFAULT_TIMEOUT_NS = TimeUnit.SECONDS.toNanos(5); // default timeout for spinlock + + // Threashold time for spinlocks in increaseWriteIndex() and increaseReadIndex(). The waiting thread calls + // Thread.yield() after this threadhold time elapses. + private static final long SPIN_FOR_TIMEOUT_THRESHOLD_NS = 1000L; + + private final Supplier bufferSupplier; + private final KeySerde keySerde; + private final BufferAggregator[] aggregators; + private final int[] aggregatorOffsets; + private final int keySize; + private final int recordSize; // size of (key + all aggregates) + + // Timeout for the current query. + // The query must fail with a timeout exception if System.nanoTime() >= queryTimeoutAtNs. This is used in the + // spinlocks to prevent the writing thread from being blocked if the iterator of this grouper is not consumed due to + // some failures which potentially makes the whole system being paused. + private final long queryTimeoutAtNs; + private final boolean hasQueryTimeout; + + // Below variables are initialized when init() is called. + private ByteBuffer buffer; + private int maxNumSlots; + private boolean initialized; + + /** + * Indicate that this grouper consumed the last input or not. The writing thread must set this value to true by + * calling {@link #finish()} when it's done. This variable is always set by the writing thread and read by the + * reading thread. + */ + private volatile boolean finished; + + /** + * Current write index of the array. This points to the array slot where the aggregation is currently performed. Its + * initial value is -1 which means any data are not written yet. Since it's assumed that the input is sorted by the + * grouping key, this variable is moved to the next slot whenever a new grouping key is found. Once it reaches the + * last slot of the array, it moves to the first slot. + * + * This is always moved ahead of {@link #nextReadIndex}. If the array is full, this variable + * cannot be moved until {@link #nextReadIndex} is moved. See {@link #increaseWriteIndex()} for more details. This + * variable is always incremented by the writing thread and read by both the writing and the reading threads. + */ + private volatile int curWriteIndex; + + /** + * Next read index of the array. This points to the array slot which the reading thread will read next. Its initial + * value is -1 which means any data are not read yet. This variable can point an array slot only when the aggregation + * for that slot is finished. Once it reaches the last slot of the array, it moves to the first slot. + * + * This always follows {@link #curWriteIndex}. If the array is empty, this variable cannot be moved until the + * aggregation for at least one grouping key is finished which in turn {@link #curWriteIndex} is moved. See + * {@link #iterator()} for more details. This variable is always incremented by the reading thread and read by both + * the writing and the reading threads. + */ + private volatile int nextReadIndex; + + /** + * Returns the minimum buffer capacity required for this grouper. This grouper keeps track read/write indexes + * and they cannot point the same array slot at the same time. Since the read/write indexes move circularly, one + * extra slot is needed in addition to the read/write slots. Finally, the required minimum buffer capacity is + * 3 * record size. + * + * @return required minimum buffer capacity + */ + public static int requiredBufferCapacity( + KeySerde keySerde, + AggregatorFactory[] aggregatorFactories + ) + { + int recordSize = keySerde.keySize(); + for (AggregatorFactory aggregatorFactory : aggregatorFactories) { + recordSize += aggregatorFactory.getMaxIntermediateSize(); + } + return recordSize * 3; + } + + StreamingMergeSortedGrouper( + final Supplier bufferSupplier, + final KeySerde keySerde, + final ColumnSelectorFactory columnSelectorFactory, + final AggregatorFactory[] aggregatorFactories, + final long queryTimeoutAtMs + ) + { + this.bufferSupplier = bufferSupplier; + this.keySerde = keySerde; + this.aggregators = new BufferAggregator[aggregatorFactories.length]; + this.aggregatorOffsets = new int[aggregatorFactories.length]; + + this.keySize = keySerde.keySize(); + int offset = keySize; + for (int i = 0; i < aggregatorFactories.length; i++) { + aggregators[i] = aggregatorFactories[i].factorizeBuffered(columnSelectorFactory); + aggregatorOffsets[i] = offset; + offset += aggregatorFactories[i].getMaxIntermediateSize(); + } + this.recordSize = offset; + + // queryTimeoutAtMs comes from System.currentTimeMillis(), but we should use System.nanoTime() to check timeout in + // this class. See increaseWriteIndex() and increaseReadIndex(). + this.hasQueryTimeout = queryTimeoutAtMs != QueryContexts.NO_TIMEOUT; + final long timeoutNs = hasQueryTimeout ? + TimeUnit.MILLISECONDS.toNanos(queryTimeoutAtMs - System.currentTimeMillis()) : + QueryContexts.NO_TIMEOUT; + + this.queryTimeoutAtNs = System.nanoTime() + timeoutNs; + } + + @Override + public void init() + { + if (!initialized) { + buffer = bufferSupplier.get(); + maxNumSlots = buffer.capacity() / recordSize; + Preconditions.checkState( + maxNumSlots > 2, + "Buffer[%s] should be large enough to store at least three records[%s]", + buffer.capacity(), + recordSize + ); + + reset(); + initialized = true; + } + } + + @Override + public boolean isInitialized() + { + return initialized; + } + + @Override + public AggregateResult aggregate(KeyType key, int notUsed) + { + return aggregate(key); + } + + @Override + public AggregateResult aggregate(KeyType key) + { + try { + final ByteBuffer keyBuffer = keySerde.toByteBuffer(key); + + if (keyBuffer.remaining() != keySize) { + throw new IAE( + "keySerde.toByteBuffer(key).remaining[%s] != keySerde.keySize[%s], buffer was the wrong size?!", + keyBuffer.remaining(), + keySize + ); + } + + final int prevRecordOffset = curWriteIndex * recordSize; + if (curWriteIndex == -1 || !keyEquals(keyBuffer, buffer, prevRecordOffset)) { + // Initialize a new slot for the new key. This may be potentially blocked if the array is full until at least + // one slot becomes available. + initNewSlot(keyBuffer); + } + + final int curRecordOffset = curWriteIndex * recordSize; + for (int i = 0; i < aggregatorOffsets.length; i++) { + aggregators[i].aggregate(buffer, curRecordOffset + aggregatorOffsets[i]); + } + + return AggregateResult.ok(); + } + catch (RuntimeException e) { + finished = true; + throw e; + } + } + + /** + * Checks two keys contained in the given buffers are same. + * + * @param curKeyBuffer the buffer for the given key from {@link #aggregate(Object)} + * @param buffer the whole array buffer + * @param bufferOffset the key offset of the buffer + * + * @return true if the two buffers are same. + */ + private boolean keyEquals(ByteBuffer curKeyBuffer, ByteBuffer buffer, int bufferOffset) + { + // Since this method is frequently called per each input row, the compare performance matters. + int i = 0; + for (; i + Long.BYTES <= keySize; i += Long.BYTES) { + if (curKeyBuffer.getLong(i) != buffer.getLong(bufferOffset + i)) { + return false; + } + } + + if (i + Integer.BYTES <= keySize) { + // This can be called at most once because we already compared using getLong() in the above. + if (curKeyBuffer.getInt(i) != buffer.getInt(bufferOffset + i)) { + return false; + } + i += Integer.BYTES; + } + + for (; i < keySize; i++) { + if (curKeyBuffer.get(i) != buffer.get(bufferOffset + i)) { + return false; + } + } + + return true; + } + + /** + * Initialize a new slot for a new grouping key. This may be potentially blocked if the array is full until at least + * one slot becomes available. + */ + private void initNewSlot(ByteBuffer newKey) + { + // Wait if the array is full and increase curWriteIndex + increaseWriteIndex(); + + final int recordOffset = recordSize * curWriteIndex; + buffer.position(recordOffset); + buffer.put(newKey); + + for (int i = 0; i < aggregators.length; i++) { + aggregators[i].init(buffer, recordOffset + aggregatorOffsets[i]); + } + } + + /** + * Wait for {@link #nextReadIndex} to be moved if necessary and move {@link #curWriteIndex}. + */ + private void increaseWriteIndex() + { + final long startAtNs = System.nanoTime(); + final long queryTimeoutAtNs = getQueryTimeoutAtNs(startAtNs); + final long spinTimeoutAtNs = startAtNs + SPIN_FOR_TIMEOUT_THRESHOLD_NS; + long timeoutNs = queryTimeoutAtNs - startAtNs; + long spinTimeoutNs = SPIN_FOR_TIMEOUT_THRESHOLD_NS; + + // In the below, we check that the array is full and wait for at least one slot to become available. + // + // nextReadIndex is a volatile variable and the changes on it are continuously checked until they are seen in + // the while loop. See the following links. + // * http://docs.oracle.com/javase/specs/jls/se7/html/jls-8.html#jls-8.3.1.4 + // * http://docs.oracle.com/javase/specs/jls/se7/html/jls-17.html#jls-17.4.5 + // * https://stackoverflow.com/questions/11761552/detailed-semantics-of-volatile-regarding-timeliness-of-visibility + + if (curWriteIndex == maxNumSlots - 1) { + // We additionally check that nextReadIndex is -1 here because the writing thread should wait for the reading + // thread to start reading only when the writing thread tries to overwrite the first slot for the first time. + + // The below condition is checked in a while loop instead of using a lock to avoid frequent thread park. + while ((nextReadIndex == -1 || nextReadIndex == 0) && !Thread.currentThread().isInterrupted()) { + if (timeoutNs <= 0L) { + throw new RuntimeException(new TimeoutException()); + } + // Thread.yield() should not be called from the very beginning + if (spinTimeoutNs <= 0L) { + Thread.yield(); + } + long now = System.nanoTime(); + timeoutNs = queryTimeoutAtNs - now; + spinTimeoutNs = spinTimeoutAtNs - now; + } + + // Changes on nextReadIndex happens-before changing curWriteIndex. + curWriteIndex = 0; + } else { + final int nextWriteIndex = curWriteIndex + 1; + + // The below condition is checked in a while loop instead of using a lock to avoid frequent thread park. + while ((nextWriteIndex == nextReadIndex) && !Thread.currentThread().isInterrupted()) { + if (timeoutNs <= 0L) { + throw new RuntimeException(new TimeoutException()); + } + // Thread.yield() should not be called from the very beginning + if (spinTimeoutNs <= 0L) { + Thread.yield(); + } + long now = System.nanoTime(); + timeoutNs = queryTimeoutAtNs - now; + spinTimeoutNs = spinTimeoutAtNs - now; + } + + // Changes on nextReadIndex happens-before changing curWriteIndex. + curWriteIndex = nextWriteIndex; + } + } + + @Override + public void reset() + { + curWriteIndex = -1; + nextReadIndex = -1; + finished = false; + } + + @Override + public void close() + { + for (BufferAggregator aggregator : aggregators) { + try { + aggregator.close(); + } + catch (Exception e) { + LOG.warn(e, "Could not close aggregator [%s], skipping.", aggregator); + } + } + } + + /** + * Signal that no more inputs are added. Must be called after {@link #aggregate(Object)} is called for the last input. + */ + public void finish() + { + increaseWriteIndex(); + // Once finished is set, curWriteIndex must not be changed. This guarantees that the remaining number of items in + // the array is always decreased as the reading thread proceeds. See hasNext() and remaining() below. + finished = true; + } + + /** + * Return a sorted iterator. This method can be called safely while writing, and the iterating thread and the writing + * thread can be different. The result iterator always returns sorted results. This method should be called only one + * time per grouper. + * + * @return a sorted iterator + */ + public CloseableIterator> iterator() + { + if (!initialized) { + throw new ISE("Grouper should be initialized first"); + } + + return new CloseableIterator>() + { + { + // Wait for some data to be ready and initialize nextReadIndex. + increaseReadIndexTo(0); + } + + @Override + public boolean hasNext() + { + // If setting finished happens-before the below check, curWriteIndex isn't changed anymore and thus remainig() + // can be computed safely because nextReadIndex is changed only by the reading thread. + // Otherwise, hasNext() always returns true. + // + // The below line can be executed between increasing curWriteIndex and setting finished in + // StreamingMergeSortedGrouper.finish(), but it is also a valid case because there should be at least one slot + // which is not read yet before finished is set. + return !finished || remaining() > 0; + } + + /** + * Calculate the number of remaining items in the array. Must be called only when + * {@link StreamingMergeSortedGrouper#finished} is true. + * + * @return the number of remaining items + */ + private int remaining() + { + if (curWriteIndex >= nextReadIndex) { + return curWriteIndex - nextReadIndex; + } else { + return (maxNumSlots - nextReadIndex) + curWriteIndex; + } + } + + @Override + public Entry next() + { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + // Here, nextReadIndex should be valid which means: + // - a valid array index which should be >= 0 and < maxNumSlots + // - an index of the array slot where the aggregation for the corresponding grouping key is done + // - an index of the array slot which is not read yet + final int recordOffset = recordSize * nextReadIndex; + final KeyType key = keySerde.fromByteBuffer(buffer, recordOffset); + + final Object[] values = new Object[aggregators.length]; + for (int i = 0; i < aggregators.length; i++) { + values[i] = aggregators[i].get(buffer, recordOffset + aggregatorOffsets[i]); + } + + final int targetIndex = nextReadIndex == maxNumSlots - 1 ? 0 : nextReadIndex + 1; + // Wait if the array is empty until at least one slot becomes available for read, and then increase + // nextReadIndex. + increaseReadIndexTo(targetIndex); + + return new Entry<>(key, values); + } + + /** + * Wait for {@link StreamingMergeSortedGrouper#curWriteIndex} to be moved if necessary and move + * {@link StreamingMergeSortedGrouper#nextReadIndex}. + * + * @param target the target index {@link StreamingMergeSortedGrouper#nextReadIndex} will move to + */ + private void increaseReadIndexTo(int target) + { + // Check that the array is empty and wait for at least one slot to become available. + // + // curWriteIndex is a volatile variable and the changes on it are continuously checked until they are seen in + // the while loop. See the following links. + // * http://docs.oracle.com/javase/specs/jls/se7/html/jls-8.html#jls-8.3.1.4 + // * http://docs.oracle.com/javase/specs/jls/se7/html/jls-17.html#jls-17.4.5 + // * https://stackoverflow.com/questions/11761552/detailed-semantics-of-volatile-regarding-timeliness-of-visibility + + final long startAtNs = System.nanoTime(); + final long queryTimeoutAtNs = getQueryTimeoutAtNs(startAtNs); + final long spinTimeoutAtNs = startAtNs + SPIN_FOR_TIMEOUT_THRESHOLD_NS; + long timeoutNs = queryTimeoutAtNs - startAtNs; + long spinTimeoutNs = SPIN_FOR_TIMEOUT_THRESHOLD_NS; + + // The below condition is checked in a while loop instead of using a lock to avoid frequent thread park. + while ((curWriteIndex == -1 || target == curWriteIndex) && + !finished && !Thread.currentThread().isInterrupted()) { + if (timeoutNs <= 0L) { + throw new RuntimeException(new TimeoutException()); + } + // Thread.yield() should not be called from the very beginning + if (spinTimeoutNs <= 0L) { + Thread.yield(); + } + long now = System.nanoTime(); + timeoutNs = queryTimeoutAtNs - now; + spinTimeoutNs = spinTimeoutAtNs - now; + } + + // Changes on curWriteIndex happens-before changing nextReadIndex. + nextReadIndex = target; + } + + @Override + public void close() throws IOException + { + // do nothing + } + }; + } + + private long getQueryTimeoutAtNs(long startAtNs) + { + return hasQueryTimeout ? queryTimeoutAtNs : startAtNs + DEFAULT_TIMEOUT_NS; + } + + /** + * Return a sorted iterator. This method can be called safely while writing and iterating thread and writing thread + * can be different. The result iterator always returns sorted results. This method should be called only one time + * per grouper. + * + * @param sorted not used + * + * @return a sorted iterator + */ + @Override + public CloseableIterator> iterator(boolean sorted) + { + return iterator(); + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java b/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java index 40a902cea271..5ad75004b760 100644 --- a/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java +++ b/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java @@ -338,6 +338,7 @@ public QueryRunner mergeRunners( queryWatcher, queryRunners, processingConfig.getNumThreads(), + bufferPool, mergeBufferPool, processingConfig.intermediateComputeSizeBytes(), spillMapper, diff --git a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java index dcb64932e9f5..b68c2c954735 100644 --- a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java @@ -296,6 +296,26 @@ public String toString() return "v2SmallDictionary"; } }; + final GroupByQueryConfig v2ParallelCombineConfig = new GroupByQueryConfig() + { + @Override + public String getDefaultStrategy() + { + return GroupByStrategySelector.STRATEGY_V2; + } + + @Override + public int getNumParallelCombineThreads() + { + return DEFAULT_PROCESSING_CONFIG.getNumThreads(); + } + + @Override + public String toString() + { + return "v2ParallelCombine"; + } + }; v1Config.setMaxIntermediateRows(10000); v1SingleThreadedConfig.setMaxIntermediateRows(10000); @@ -305,7 +325,8 @@ public String toString() v1SingleThreadedConfig, v2Config, v2SmallBufferConfig, - v2SmallDictionaryConfig + v2SmallDictionaryConfig, + v2ParallelCombineConfig ); } diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java index 9d3e5f2dc9db..0d0137b2e5f3 100644 --- a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java @@ -20,14 +20,19 @@ package io.druid.query.groupby.epinephelinae; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import com.google.common.primitives.Longs; import com.google.common.util.concurrent.MoreExecutors; -import io.druid.java.util.common.concurrent.Execs; +import io.druid.collections.ResourceHolder; +import io.druid.jackson.DefaultObjectMapper; import io.druid.java.util.common.IAE; +import io.druid.java.util.common.parsers.CloseableIterator; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.dimension.DimensionSpec; import io.druid.query.groupby.epinephelinae.Grouper.BufferComparator; +import io.druid.query.groupby.epinephelinae.Grouper.Entry; import io.druid.query.groupby.epinephelinae.Grouper.KeySerde; import io.druid.query.groupby.epinephelinae.Grouper.KeySerdeFactory; import io.druid.segment.ColumnSelectorFactory; @@ -35,44 +40,171 @@ import io.druid.segment.DimensionSelector; import io.druid.segment.column.ColumnCapabilities; import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; import java.util.Comparator; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicBoolean; +@RunWith(Parameterized.class) public class ConcurrentGrouperTest { - private static final ExecutorService service = Executors.newFixedThreadPool(8); - private static final int BYTE_BUFFER_SIZE = 192; + private static final ExecutorService SERVICE = Executors.newFixedThreadPool(8); + private static final TestResourceHolder TEST_RESOURCE_HOLDER = new TestResourceHolder(256); + private static final KeySerdeFactory KEY_SERDE_FACTORY = new TestKeySerdeFactory(); + private static final Supplier> COMBINE_BUFFER_SUPPLIER = new TestBufferSupplier(); + private static final ColumnSelectorFactory NULL_FACTORY = new TestColumnSelectorFactory(); + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private Supplier bufferSupplier; + + @Parameters(name = "bufferSize={0}") + public static Collection constructorFeeder() + { + return ImmutableList.of( + new Object[]{1024 * 32}, + new Object[]{1024 * 1024} + ); + } @AfterClass public static void teardown() { - service.shutdown(); + SERVICE.shutdown(); } - private static final Supplier bufferSupplier = new Supplier() + public ConcurrentGrouperTest(int bufferSize) { - private final AtomicBoolean called = new AtomicBoolean(false); + bufferSupplier = new Supplier() + { + private final AtomicBoolean called = new AtomicBoolean(false); + private ByteBuffer buffer; + + @Override + public ByteBuffer get() + { + if (called.compareAndSet(false, true)) { + buffer = ByteBuffer.allocate(bufferSize); + } + + return buffer; + } + }; + } + + @Test() + public void testAggregate() throws InterruptedException, ExecutionException, IOException + { + final ConcurrentGrouper grouper = new ConcurrentGrouper<>( + bufferSupplier, + COMBINE_BUFFER_SUPPLIER, + KEY_SERDE_FACTORY, + KEY_SERDE_FACTORY, + NULL_FACTORY, + new AggregatorFactory[]{new CountAggregatorFactory("cnt")}, + 1024, + 0.7f, + 1, + new LimitedTemporaryStorage(temporaryFolder.newFolder(), 1024 * 1024), + new DefaultObjectMapper(), + 8, + null, + false, + MoreExecutors.listeningDecorator(SERVICE), + 0, + false, + 0, + 4, + 8 + ); + grouper.init(); + + final int numRows = 1000; + + Future[] futures = new Future[8]; + + for (int i = 0; i < 8; i++) { + futures[i] = SERVICE.submit(new Runnable() + { + @Override + public void run() + { + for (long i = 0; i < numRows; i++) { + grouper.aggregate(i); + } + } + }); + } + + for (Future eachFuture : futures) { + eachFuture.get(); + } + + final CloseableIterator> iterator = grouper.iterator(true); + final List> actual = Lists.newArrayList(iterator); + iterator.close(); + + Assert.assertTrue(!TEST_RESOURCE_HOLDER.taken || TEST_RESOURCE_HOLDER.closed); + + final List> expected = new ArrayList<>(); + for (long i = 0; i < numRows; i++) { + expected.add(new Entry<>(i, new Object[]{8L})); + } + + Assert.assertEquals(expected, actual); + + grouper.close(); + } + + static class TestResourceHolder implements ResourceHolder + { + private boolean taken; + private boolean closed; + private ByteBuffer buffer; + + TestResourceHolder(int bufferSize) + { + buffer = ByteBuffer.allocate(bufferSize); + } @Override public ByteBuffer get() { - if (called.compareAndSet(false, true)) { - return ByteBuffer.allocate(BYTE_BUFFER_SIZE); - } else { - throw new IAE("should be called once"); - } + taken = true; + return buffer; + } + + @Override + public void close() + { + closed = true; } - }; + } - private static final KeySerdeFactory keySerdeFactory = new KeySerdeFactory() + static class TestKeySerdeFactory implements KeySerdeFactory { + @Override + public long getMaxDictionarySize() + { + return 0; + } + @Override public KeySerde factorize() { @@ -92,6 +224,12 @@ public Class keyClazz() return Long.class; } + @Override + public List getDictionary() + { + return ImmutableList.of(); + } + @Override public ByteBuffer toByteBuffer(Long key) { @@ -134,6 +272,12 @@ public void reset() {} }; } + @Override + public KeySerde factorizeWithDictionary(List dictionary) + { + return factorize(); + } + @Override public Comparator> objectComparator(boolean forceDefaultOrder) { @@ -146,9 +290,24 @@ public int compare(Grouper.Entry o1, Grouper.Entry o2) } }; } - }; + } + + private static class TestBufferSupplier implements Supplier> + { + private final AtomicBoolean called = new AtomicBoolean(false); + + @Override + public ResourceHolder get() + { + if (called.compareAndSet(false, true)) { + return TEST_RESOURCE_HOLDER; + } else { + throw new IAE("should be called once"); + } + } + } - private static final ColumnSelectorFactory null_factory = new ColumnSelectorFactory() + private static class TestColumnSelectorFactory implements ColumnSelectorFactory { @Override public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec) @@ -167,51 +326,5 @@ public ColumnCapabilities getColumnCapabilities(String columnName) { return null; } - }; - - @Test(timeout = 5000L) - public void testAggregate() throws InterruptedException, ExecutionException - { - final ConcurrentGrouper grouper = new ConcurrentGrouper<>( - bufferSupplier, - keySerdeFactory, - null_factory, - new AggregatorFactory[]{new CountAggregatorFactory("cnt")}, - 24, - 0.7f, - 1, - null, - null, - 8, - null, - false, - MoreExecutors.listeningDecorator(Execs.multiThreaded(4, "concurrent-grouper-test-%d")), - 0, - false, - 0, - BYTE_BUFFER_SIZE - ); - - Future[] futures = new Future[8]; - - for (int i = 0; i < 8; i++) { - futures[i] = service.submit(new Runnable() - { - @Override - public void run() - { - grouper.init(); - for (long i = 0; i < 100; i++) { - grouper.aggregate(0L); - } - } - }); - } - - for (Future eachFuture : futures) { - eachFuture.get(); - } - - grouper.close(); } } diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java index 8f017caff953..387e14d8406a 100644 --- a/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java @@ -19,11 +19,13 @@ package io.druid.query.groupby.epinephelinae; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import io.druid.query.aggregation.AggregatorFactory; import java.nio.ByteBuffer; import java.util.Comparator; +import java.util.List; public class IntKeySerde implements Grouper.KeySerde { @@ -66,6 +68,12 @@ public Class keyClazz() return Integer.class; } + @Override + public List getDictionary() + { + return ImmutableList.of(); + } + @Override public ByteBuffer toByteBuffer(Integer key) { diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ParallelCombinerTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ParallelCombinerTest.java new file mode 100644 index 000000000000..9e7f00376598 --- /dev/null +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ParallelCombinerTest.java @@ -0,0 +1,147 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Supplier; +import com.google.common.util.concurrent.MoreExecutors; +import io.druid.collections.ResourceHolder; +import io.druid.java.util.common.IAE; +import io.druid.java.util.common.concurrent.Execs; +import io.druid.java.util.common.parsers.CloseableIterator; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; +import io.druid.query.groupby.epinephelinae.ConcurrentGrouperTest.TestKeySerdeFactory; +import io.druid.query.groupby.epinephelinae.ConcurrentGrouperTest.TestResourceHolder; +import io.druid.query.groupby.epinephelinae.Grouper.Entry; +import io.druid.query.groupby.epinephelinae.Grouper.KeySerdeFactory; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; + +public class ParallelCombinerTest +{ + private static final int THREAD_NUM = 8; + private static final ExecutorService SERVICE = Execs.multiThreaded(THREAD_NUM, "parallel-combiner-test-%d"); + private static final TestResourceHolder TEST_RESOURCE_HOLDER = new TestResourceHolder(512); + private static final KeySerdeFactory KEY_SERDE_FACTORY = new TestKeySerdeFactory(); + + private static final Supplier> COMBINE_BUFFER_SUPPLIER = + new Supplier>() + { + private final AtomicBoolean called = new AtomicBoolean(false); + + @Override + public ResourceHolder get() + { + if (called.compareAndSet(false, true)) { + return TEST_RESOURCE_HOLDER; + } else { + throw new IAE("should be called once"); + } + } + }; + + private static final class TestIterator implements CloseableIterator> + { + private final Iterator> innerIterator; + private boolean closed; + + TestIterator(Iterator> innerIterator) + { + this.innerIterator = innerIterator; + } + + @Override + public boolean hasNext() + { + return innerIterator.hasNext(); + } + + @Override + public Entry next() + { + return innerIterator.next(); + } + + public boolean isClosed() + { + return closed; + } + + @Override + public void close() throws IOException + { + if (!closed) { + closed = true; + } + } + } + + @AfterClass + public static void teardown() + { + SERVICE.shutdownNow(); + } + + @Test + public void testCombine() throws IOException + { + final ParallelCombiner combiner = new ParallelCombiner<>( + COMBINE_BUFFER_SUPPLIER, + new AggregatorFactory[]{new CountAggregatorFactory("cnt").getCombiningFactory()}, + KEY_SERDE_FACTORY, + MoreExecutors.listeningDecorator(SERVICE), + false, + THREAD_NUM, + 0, // default priority + 0, // default timeout + 4 + ); + + final int numRows = 1000; + final List> baseIterator = new ArrayList<>(numRows); + for (long i = 0; i < numRows; i++) { + baseIterator.add(new Entry<>(i, new Object[]{i * 10})); + } + + final int leafNum = 8; + final List iterators = new ArrayList<>(leafNum); + for (int i = 0; i < leafNum; i++) { + iterators.add(new TestIterator(baseIterator.iterator())); + } + + try (final CloseableIterator> iterator = combiner.combine(iterators, new ArrayList<>())) { + long expectedKey = 0; + while (iterator.hasNext()) { + Assert.assertEquals(new Entry<>(expectedKey, new Object[]{expectedKey++ * leafNum * 10}), iterator.next()); + } + } + + iterators.forEach(it -> Assert.assertTrue(it.isClosed())); + } +} diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/StreamingMergeSortedGrouperTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/StreamingMergeSortedGrouperTest.java new file mode 100644 index 000000000000..e66cd77b0ac5 --- /dev/null +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/StreamingMergeSortedGrouperTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.collect.Ordering; +import com.google.common.primitives.Ints; +import io.druid.data.input.MapBasedRow; +import io.druid.java.util.common.concurrent.Execs; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; +import io.druid.query.aggregation.LongSumAggregatorFactory; +import io.druid.query.groupby.epinephelinae.Grouper.Entry; +import org.hamcrest.CoreMatchers; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeoutException; + +public class StreamingMergeSortedGrouperTest +{ + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testAggregate() + { + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final StreamingMergeSortedGrouper grouper = newGrouper(columnSelectorFactory, 1024); + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 10L))); + grouper.aggregate(6); + grouper.aggregate(6); + grouper.aggregate(6); + grouper.aggregate(10); + grouper.aggregate(12); + grouper.aggregate(12); + + grouper.finish(); + + final List> expected = ImmutableList.of( + new Grouper.Entry<>(6, new Object[]{30L, 3L}), + new Grouper.Entry<>(10, new Object[]{10L, 1L}), + new Grouper.Entry<>(12, new Object[]{20L, 2L}) + ); + final List> unsortedEntries = Lists.newArrayList(grouper.iterator(true)); + + Assert.assertEquals( + expected, + unsortedEntries + ); + } + + @Test(timeout = 5000L) + public void testEmptyIterator() + { + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final StreamingMergeSortedGrouper grouper = newGrouper(columnSelectorFactory, 1024); + + grouper.finish(); + + Assert.assertTrue(!grouper.iterator(true).hasNext()); + } + + @Test(timeout = 5000L) + public void testStreamingAggregateWithLargeBuffer() throws ExecutionException, InterruptedException + { + testStreamingAggregate(1024); + } + + @Test(timeout = 5000L) + public void testStreamingAggregateWithMinimumBuffer() throws ExecutionException, InterruptedException + { + testStreamingAggregate(60); + } + + private void testStreamingAggregate(int bufferSize) throws ExecutionException, InterruptedException + { + final ExecutorService exec = Execs.multiThreaded(2, "merge-sorted-grouper-test-%d"); + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final StreamingMergeSortedGrouper grouper = newGrouper(columnSelectorFactory, bufferSize); + + final List> expected = new ArrayList<>(1024); + for (int i = 0; i < 1024; i++) { + expected.add(new Entry<>(i, new Object[]{100L, 10L})); + } + + try { + final Future future = exec.submit(() -> { + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 10L))); + + for (int i = 0; i < 1024; i++) { + for (int j = 0; j < 10; j++) { + grouper.aggregate(i); + } + } + + grouper.finish(); + }); + + final List> unsortedEntries = Lists.newArrayList(grouper.iterator(true)); + final List> actual = Ordering.from((Comparator>) (o1, o2) -> Ints.compare(o1.getKey(), o2.getKey())) + .sortedCopy(unsortedEntries); + + if (!actual.equals(expected)) { + future.get(); // Check there is an exception occured + Assert.fail(); + } + } + finally { + exec.shutdownNow(); + } + } + + @Test + public void testNotEnoughBuffer() + { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Buffer[50] should be large enough to store at least three records[20]"); + + newGrouper(GrouperTestUtil.newColumnSelectorFactory(), 50); + } + + @Test + public void testTimeout() + { + expectedException.expect(RuntimeException.class); + expectedException.expectCause(CoreMatchers.instanceOf(TimeoutException.class)); + + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final StreamingMergeSortedGrouper grouper = newGrouper(columnSelectorFactory, 60); + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 10L))); + grouper.aggregate(6); + + grouper.iterator(); + } + + private StreamingMergeSortedGrouper newGrouper( + TestColumnSelectorFactory columnSelectorFactory, + int bufferSize + ) + { + final ByteBuffer buffer = ByteBuffer.allocate(bufferSize); + + final StreamingMergeSortedGrouper grouper = new StreamingMergeSortedGrouper<>( + Suppliers.ofInstance(buffer), + GrouperTestUtil.intKeySerde(), + columnSelectorFactory, + new AggregatorFactory[]{ + new LongSumAggregatorFactory("valueSum", "value"), + new CountAggregatorFactory("count") + }, + System.currentTimeMillis() + 1000L + ); + grouper.init(); + return grouper; + } +}