diff --git a/docs/content/querying/groupbyquery.md b/docs/content/querying/groupbyquery.md
index 82f2d81a39e1..32bbbf42492c 100644
--- a/docs/content/querying/groupbyquery.md
+++ b/docs/content/querying/groupbyquery.md
@@ -245,6 +245,7 @@ When using the "v2" strategy, the following query context parameters apply:
 |`maxMergingDictionarySize`|Can be used to lower the value of `druid.query.groupBy.maxMergingDictionarySize` for this query.|
 |`maxOnDiskStorage`|Can be used to lower the value of `druid.query.groupBy.maxOnDiskStorage` for this query.|
 |`sortByDimsFirst`|Sort the results first by dimension values and then by timestamp.|
+|`forcePushDownLimit`|When all fields in the orderby are part of the grouping key, the broker pushes limit application down to the historical nodes. When the sort order uses fields that are not in the grouping key, applying this optimization can result in approximate results with unknown accuracy, so it is disabled by default in that case. Enabling this context flag turns on limit push down for limit/orderby specs that contain non-grouping-key columns.|
 
 When using the "v1" strategy, the following query context parameters apply:
 
diff --git a/processing/src/main/java/io/druid/query/groupby/GroupByQuery.java b/processing/src/main/java/io/druid/query/groupby/GroupByQuery.java
index ba77ef9f6e2b..893bdd13d26f 100644
--- a/processing/src/main/java/io/druid/query/groupby/GroupByQuery.java
+++ b/processing/src/main/java/io/druid/query/groupby/GroupByQuery.java
@@ -25,6 +25,7 @@
 import com.google.common.base.Function;
 import com.google.common.base.Functions;
 import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Ordering;
@@ -53,6 +54,9 @@
 import io.druid.query.groupby.orderby.LimitSpec;
 import io.druid.query.groupby.orderby.NoopLimitSpec;
 import io.druid.query.groupby.orderby.OrderByColumnSpec;
+import io.druid.query.groupby.strategy.GroupByStrategyV2;
+import io.druid.query.ordering.StringComparator;
+import io.druid.query.ordering.StringComparators;
 import io.druid.query.spec.LegacySegmentSpec;
 import io.druid.query.spec.QuerySegmentSpec;
 import io.druid.segment.VirtualColumn;
@@ -65,6 +69,7 @@
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -97,6 +102,8 @@ public static Builder builder()
   private final List<AggregatorFactory> aggregatorSpecs;
   private final List<PostAggregator> postAggregatorSpecs;
 
+  private final Function<Sequence<Row>, Sequence<Row>> limitFn;
+  private final boolean applyLimitPushDown;
   private final Function<Sequence<Row>, Sequence<Row>> postProcessingFn;
 
   @JsonCreator
@@ -190,6 +197,45 @@ private GroupByQuery(
     verifyOutputNames(this.dimensions, this.aggregatorSpecs, this.postAggregatorSpecs);
 
     this.postProcessingFn = postProcessingFn != null ? postProcessingFn : makePostProcessingFn();
+
+    // Check whether the limit push down configuration is valid and whether limit push down will be applied
+    this.applyLimitPushDown = determineApplyLimitPushDown();
+
+    // On an inner query, we may sometimes get a LimitSpec so that row orderings can be determined for limit push down.
+    // However, it's not necessary to build the real limitFn from it at this stage.
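    // Illustrative sketch, not part of this patch: forced limit push down is requested per query
    // through the context. The "forceLimitPushDown" key matches
    // GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN added later in this diff; building the
    // context with ImmutableMap is only an assumption of the example.
    final Map<String, Object> forceLimitPushDownContext = ImmutableMap.<String, Object>of(
        "forceLimitPushDown", true
    );
    // A query carrying this context must also have a DefaultLimitSpec with an explicit limit and
    // must not order on a post-aggregator, otherwise validateAndGetForceLimitPushDown() below
    // throws an IAE.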
+ Function, Sequence> postProcFn; + if (getContextBoolean(GroupByStrategyV2.CTX_KEY_OUTERMOST, true)) { + postProcFn = this.limitSpec.build(this.dimensions, this.aggregatorSpecs, this.postAggregatorSpecs); + } else { + postProcFn = NoopLimitSpec.INSTANCE.build(this.dimensions, this.aggregatorSpecs, this.postAggregatorSpecs); + } + + if (havingSpec != null) { + postProcFn = Functions.compose( + postProcFn, + new Function, Sequence>() + { + @Override + public Sequence apply(Sequence input) + { + GroupByQuery.this.havingSpec.setRowSignature(GroupByQueryHelper.rowSignatureFor(GroupByQuery.this)); + return Sequences.filter( + input, + new Predicate() + { + @Override + public boolean apply(Row input) + { + return GroupByQuery.this.havingSpec.eval(input); + } + } + ); + } + } + ); + } + + limitFn = postProcFn; } @JsonProperty @@ -264,6 +310,12 @@ public boolean getContextSortByDimsFirst() return getContextBoolean(CTX_KEY_SORT_BY_DIMS_FIRST, false); } + @JsonIgnore + public boolean isApplyLimitPushDown() + { + return applyLimitPushDown; + } + @Override public Ordering getResultOrdering() { @@ -281,10 +333,177 @@ public Ordering getResultOrdering() ); } - public Ordering getRowOrdering(final boolean granular) + private boolean validateAndGetForceLimitPushDown() + { + final boolean forcePushDown = getContextBoolean(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, false); + if (forcePushDown) { + if (!(limitSpec instanceof DefaultLimitSpec)) { + throw new IAE("When forcing limit push down, a limit spec must be provided."); + } + + if (((DefaultLimitSpec) limitSpec).getLimit() == Integer.MAX_VALUE) { + throw new IAE("When forcing limit push down, the provided limit spec must have a limit."); + } + + for (OrderByColumnSpec orderBySpec : ((DefaultLimitSpec) limitSpec).getColumns()) { + if (OrderByColumnSpec.getPostAggIndexForOrderBy(orderBySpec, postAggregatorSpecs) > -1) { + throw new UnsupportedOperationException("Limit push down when sorting by a post aggregator is not supported."); + } + } + } + return forcePushDown; + } + + public boolean determineApplyLimitPushDown() + { + final boolean forceLimitPushDown = validateAndGetForceLimitPushDown(); + + if (limitSpec instanceof DefaultLimitSpec) { + DefaultLimitSpec defaultLimitSpec = (DefaultLimitSpec) limitSpec; + + // If only applying an orderby without a limit, don't try to push down + if (defaultLimitSpec.getLimit() == Integer.MAX_VALUE) { + return false; + } + + if (forceLimitPushDown) { + return true; + } + + // If the sorting order only uses columns in the grouping key, we can always push the limit down + // to the buffer grouper without affecting result accuracy + boolean sortHasNonGroupingFields = DefaultLimitSpec.sortingOrderHasNonGroupingFields( + (DefaultLimitSpec) limitSpec, + getDimensions() + ); + + return !sortHasNonGroupingFields; + } + + return false; + } + + /** + * When limit push down is applied, the partial results would be sorted by the ordering specified by the + * limit/order spec (unlike non-push down case where the results always use the default natural ascending order), + * so when merging these partial result streams, the merge needs to use the same ordering to get correct results. 
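 * (Editorial example, not part of the original patch.) For instance, a query grouping on
 * (dimA, dimB) with an order-by of dimB descending plus a limit yields partial streams sorted by
 * dimB descending with dimA ascending as the tie-breaker; the comparator built below reproduces
 * exactly that order so the merged stream of pushed-down partial results stays correct.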
+ */ + private Ordering getRowOrderingForPushDown( + final boolean granular, + final DefaultLimitSpec limitSpec + ) { final boolean sortByDimsFirst = getContextSortByDimsFirst(); + final List orderedFieldNames = new ArrayList<>(); + final Set dimsInOrderBy = new HashSet<>(); + final List needsReverseList = new ArrayList<>(); + final List isNumericField = new ArrayList<>(); + final List comparators = new ArrayList<>(); + + for (OrderByColumnSpec orderSpec : limitSpec.getColumns()) { + boolean needsReverse = orderSpec.getDirection() != OrderByColumnSpec.Direction.ASCENDING; + int dimIndex = OrderByColumnSpec.getDimIndexForOrderBy(orderSpec, dimensions); + if (dimIndex >= 0) { + DimensionSpec dim = dimensions.get(dimIndex); + orderedFieldNames.add(dim.getOutputName()); + dimsInOrderBy.add(dimIndex); + needsReverseList.add(needsReverse); + final ValueType type = dimensions.get(dimIndex).getOutputType(); + isNumericField.add(type == ValueType.LONG || type == ValueType.FLOAT); + comparators.add(orderSpec.getDimensionComparator()); + } + } + + for (int i = 0; i < dimensions.size(); i++) { + if (!dimsInOrderBy.contains(i)) { + orderedFieldNames.add(dimensions.get(i).getOutputName()); + needsReverseList.add(false); + final ValueType type = dimensions.get(i).getOutputType(); + isNumericField.add(type == ValueType.LONG || type == ValueType.FLOAT); + comparators.add(StringComparators.LEXICOGRAPHIC); + } + } + + final Comparator timeComparator = getTimeComparator(granular); + + if (timeComparator == null) { + return Ordering.from( + new Comparator() + { + @Override + public int compare(Row lhs, Row rhs) + { + return compareDimsForLimitPushDown( + orderedFieldNames, + needsReverseList, + isNumericField, + comparators, + lhs, + rhs + ); + } + } + ); + } else if (sortByDimsFirst) { + return Ordering.from( + new Comparator() + { + @Override + public int compare(Row lhs, Row rhs) + { + final int cmp = compareDimsForLimitPushDown( + orderedFieldNames, + needsReverseList, + isNumericField, + comparators, + lhs, + rhs + ); + if (cmp != 0) { + return cmp; + } + + return timeComparator.compare(lhs, rhs); + } + } + ); + } else { + return Ordering.from( + new Comparator() + { + @Override + public int compare(Row lhs, Row rhs) + { + final int timeCompare = timeComparator.compare(lhs, rhs); + + if (timeCompare != 0) { + return timeCompare; + } + + return compareDimsForLimitPushDown( + orderedFieldNames, + needsReverseList, + isNumericField, + comparators, + lhs, + rhs + ); + } + } + ); + } + } + + public Ordering getRowOrdering(final boolean granular) + { + if (applyLimitPushDown) { + if (!DefaultLimitSpec.sortingOrderHasNonGroupingFields((DefaultLimitSpec) limitSpec, dimensions)) { + return getRowOrderingForPushDown(granular, (DefaultLimitSpec) limitSpec); + } + } + + final boolean sortByDimsFirst = getContextSortByDimsFirst(); final Comparator timeComparator = getTimeComparator(granular); if (timeComparator == null) { @@ -357,6 +576,51 @@ private static int compareDims(List dimensions, Row lhs, Row rhs) return 0; } + private static int compareDimsForLimitPushDown( + final List fields, + final List needsReverseList, + final List isNumericField, + final List comparators, + Row lhs, + Row rhs + ) + { + for (int i = 0; i < fields.size(); i++) { + final String fieldName = fields.get(i); + final StringComparator comparator = comparators.get(i); + + final int dimCompare; + + Object lhsObj; + Object rhsObj; + if (needsReverseList.get(i)) { + lhsObj = rhs.getRaw(fieldName); + rhsObj = lhs.getRaw(fieldName); + } else { + 
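        // (Descriptive note, not part of the original patch.) The ascending branch reads the raw
        // values in their natural lhs/rhs order; the descending branch above swaps them so that a
        // single ascending comparator also serves reversed order-by columns.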
lhsObj = lhs.getRaw(fieldName); + rhsObj = rhs.getRaw(fieldName); + } + + if (isNumericField.get(i)) { + if (comparator == StringComparators.NUMERIC) { + dimCompare = NATURAL_NULLS_FIRST.compare( + rhs.getRaw(fieldName), + lhs.getRaw(fieldName) + ); + } else { + dimCompare = comparator.compare(String.valueOf(lhsObj), String.valueOf(rhsObj)); + } + } else { + dimCompare = comparator.compare((String) lhsObj, (String) rhsObj); + } + + if (dimCompare != 0) { + return dimCompare; + } + } + return 0; + } + /** * Apply the havingSpec and limitSpec. Because havingSpecs are not thread safe, and because they are applied during * accumulation of the returned sequence, callers must take care to avoid accumulating two different Sequences diff --git a/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java b/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java index 70090210ae6f..6f55b9a2abc7 100644 --- a/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java +++ b/processing/src/main/java/io/druid/query/groupby/GroupByQueryConfig.java @@ -27,6 +27,7 @@ public class GroupByQueryConfig { public static final String CTX_KEY_STRATEGY = "groupByStrategy"; + public static final String CTX_KEY_FORCE_LIMIT_PUSH_DOWN = "forceLimitPushDown"; private static final String CTX_KEY_IS_SINGLE_THREADED = "groupByIsSingleThreaded"; private static final String CTX_KEY_MAX_INTERMEDIATE_ROWS = "maxIntermediateRows"; private static final String CTX_KEY_MAX_RESULTS = "maxResults"; @@ -66,6 +67,12 @@ public class GroupByQueryConfig // Max on-disk temporary storage, per-query; when exceeded, the query fails private long maxOnDiskStorage = 0L; + @JsonProperty + private boolean forcePushDownLimit = false; + + @JsonProperty + private Class queryMetricsFactory; + public String getDefaultStrategy() { return defaultStrategy; @@ -126,6 +133,21 @@ public long getMaxOnDiskStorage() return maxOnDiskStorage; } + public boolean isForcePushDownLimit() + { + return forcePushDownLimit; + } + + public Class getQueryMetricsFactory() + { + return queryMetricsFactory != null ? queryMetricsFactory : DefaultGroupByQueryMetricsFactory.class; + } + + public void setQueryMetricsFactory(Class queryMetricsFactory) + { + this.queryMetricsFactory = queryMetricsFactory; + } + public GroupByQueryConfig withOverrides(final GroupByQuery query) { final GroupByQueryConfig newConfig = new GroupByQueryConfig(); @@ -159,6 +181,7 @@ public GroupByQueryConfig withOverrides(final GroupByQuery query) ((Number) query.getContextValue(CTX_KEY_MAX_MERGING_DICTIONARY_SIZE, getMaxMergingDictionarySize())).longValue(), getMaxMergingDictionarySize() ); + newConfig.forcePushDownLimit = query.getContextBoolean(CTX_KEY_FORCE_LIMIT_PUSH_DOWN, isForcePushDownLimit()); return newConfig; } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/AbstractBufferGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/AbstractBufferGrouper.java new file mode 100644 index 000000000000..a0b5e8d10cef --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/AbstractBufferGrouper.java @@ -0,0 +1,214 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Supplier; +import com.google.common.primitives.Ints; +import io.druid.java.util.common.IAE; +import io.druid.java.util.common.logger.Logger; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.BufferAggregator; + +import java.nio.ByteBuffer; + +public abstract class AbstractBufferGrouper implements Grouper +{ + private static final AggregateResult DICTIONARY_FULL = AggregateResult.failure( + "Not enough dictionary space to execute this query. Try increasing " + + "druid.query.groupBy.maxMergingDictionarySize or enable disk spilling by setting " + + "druid.query.groupBy.maxOnDiskStorage to a positive number." + ); + private static final AggregateResult HASHTABLE_FULL = AggregateResult.failure( + "Not enough aggregation table space to execute this query. Try increasing " + + "druid.processing.buffer.sizeBytes or enable disk spilling by setting " + + "druid.query.groupBy.maxOnDiskStorage to a positive number." + ); + + protected static final int HASH_SIZE = Ints.BYTES; + protected static final Logger log = new Logger(AbstractBufferGrouper.class); + + protected final Supplier bufferSupplier; + protected final KeySerde keySerde; + protected final int keySize; + protected final BufferAggregator[] aggregators; + protected final int[] aggregatorOffsets; + protected final int bufferGrouperMaxSize; // Integer.MAX_VALUE in production, only used for unit tests + + // The load factor and bucket configurations are not final, to allow subclasses to set their own values + protected float maxLoadFactor; + protected int initialBuckets; + protected int bucketSize; + + // The hashTable and its buffer are not final, these are set during init() for buffer management purposes + // See PR 3863 for details: https://github.com/druid-io/druid/pull/3863 + protected ByteBufferHashTable hashTable; + protected ByteBuffer hashTableBuffer; // buffer for the entire hash table (total space, not individual growth) + + public AbstractBufferGrouper( + final Supplier bufferSupplier, + final KeySerde keySerde, + final AggregatorFactory[] aggregatorFactories, + final int bufferGrouperMaxSize + ) + { + this.bufferSupplier = bufferSupplier; + this.keySerde = keySerde; + this.keySize = keySerde.keySize(); + this.aggregators = new BufferAggregator[aggregatorFactories.length]; + this.aggregatorOffsets = new int[aggregatorFactories.length]; + this.bufferGrouperMaxSize = bufferGrouperMaxSize; + } + + /** + * Called when a new bucket is used for an entry in the hash table. An implementing BufferGrouper class + * can use this to update its own state, e.g. tracking bucket offsets in a structure outside of the hash table. 
+ * + * @param bucketOffset offset of the new bucket, within the buffer returned by hashTable.getTableBuffer() + */ + public abstract void newBucketHook(int bucketOffset); + + /** + * Called to check if it's possible to skip aggregation for a row. + * + * @param bucketWasUsed Was the row a new entry in the hash table? + * @param bucketOffset Offset of the bucket containing this row's entry in the hash table, + * within the buffer returned by hashTable.getTableBuffer() + * @return true if aggregation can be skipped, false otherwise. + */ + public abstract boolean canSkipAggregate(boolean bucketWasUsed, int bucketOffset); + + /** + * Called after a row is aggregated. An implementing BufferGrouper class can use this to update + * its own state, e.g. reading the new aggregated values for the row's key and acting on that information. + * + * @param bucketOffset Offset of the bucket containing the row that was aggregated, + * within the buffer returned by hashTable.getTableBuffer() + */ + public abstract void afterAggregateHook(int bucketOffset); + + // how many times the hash table's buffer has filled/readjusted (through adjustTableWhenFull()) + public int getGrowthCount() + { + return hashTable.getGrowthCount(); + } + + // Number of elements in the table right now + public int getSize() + { + return hashTable.getSize(); + } + + // Current number of available/used buckets in the table + public int getBuckets() + { + return hashTable.getMaxBuckets(); + } + + // Maximum number of elements in the table before it must be resized + public int getMaxSize() + { + return hashTable.getRegrowthThreshold(); + } + + @Override + public AggregateResult aggregate(KeyType key, int keyHash) + { + final ByteBuffer keyBuffer = keySerde.toByteBuffer(key); + if (keyBuffer == null) { + // This may just trigger a spill and get ignored, which is ok. If it bubbles up to the user, the message will + // be correct. + return DICTIONARY_FULL; + } + + if (keyBuffer.remaining() != keySize) { + throw new IAE( + "keySerde.toByteBuffer(key).remaining[%s] != keySerde.keySize[%s], buffer was the wrong size?!", + keyBuffer.remaining(), + keySize + ); + } + + // find and try to expand if table is full and find again + int bucket = hashTable.findBucketWithAutoGrowth(keyBuffer, keyHash); + if (bucket < 0) { + // This may just trigger a spill and get ignored, which is ok. If it bubbles up to the user, the message will + // be correct. + return HASHTABLE_FULL; + } + + final int bucketStartOffset = hashTable.getOffsetForBucket(bucket); + final boolean bucketWasUsed = hashTable.isBucketUsed(bucket); + final ByteBuffer tableBuffer = hashTable.getTableBuffer(); + + // Set up key and initialize the aggs if this is a new bucket. + if (!bucketWasUsed) { + hashTable.initializeNewBucketKey(bucket, keyBuffer, keyHash); + for (int i = 0; i < aggregators.length; i++) { + aggregators[i].init(tableBuffer, bucketStartOffset + aggregatorOffsets[i]); + } + + newBucketHook(bucketStartOffset); + } + + if (canSkipAggregate(bucketWasUsed, bucketStartOffset)) { + return AggregateResult.ok(); + } + + // Aggregate the current row. 
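      // (Descriptive note, not part of the original patch.) At this point the bucket exists and,
      // if it is new, newBucketHook() has let the subclass record its offset; a limit-aware
      // subclass can instead return true from canSkipAggregate() above to drop rows that can no
      // longer enter its bounded result set.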
+ for (int i = 0; i < aggregators.length; i++) { + aggregators[i].aggregate(tableBuffer, bucketStartOffset + aggregatorOffsets[i]); + } + + afterAggregateHook(bucketStartOffset); + + return AggregateResult.ok(); + } + + @Override + public AggregateResult aggregate(final KeyType key) + { + return aggregate(key, Groupers.hash(key)); + } + + @Override + public void close() + { + for (BufferAggregator aggregator : aggregators) { + try { + aggregator.close(); + } + catch (Exception e) { + log.warn(e, "Could not close aggregator, skipping.", aggregator); + } + } + } + + protected Entry bucketEntryForOffset(final int bucketOffset) + { + final ByteBuffer tableBuffer = hashTable.getTableBuffer(); + final KeyType key = keySerde.fromByteBuffer(tableBuffer, bucketOffset + HASH_SIZE); + final Object[] values = new Object[aggregators.length]; + for (int i = 0; i < aggregators.length; i++) { + values[i] = aggregators[i].get(tableBuffer, bucketOffset + aggregatorOffsets[i]); + } + + return new Entry<>(key, values); + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java index 58d04de6e688..cb738ed4d0cd 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/BufferGrouper.java @@ -20,12 +20,11 @@ package io.druid.query.groupby.epinephelinae; import com.google.common.base.Supplier; +import com.google.common.collect.Iterators; import com.google.common.primitives.Ints; import io.druid.java.util.common.IAE; -import io.druid.java.util.common.ISE; import io.druid.java.util.common.logger.Logger; import io.druid.query.aggregation.AggregatorFactory; -import io.druid.query.aggregation.BufferAggregator; import io.druid.segment.ColumnSelectorFactory; import java.nio.ByteBuffer; @@ -34,77 +33,24 @@ import java.util.Comparator; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; -/** - * Grouper based around a hash table and companion array in a single ByteBuffer. Not thread-safe. - * - * The buffer has two parts: a table arena (offset 0 to tableArenaSize) and an array containing pointers objects in - * the table (tableArenaSize until the end of the buffer). - * - * The table uses open addressing with linear probing on collisions. Each bucket contains the key hash (with the high - * bit set to signify the bucket is used), the serialized key (which are a fixed size) and scratch space for - * BufferAggregators (which is also fixed size). The actual table is represented by "tableBuffer", which points to the - * same memory as positions "tableStart" through "tableStart + buckets * bucketSize" of "buffer". Everything else in - * the table arena is potentially junk. - * - * The array of pointers starts out ordered by insertion order, but might be sorted on calls to - * {@link #iterator(boolean)}. This sorting is done in-place to avoid materializing the full array of pointers. The - * first "size" pointers in the array of pointers are valid; everything else is potentially junk. - * - * The table is periodically grown to accommodate more keys. Even though starting small is not necessary to control - * memory use (we already have the entire buffer allocated) or iteration speed (iteration is fast due to the array - * of pointers) it still helps significantly on initialization times. 
Otherwise, we'd need to clear the used bits of - * each bucket in the entire buffer, which is a lot of writes if the buckets are small. - */ -public class BufferGrouper implements Grouper +public class BufferGrouper extends AbstractBufferGrouper { private static final Logger log = new Logger(BufferGrouper.class); - private static final AggregateResult DICTIONARY_FULL = AggregateResult.failure( - "Not enough dictionary space to execute this query. Try increasing " - + "druid.query.groupBy.maxMergingDictionarySize or enable disk spilling by setting " - + "druid.query.groupBy.maxOnDiskStorage to a positive number." - ); - private static final AggregateResult HASHTABLE_FULL = AggregateResult.failure( - "Not enough aggregation table space to execute this query. Try increasing " - + "druid.processing.buffer.sizeBytes or enable disk spilling by setting " - + "druid.query.groupBy.maxOnDiskStorage to a positive number." - ); - private static final int MIN_INITIAL_BUCKETS = 4; private static final int DEFAULT_INITIAL_BUCKETS = 1024; private static final float DEFAULT_MAX_LOAD_FACTOR = 0.7f; - private static final int HASH_SIZE = Ints.BYTES; - - private final Supplier bufferSupplier; - private final KeySerde keySerde; - private final int keySize; - private final BufferAggregator[] aggregators; - private final int[] aggregatorOffsets; - private final int initialBuckets; - private final int bucketSize; - private final int bufferGrouperMaxSize; // Integer.MAX_VALUE in production, only used for unit tests - private final float maxLoadFactor; private ByteBuffer buffer; - private int tableArenaSize = -1; - - // Buffer pointing to the current table (it moves around as the table grows) - private ByteBuffer tableBuffer; - - // Offset of tableBuffer within the larger buffer - private int tableStart; - - // Current number of buckets in the table - private int buckets; - - // Number of elements in the table right now - private int size; - - // Maximum number of elements in the table before it must be resized - private int maxSize; - private boolean initialized = false; + // Track the offsets of used buckets using this list. + // When a new bucket is initialized by initializeNewBucketKey(), an offset is added to this list. + // When expanding the table, the list is reset() and filled with the new offsets of the copied buckets. + private ByteBuffer offsetListBuffer; + private ByteBufferIntList offsetList; + public BufferGrouper( final Supplier bufferSupplier, final KeySerde keySerde, @@ -115,12 +61,8 @@ public BufferGrouper( final int initialBuckets ) { - this.bufferSupplier = bufferSupplier; - this.keySerde = keySerde; - this.keySize = keySerde.keySize(); - this.aggregators = new BufferAggregator[aggregatorFactories.length]; - this.aggregatorOffsets = new int[aggregatorFactories.length]; - this.bufferGrouperMaxSize = bufferGrouperMaxSize; + super(bufferSupplier, keySerde, aggregatorFactories, bufferGrouperMaxSize); + this.maxLoadFactor = maxLoadFactor > 0 ? maxLoadFactor : DEFAULT_MAX_LOAD_FACTOR; this.initialBuckets = initialBuckets > 0 ? 
Math.max(MIN_INITIAL_BUCKETS, initialBuckets) : DEFAULT_INITIAL_BUCKETS; @@ -143,7 +85,38 @@ public void init() { if (!initialized) { this.buffer = bufferSupplier.get(); - this.tableArenaSize = (buffer.capacity() / (bucketSize + Ints.BYTES)) * bucketSize; + + int hashTableSize = ByteBufferHashTable.calculateTableArenaSizeWithPerBucketAdditionalSize( + buffer.capacity(), + bucketSize, + Ints.BYTES + ); + + hashTableBuffer = buffer.duplicate(); + hashTableBuffer.position(0); + hashTableBuffer.limit(hashTableSize); + hashTableBuffer = hashTableBuffer.slice(); + + offsetListBuffer = buffer.duplicate(); + offsetListBuffer.position(hashTableSize); + offsetListBuffer.limit(buffer.capacity()); + offsetListBuffer = offsetListBuffer.slice(); + + this.offsetList = new ByteBufferIntList( + offsetListBuffer, + offsetListBuffer.capacity() / Ints.BYTES + ); + + this.hashTable = new ByteBufferHashTable( + maxLoadFactor, + initialBuckets, + bucketSize, + hashTableBuffer, + keySize, + bufferGrouperMaxSize, + new BufferGrouperBucketUpdateHandler() + ); + reset(); initialized = true; } @@ -156,149 +129,64 @@ public boolean isInitialized() } @Override - public AggregateResult aggregate(KeyType key, int keyHash) + public void newBucketHook(int bucketOffset) { - final ByteBuffer keyBuffer = keySerde.toByteBuffer(key); - if (keyBuffer == null) { - // This may just trigger a spill and get ignored, which is ok. If it bubbles up to the user, the message will - // be correct. - return DICTIONARY_FULL; - } - - if (keyBuffer.remaining() != keySize) { - throw new IAE( - "keySerde.toByteBuffer(key).remaining[%s] != keySerde.keySize[%s], buffer was the wrong size?!", - keyBuffer.remaining(), - keySize - ); - } - - int bucket = findBucket( - tableBuffer, - buckets, - bucketSize, - size < Math.min(maxSize, bufferGrouperMaxSize), - keyBuffer, - keySize, - keyHash - ); - - if (bucket < 0) { - if (size < bufferGrouperMaxSize) { - growIfPossible(); - bucket = findBucket(tableBuffer, buckets, bucketSize, size < maxSize, keyBuffer, keySize, keyHash); - } - - if (bucket < 0) { - // This may just trigger a spill and get ignored, which is ok. If it bubbles up to the user, the message will - // be correct. - return HASHTABLE_FULL; - } - } - - final int offset = bucket * bucketSize; - - // Set up key if this is a new bucket. - if (!isUsed(bucket)) { - tableBuffer.position(offset); - tableBuffer.putInt(keyHash | 0x80000000); - tableBuffer.put(keyBuffer); - - for (int i = 0; i < aggregators.length; i++) { - aggregators[i].init(tableBuffer, offset + aggregatorOffsets[i]); - } - - buffer.putInt(tableArenaSize + size * Ints.BYTES, offset); - size++; - } - - // Aggregate the current row. - for (int i = 0; i < aggregators.length; i++) { - aggregators[i].aggregate(tableBuffer, offset + aggregatorOffsets[i]); - } - - return AggregateResult.ok(); } @Override - public AggregateResult aggregate(final KeyType key) + public boolean canSkipAggregate(boolean bucketWasUsed, int bucketOffset) { - return aggregate(key, Groupers.hash(key)); + return false; } @Override - public void reset() + public void afterAggregateHook(int bucketOffset) { - size = 0; - buckets = Math.min(tableArenaSize / bucketSize, initialBuckets); - maxSize = maxSizeForBuckets(buckets); - - if (buckets < 1) { - throw new IAE( - "Not enough capacity for even one row! Need[%,d] but have[%,d].", - bucketSize + Ints.BYTES, - buffer.capacity() - ); - } - - // Start table part-way through the buffer so the last growth can start from zero and thereby use more space. 
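      // (Editorial note, not part of the original patch.) The table-placement logic being removed
      // here is not lost; it reappears essentially verbatim in ByteBufferHashTable.reset() later
      // in this diff.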
- tableStart = tableArenaSize - buckets * bucketSize; - int nextBuckets = buckets * 2; - while (true) { - final int nextTableStart = tableStart - nextBuckets * bucketSize; - if (nextTableStart > tableArenaSize / 2) { - tableStart = nextTableStart; - nextBuckets = nextBuckets * 2; - } else { - break; - } - } - - if (tableStart < tableArenaSize / 2) { - tableStart = 0; - } - final ByteBuffer bufferDup = buffer.duplicate(); - bufferDup.position(tableStart); - bufferDup.limit(tableStart + buckets * bucketSize); - tableBuffer = bufferDup.slice(); - - // Clear used bits of new table - for (int i = 0; i < buckets; i++) { - tableBuffer.put(i * bucketSize, (byte) 0); - } + } + @Override + public void reset() + { + offsetList.reset(); + hashTable.reset(); keySerde.reset(); } @Override - public Iterator> iterator(final boolean sorted) + public Iterator> iterator(boolean sorted) { + if (!initialized) { + // it's possible for iterator() to be called before initialization when + // a nested groupBy's subquery has an empty result set (see testEmptySubquery() in GroupByQueryRunnerTest) + return Iterators.>emptyIterator(); + } + if (sorted) { final List wrappedOffsets = new AbstractList() { @Override public Integer get(int index) { - return buffer.getInt(tableArenaSize + index * Ints.BYTES); + return offsetList.get(index); } @Override public Integer set(int index, Integer element) { final Integer oldValue = get(index); - buffer.putInt(tableArenaSize + index * Ints.BYTES, element); + offsetList.set(index, element); return oldValue; } @Override public int size() { - return size; + return hashTable.getSize(); } }; - final KeyComparator comparator = keySerde.bufferComparator(); + final BufferComparator comparator = keySerde.bufferComparator(); // Sort offsets in-place. 
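      // (Descriptive note, not part of the original patch.) wrappedOffsets is a live List view
      // over the offset list, so the sort below reorders the stored bucket offsets directly in
      // the ByteBuffer rather than copying them onto the heap first.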
Collections.sort( @@ -308,6 +196,7 @@ public int size() @Override public int compare(Integer lhs, Integer rhs) { + final ByteBuffer tableBuffer = hashTable.getTableBuffer(); return comparator.compare( tableBuffer, tableBuffer, @@ -321,6 +210,7 @@ public int compare(Integer lhs, Integer rhs) return new Iterator>() { int curr = 0; + final int size = getSize(); @Override public boolean hasNext() @@ -331,6 +221,9 @@ public boolean hasNext() @Override public Entry next() { + if (curr >= size) { + throw new NoSuchElementException(); + } return bucketEntryForOffset(wrappedOffsets.get(curr++)); } @@ -345,6 +238,7 @@ public void remove() return new Iterator>() { int curr = 0; + final int size = getSize(); @Override public boolean hasNext() @@ -355,7 +249,10 @@ public boolean hasNext() @Override public Entry next() { - final int offset = buffer.getInt(tableArenaSize + curr * Ints.BYTES); + if (curr >= size) { + throw new NoSuchElementException(); + } + final int offset = offsetList.get(curr); final Entry entry = bucketEntryForOffset(offset); curr++; @@ -371,174 +268,36 @@ public void remove() } } - @Override - public void close() - { - for (BufferAggregator aggregator : aggregators) { - try { - aggregator.close(); - } - catch (Exception e) { - log.warn(e, "Could not close aggregator, skipping.", aggregator); - } - } - } - - private boolean isUsed(final int bucket) - { - return (tableBuffer.get(bucket * bucketSize) & 0x80) == 0x80; - } - - private Entry bucketEntryForOffset(final int bucketOffset) - { - final KeyType key = keySerde.fromByteBuffer(tableBuffer, bucketOffset + HASH_SIZE); - final Object[] values = new Object[aggregators.length]; - for (int i = 0; i < aggregators.length; i++) { - values[i] = aggregators[i].get(tableBuffer, bucketOffset + aggregatorOffsets[i]); - } - - return new Entry<>(key, values); - } - - private void growIfPossible() + private class BufferGrouperBucketUpdateHandler implements ByteBufferHashTable.BucketUpdateHandler { - if (tableStart == 0) { - // tableStart = 0 is the last growth; no further growing is possible. - return; - } - - final int newBuckets; - final int newMaxSize; - final int newTableStart; - - if ((long) buckets * 3 * bucketSize > (long) tableArenaSize - tableStart) { - // Not enough space to grow upwards, start back from zero - newTableStart = 0; - newBuckets = tableStart / bucketSize; - newMaxSize = maxSizeForBuckets(newBuckets); - } else { - newTableStart = tableStart + tableBuffer.limit(); - newBuckets = buckets * 2; - newMaxSize = maxSizeForBuckets(newBuckets); - } - - if (newBuckets < buckets) { - throw new ISE("WTF?! 
newBuckets[%,d] < buckets[%,d]", newBuckets, buckets); - } - - ByteBuffer newTableBuffer = buffer.duplicate(); - newTableBuffer.position(newTableStart); - newTableBuffer.limit(newTableStart + newBuckets * bucketSize); - newTableBuffer = newTableBuffer.slice(); - - int newSize = 0; - - // Clear used bits of new table - for (int i = 0; i < newBuckets; i++) { - newTableBuffer.put(i * bucketSize, (byte) 0); - } - - // Loop over old buckets and copy to new table - final ByteBuffer entryBuffer = tableBuffer.duplicate(); - final ByteBuffer keyBuffer = tableBuffer.duplicate(); - - for (int oldBucket = 0; oldBucket < buckets; oldBucket++) { - if (isUsed(oldBucket)) { - int oldPosition = oldBucket * bucketSize; - entryBuffer.limit((oldBucket + 1) * bucketSize); - entryBuffer.position(oldPosition); - keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize); - keyBuffer.position(entryBuffer.position() + HASH_SIZE); - - final int keyHash = entryBuffer.getInt(entryBuffer.position()) & 0x7fffffff; - final int newBucket = findBucket(newTableBuffer, newBuckets, bucketSize, true, keyBuffer, keySize, keyHash); - - if (newBucket < 0) { - throw new ISE("WTF?! Couldn't find a bucket while resizing?!"); - } - - int newPosition = newBucket * bucketSize; - newTableBuffer.position(newPosition); - newTableBuffer.put(entryBuffer); - - for (int i = 0; i < aggregators.length; i++) { - aggregators[i].relocate( - oldPosition + aggregatorOffsets[i], - newPosition + aggregatorOffsets[i], - tableBuffer, - newTableBuffer - ); - } - - buffer.putInt(tableArenaSize + newSize * Ints.BYTES, newBucket * bucketSize); - newSize++; - } + @Override + public void handleNewBucket(int bucketOffset) + { + offsetList.add(bucketOffset); } - buckets = newBuckets; - maxSize = newMaxSize; - tableBuffer = newTableBuffer; - tableStart = newTableStart; - - if (size != newSize) { - throw new ISE("WTF?! size[%,d] != newSize[%,d] after resizing?!", size, maxSize); + @Override + public void handlePreTableSwap() + { + offsetList.reset(); } - } - - private int maxSizeForBuckets(int buckets) - { - return Math.max(1, (int) (buckets * maxLoadFactor)); - } - - /** - * Finds the bucket into which we should insert a key. - * - * @param keyBuffer key, must have exactly keySize bytes remaining. Will not be modified. - * - * @return bucket index for this key, or -1 if no bucket is available due to being full - */ - private static int findBucket( - final ByteBuffer tableBuffer, - final int buckets, - final int bucketSize, - final boolean allowNewBucket, - final ByteBuffer keyBuffer, - final int keySize, - final int keyHash - ) - { - // startBucket will never be negative since keyHash is always positive (see Groupers.hash) - final int startBucket = keyHash % buckets; - int bucket = startBucket; - -outer: - while (true) { - final int bucketOffset = bucket * bucketSize; - - if ((tableBuffer.get(bucketOffset) & 0x80) == 0) { - // Found unused bucket before finding our key - return allowNewBucket ? bucket : -1; - } - - for (int i = bucketOffset + HASH_SIZE, j = keyBuffer.position(); j < keyBuffer.position() + keySize; i++, j++) { - if (tableBuffer.get(i) != keyBuffer.get(j)) { - bucket += 1; - if (bucket == buckets) { - bucket = 0; - } - if (bucket == startBucket) { - // Came back around to the start without finding a free slot, that was a long trip! - // Should never happen unless buckets == maxSize. 
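      // (Editorial note, not part of the original patch.) The probing loop removed here reappears
      // as ByteBufferHashTable.findBucket(), and growIfPossible() becomes
      // ByteBufferHashTable.adjustTableWhenFull().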
- return -1; - } - - continue outer; - } + @Override + public void handleBucketMove( + int oldBucketOffset, int newBucketOffset, ByteBuffer oldBuffer, ByteBuffer newBuffer + ) + { + // relocate aggregators (see https://github.com/druid-io/druid/pull/4071) + for (int i = 0; i < aggregators.length; i++) { + aggregators[i].relocate( + oldBucketOffset + aggregatorOffsets[i], + newBucketOffset + aggregatorOffsets[i], + oldBuffer, + newBuffer + ); } - // Found our key in a used bucket - return bucket; + offsetList.add(newBucketOffset); } } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferHashTable.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferHashTable.java new file mode 100644 index 000000000000..cd83b229c355 --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferHashTable.java @@ -0,0 +1,382 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.primitives.Ints; +import io.druid.java.util.common.IAE; +import io.druid.java.util.common.ISE; + +import java.nio.ByteBuffer; + +public class ByteBufferHashTable +{ + public static int calculateTableArenaSizeWithPerBucketAdditionalSize( + int bufferCapacity, + int bucketSize, + int perBucketAdditionalSize + ) + { + return (bufferCapacity / (bucketSize + perBucketAdditionalSize)) * bucketSize; + } + + public static int calculateTableArenaSizeWithFixedAdditionalSize( + int bufferCapacity, + int bucketSize, + int fixedAdditionalSize + ) + { + return ((bufferCapacity - fixedAdditionalSize) / bucketSize) * bucketSize; + } + + protected final int maxSizeForTesting; // Integer.MAX_VALUE in production, only used for unit tests + + protected static final int HASH_SIZE = Ints.BYTES; + + protected final float maxLoadFactor; + protected final int initialBuckets; + protected final ByteBuffer buffer; + protected final int bucketSizeWithHash; + protected final int tableArenaSize; + protected final int keySize; + + protected int tableStart; + + // Buffer pointing to the current table (it moves around as the table grows) + protected ByteBuffer tableBuffer; + + // Number of elements in the table right now + protected int size; + + // Maximum number of elements in the table before it must be resized + // This value changes when the table is resized. + protected int regrowthThreshold; + + // current number of available/used buckets in the table + // This value changes when the table is resized. 
+ protected int maxBuckets; + + // how many times the table buffer has filled/readjusted (through adjustTableWhenFull()) + protected int growthCount; + + + + protected BucketUpdateHandler bucketUpdateHandler; + + public ByteBufferHashTable( + float maxLoadFactor, + int initialBuckets, + int bucketSizeWithHash, + ByteBuffer buffer, + int keySize, + int maxSizeForTesting, + BucketUpdateHandler bucketUpdateHandler + ) + { + this.maxLoadFactor = maxLoadFactor; + this.initialBuckets = initialBuckets; + this.bucketSizeWithHash = bucketSizeWithHash; + this.buffer = buffer; + this.keySize = keySize; + this.maxSizeForTesting = maxSizeForTesting; + this.tableArenaSize = buffer.capacity(); + this.bucketUpdateHandler = bucketUpdateHandler; + } + + public void reset() + { + size = 0; + + maxBuckets = Math.min(tableArenaSize / bucketSizeWithHash, initialBuckets); + regrowthThreshold = maxSizeForBuckets(maxBuckets); + + if (maxBuckets < 1) { + throw new IAE( + "Not enough capacity for even one row! Need[%,d] but have[%,d].", + bucketSizeWithHash + Ints.BYTES, + buffer.capacity() + ); + } + + // Start table part-way through the buffer so the last growth can start from zero and thereby use more space. + tableStart = tableArenaSize - maxBuckets * bucketSizeWithHash; + int nextBuckets = maxBuckets * 2; + while (true) { + final int nextTableStart = tableStart - nextBuckets * bucketSizeWithHash; + if (nextTableStart > tableArenaSize / 2) { + tableStart = nextTableStart; + nextBuckets = nextBuckets * 2; + } else { + break; + } + } + + if (tableStart < tableArenaSize / 2) { + tableStart = 0; + } + + final ByteBuffer bufferDup = buffer.duplicate(); + bufferDup.position(tableStart); + bufferDup.limit(tableStart + maxBuckets * bucketSizeWithHash); + tableBuffer = bufferDup.slice(); + + // Clear used bits of new table + for (int i = 0; i < maxBuckets; i++) { + tableBuffer.put(i * bucketSizeWithHash, (byte) 0); + } + } + + public void adjustTableWhenFull() + { + if (tableStart == 0) { + // tableStart = 0 is the last growth; no further growing is possible. + return; + } + + final int newBuckets; + final int newMaxSize; + final int newTableStart; + + if (((long) maxBuckets * 3 * bucketSizeWithHash) > (long) tableArenaSize - tableStart) { + // Not enough space to grow upwards, start back from zero + newTableStart = 0; + newBuckets = tableStart / bucketSizeWithHash; + newMaxSize = maxSizeForBuckets(newBuckets); + } else { + newTableStart = tableStart + tableBuffer.limit(); + newBuckets = maxBuckets * 2; + newMaxSize = maxSizeForBuckets(newBuckets); + } + + if (newBuckets < maxBuckets) { + throw new ISE("WTF?! 
newBuckets[%,d] < maxBuckets[%,d]", newBuckets, maxBuckets); + } + + ByteBuffer newTableBuffer = buffer.duplicate(); + newTableBuffer.position(newTableStart); + newTableBuffer.limit(newTableStart + newBuckets * bucketSizeWithHash); + newTableBuffer = newTableBuffer.slice(); + + int newSize = 0; + + // Clear used bits of new table + for (int i = 0; i < newBuckets; i++) { + newTableBuffer.put(i * bucketSizeWithHash, (byte) 0); + } + + // Loop over old buckets and copy to new table + final ByteBuffer entryBuffer = tableBuffer.duplicate(); + final ByteBuffer keyBuffer = tableBuffer.duplicate(); + + int oldBuckets = maxBuckets; + + if (bucketUpdateHandler != null) { + bucketUpdateHandler.handlePreTableSwap(); + } + + for (int oldBucket = 0; oldBucket < oldBuckets; oldBucket++) { + if (isBucketUsed(oldBucket)) { + int oldBucketOffset = oldBucket * bucketSizeWithHash; + entryBuffer.limit((oldBucket + 1) * bucketSizeWithHash); + entryBuffer.position(oldBucketOffset); + keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize); + keyBuffer.position(entryBuffer.position() + HASH_SIZE); + + final int keyHash = entryBuffer.getInt(entryBuffer.position()) & 0x7fffffff; + final int newBucket = findBucket(true, newBuckets, newTableBuffer, keyBuffer, keyHash); + + if (newBucket < 0) { + throw new ISE("WTF?! Couldn't find a bucket while resizing?!"); + } + + final int newBucketOffset = newBucket * bucketSizeWithHash; + + newTableBuffer.position(newBucketOffset); + newTableBuffer.put(entryBuffer); + + newSize++; + + if (bucketUpdateHandler != null) { + bucketUpdateHandler.handleBucketMove(oldBucketOffset, newBucketOffset, tableBuffer, newTableBuffer); + } + } + } + + maxBuckets = newBuckets; + regrowthThreshold = newMaxSize; + tableBuffer = newTableBuffer; + tableStart = newTableStart; + + growthCount++; + + if (size != newSize) { + throw new ISE("WTF?! size[%,d] != newSize[%,d] after resizing?!", size, newSize); + } + } + + protected void initializeNewBucketKey( + final int bucket, + final ByteBuffer keyBuffer, + final int keyHash + ) + { + int offset = bucket * bucketSizeWithHash; + tableBuffer.position(offset); + tableBuffer.putInt(keyHash | 0x80000000); + tableBuffer.put(keyBuffer); + size++; + + if (bucketUpdateHandler != null) { + bucketUpdateHandler.handleNewBucket(offset); + } + } + + /** + * Find a bucket for a key, attempting to resize the table with adjustTableWhenFull() if possible. + * + * @param keyBuffer buffer containing the key + * @param keyHash hash of the key + * @return bucket number of the found bucket or -1 if a bucket could not be allocated after resizing. + */ + protected int findBucketWithAutoGrowth( + final ByteBuffer keyBuffer, + final int keyHash + ) + { + int bucket = findBucket(canAllowNewBucket(), maxBuckets, tableBuffer, keyBuffer, keyHash); + + if (bucket < 0) { + if (size < maxSizeForTesting) { + adjustTableWhenFull(); + bucket = findBucket(size < regrowthThreshold, maxBuckets, tableBuffer, keyBuffer, keyHash); + } + } + + return bucket; + } + + /** + * Finds the bucket into which we should insert a key. + * + * @param keyBuffer key, must have exactly keySize bytes remaining. Will not be modified. 
+ * @param targetTableBuffer Need selectable buffer, since when resizing hash table, + * findBucket() is used on the newly allocated table buffer + * + * @return bucket index for this key, or -1 if no bucket is available due to being full + */ + protected int findBucket( + final boolean allowNewBucket, + final int buckets, + final ByteBuffer targetTableBuffer, + final ByteBuffer keyBuffer, + final int keyHash + ) + { + // startBucket will never be negative since keyHash is always positive (see Groupers.hash) + final int startBucket = keyHash % buckets; + int bucket = startBucket; + +outer: + while (true) { + final int bucketOffset = bucket * bucketSizeWithHash; + + if ((targetTableBuffer.get(bucketOffset) & 0x80) == 0) { + // Found unused bucket before finding our key + return allowNewBucket ? bucket : -1; + } + + for (int i = bucketOffset + HASH_SIZE, j = keyBuffer.position(); j < keyBuffer.position() + keySize; i++, j++) { + if (targetTableBuffer.get(i) != keyBuffer.get(j)) { + bucket += 1; + if (bucket == buckets) { + bucket = 0; + } + + if (bucket == startBucket) { + // Came back around to the start without finding a free slot, that was a long trip! + // Should never happen unless buckets == regrowthThreshold. + return -1; + } + + continue outer; + } + } + + // Found our key in a used bucket + return bucket; + } + } + + protected boolean canAllowNewBucket() + { + return size < Math.min(regrowthThreshold, maxSizeForTesting); + } + + protected int getOffsetForBucket(int bucket) + { + return bucket * bucketSizeWithHash; + } + + protected int maxSizeForBuckets(int buckets) + { + return Math.max(1, (int) (buckets * maxLoadFactor)); + } + + protected boolean isBucketUsed(final int bucket) + { + return (tableBuffer.get(bucket * bucketSizeWithHash) & 0x80) == 0x80; + } + + protected boolean isOffsetUsed(final int bucketOffset) + { + return (tableBuffer.get(bucketOffset) & 0x80) == 0x80; + } + + public ByteBuffer getTableBuffer() + { + return tableBuffer; + } + + public int getSize() + { + return size; + } + + public int getRegrowthThreshold() + { + return regrowthThreshold; + } + + public int getMaxBuckets() + { + return maxBuckets; + } + + public int getGrowthCount() + { + return growthCount; + } + + public interface BucketUpdateHandler + { + void handleNewBucket(int bucketOffset); + void handlePreTableSwap(); + void handleBucketMove(int oldBucketOffset, int newBucketOffset, ByteBuffer oldBuffer, ByteBuffer newBuffer); + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferIntList.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferIntList.java new file mode 100644 index 000000000000..2fe1706b2a24 --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferIntList.java @@ -0,0 +1,78 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.primitives.Ints; +import com.metamx.common.IAE; + +import java.nio.ByteBuffer; + +public class ByteBufferIntList +{ + private final ByteBuffer buffer; + private final int maxElements; + private int numElements; + + public ByteBufferIntList( + ByteBuffer buffer, + int maxElements + ) + { + this.buffer = buffer; + this.maxElements = maxElements; + this.numElements = 0; + + if (buffer.capacity() < (maxElements * Ints.BYTES)) { + throw new IAE( + "buffer for list is too small, was [%s] bytes, but need [%s] bytes.", + buffer.capacity(), + maxElements * Ints.BYTES + ); + } + } + + public void add(int val) + { + if (numElements == maxElements) { + throw new IndexOutOfBoundsException(String.format("List is full with %s elements.", maxElements)); + } + buffer.putInt(numElements * Ints.BYTES, val); + numElements++; + } + + public void set(int index, int val) + { + buffer.putInt(index * Ints.BYTES, val); + } + + public int get(int index) { + return buffer.getInt(index * Ints.BYTES); + } + + public int getNumElements() + { + return numElements; + } + + public void reset() + { + numElements = 0; + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeap.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeap.java new file mode 100644 index 000000000000..ea203a660514 --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeap.java @@ -0,0 +1,493 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Ordering; +import com.google.common.primitives.Ints; +import io.druid.java.util.common.ISE; + +import java.nio.ByteBuffer; +import java.util.Comparator; + +/** + * ByteBuffer-based implementation of the min-max heap developed by Atkinson, et al. + * (http://portal.acm.org/citation.cfm?id=6621), with some utility functions from + * Guava's MinMaxPriorityQueue. 
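 * (Editorial note, not part of the original patch.) Levels alternate: even levels, counting the
 * root as level 0, are min-ordered and odd levels are max-ordered, so the smallest offset sits at
 * index 0 and the largest at index 1 or 2. addOffset() evicts the current maximum once the heap
 * grows past the configured limit, which is what lets a limit-aware grouper retain only the rows
 * that can still appear in the final result.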
+ */ +public class ByteBufferMinMaxOffsetHeap +{ + private static final int EVEN_POWERS_OF_TWO = 0x55555555; + private static final int ODD_POWERS_OF_TWO = 0xaaaaaaaa; + + private final Comparator minComparator; + private final Comparator maxComparator; + private final ByteBuffer buf; + private final int limit; + private final LimitedBufferGrouper.BufferGrouperOffsetHeapIndexUpdater heapIndexUpdater; + + private int heapSize; + + public ByteBufferMinMaxOffsetHeap( + ByteBuffer buf, + int limit, + Comparator minComparator, + LimitedBufferGrouper.BufferGrouperOffsetHeapIndexUpdater heapIndexUpdater + ) + { + this.buf = buf; + this.limit = limit; + this.heapSize = 0; + this.minComparator = minComparator; + this.maxComparator = Ordering.from(minComparator).reverse(); + this.heapIndexUpdater = heapIndexUpdater; + } + + public void reset() + { + heapSize = 0; + } + + public int addOffset(int offset) + { + int pos = heapSize; + buf.putInt(pos * Ints.BYTES, offset); + heapSize++; + + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(offset, pos); + } + + bubbleUp(pos); + + if (heapSize > limit) { + return removeMax(); + } else { + return -1; + } + } + + public int removeMin() { + if (heapSize < 1) { + throw new ISE("Empty heap"); + } + int minOffset = buf.getInt(0); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(minOffset, -1); + } + + if (heapSize == 1) { + heapSize--; + return minOffset; + } + + int lastIndex = heapSize - 1; + int lastOffset = buf.getInt(lastIndex * Ints.BYTES); + heapSize--; + buf.putInt(0, lastOffset); + + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(lastOffset, 0); + } + + Comparator comparator = isEvenLevel(0) ? minComparator : maxComparator; + siftDown(comparator, 0); + + return minOffset; + } + + public int removeMax() { + int maxOffset; + if (heapSize < 1) { + throw new ISE("Empty heap"); + } + if (heapSize == 1) { + heapSize--; + maxOffset = buf.getInt(0); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(maxOffset, -1); + } + return maxOffset; + } + + // index of max must be 1, just remove it and shrink the heap + if (heapSize == 2) { + heapSize--; + maxOffset = buf.getInt(Ints.BYTES); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(maxOffset, -1); + } + return maxOffset; + } + + int maxIndex = findMaxElementIndex(); + maxOffset = buf.getInt(maxIndex * Ints.BYTES); + + int lastIndex = heapSize - 1; + int lastOffset = buf.getInt(lastIndex * Ints.BYTES); + heapSize--; + buf.putInt(maxIndex * Ints.BYTES, lastOffset); + + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(maxOffset, -1); + heapIndexUpdater.updateHeapIndexForOffset(lastOffset, maxIndex); + } + + Comparator comparator = isEvenLevel(maxIndex) ? 
minComparator : maxComparator; + siftDown(comparator, maxIndex); + + return maxOffset; + } + + public int removeAt(int deletedIndex) { + if (heapSize < 1) { + throw new ISE("Empty heap"); + } + int deletedOffset = buf.getInt(deletedIndex * Ints.BYTES); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(deletedOffset, -1); + } + + int lastIndex = heapSize - 1; + heapSize--; + if (lastIndex == deletedIndex) { + return deletedOffset; + } + int lastOffset = buf.getInt(lastIndex * Ints.BYTES); + buf.putInt(deletedIndex * Ints.BYTES, lastOffset); + + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(lastOffset, deletedIndex); + } + + Comparator comparator = isEvenLevel(deletedIndex) ? minComparator : maxComparator; + + bubbleUp(deletedIndex); + siftDown(comparator, deletedIndex); + + return deletedOffset; + } + + public void setAt(int index, int newVal) { + buf.putInt(index * Ints.BYTES, newVal); + } + + public int getAt(int index) { + return buf.getInt(index * Ints.BYTES); + } + + public int indexOf(int offset) { + for (int i = 0; i < heapSize; i++) { + int curOffset = buf.getInt(i * Ints.BYTES); + if (curOffset == offset) { + return i; + } + } + return -1; + } + + public void removeOffset(int offset) { + int index = indexOf(offset); + if (index > -1) { + removeAt(index); + } + } + + public int getHeapSize() { + return heapSize; + } + + private void bubbleUp(int pos) + { + if (isEvenLevel(pos)) { + int parentIndex = getParentIndex(pos); + if (parentIndex > -1) { + int parentOffset = buf.getInt(parentIndex * Ints.BYTES); + int offset = buf.getInt(pos * Ints.BYTES); + if (minComparator.compare(offset, parentOffset) > 0) { + buf.putInt(parentIndex * Ints.BYTES, offset); + buf.putInt(pos * Ints.BYTES, parentOffset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(offset, parentIndex); + heapIndexUpdater.updateHeapIndexForOffset(parentOffset, pos); + } + bubbleUpDirectional(maxComparator, parentIndex); + } else { + bubbleUpDirectional(minComparator, pos); + } + } else { + bubbleUpDirectional(minComparator, pos); + } + } else { + int parentIndex = getParentIndex(pos); + if (parentIndex > -1) { + int parentOffset = buf.getInt(parentIndex * Ints.BYTES); + int offset = buf.getInt(pos * Ints.BYTES); + if (minComparator.compare(offset, parentOffset) < 0) { + buf.putInt(parentIndex * Ints.BYTES, offset); + buf.putInt(pos * Ints.BYTES, parentOffset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(offset, parentIndex); + heapIndexUpdater.updateHeapIndexForOffset(parentOffset, pos); + } + bubbleUpDirectional(minComparator, parentIndex); + } else { + bubbleUpDirectional(maxComparator, pos); + } + } else { + bubbleUpDirectional(maxComparator, pos); + } + } + } + + private void bubbleUpDirectional(Comparator comparator, int pos) + { + int grandparent = getGrandparentIndex(pos); + while (grandparent > -1) { + int offset = buf.getInt(pos * Ints.BYTES); + int gpOffset = buf.getInt(grandparent * Ints.BYTES); + + if (comparator.compare(offset, gpOffset) < 0) { + buf.putInt(pos * Ints.BYTES, gpOffset); + buf.putInt(grandparent * Ints.BYTES, offset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(gpOffset, pos); + heapIndexUpdater.updateHeapIndexForOffset(offset, grandparent); + } + } + pos = grandparent; + grandparent = getGrandparentIndex(pos); + } + } + + private void siftDown(Comparator comparator, int pos) + { + int minChild = findMinChild(comparator, pos); + int 
minGrandchild; + int minIndex; + while (minChild > -1) { + minGrandchild = findMinGrandChild(comparator, pos); + if (minGrandchild > -1) { + int minChildOffset = buf.getInt(minChild * Ints.BYTES); + int minGcOffset = buf.getInt(minGrandchild * Ints.BYTES); + int cmp = comparator.compare(minChildOffset, minGcOffset); + minIndex = (cmp > 0) ? minGrandchild : minChild; + } else if (minChild > -1) { + minIndex = minChild; + } else { + break; + } + if (minIndex == minGrandchild) { + int offset = buf.getInt(pos * Ints.BYTES); + int minOffset = buf.getInt(minIndex * Ints.BYTES); + + if (comparator.compare(minOffset, offset) < 0) { + buf.putInt(pos * Ints.BYTES, minOffset); + buf.putInt(minIndex * Ints.BYTES, offset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(minOffset, pos); + heapIndexUpdater.updateHeapIndexForOffset(offset, minIndex); + } + + int parent = getParentIndex(minIndex); + int parentOffset = buf.getInt(parent * Ints.BYTES); + + if (comparator.compare(offset, parentOffset) > 0) { + buf.putInt(minIndex * Ints.BYTES, parentOffset); + buf.putInt(parent * Ints.BYTES, offset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(offset, parent); + heapIndexUpdater.updateHeapIndexForOffset(parentOffset, minIndex); + } + } + minChild = findMinChild(comparator, minIndex); + } + pos = minIndex; + } else { + int offset = buf.getInt(pos * Ints.BYTES); + int minOffset = buf.getInt(minIndex * Ints.BYTES); + if (comparator.compare(minOffset, offset) < 0) { + buf.putInt(pos * Ints.BYTES, minOffset); + buf.putInt(minIndex * Ints.BYTES, offset); + if (heapIndexUpdater != null) { + heapIndexUpdater.updateHeapIndexForOffset(offset, minIndex); + heapIndexUpdater.updateHeapIndexForOffset(minOffset, pos); + } + } + break; + } + } + } + + private boolean isEvenLevel(int index) { + int oneBased = index + 1; + return (oneBased & EVEN_POWERS_OF_TWO) > (oneBased & ODD_POWERS_OF_TWO); + } + + /** + * Returns the index of minimum value between {@code index} and + * {@code index + len}, or {@code -1} if {@code index} is greater than + * {@code size}. + */ + private int findMin(Comparator comparator, int index, int len) { + if (index >= heapSize) { + return -1; + } + int limit = Math.min(index, heapSize - len) + len; + int minIndex = index; + for (int i = index + 1; i < limit; i++) { + if (comparator.compare(buf.getInt(i * Ints.BYTES), buf.getInt(minIndex * Ints.BYTES)) < 0) { + minIndex = i; + } + } + return minIndex; + } + + /** + * Returns the minimum child or {@code -1} if no child exists. + */ + private int findMinChild(Comparator comparator, int index) { + return findMin(comparator, getLeftChildIndex(index), 2); + } + + /** + * Returns the minimum grand child or -1 if no grand child exists. + */ + private int findMinGrandChild(Comparator comparator, int index) { + int leftChildIndex = getLeftChildIndex(index); + if (leftChildIndex < 0) { + return -1; + } + return findMin(comparator, getLeftChildIndex(leftChildIndex), 4); + } + + private int getLeftChildIndex(int i) { + return i * 2 + 1; + } + + private int getRightChildIndex(int i) { + return i * 2 + 2; + } + + private int getParentIndex(int i) { + if (i == 0) { + return -1; + } + return (i - 1) / 2; + } + + private int getGrandparentIndex(int i) { + if (i < 3) { + return -1; + } + return (i - 3) / 4; + } + + /** + * Returns the index of the max element. + */ + private int findMaxElementIndex() { + switch (heapSize) { + case 1: + return 0; // The lone element in the queue is the maximum. 
+ case 2: + return 1; // The lone element in the maxHeap is the maximum. + default: + // The max element must sit on the first level of the maxHeap. It is + // actually the *lesser* of the two from the maxHeap's perspective. + int offset1 = buf.getInt(1 * Ints.BYTES); + int offset2 = buf.getInt(2 * Ints.BYTES); + return maxComparator.compare(offset1, offset2) <= 0 ? 1 : 2; + } + } + + @VisibleForTesting + boolean isIntact() { + for (int i = 0; i < heapSize; i++) { + if (!verifyIndex(i)) { + return false; + } + } + return true; + } + + private boolean verifyIndex(int i) + { + Comparator comparator = isEvenLevel(i) ? minComparator : maxComparator; + int offset = buf.getInt(i * Ints.BYTES); + + int lcIdx = getLeftChildIndex(i); + if (lcIdx < heapSize) { + int leftChildOffset = buf.getInt(lcIdx * Ints.BYTES); + if (comparator.compare(offset, leftChildOffset) > 0) { + throw new ISE("Left child val[%d] at idx[%d] is less than val[%d] at idx[%d]", + leftChildOffset, lcIdx, offset, i); + } + } + + int rcIdx = getRightChildIndex(i); + if (rcIdx < heapSize) { + int rightChildOffset = buf.getInt(rcIdx * Ints.BYTES); + if (comparator.compare(offset, rightChildOffset) > 0) { + throw new ISE("Right child val[%d] at idx[%d] is less than val[%d] at idx[%d]", + rightChildOffset, rcIdx, offset, i); + } + } + + if (i > 0) { + int parentIdx = getParentIndex(i); + int parentOffset = buf.getInt(parentIdx * Ints.BYTES); + if (comparator.compare(offset, parentOffset) > 0) { + throw new ISE("Parent val[%d] at idx[%d] is less than val[%d] at idx[%d]", + parentOffset, parentIdx, offset, i); + } + } + + if (i > 2) { + int gpIdx = getGrandparentIndex(i); + int gpOffset = buf.getInt(gpIdx * Ints.BYTES); + if (comparator.compare(gpOffset, offset) > 0) { + throw new ISE("Grandparent val[%d] at idx[%d] is less than val[%d] at idx[%d]", + gpOffset, gpIdx, offset, i); + } + } + + return true; + } + + @Override + public String toString() + { + if (heapSize == 0) { + return "[]"; + } + + String ret = "["; + for (int i = 0; i < heapSize; i++) { + ret += buf.getInt(i * Ints.BYTES); + if (i < heapSize - 1) { + ret += ", "; + } + } + + ret += "]"; + return ret; + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java index aa4ee9a73c70..71d3ff6f1543 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouper.java @@ -25,6 +25,7 @@ import com.google.common.base.Suppliers; import io.druid.java.util.common.ISE; import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.segment.ColumnSelectorFactory; import java.nio.ByteBuffer; @@ -50,7 +51,7 @@ public class ConcurrentGrouper implements Grouper private final AtomicInteger threadNumber = new AtomicInteger(); private volatile boolean spilling = false; private volatile boolean closed = false; - private final Comparator keyObjComparator; + private final Comparator> keyObjComparator; private final Supplier bufferSupplier; private final ColumnSelectorFactory columnSelectorFactory; @@ -62,6 +63,8 @@ public class ConcurrentGrouper implements Grouper private final ObjectMapper spillMapper; private final int concurrencyHint; private final KeySerdeFactory keySerdeFactory; + private final DefaultLimitSpec limitSpec; + private final boolean sortHasNonGroupingFields; private 
volatile boolean initialized = false; @@ -75,7 +78,9 @@ public ConcurrentGrouper( final int bufferGrouperInitialBuckets, final LimitedTemporaryStorage temporaryStorage, final ObjectMapper spillMapper, - final int concurrencyHint + final int concurrencyHint, + final DefaultLimitSpec limitSpec, + final boolean sortHasNonGroupingFields ) { Preconditions.checkArgument(concurrencyHint > 0, "concurrencyHint > 0"); @@ -100,7 +105,9 @@ protected SpillingGrouper initialValue() this.spillMapper = spillMapper; this.concurrencyHint = concurrencyHint; this.keySerdeFactory = keySerdeFactory; - this.keyObjComparator = keySerdeFactory.objectComparator(); + this.limitSpec = limitSpec; + this.sortHasNonGroupingFields = sortHasNonGroupingFields; + this.keyObjComparator = keySerdeFactory.objectComparator(sortHasNonGroupingFields); } @Override @@ -126,7 +133,9 @@ public void init() bufferGrouperInitialBuckets, temporaryStorage, spillMapper, - false + false, + limitSpec, + sortHasNonGroupingFields ); grouper.init(); groupers.add(grouper); diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java index a0f82b996304..932136f07513 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/GroupByQueryEngineV2.java @@ -441,12 +441,21 @@ public ByteBuffer fromByteBuffer(ByteBuffer buffer, int position) } @Override - public Grouper.KeyComparator bufferComparator() + public Grouper.BufferComparator bufferComparator() { // No sorting, let mergeRunners handle that throw new UnsupportedOperationException(); } + @Override + public Grouper.BufferComparator bufferComparatorWithAggregators( + AggregatorFactory[] aggregatorFactories, int[] aggregatorOffsets + ) + { + // not called on this + throw new UnsupportedOperationException(); + } + @Override public void reset() { diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java index 1971c2b86b27..2f6f795fd091 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Grouper.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import io.druid.query.aggregation.AggregatorFactory; import java.io.Closeable; import java.nio.ByteBuffer; @@ -178,9 +179,12 @@ interface KeySerdeFactory * Return an object that knows how to compare two serialized key instances. Will be called by the * {@link #iterator(boolean)} method if sorting is enabled. * + * @param forceDefaultOrder Return a comparator that sorts by the key in default lexicographic ascending order, + * regardless of any other conditions (e.g., presence of OrderBySpecs). + * * @return comparator for key objects. */ - Comparator objectComparator(); + Comparator> objectComparator(boolean forceDefaultOrder); } /** @@ -228,7 +232,19 @@ interface KeySerde * * @return comparator for keys */ - KeyComparator bufferComparator(); + BufferComparator bufferComparator(); + + /** + * When pushing down limits, it may also be necessary to compare aggregated values along with the key + * using the bufferComparator. 
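+ * For example, an ORDER BY on an aggregator can only be evaluated by reading the partially aggregated value stored after the grouping key in each bucket, which is why the aggregator offsets into that buffer are passed in here.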
+ * + * @param aggregatorFactories Array of aggregators from a GroupByQuery + * @param aggregatorOffsets Offsets for each aggregator in aggregatorFactories pointing to their location + * within the grouping key + aggs buffer. + * + * @return comparator for keys + aggs + */ + BufferComparator bufferComparatorWithAggregators(AggregatorFactory[] aggregatorFactories, int[] aggregatorOffsets); /** * Reset the keySerde to its initial state. After this method is called, {@link #fromByteBuffer(ByteBuffer, int)} @@ -237,7 +253,7 @@ interface KeySerde void reset(); } - interface KeyComparator + interface BufferComparator { int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition); } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java index 92400427cfdc..2324f0f3779d 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/Groupers.java @@ -57,7 +57,7 @@ public static int hash(final Object obj) public static Iterator> mergeIterators( final Iterable>> iterators, - final Comparator keyTypeComparator + final Comparator> keyTypeComparator ) { if (keyTypeComparator != null) { @@ -68,7 +68,7 @@ public static Iterator> mergeIterators( @Override public int compare(Grouper.Entry lhs, Grouper.Entry rhs) { - return keyTypeComparator.compare(lhs.getKey(), rhs.getKey()); + return keyTypeComparator.compare(lhs, rhs); } } ); diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouper.java new file mode 100644 index 000000000000..dd6442835bab --- /dev/null +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouper.java @@ -0,0 +1,519 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Supplier; +import com.google.common.collect.Iterators; +import com.google.common.primitives.Ints; +import io.druid.java.util.common.IAE; +import io.druid.java.util.common.ISE; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.segment.ColumnSelectorFactory; + +import java.nio.ByteBuffer; +import java.util.AbstractList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +public class LimitedBufferGrouper extends AbstractBufferGrouper +{ + private static final int MIN_INITIAL_BUCKETS = 4; + private static final int DEFAULT_INITIAL_BUCKETS = 1024; + private static final float DEFAULT_MAX_LOAD_FACTOR = 0.7f; + + private final AggregatorFactory[] aggregatorFactories; + + // Limit to apply to results. + private int limit; + + // Indicates if the sorting order has fields not in the grouping key, used when pushing down limit/sorting. + // In this case, grouping key comparisons need to also compare on aggregators. + // Additionally, results must be resorted by grouping key to allow results to merge correctly. + private boolean sortHasNonGroupingFields; + + // Min-max heap, used for storing offsets when applying limits/sorting in the BufferGrouper + private ByteBufferMinMaxOffsetHeap offsetHeap; + + // ByteBuffer slices used by the grouper + private ByteBuffer totalBuffer; + private ByteBuffer hashTableBuffer; + private ByteBuffer offsetHeapBuffer; + + // Updates the heap index field for buckets, created passed to the heap when + // pushing down limit and the sort order includes aggregators + private BufferGrouperOffsetHeapIndexUpdater heapIndexUpdater; + private boolean initialized = false; + + public LimitedBufferGrouper( + final Supplier bufferSupplier, + final Grouper.KeySerde keySerde, + final ColumnSelectorFactory columnSelectorFactory, + final AggregatorFactory[] aggregatorFactories, + final int bufferGrouperMaxSize, + final float maxLoadFactor, + final int initialBuckets, + final int limit, + final boolean sortHasNonGroupingFields + ) + { + super(bufferSupplier, keySerde, aggregatorFactories, bufferGrouperMaxSize); + this.maxLoadFactor = maxLoadFactor > 0 ? maxLoadFactor : DEFAULT_MAX_LOAD_FACTOR; + this.initialBuckets = initialBuckets > 0 ? 
Math.max(MIN_INITIAL_BUCKETS, initialBuckets) : DEFAULT_INITIAL_BUCKETS; + this.limit = limit; + this.sortHasNonGroupingFields = sortHasNonGroupingFields; + + if (this.maxLoadFactor >= 1.0f) { + throw new IAE("Invalid maxLoadFactor[%f], must be < 1.0", maxLoadFactor); + } + + int offset = HASH_SIZE + keySize; + this.aggregatorFactories = aggregatorFactories; + for (int i = 0; i < aggregatorFactories.length; i++) { + aggregators[i] = aggregatorFactories[i].factorizeBuffered(columnSelectorFactory); + aggregatorOffsets[i] = offset; + offset += aggregatorFactories[i].getMaxIntermediateSize(); + } + + // For each bucket, store an extra field indicating the bucket's current index within the heap when + // pushing down limits + offset += Ints.BYTES; + this.bucketSize = offset; + } + + @Override + public void init() + { + if (initialized) { + return; + } + this.totalBuffer = bufferSupplier.get(); + + validateBufferCapacity( + limit, + maxLoadFactor, + totalBuffer, + bucketSize + ); + + //only store offsets up to `limit` + 1 instead of up to # of buckets, we only keep the top results + int heapByteSize = (limit + 1) * Ints.BYTES; + + int hashTableSize = ByteBufferHashTable.calculateTableArenaSizeWithFixedAdditionalSize( + totalBuffer.capacity(), + bucketSize, + heapByteSize + ); + + hashTableBuffer = totalBuffer.duplicate(); + hashTableBuffer.position(0); + hashTableBuffer.limit(hashTableSize); + hashTableBuffer = hashTableBuffer.slice(); + + offsetHeapBuffer = totalBuffer.duplicate(); + offsetHeapBuffer.position(hashTableSize); + offsetHeapBuffer = offsetHeapBuffer.slice(); + offsetHeapBuffer.limit(totalBuffer.capacity() - hashTableSize); + + this.hashTable = new AlternatingByteBufferHashTable( + maxLoadFactor, + initialBuckets, + bucketSize, + hashTableBuffer, + keySize, + bufferGrouperMaxSize + ); + this.heapIndexUpdater = new BufferGrouperOffsetHeapIndexUpdater(totalBuffer, bucketSize - Ints.BYTES); + this.offsetHeap = new ByteBufferMinMaxOffsetHeap(offsetHeapBuffer, limit, makeHeapComparator(), heapIndexUpdater); + + reset(); + + initialized = true; + } + + @Override + public boolean isInitialized() + { + return initialized; + } + + @Override + public void newBucketHook(int bucketOffset) + { + heapIndexUpdater.updateHeapIndexForOffset(bucketOffset, -1); + } + + @Override + public boolean canSkipAggregate(boolean bucketWasUsed, int bucketOffset) + { + if (bucketWasUsed) { + if (!sortHasNonGroupingFields) { + if (heapIndexUpdater.getHeapIndexForOffset(bucketOffset) < 0) { + return true; + } + } + } + return false; + } + + @Override + public void afterAggregateHook(int bucketOffset) + { + int heapIndex = heapIndexUpdater.getHeapIndexForOffset(bucketOffset); + if (heapIndex < 0) { + // not in the heap, add it + offsetHeap.addOffset(bucketOffset); + } else if (sortHasNonGroupingFields) { + // Since the sorting columns contain at least one aggregator, we need to remove and reinsert + // the entries after aggregating to maintain proper ordering + offsetHeap.removeAt(heapIndex); + offsetHeap.addOffset(bucketOffset); + } + } + + @Override + public void reset() + { + hashTable.reset(); + keySerde.reset(); + offsetHeap.reset(); + heapIndexUpdater.setHashTableBuffer(hashTable.getTableBuffer()); + } + + @Override + public Iterator> iterator(boolean sorted) + { + if (!initialized) { + // it's possible for iterator() to be called before initialization when + // a nested groupBy's subquery has an empty result set (see testEmptySubqueryWithLimitPushDown() + // in GroupByQueryRunnerTest) + return 
Iterators.>emptyIterator(); + } + + if (sortHasNonGroupingFields) { + // re-sort the heap in place, it's also an array of offsets in the totalBuffer + return makeDefaultOrderingIterator(); + } else { + return makeHeapIterator(); + } + } + + public int getLimit() + { + return limit; + } + + public static class BufferGrouperOffsetHeapIndexUpdater + { + private ByteBuffer hashTableBuffer; + private final int indexPosition; + + public BufferGrouperOffsetHeapIndexUpdater( + ByteBuffer hashTableBuffer, + int indexPosition + ) + { + this.hashTableBuffer = hashTableBuffer; + this.indexPosition = indexPosition; + } + + public void setHashTableBuffer(ByteBuffer newTableBuffer) { + hashTableBuffer = newTableBuffer; + } + + public void updateHeapIndexForOffset(int bucketOffset, int newHeapIndex) + { + hashTableBuffer.putInt(bucketOffset + indexPosition, newHeapIndex); + } + + public int getHeapIndexForOffset(int bucketOffset) + { + return hashTableBuffer.getInt(bucketOffset + indexPosition); + } + } + + private Iterator> makeDefaultOrderingIterator() + { + final int size = offsetHeap.getHeapSize(); + + final List wrappedOffsets = new AbstractList() + { + @Override + public Integer get(int index) + { + return offsetHeap.getAt(index); + } + + @Override + public Integer set(int index, Integer element) + { + final Integer oldValue = get(index); + offsetHeap.setAt(index, element); + return oldValue; + } + + @Override + public int size() + { + return size; + } + }; + + final BufferComparator comparator = keySerde.bufferComparator(); + + // Sort offsets in-place. + Collections.sort( + wrappedOffsets, + new Comparator() + { + @Override + public int compare(Integer lhs, Integer rhs) + { + final ByteBuffer curHashTableBuffer = hashTable.getTableBuffer(); + return comparator.compare( + curHashTableBuffer, + curHashTableBuffer, + lhs + HASH_SIZE, + rhs + HASH_SIZE + ); + } + } + ); + + return new Iterator>() + { + int curr = 0; + + @Override + public boolean hasNext() + { + return curr < size; + } + + @Override + public Grouper.Entry next() + { + return bucketEntryForOffset(wrappedOffsets.get(curr++)); + } + + @Override + public void remove() + { + throw new UnsupportedOperationException(); + } + }; + } + + private Iterator> makeHeapIterator() + { + final int initialHeapSize = offsetHeap.getHeapSize(); + return new Iterator>() + { + int curr = 0; + + @Override + public boolean hasNext() + { + return curr < initialHeapSize; + } + + @Override + public Grouper.Entry next() + { + if (curr >= initialHeapSize) { + throw new NoSuchElementException(); + } + final int offset = offsetHeap.removeMin(); + final Grouper.Entry entry = bucketEntryForOffset(offset); + curr++; + + return entry; + } + + @Override + public void remove() + { + throw new UnsupportedOperationException(); + } + }; + } + + private Comparator makeHeapComparator() + { + return new Comparator() + { + final BufferComparator bufferComparator = keySerde.bufferComparatorWithAggregators( + aggregatorFactories, + aggregatorOffsets + ); + @Override + public int compare(Integer o1, Integer o2) + { + final ByteBuffer tableBuffer = hashTable.getTableBuffer(); + return bufferComparator.compare(tableBuffer, tableBuffer, o1 + HASH_SIZE, o2 + HASH_SIZE); + } + }; + } + + + private void validateBufferCapacity( + int limit, + float maxLoadFactor, + ByteBuffer buffer, + int bucketSize + ) + { + int numBucketsNeeded = (int) Math.ceil((limit + 1) / maxLoadFactor); + int targetTableArenaSize = numBucketsNeeded * bucketSize * 2; + int heapSize = (limit + 1) * (Ints.BYTES); + 
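+ // Sizing note: the table arena is split into two alternating sub-tables (hence the factor of 2), each holding up to ceil((limit + 1) / maxLoadFactor) buckets, and the offset heap stores at most limit + 1 int offsets.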
int requiredSize = targetTableArenaSize + heapSize; + + if (buffer.capacity() < requiredSize) { + throw new IAE( + "Buffer capacity [%d] is too small for limit[%d] with load factor[%f], minimum bytes needed: [%d]", + buffer.capacity(), + limit, + maxLoadFactor, + requiredSize + ); + } + } + + private class AlternatingByteBufferHashTable extends ByteBufferHashTable + { + // The base buffer is split into two alternating halves, with one sub-buffer in use at a given time. + // When the current sub-buffer fills, the used bits of the other sub-buffer are cleared, entries up to the limit + // are copied from the current full sub-buffer to the new buffer, and the active buffer (referenced by tableBuffer) + // is swapped to the new buffer. + private ByteBuffer[] subHashTableBuffers; + + public AlternatingByteBufferHashTable( + float maxLoadFactor, + int initialBuckets, + int bucketSizeWithHash, + ByteBuffer totalHashTableBuffer, + int keySize, + int maxSizeForTesting + ) + { + super( + maxLoadFactor, + initialBuckets, + bucketSizeWithHash, + totalHashTableBuffer, + keySize, + maxSizeForTesting, + null + ); + + this.growthCount = 0; + + int subHashTableSize = tableArenaSize / 2; + maxBuckets = subHashTableSize / bucketSizeWithHash; + regrowthThreshold = maxSizeForBuckets(maxBuckets); + + // split the hashtable into 2 sub tables that we rotate between + ByteBuffer subHashTable1Buffer = totalHashTableBuffer.duplicate(); + subHashTable1Buffer.position(0); + subHashTable1Buffer.limit(subHashTableSize); + subHashTable1Buffer = subHashTable1Buffer.slice(); + + ByteBuffer subHashTable2Buffer = totalHashTableBuffer.duplicate(); + subHashTable2Buffer.position(subHashTableSize); + subHashTable2Buffer.limit(tableArenaSize); + subHashTable2Buffer = subHashTable2Buffer.slice(); + + subHashTableBuffers = new ByteBuffer[] {subHashTable1Buffer, subHashTable2Buffer}; + } + + @Override + public void reset() + { + size = 0; + growthCount = 0; + // clear the used bits of the first buffer + for (int i = 0; i < maxBuckets; i++) { + subHashTableBuffers[0].put(i * bucketSizeWithHash, (byte) 0); + } + tableBuffer = subHashTableBuffers[0]; + } + + @Override + public void adjustTableWhenFull() + { + int newTableIdx = (growthCount % 2 == 0) ? 1 : 0; + ByteBuffer newTableBuffer = subHashTableBuffers[newTableIdx]; + + // clear the used bits of the buffer we're swapping to + for (int i = 0; i < maxBuckets; i++) { + newTableBuffer.put(i * bucketSizeWithHash, (byte) 0); + } + + // Get the offsets of the top N buckets from the heap and copy the buckets to new table + final ByteBuffer entryBuffer = tableBuffer.duplicate(); + final ByteBuffer keyBuffer = tableBuffer.duplicate(); + + int numCopied = 0; + for (int i = 0; i < offsetHeap.getHeapSize(); i++) { + final int oldBucketOffset = offsetHeap.getAt(i); + + if (isOffsetUsed(oldBucketOffset)) { + // Read the entry from the old hash table + entryBuffer.limit(oldBucketOffset + bucketSizeWithHash); + entryBuffer.position(oldBucketOffset); + keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize); + keyBuffer.position(entryBuffer.position() + HASH_SIZE); + + // Put the entry in the new hash table + final int keyHash = entryBuffer.getInt(entryBuffer.position()) & 0x7fffffff; + final int newBucket = findBucket(true, maxBuckets, newTableBuffer, keyBuffer, keyHash); + + if (newBucket < 0) { + throw new ISE("WTF?! 
Couldn't find a bucket while resizing?!"); + } + + final int newBucketOffset = newBucket * bucketSizeWithHash; + newTableBuffer.position(newBucketOffset); + newTableBuffer.put(entryBuffer); + numCopied++; + + // Update the heap with the copied bucket's new offset in the new table + offsetHeap.setAt(i, newBucketOffset); + + // relocate aggregators (see https://github.com/druid-io/druid/pull/4071) + for (int j = 0; j < aggregators.length; j++) { + aggregators[j].relocate( + oldBucketOffset + aggregatorOffsets[j], + newBucketOffset + aggregatorOffsets[j], + tableBuffer, + newTableBuffer + ); + } + } + } + + size = numCopied; + tableBuffer = newTableBuffer; + growthCount++; + } + } +} diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java index c866405ff0b3..6ce1f85f9f97 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/RowBasedGrouperHelper.java @@ -29,6 +29,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.primitives.Chars; +import com.google.common.primitives.Doubles; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; @@ -47,7 +48,11 @@ import io.druid.query.groupby.GroupByQuery; import io.druid.query.groupby.GroupByQueryConfig; import io.druid.query.groupby.RowBasedColumnSelectorFactory; +import io.druid.query.groupby.orderby.DefaultLimitSpec; +import io.druid.query.groupby.orderby.OrderByColumnSpec; import io.druid.query.groupby.strategy.GroupByStrategyV2; +import io.druid.query.ordering.StringComparator; +import io.druid.query.ordering.StringComparators; import io.druid.segment.ColumnSelectorFactory; import io.druid.segment.ColumnValueSelector; import io.druid.segment.DimensionHandlerUtils; @@ -65,8 +70,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; // this class contains shared code between GroupByMergingQueryRunnerV2 and GroupByRowProcessor public class RowBasedGrouperHelper @@ -95,13 +102,7 @@ public static Pair, Accumulator> crea final GroupByQueryConfig querySpecificConfig = config.withOverrides(query); final boolean includeTimestamp = GroupByStrategyV2.getUniversalTimestamp(query) == null; - final Grouper.KeySerdeFactory keySerdeFactory = new RowBasedKeySerdeFactory( - includeTimestamp, - query.getContextSortByDimsFirst(), - query.getDimensions().size(), - querySpecificConfig.getMaxMergingDictionarySize() / (concurrencyHint == -1 ? 1 : concurrencyHint), - valueTypes - ); + final ThreadLocal columnSelectorRow = new ThreadLocal<>(); final ColumnSelectorFactory columnSelectorFactory = query.getVirtualColumns().wrap( RowBasedColumnSelectorFactory.create( @@ -109,6 +110,27 @@ public static Pair, Accumulator> crea rawInputRowSignature ) ); + + final boolean willApplyLimitPushDown = query.isApplyLimitPushDown(); + final DefaultLimitSpec limitSpec = willApplyLimitPushDown ? 
(DefaultLimitSpec) query.getLimitSpec() : null; + boolean sortHasNonGroupingFields = false; + if (willApplyLimitPushDown) { + sortHasNonGroupingFields = DefaultLimitSpec.sortingOrderHasNonGroupingFields( + limitSpec, + query.getDimensions() + ); + } + + final Grouper.KeySerdeFactory keySerdeFactory = new RowBasedKeySerdeFactory( + includeTimestamp, + query.getContextSortByDimsFirst(), + query.getDimensions(), + querySpecificConfig.getMaxMergingDictionarySize() / (concurrencyHint == -1 ? 1 : concurrencyHint), + valueTypes, + aggregatorFactories, + limitSpec + ); + final Grouper grouper; if (concurrencyHint == -1) { grouper = new SpillingGrouper<>( @@ -121,7 +143,9 @@ public static Pair, Accumulator> crea querySpecificConfig.getBufferGrouperInitialBuckets(), temporaryStorage, spillMapper, - true + true, + limitSpec, + sortHasNonGroupingFields ); } else { grouper = new ConcurrentGrouper<>( @@ -134,7 +158,9 @@ public static Pair, Accumulator> crea querySpecificConfig.getBufferGrouperInitialBuckets(), temporaryStorage, spillMapper, - concurrencyHint + concurrencyHint, + limitSpec, + sortHasNonGroupingFields ); } @@ -586,70 +612,207 @@ private static class RowBasedKeySerdeFactory implements Grouper.KeySerdeFactory< private final boolean sortByDimsFirst; private final int dimCount; private final long maxDictionarySize; + private final DefaultLimitSpec limitSpec; + private final List dimensions; + final AggregatorFactory[] aggregatorFactories; private final List valueTypes; RowBasedKeySerdeFactory( boolean includeTimestamp, boolean sortByDimsFirst, - int dimCount, + List dimensions, long maxDictionarySize, - List valueTypes + List valueTypes, + final AggregatorFactory[] aggregatorFactories, + DefaultLimitSpec limitSpec ) { this.includeTimestamp = includeTimestamp; this.sortByDimsFirst = sortByDimsFirst; - this.dimCount = dimCount; + this.dimensions = dimensions; + this.dimCount = dimensions.size(); this.maxDictionarySize = maxDictionarySize; + this.limitSpec = limitSpec; + this.aggregatorFactories = aggregatorFactories; this.valueTypes = valueTypes; } @Override public Grouper.KeySerde factorize() { - return new RowBasedKeySerde(includeTimestamp, sortByDimsFirst, dimCount, maxDictionarySize, valueTypes); + return new RowBasedKeySerde( + includeTimestamp, + sortByDimsFirst, + dimensions, + maxDictionarySize, + limitSpec, + valueTypes + ); } @Override - public Comparator objectComparator() + public Comparator> objectComparator(boolean forceDefaultOrder) + { + if (limitSpec != null && !forceDefaultOrder) { + return objectComparatorWithAggs(); + } + + if (includeTimestamp) { + if (sortByDimsFirst) { + return new Comparator>() + { + @Override + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) + { + final int cmp = compareDimsInRows(entry1.getKey(), entry2.getKey(), 1); + if (cmp != 0) { + return cmp; + } + + return Longs.compare((long) entry1.getKey().getKey()[0], (long) entry2.getKey().getKey()[0]); + } + }; + } else { + return new Comparator>() + { + @Override + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) + { + final int timeCompare = Longs.compare( + (long) entry1.getKey().getKey()[0], + (long) entry2.getKey().getKey()[0] + ); + + if (timeCompare != 0) { + return timeCompare; + } + + return compareDimsInRows(entry1.getKey(), entry2.getKey(), 1); + } + }; + } + } else { + return new Comparator>() + { + @Override + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) + { + return compareDimsInRows(entry1.getKey(), entry2.getKey(), 0); + } + }; + } + } + + 
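+ // Entry comparator used when a limit spec is present and the default order is not forced: entries are compared on the order-by columns (dimensions and aggregators, honoring sort direction) first, then on the remaining grouping key columns, so the spilling grouper merges partial results in the same order the pushed-down limit uses.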
private Comparator> objectComparatorWithAggs() { + // use the actual sort order from the limitspec if pushing down to merge partial results correctly + final List needsReverses = Lists.newArrayList(); + final List aggFlags = Lists.newArrayList(); + final List isNumericField = Lists.newArrayList(); + final List comparators = Lists.newArrayList(); + final List fieldIndices = Lists.newArrayList(); + final Set orderByIndices = new HashSet<>(); + + for (OrderByColumnSpec orderSpec : limitSpec.getColumns()) { + final boolean needsReverse = orderSpec.getDirection() != OrderByColumnSpec.Direction.ASCENDING; + int dimIndex = OrderByColumnSpec.getDimIndexForOrderBy(orderSpec, dimensions); + if (dimIndex >= 0) { + fieldIndices.add(dimIndex); + orderByIndices.add(dimIndex); + needsReverses.add(needsReverse); + aggFlags.add(false); + final ValueType type = dimensions.get(dimIndex).getOutputType(); + isNumericField.add(type == ValueType.LONG || type == ValueType.FLOAT); + comparators.add(orderSpec.getDimensionComparator()); + } else { + int aggIndex = OrderByColumnSpec.getAggIndexForOrderBy(orderSpec, Arrays.asList(aggregatorFactories)); + if (aggIndex >= 0) { + fieldIndices.add(aggIndex); + needsReverses.add(needsReverse); + aggFlags.add(true); + final String typeName = aggregatorFactories[aggIndex].getTypeName(); + isNumericField.add(typeName.equals("long") || typeName.equals("float")); + comparators.add(orderSpec.getDimensionComparator()); + } + } + } + + for (int i = 0; i < dimCount; i++) { + if (!orderByIndices.contains(i)) { + fieldIndices.add(i); + aggFlags.add(false); + needsReverses.add(false); + final ValueType type = dimensions.get(i).getOutputType(); + isNumericField.add(type == ValueType.LONG || type == ValueType.FLOAT); + comparators.add(StringComparators.LEXICOGRAPHIC); + } + } + if (includeTimestamp) { if (sortByDimsFirst) { - return new Comparator() + return new Comparator>() { @Override - public int compare(RowBasedKey key1, RowBasedKey key2) + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) { - final int cmp = compareDimsInRows(key1, key2, 1); + final int cmp = compareDimsInRowsWithAggs( + entry1, + entry2, + 1, + needsReverses, + aggFlags, + fieldIndices, + isNumericField, + comparators + ); if (cmp != 0) { return cmp; } - return Longs.compare((long) key1.getKey()[0], (long) key2.getKey()[0]); + return Longs.compare((long) entry1.getKey().getKey()[0], (long) entry2.getKey().getKey()[0]); } }; } else { - return new Comparator() + return new Comparator>() { @Override - public int compare(RowBasedKey key1, RowBasedKey key2) + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) { - final int timeCompare = Longs.compare((long) key1.getKey()[0], (long) key2.getKey()[0]); + final int timeCompare = Longs.compare((long) entry1.getKey().getKey()[0], (long) entry2.getKey().getKey()[0]); if (timeCompare != 0) { return timeCompare; } - return compareDimsInRows(key1, key2, 1); + return compareDimsInRowsWithAggs( + entry1, + entry2, + 1, + needsReverses, + aggFlags, + fieldIndices, + isNumericField, + comparators + ); } }; } } else { - return new Comparator() + return new Comparator>() { @Override - public int compare(RowBasedKey key1, RowBasedKey key2) + public int compare(Grouper.Entry entry1, Grouper.Entry entry2) { - return compareDimsInRows(key1, key2, 0); + return compareDimsInRowsWithAggs( + entry1, + entry2, + 0, + needsReverses, + aggFlags, + fieldIndices, + isNumericField, + comparators + ); } }; } @@ -666,22 +829,77 @@ private static int 
compareDimsInRows(RowBasedKey key1, RowBasedKey key2, int dim return 0; } + + private static int compareDimsInRowsWithAggs( + Grouper.Entry entry1, + Grouper.Entry entry2, + int dimStart, + final List needsReverses, + final List aggFlags, + final List fieldIndices, + final List isNumericField, + final List comparators + ) + { + for (int i = 0; i < fieldIndices.size(); i++) { + final int fieldIndex = fieldIndices.get(i); + final boolean needsReverse = needsReverses.get(i); + final int cmp; + final Comparable lhs; + final Comparable rhs; + + if (aggFlags.get(i)) { + if (needsReverse) { + lhs = (Comparable) entry2.getValues()[fieldIndex]; + rhs = (Comparable) entry1.getValues()[fieldIndex]; + } else { + lhs = (Comparable) entry1.getValues()[fieldIndex]; + rhs = (Comparable) entry2.getValues()[fieldIndex]; + } + } else { + if (needsReverse) { + lhs = (Comparable) entry2.getKey().getKey()[fieldIndex + dimStart]; + rhs = (Comparable) entry1.getKey().getKey()[fieldIndex + dimStart]; + } else { + lhs = (Comparable) entry1.getKey().getKey()[fieldIndex + dimStart]; + rhs = (Comparable) entry2.getKey().getKey()[fieldIndex + dimStart]; + } + } + + final StringComparator comparator = comparators.get(i); + + if (isNumericField.get(i) && comparator == StringComparators.NUMERIC) { + // use natural comparison + cmp = lhs.compareTo(rhs); + } else { + cmp = comparator.compare(lhs.toString(), rhs.toString()); + } + + if (cmp != 0) { + return cmp; + } + } + + return 0; + } } - private static class RowBasedKeySerde implements Grouper.KeySerde + private static class RowBasedKeySerde implements Grouper.KeySerde { // Entry in dictionary, node pointer in reverseDictionary, hash + k/v/next pointer in reverseDictionary nodes private static final int ROUGH_OVERHEAD_PER_DICTIONARY_ENTRY = Longs.BYTES * 5 + Ints.BYTES; private final boolean includeTimestamp; private final boolean sortByDimsFirst; + private final List dimensions; private final int dimCount; private final int keySize; private final ByteBuffer keyBuffer; private final List dictionary = Lists.newArrayList(); private final Map reverseDictionary = Maps.newHashMap(); - private final List valueTypes; private final List serdeHelpers; + private final DefaultLimitSpec limitSpec; + private final List valueTypes; // Size limiting for the dictionary, in (roughly estimated) bytes. private final long maxDictionarySize; @@ -693,16 +911,19 @@ private static class RowBasedKeySerde implements Grouper.KeySerde RowBasedKeySerde( final boolean includeTimestamp, final boolean sortByDimsFirst, - final int dimCount, + final List dimensions, final long maxDictionarySize, + final DefaultLimitSpec limitSpec, final List valueTypes ) { this.includeTimestamp = includeTimestamp; this.sortByDimsFirst = sortByDimsFirst; - this.dimCount = dimCount; + this.dimensions = dimensions; + this.dimCount = dimensions.size(); this.maxDictionarySize = maxDictionarySize; this.valueTypes = valueTypes; + this.limitSpec = limitSpec; this.serdeHelpers = makeSerdeHelpers(); this.keySize = (includeTimestamp ? 
Longs.BYTES : 0) + getTotalKeySize(); this.keyBuffer = ByteBuffer.allocate(keySize); @@ -769,7 +990,7 @@ public RowBasedKey fromByteBuffer(ByteBuffer buffer, int position) } @Override - public Grouper.KeyComparator bufferComparator() + public Grouper.BufferComparator bufferComparator() { if (sortableIds == null) { Map sortedMap = Maps.newTreeMap(); @@ -785,7 +1006,7 @@ public Grouper.KeyComparator bufferComparator() if (includeTimestamp) { if (sortByDimsFirst) { - return new Grouper.KeyComparator() + return new Grouper.BufferComparator() { @Override public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) @@ -807,7 +1028,7 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } }; } else { - return new Grouper.KeyComparator() + return new Grouper.BufferComparator() { @Override public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) @@ -831,7 +1052,7 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, }; } } else { - return new Grouper.KeyComparator() + return new Grouper.BufferComparator() { @Override public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) @@ -855,6 +1076,158 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } } + @Override + public Grouper.BufferComparator bufferComparatorWithAggregators( + AggregatorFactory[] aggregatorFactories, + int[] aggregatorOffsets + ) + { + final List adjustedSerdeHelpers; + final List needsReverses = Lists.newArrayList(); + List orderByHelpers = new ArrayList<>(); + List otherDimHelpers = new ArrayList<>(); + Set orderByIndices = new HashSet<>(); + + int aggCount = 0; + boolean needsReverse; + for (OrderByColumnSpec orderSpec : limitSpec.getColumns()) { + needsReverse = orderSpec.getDirection() != OrderByColumnSpec.Direction.ASCENDING; + int dimIndex = OrderByColumnSpec.getDimIndexForOrderBy(orderSpec, dimensions); + if (dimIndex >= 0) { + RowBasedKeySerdeHelper serdeHelper = serdeHelpers.get(dimIndex); + orderByHelpers.add(serdeHelper); + orderByIndices.add(dimIndex); + needsReverses.add(needsReverse); + } else { + int aggIndex = OrderByColumnSpec.getAggIndexForOrderBy(orderSpec, Arrays.asList(aggregatorFactories)); + if (aggIndex >= 0) { + final RowBasedKeySerdeHelper serdeHelper; + final StringComparator cmp = orderSpec.getDimensionComparator(); + final boolean cmpIsNumeric = cmp == StringComparators.NUMERIC; + final String typeName = aggregatorFactories[aggIndex].getTypeName(); + final int aggOffset = aggregatorOffsets[aggIndex] - Ints.BYTES; + + aggCount++; + + if (typeName.equals("long")) { + if (cmpIsNumeric) { + serdeHelper = new LongRowBasedKeySerdeHelper(aggOffset); + } else { + serdeHelper = new LimitPushDownLongRowBasedKeySerdeHelper(aggOffset, cmp); + } + } else if (typeName.equals("float")) { + // called "float", but the aggs really return doubles + if (cmpIsNumeric) { + serdeHelper = new DoubleRowBasedKeySerdeHelper(aggOffset); + } else { + serdeHelper = new LimitPushDownDoubleRowBasedKeySerdeHelper(aggOffset, cmp); + } + } else { + throw new IAE("Cannot order by a non-numeric aggregator[%s]", orderSpec); + } + + orderByHelpers.add(serdeHelper); + needsReverses.add(needsReverse); + } + } + } + + for (int i = 0; i < dimCount; i++) { + if (!orderByIndices.contains(i)) { + otherDimHelpers.add(serdeHelpers.get(i)); + needsReverses.add(false); // default to Ascending order if dim is not in an orderby spec + } + } + + 
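+ // Concatenate the helpers so that order-by columns are compared first and the remaining grouping key columns act as an ascending tie-breaker.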
adjustedSerdeHelpers = orderByHelpers; + adjustedSerdeHelpers.addAll(otherDimHelpers); + + final int fieldCount = dimCount + aggCount; + + if (includeTimestamp) { + if (sortByDimsFirst) { + return new Grouper.BufferComparator() + { + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + final int cmp = compareDimsInBuffersForNullFudgeTimestampForPushDown( + adjustedSerdeHelpers, + needsReverses, + fieldCount, + lhsBuffer, + rhsBuffer, + lhsPosition, + rhsPosition + ); + if (cmp != 0) { + return cmp; + } + + return Longs.compare(lhsBuffer.getLong(lhsPosition), rhsBuffer.getLong(rhsPosition)); + } + }; + } else { + return new Grouper.BufferComparator() + { + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + final int timeCompare = Longs.compare(lhsBuffer.getLong(lhsPosition), rhsBuffer.getLong(rhsPosition)); + + if (timeCompare != 0) { + return timeCompare; + } + + int cmp = compareDimsInBuffersForNullFudgeTimestampForPushDown( + adjustedSerdeHelpers, + needsReverses, + fieldCount, + lhsBuffer, + rhsBuffer, + lhsPosition, + rhsPosition + ); + + return cmp; + } + }; + } + } else { + return new Grouper.BufferComparator() + { + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + for (int i = 0; i < fieldCount; i++) { + final int cmp; + if (needsReverses.get(i)) { + cmp = adjustedSerdeHelpers.get(i).compare( + rhsBuffer, + lhsBuffer, + rhsPosition, + lhsPosition + ); + } else { + cmp = adjustedSerdeHelpers.get(i).compare( + lhsBuffer, + rhsBuffer, + lhsPosition, + rhsPosition + ); + } + + if (cmp != 0) { + return cmp; + } + } + + return 0; + } + }; + } + } + private static int compareDimsInBuffersForNullFudgeTimestamp( List serdeHelpers, int[] sortableIds, @@ -880,6 +1253,41 @@ private static int compareDimsInBuffersForNullFudgeTimestamp( return 0; } + private static int compareDimsInBuffersForNullFudgeTimestampForPushDown( + List serdeHelpers, + List needsReverses, + int dimCount, + ByteBuffer lhsBuffer, + ByteBuffer rhsBuffer, + int lhsPosition, + int rhsPosition + ) + { + for (int i = 0; i < dimCount; i++) { + final int cmp; + if (needsReverses.get(i)) { + cmp = serdeHelpers.get(i).compare( + rhsBuffer, + lhsBuffer, + rhsPosition + Longs.BYTES, + lhsPosition + Longs.BYTES + ); + } else { + cmp = serdeHelpers.get(i).compare( + lhsBuffer, + rhsBuffer, + lhsPosition + Longs.BYTES, + rhsPosition + Longs.BYTES + ); + } + if (cmp != 0) { + return cmp; + } + } + + return 0; + } + @Override public void reset() { @@ -925,6 +1333,10 @@ private int getTotalKeySize() private List makeSerdeHelpers() { + if (limitSpec != null) { + return makeSerdeHelpersForLimitPushDown(); + } + List helpers = new ArrayList<>(); int keyBufferPosition = 0; for (ValueType valType : valueTypes) { @@ -948,6 +1360,48 @@ private List makeSerdeHelpers() return helpers; } + private List makeSerdeHelpersForLimitPushDown() + { + List helpers = new ArrayList<>(); + int keyBufferPosition = 0; + + for (int i = 0; i < valueTypes.size(); i++) { + final ValueType valType = valueTypes.get(i); + final String dimName = dimensions.get(i).getOutputName(); + StringComparator cmp = DefaultLimitSpec.getComparatorForDimName(limitSpec, dimName); + final boolean cmpIsNumeric = cmp == StringComparators.NUMERIC; + + RowBasedKeySerdeHelper helper; + switch (valType) { + case STRING: + if (cmp == null) { + cmp = StringComparators.LEXICOGRAPHIC; + } + helper = new 
LimitPushDownStringRowBasedKeySerdeHelper(keyBufferPosition, cmp); + break; + case LONG: + if (cmp == null || cmpIsNumeric) { + helper = new LongRowBasedKeySerdeHelper(keyBufferPosition); + } else { + helper = new LimitPushDownLongRowBasedKeySerdeHelper(keyBufferPosition, cmp); + } + break; + case FLOAT: + if (cmp == null || cmpIsNumeric) { + helper = new FloatRowBasedKeySerdeHelper(keyBufferPosition); + } else { + helper = new LimitPushDownFloatRowBasedKeySerdeHelper(keyBufferPosition, cmp); + } + break; + default: + throw new IAE("invalid type: %s", valType); + } + keyBufferPosition += helper.getKeyBufferValueSize(); + helpers.add(helper); + } + return helpers; + } + private interface RowBasedKeySerdeHelper { /** @@ -1039,6 +1493,25 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } } + private class LimitPushDownStringRowBasedKeySerdeHelper extends StringRowBasedKeySerdeHelper + { + final StringComparator cmp; + + public LimitPushDownStringRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) + { + super(keyBufferPosition); + this.cmp = cmp; + } + + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + String lhsStr = dictionary.get(lhsBuffer.getInt(lhsPosition + keyBufferPosition)); + String rhsStr = dictionary.get(rhsBuffer.getInt(rhsPosition + keyBufferPosition)); + return cmp.compare(lhsStr, rhsStr); + } + } + private class LongRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; @@ -1077,6 +1550,26 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, } } + private class LimitPushDownLongRowBasedKeySerdeHelper extends LongRowBasedKeySerdeHelper + { + final StringComparator cmp; + + public LimitPushDownLongRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) + { + super(keyBufferPosition); + this.cmp = cmp; + } + + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + long lhs = lhsBuffer.getLong(lhsPosition + keyBufferPosition); + long rhs = rhsBuffer.getLong(rhsPosition + keyBufferPosition); + + return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + } + } + private class FloatRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper { final int keyBufferPosition; @@ -1114,5 +1607,81 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, ); } } + + private class LimitPushDownFloatRowBasedKeySerdeHelper extends FloatRowBasedKeySerdeHelper + { + final StringComparator cmp; + + public LimitPushDownFloatRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) + { + super(keyBufferPosition); + this.cmp = cmp; + } + + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + float lhs = lhsBuffer.getFloat(lhsPosition + keyBufferPosition); + float rhs = rhsBuffer.getFloat(rhsPosition + keyBufferPosition); + return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + } + } + + private class DoubleRowBasedKeySerdeHelper implements RowBasedKeySerdeHelper + { + final int keyBufferPosition; + + public DoubleRowBasedKeySerdeHelper(int keyBufferPosition) + { + this.keyBufferPosition = keyBufferPosition; + } + + @Override + public int getKeyBufferValueSize() + { + return Doubles.BYTES; + } + + @Override + public boolean putToKeyBuffer(RowBasedKey key, int idx) + { + keyBuffer.putDouble((Double) key.getKey()[idx]); + return true; + } + + @Override 
+ public void getFromByteBuffer(ByteBuffer buffer, int initialOffset, int dimValIdx, Comparable[] dimValues) + { + dimValues[dimValIdx] = buffer.getDouble(initialOffset + keyBufferPosition); + } + + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + return Double.compare( + lhsBuffer.getDouble(lhsPosition + keyBufferPosition), + rhsBuffer.getDouble(rhsPosition + keyBufferPosition) + ); + } + } + + private class LimitPushDownDoubleRowBasedKeySerdeHelper extends DoubleRowBasedKeySerdeHelper + { + final StringComparator cmp; + + public LimitPushDownDoubleRowBasedKeySerdeHelper(int keyBufferPosition, StringComparator cmp) + { + super(keyBufferPosition); + this.cmp = cmp; + } + + @Override + public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) + { + double lhs = lhsBuffer.getDouble(lhsPosition + keyBufferPosition); + double rhs = rhsBuffer.getDouble(rhsPosition + keyBufferPosition); + return cmp.compare(String.valueOf(lhs), String.valueOf(rhs)); + } + } } } diff --git a/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java b/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java index 21490173cbac..41d8bdba178b 100644 --- a/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java +++ b/processing/src/main/java/io/druid/query/groupby/epinephelinae/SpillingGrouper.java @@ -30,6 +30,7 @@ import io.druid.java.util.common.guava.CloseQuietly; import io.druid.query.BaseQuery; import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.groupby.orderby.DefaultLimitSpec; import io.druid.segment.ColumnSelectorFactory; import net.jpountz.lz4.LZ4BlockInputStream; import net.jpountz.lz4.LZ4BlockOutputStream; @@ -51,19 +52,20 @@ */ public class SpillingGrouper implements Grouper { + private final Grouper grouper; private static final AggregateResult DISK_FULL = AggregateResult.failure( "Not enough disk space to execute this query. Try raising druid.query.groupBy.maxOnDiskStorage." 
); - - private final BufferGrouper grouper; private final KeySerde keySerde; private final LimitedTemporaryStorage temporaryStorage; private final ObjectMapper spillMapper; private final AggregatorFactory[] aggregatorFactories; - private final Comparator keyObjComparator; + private final Comparator> keyObjComparator; + private final Comparator> defaultOrderKeyObjComparator; private final List files = Lists.newArrayList(); private final List closeables = Lists.newArrayList(); + private final boolean sortHasNonGroupingFields; private boolean spillingAllowed = false; @@ -77,24 +79,42 @@ public SpillingGrouper( final int bufferGrouperInitialBuckets, final LimitedTemporaryStorage temporaryStorage, final ObjectMapper spillMapper, - final boolean spillingAllowed + final boolean spillingAllowed, + final DefaultLimitSpec limitSpec, + final boolean sortHasNonGroupingFields ) { this.keySerde = keySerdeFactory.factorize(); - this.keyObjComparator = keySerdeFactory.objectComparator(); - this.grouper = new BufferGrouper<>( - bufferSupplier, - keySerde, - columnSelectorFactory, - aggregatorFactories, - bufferGrouperMaxSize, - bufferGrouperMaxLoadFactor, - bufferGrouperInitialBuckets - ); + this.keyObjComparator = keySerdeFactory.objectComparator(false); + this.defaultOrderKeyObjComparator = keySerdeFactory.objectComparator(true); + if (limitSpec != null) { + this.grouper = new LimitedBufferGrouper<>( + bufferSupplier, + keySerde, + columnSelectorFactory, + aggregatorFactories, + bufferGrouperMaxSize, + bufferGrouperMaxLoadFactor, + bufferGrouperInitialBuckets, + limitSpec.getLimit(), + sortHasNonGroupingFields + ); + } else { + this.grouper = new BufferGrouper<>( + bufferSupplier, + keySerde, + columnSelectorFactory, + aggregatorFactories, + bufferGrouperMaxSize, + bufferGrouperMaxLoadFactor, + bufferGrouperInitialBuckets + ); + } this.aggregatorFactories = aggregatorFactories; this.temporaryStorage = temporaryStorage; this.spillMapper = spillMapper; this.spillingAllowed = spillingAllowed; + this.sortHasNonGroupingFields = sortHasNonGroupingFields; } @Override @@ -191,7 +211,11 @@ public Entry apply(Entry entry) closeables.add(fileIterator); } - return Groupers.mergeIterators(iterators, sorted ? keyObjComparator : null); + if (sortHasNonGroupingFields) { + return Groupers.mergeIterators(iterators, defaultOrderKeyObjComparator); + } else { + return Groupers.mergeIterators(iterators, sorted ? keyObjComparator : null); + } } private void spill() throws IOException diff --git a/processing/src/main/java/io/druid/query/groupby/orderby/DefaultLimitSpec.java b/processing/src/main/java/io/druid/query/groupby/orderby/DefaultLimitSpec.java index eac0c41273fa..fb1c1d433e14 100644 --- a/processing/src/main/java/io/druid/query/groupby/orderby/DefaultLimitSpec.java +++ b/processing/src/main/java/io/druid/query/groupby/orderby/DefaultLimitSpec.java @@ -57,6 +57,34 @@ public class DefaultLimitSpec implements LimitSpec private final List columns; private final int limit; + /** + * Check if a limitSpec has columns in the sorting order that are not part of the grouping fields represented + * by `dimensions`. 
+ * + * @param limitSpec LimitSpec, assumed to be non-null + * @param dimensions Grouping fields for a groupBy query + * @return True if limitSpec has sorting columns not contained in dimensions + */ + public static boolean sortingOrderHasNonGroupingFields(DefaultLimitSpec limitSpec, List dimensions) + { + for (OrderByColumnSpec orderSpec : limitSpec.getColumns()) { + int dimIndex = OrderByColumnSpec.getDimIndexForOrderBy(orderSpec, dimensions); + if (dimIndex < 0) { + return true; + } + } + return false; + } + + public static StringComparator getComparatorForDimName(DefaultLimitSpec limitSpec, String dimName) { + final OrderByColumnSpec orderBy = OrderByColumnSpec.getOrderByForDimName(limitSpec.getColumns(), dimName); + if (orderBy == null) { + return null; + } + + return orderBy.getDimensionComparator(); + } + @JsonCreator public DefaultLimitSpec( @JsonProperty("columns") List columns, diff --git a/processing/src/main/java/io/druid/query/groupby/orderby/OrderByColumnSpec.java b/processing/src/main/java/io/druid/query/groupby/orderby/OrderByColumnSpec.java index 10bc732a883d..c4a5d1c4665c 100644 --- a/processing/src/main/java/io/druid/query/groupby/orderby/OrderByColumnSpec.java +++ b/processing/src/main/java/io/druid/query/groupby/orderby/OrderByColumnSpec.java @@ -27,6 +27,9 @@ import com.google.common.collect.Lists; import io.druid.java.util.common.ISE; import io.druid.java.util.common.StringUtils; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.PostAggregator; +import io.druid.query.dimension.DimensionSpec; import io.druid.query.ordering.StringComparator; import io.druid.query.ordering.StringComparators; @@ -150,6 +153,49 @@ public OrderByColumnSpec apply(@Nullable String input) ); } + public static OrderByColumnSpec getOrderByForDimName(List orderBys, String dimName) + { + for (OrderByColumnSpec orderBy : orderBys) { + if (orderBy.dimension.equals(dimName)) { + return orderBy; + } + } + return null; + } + + public static int getDimIndexForOrderBy(OrderByColumnSpec orderSpec, List dimensions) { + int i = 0; + for (DimensionSpec dimSpec : dimensions) { + if (orderSpec.getDimension().equals((dimSpec.getOutputName()))) { + return i; + } + i++; + } + return -1; + } + + public static int getAggIndexForOrderBy(OrderByColumnSpec orderSpec, List aggregatorFactories) { + int i = 0; + for (AggregatorFactory agg : aggregatorFactories) { + if (orderSpec.getDimension().equals((agg.getName()))) { + return i; + } + i++; + } + return -1; + } + + public static int getPostAggIndexForOrderBy(OrderByColumnSpec orderSpec, List postAggs) { + int i = 0; + for (PostAggregator postAgg : postAggs) { + if (orderSpec.getDimension().equals((postAgg.getName()))) { + return i; + } + i++; + } + return -1; + } + public OrderByColumnSpec( String dimension, Direction direction diff --git a/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java b/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java index b23f3cd50a90..ff9a43032973 100644 --- a/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java +++ b/processing/src/main/java/io/druid/query/groupby/strategy/GroupByStrategyV2.java @@ -22,7 +22,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Function; import com.google.common.base.Supplier; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.common.collect.Ordering; 
@@ -62,7 +61,6 @@ import io.druid.query.groupby.epinephelinae.GroupByMergingQueryRunnerV2; import io.druid.query.groupby.epinephelinae.GroupByQueryEngineV2; import io.druid.query.groupby.epinephelinae.GroupByRowProcessor; -import io.druid.query.groupby.orderby.NoopLimitSpec; import io.druid.query.groupby.resource.GroupByQueryResource; import io.druid.segment.StorageAdapter; import org.joda.time.DateTime; @@ -212,7 +210,6 @@ public Sequence mergeResults( { // Merge streams using ResultMergeQueryRunner, then apply postaggregators, then apply limit (which may // involve materialization) - final ResultMergeQueryRunner mergingQueryRunner = new ResultMergeQueryRunner(baseRunner) { @Override @@ -231,60 +228,71 @@ protected BinaryFn createMergeFn(Query queryParam) // Fudge timestamp, maybe. final DateTime fudgeTimestamp = getUniversalTimestamp(query); - return query.postProcess( - Sequences.map( - mergingQueryRunner.run( - QueryPlus.wrap( - new GroupByQuery.Builder(query) - // Don't do post aggs until the end of this method. - .setPostAggregatorSpecs(ImmutableList.of()) - // Don't do "having" clause until the end of this method. - .setHavingSpec(null) - .setLimitSpec(NoopLimitSpec.instance()) - .overrideContext( - ImmutableMap.of( - "finalize", false, - GroupByQueryConfig.CTX_KEY_STRATEGY, GroupByStrategySelector.STRATEGY_V2, - CTX_KEY_FUDGE_TIMESTAMP, fudgeTimestamp == null ? "" : String.valueOf(fudgeTimestamp.getMillis()), - CTX_KEY_OUTERMOST, false - ) - ) - .build() - ), - responseContext - ), - new Function() - { - @Override - public Row apply(final Row row) - { - // Apply postAggregators and fudgeTimestamp if present and if this is the outermost mergeResults. + final GroupByQuery newQuery = new GroupByQuery( + query.getDataSource(), + query.getQuerySegmentSpec(), + query.getVirtualColumns(), + query.getDimFilter(), + query.getGranularity(), + query.getDimensions(), + query.getAggregatorSpecs(), + query.getPostAggregatorSpecs(), + // Don't do "having" clause until the end of this method. + null, + query.getLimitSpec(), + query.getContext() + ).withOverriddenContext( + ImmutableMap.of( + "finalize", false, + GroupByQueryConfig.CTX_KEY_STRATEGY, GroupByStrategySelector.STRATEGY_V2, + CTX_KEY_FUDGE_TIMESTAMP, fudgeTimestamp == null ? "" : String.valueOf(fudgeTimestamp.getMillis()), + CTX_KEY_OUTERMOST, false + ) + ); - if (!query.getContextBoolean(CTX_KEY_OUTERMOST, true)) { - return row; - } + Sequence rowSequence = Sequences.map( + mergingQueryRunner.run( + QueryPlus.wrap(newQuery), + responseContext + ), + new Function() + { + @Override + public Row apply(final Row row) + { + // Apply postAggregators and fudgeTimestamp if present and if this is the outermost mergeResults. 
- if (query.getPostAggregatorSpecs().isEmpty() && fudgeTimestamp == null) { - return row; - } + if (!query.getContextBoolean(CTX_KEY_OUTERMOST, true)) { + return row; + } - final Map newMap; + if (query.getPostAggregatorSpecs().isEmpty() && fudgeTimestamp == null) { + return row; + } - if (query.getPostAggregatorSpecs().isEmpty()) { - newMap = ((MapBasedRow) row).getEvent(); - } else { - newMap = Maps.newLinkedHashMap(((MapBasedRow) row).getEvent()); + final Map newMap; - for (PostAggregator postAggregator : query.getPostAggregatorSpecs()) { - newMap.put(postAggregator.getName(), postAggregator.compute(newMap)); - } - } + if (query.getPostAggregatorSpecs().isEmpty()) { + newMap = ((MapBasedRow) row).getEvent(); + } else { + newMap = Maps.newLinkedHashMap(((MapBasedRow) row).getEvent()); - return new MapBasedRow(fudgeTimestamp != null ? fudgeTimestamp : row.getTimestamp(), newMap); + for (PostAggregator postAggregator : query.getPostAggregatorSpecs()) { + newMap.put(postAggregator.getName(), postAggregator.compute(newMap)); } } - ) + + return new MapBasedRow(fudgeTimestamp != null ? fudgeTimestamp : row.getTimestamp(), newMap); + } + } ); + + // Don't apply limit here for inner results, that will be pushed down to the BufferGrouper + if (query.getContextBoolean(CTX_KEY_OUTERMOST, true)) { + return query.postProcess(rowSequence); + } else { + return rowSequence; + } } @Override diff --git a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java index a38f4710d697..ffdcc648d459 100644 --- a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java @@ -51,6 +51,7 @@ import io.druid.query.DruidProcessingConfig; import io.druid.query.Druids; import io.druid.query.FinalizeResultsQueryRunner; +import io.druid.query.Query; import io.druid.query.QueryContexts; import io.druid.query.QueryDataSource; import io.druid.query.QueryPlus; @@ -8378,4 +8379,620 @@ public void testGroupByNestedDoubleTimeExtractionFnWithLongOutputTypes() Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, outerQuery); TestHelper.assertExpectedObjects(expectedResults, results, ""); } + + @Test + public void testGroupByLimitPushDown() + { + if (!config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + return; + } + GroupByQuery query = new GroupByQuery.Builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setGranularity(QueryRunnerTestHelper.allGran) + .setDimensions( + Arrays.asList( + new DefaultDimensionSpec( + QueryRunnerTestHelper.marketDimension, + "marketalias" + ) + ) + ) + .setInterval(QueryRunnerTestHelper.fullOnInterval) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "marketalias", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 2 + ) + ) + .setAggregatorSpecs( + Lists.newArrayList( + QueryRunnerTestHelper.rowsCount + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .build(); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow( + "1970-01-01T00:00:00.000Z", + "marketalias", + "upfront", + "rows", + 186L + ), + GroupByQueryRunnerTestHelper.createExpectedRow( + "1970-01-01T00:00:00.000Z", + "marketalias", + "total_market", + "rows", + 186L + ) + ); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, 
query); + TestHelper.assertExpectedObjects(expectedResults, results, "order-limit"); + } + + @Test + public void testMergeResultsWithLimitPushDown() + { + if (!config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + return; + } + GroupByQuery.Builder builder = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setInterval("2011-04-02/2011-04-04") + .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("quality", "alias"))) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 5 + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .setGranularity(Granularities.ALL); + + final GroupByQuery allGranQuery = builder.build(); + + QueryRunner mergedRunner = factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run( + Query query, Map responseContext + ) + { + // simulate two daily segments + final Query query1 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-02/2011-04-03"))) + ); + final Query query2 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-03/2011-04-04"))) + ); + + return factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run(Query query, Map responseContext) + { + return new MergeSequence( + query.getResultOrdering(), + Sequences.simple( + Arrays.asList(runner.run(query1, responseContext), runner.run(query2, responseContext)) + ) + ); + } + } + ).run(query, responseContext); + } + } + ); + Map context = Maps.newHashMap(); + List allGranExpectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "travel", "rows", 2L, "idx", 243L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "technology", "rows", 2L, "idx", 177L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "rows", 6L, "idx", 4416L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "news", "rows", 2L, "idx", 221L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "mezzanine", "rows", 6L, "idx", 4420L) + ); + + TestHelper.assertExpectedObjects(allGranExpectedResults, mergedRunner.run(allGranQuery, context), "merged"); + } + + @Test + public void testMergeResultsWithLimitPushDownSortByAgg() + { + if (!config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + return; + } + GroupByQuery.Builder builder = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setInterval("2011-04-02/2011-04-04") + .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("quality", "alias"))) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "idx", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 5 + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .setGranularity(Granularities.ALL); + + final GroupByQuery allGranQuery = builder.build(); + + QueryRunner mergedRunner = factory.getToolchest().mergeResults( + new 
QueryRunner() + { + @Override + public Sequence run( + Query query, Map responseContext + ) + { + // simulate two daily segments + final Query query1 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-02/2011-04-03"))) + ); + final Query query2 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-03/2011-04-04"))) + ); + + return factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run(Query query, Map responseContext) + { + return new MergeSequence( + query.getResultOrdering(), + Sequences.simple( + Arrays.asList(runner.run(query1, responseContext), runner.run(query2, responseContext)) + ) + ); + } + } + ).run(query, responseContext); + } + } + ); + Map context = Maps.newHashMap(); + + List allGranExpectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "mezzanine", "rows", 6L, "idx", 4420L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "rows", 6L, "idx", 4416L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "entertainment", "rows", 2L, "idx", 319L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "automotive", "rows", 2L, "idx", 269L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "travel", "rows", 2L, "idx", 243L) + ); + + Iterable results = Sequences.toList(mergedRunner.run(allGranQuery, context), Lists.newArrayList()); + TestHelper.assertExpectedObjects(allGranExpectedResults, results, "merged"); + } + + @Test + public void testMergeResultsWithLimitPushDownSortByDimDim() + { + if (!config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + return; + } + GroupByQuery.Builder builder = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setInterval("2011-04-02/2011-04-04") + .setDimensions(Lists.newArrayList( + new DefaultDimensionSpec("quality", "alias"), + new DefaultDimensionSpec("market", "market") + ) + ) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ), + new OrderByColumnSpec( + "market", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 5 + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .setGranularity(Granularities.ALL); + + final GroupByQuery allGranQuery = builder.build(); + + QueryRunner mergedRunner = factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run( + Query query, Map responseContext + ) + { + // simulate two daily segments + final Query query1 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-02/2011-04-03"))) + ); + final Query query2 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-03/2011-04-04"))) + ); + + return factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run(Query query, Map responseContext) + { + return new MergeSequence( + query.getResultOrdering(), + Sequences.simple( + Arrays.asList(runner.run(query1, responseContext), runner.run(query2, responseContext)) + ) + ); + } + } + ).run(query, responseContext); + } + } + ); + Map context = 
Maps.newHashMap(); + + List allGranExpectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "travel", "market", "spot", "rows", 2L, "idx", 243L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "technology", "market", "spot", "rows", 2L, "idx", 177L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "upfront", "rows", 2L, "idx", 1817L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "total_market", "rows", 2L, "idx", 2342L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "spot", "rows", 2L, "idx", 257L) + ); + + Iterable results = Sequences.toList(mergedRunner.run(allGranQuery, context), Lists.newArrayList()); + TestHelper.assertExpectedObjects(allGranExpectedResults, results, "merged"); + } + + @Test + public void testMergeResultsWithLimitPushDownSortByDimAggDim() + { + if (!config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + return; + } + GroupByQuery.Builder builder = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setInterval("2011-04-02/2011-04-04") + .setDimensions(Lists.newArrayList( + new DefaultDimensionSpec("quality", "alias"), + new DefaultDimensionSpec("market", "market") + ) + ) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ), + new OrderByColumnSpec( + "idx", + OrderByColumnSpec.Direction.DESCENDING + ), + new OrderByColumnSpec( + "market", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 5 + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .setGranularity(Granularities.ALL); + + final GroupByQuery allGranQuery = builder.build(); + + QueryRunner mergedRunner = factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run( + Query query, Map responseContext + ) + { + // simulate two daily segments + final Query query1 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-02/2011-04-03"))) + ); + final Query query2 = query.withQuerySegmentSpec( + new MultipleIntervalSegmentSpec(Lists.newArrayList(new Interval("2011-04-03/2011-04-04"))) + ); + + return factory.getToolchest().mergeResults( + new QueryRunner() + { + @Override + public Sequence run(Query query, Map responseContext) + { + return new MergeSequence( + query.getResultOrdering(), + Sequences.simple( + Arrays.asList(runner.run(query1, responseContext), runner.run(query2, responseContext)) + ) + ); + } + } + ).run(query, responseContext); + } + } + ); + Map context = Maps.newHashMap(); + + List allGranExpectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "travel", "market", "spot", "rows", 2L, "idx", 243L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "technology", "market", "spot", "rows", 2L, "idx", 177L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "total_market", "rows", 2L, "idx", 2342L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "upfront", "rows", 2L, "idx", 1817L), + 
GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "market", "spot", "rows", 2L, "idx", 257L) + ); + + Iterable results = Sequences.toList(mergedRunner.run(allGranQuery, context), Lists.newArrayList()); + TestHelper.assertExpectedObjects(allGranExpectedResults, results, "merged"); + } + + @Test + public void testGroupByLimitPushDownPostAggNotSupported() + { + //if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V2)) { + expectedException.expect(UnsupportedOperationException.class); + expectedException.expectMessage("Limit push down when sorting by a post aggregator is not supported."); + //} + + GroupByQuery query = new GroupByQuery.Builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setGranularity(QueryRunnerTestHelper.allGran) + .setDimensions( + Arrays.asList( + new DefaultDimensionSpec( + QueryRunnerTestHelper.marketDimension, + "marketalias" + ) + ) + ) + .setInterval(QueryRunnerTestHelper.fullOnInterval) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "constant", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 2 + ) + ) + .setAggregatorSpecs( + Lists.newArrayList( + QueryRunnerTestHelper.rowsCount + ) + ) + .setPostAggregatorSpecs( + Lists.newArrayList( + new ConstantPostAggregator("constant", 1) + ) + ) + .setContext( + ImmutableMap.of( + GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, + true + ) + ) + .build(); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + } + + @Test + public void testEmptySubqueryWithLimitPushDown() + { + GroupByQuery subquery = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setQuerySegmentSpec(QueryRunnerTestHelper.emptyInterval) + .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("quality", "alias"))) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 5 + ) + ) + .setGranularity(QueryRunnerTestHelper.dayGran) + .build(); + + GroupByQuery query = GroupByQuery + .builder() + .setDataSource(subquery) + .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird) + .setAggregatorSpecs( + Arrays.asList( + new DoubleMaxAggregatorFactory("idx", "idx") + ) + ) + .setLimitSpec( + new DefaultLimitSpec( + null, + 5 + ) + ) + .setGranularity(QueryRunnerTestHelper.dayGran) + .build(); + + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + Assert.assertFalse(results.iterator().hasNext()); + } + + + @Test + public void testSubqueryWithMultipleIntervalsInOuterQueryWithLimitPushDown() + { + GroupByQuery subquery = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.dataSource) + .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird) + .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("quality", "alias"))) + .setDimFilter(new JavaScriptDimFilter( + "quality", + "function(dim){ return true; }", + null, + JavaScriptConfig.getEnabledInstance() + )) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 12 + ) + ) + .setAggregatorSpecs( + Arrays.asList( + QueryRunnerTestHelper.rowsCount, + new LongSumAggregatorFactory("idx", "index"), + new LongSumAggregatorFactory("indexMaxPlusTen", "indexMaxPlusTen") + ) + ) + 
.setGranularity(QueryRunnerTestHelper.dayGran) + .build(); + + GroupByQuery query = GroupByQuery + .builder() + .setDataSource(subquery) + .setQuerySegmentSpec( + new MultipleIntervalSegmentSpec( + ImmutableList.of( + new Interval("2011-04-01T00:00:00.000Z/2011-04-01T23:58:00.000Z"), + new Interval("2011-04-02T00:00:00.000Z/2011-04-03T00:00:00.000Z") + ) + ) + ) + .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("alias", "alias"))) + .setLimitSpec( + new DefaultLimitSpec( + Lists.newArrayList( + new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.DESCENDING + ) + ), + 15 + ) + ) + .setAggregatorSpecs( + Arrays.asList( + new LongSumAggregatorFactory("rows", "rows"), + new LongSumAggregatorFactory("idx", "idx") + ) + ) + .setGranularity(QueryRunnerTestHelper.dayGran) + .build(); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "travel", "rows", 1L, "idx", 119L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "technology", "rows", 1L, "idx", 78L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "premium", "rows", 3L, "idx", 2900L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "news", "rows", 1L, "idx", 121L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "mezzanine", "rows", 3L, "idx", 2870L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "health", "rows", 1L, "idx", 120L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "entertainment", "rows", 1L, "idx", 158L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "business", "rows", 1L, "idx", 118L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-01", "alias", "automotive", "rows", 1L, "idx", 135L), + + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "travel", "rows", 1L, "idx", 126L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "technology", "rows", 1L, "idx", 97L), + GroupByQueryRunnerTestHelper.createExpectedRow("2011-04-02", "alias", "premium", "rows", 3L, "idx", 2505L) + ); + + // Subqueries are handled by the ToolChest + Iterable results = GroupByQueryRunnerTestHelper.runQuery(factory, runner, query); + TestHelper.assertExpectedObjects(expectedResults, results, ""); + } } diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeapTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeapTest.java new file mode 100644 index 000000000000..336c1b2fc16a --- /dev/null +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ByteBufferMinMaxOffsetHeapTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.collect.Lists; +import com.google.common.collect.Ordering; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +public class ByteBufferMinMaxOffsetHeapTest +{ + @Test + public void testSimple() + { + int limit = 15; + ByteBuffer myBuffer = ByteBuffer.allocate(1000000); + ByteBufferMinMaxOffsetHeap heap = new ByteBufferMinMaxOffsetHeap(myBuffer, limit, Ordering.natural(), null); + + ArrayList values = Lists.newArrayList( + 30, 45, 81, 92, 68, 54, 66, 33, 89, 98, + 87, 62, 84, 39, 13, 32, 67, 50, 21, 53, + 93, 18, 86, 41, 14, 56, 51, 69, 91, 60, + 6, 2, 79, 4, 35, 17, 71, 22, 29, 76, + 57, 97, 73, 24, 94, 77, 80, 15, 52, 88, + 95, 96, 9, 3, 48, 58, 75, 82, 90, 65, + 36, 85, 20, 34, 37, 72, 11, 78, 28, 43, + 27, 12, 83, 38, 59, 19, 31, 46, 40, 63, + 23, 70, 26, 8, 64, 16, 10, 74, 7, 25, + 5, 42, 47, 44, 1, 49, 99 + ); + + for (int i = 0; i < values.size(); i++){ + heap.addOffset(values.get(i)); + } + + int x = heap.removeAt(8); + heap.addOffset(x); + + x = heap.removeAt(2); + heap.addOffset(x); + + Collections.sort(values); + List expected = values.subList(0, limit); + + List actual = Lists.newArrayList(); + for (int i = 0; i < limit; i++) { + int min = heap.removeMin(); + actual.add(min); + } + + Assert.assertEquals(expected, actual); + } + + + @Test + public void testRandom() + { + int limit = 20; + + Random rng = new Random(999); + + ArrayList values = Lists.newArrayList(); + for (int i = 0; i < 100000; i++) { + values.add(rng.nextInt(1000000)); + } + ArrayList deletedValues = Lists.newArrayList(); + + ByteBuffer myBuffer = ByteBuffer.allocate(1000000); + ByteBufferMinMaxOffsetHeap heap = new ByteBufferMinMaxOffsetHeap(myBuffer, limit, Ordering.natural(), null); + + for (int i = 0; i < values.size(); i++){ + int droppedOffset = heap.addOffset(values.get(i)); + Assert.assertTrue(heap.isIntact()); + + if (droppedOffset > 0) { + deletedValues.add(droppedOffset); + } + + // 15% chance to delete a random value for every two values added when heap is > 50% full + if (heap.getHeapSize() > (limit / 2) && i % 2 == 1) { + double deleteRoll = rng.nextDouble(); + if (deleteRoll > 0.15) { + int indexToRemove = rng.nextInt(heap.getHeapSize()); + int deadOffset = heap.removeAt(indexToRemove); + Assert.assertTrue(heap.isIntact()); + deletedValues.add(deadOffset); + } + } + } + + Collections.sort(values); + Collections.sort(deletedValues); + + for (int deletedValue : deletedValues) { + int idx = values.indexOf(deletedValue); + values.remove(idx); + } + + Assert.assertTrue(heap.getHeapSize() <= limit); + List expected = values.subList(0, heap.getHeapSize()); + + List actual = Lists.newArrayList(); + int initialHeapSize = heap.getHeapSize(); + for (int i = 0; i < initialHeapSize; i++){ + int min = heap.removeMin(); + actual.add(min); + } + + Assert.assertEquals(expected, actual); + } + + @Test + public void testRandom2() + { + int limit = 20000; + + Random rng = new Random(9999); + + ArrayList values = Lists.newArrayList(); + for (int i = 0; i < 100000; i++) { + values.add(rng.nextInt(1000000)); + } + ArrayList deletedValues = Lists.newArrayList(); + + ByteBuffer myBuffer = ByteBuffer.allocate(1000000); + ByteBufferMinMaxOffsetHeap heap = new ByteBufferMinMaxOffsetHeap(myBuffer, 
limit, Ordering.natural(), null); + + for (int i = 0; i < values.size(); i++){ + int droppedOffset = heap.addOffset(values.get(i)); + Assert.assertTrue(heap.isIntact()); + + if (droppedOffset > 0) { + deletedValues.add(droppedOffset); + } + + // 15% chance to delete a random value for every two values added when heap is > 50% full + if (heap.getHeapSize() > (limit / 2) && i % 2 == 1) { + double deleteRoll = rng.nextDouble(); + if (deleteRoll > 0.15) { + int indexToRemove = rng.nextInt(heap.getHeapSize()); + int deadOffset = heap.removeAt(indexToRemove); + Assert.assertTrue(heap.isIntact()); + deletedValues.add(deadOffset); + } + } + } + + Collections.sort(values); + Collections.sort(deletedValues); + + for (int deletedValue : deletedValues) { + int idx = values.indexOf(deletedValue); + values.remove(idx); + } + + Assert.assertTrue(heap.getHeapSize() <= limit); + List expected = values.subList(0, heap.getHeapSize()); + + List actual = Lists.newArrayList(); + int initialHeapSize = heap.getHeapSize(); + for (int i = 0; i < initialHeapSize; i++){ + int min = heap.removeMin(); + actual.add(min); + } + + Assert.assertEquals(expected, actual); + } + + + @Test + public void testRemove() + { + int limit = 100; + + ArrayList values = Lists.newArrayList( + 1, 20, 1000, 2, 3, 30, 40, 10, 11, 12, 13, 300, 400, 500, 600 + ); + + ByteBuffer myBuffer = ByteBuffer.allocate(1000000); + ByteBufferMinMaxOffsetHeap heap = new ByteBufferMinMaxOffsetHeap(myBuffer, limit, Ordering.natural(), null); + + for (int i = 0; i < values.size(); i++){ + heap.addOffset(values.get(i)); + Assert.assertTrue(heap.isIntact()); + } + + heap.removeOffset(12); + + Assert.assertTrue(heap.isIntact()); + + Collections.sort(values); + values.remove((Number) 12); + + List actual = Lists.newArrayList(); + for (int i = 0; i < values.size(); i++){ + int min = heap.removeMin(); + actual.add(min); + } + + Assert.assertEquals(values, actual); + } + + @Test + public void testRemove2() + { + int limit = 100; + + ArrayList values = Lists.newArrayList( + 1, 20, 1000, 2, 3, 30, 40, 10, 11, 12, 13, 300, 400, 500, 600, 4, 5, + 6, 7, 8, 9, 4, 5, 200, 250 + ); + + ByteBuffer myBuffer = ByteBuffer.allocate(1000000); + ByteBufferMinMaxOffsetHeap heap = new ByteBufferMinMaxOffsetHeap(myBuffer, limit, Ordering.natural(), null); + + for (int i = 0; i < values.size(); i++){ + heap.addOffset(values.get(i)); + } + Assert.assertTrue(heap.isIntact()); + + heap.removeOffset(2); + Assert.assertTrue(heap.isIntact()); + + Collections.sort(values); + values.remove((Number) 2); + Assert.assertTrue(heap.isIntact()); + + List actual = Lists.newArrayList(); + for (int i = 0; i < values.size(); i++){ + int min = heap.removeMin(); + actual.add(min); + } + + Assert.assertTrue(heap.isIntact()); + + Assert.assertEquals(values, actual); + } +} diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java index cdd064a2dd3a..c5fe90ad5550 100644 --- a/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/ConcurrentGrouperTest.java @@ -25,7 +25,7 @@ import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.dimension.DimensionSpec; -import io.druid.query.groupby.epinephelinae.Grouper.KeyComparator; +import io.druid.query.groupby.epinephelinae.Grouper.BufferComparator; import 
io.druid.query.groupby.epinephelinae.Grouper.KeySerde; import io.druid.query.groupby.epinephelinae.Grouper.KeySerdeFactory; import io.druid.segment.ColumnSelectorFactory; @@ -107,9 +107,9 @@ public Long fromByteBuffer(ByteBuffer buffer, int position) } @Override - public KeyComparator bufferComparator() + public BufferComparator bufferComparator() { - return new KeyComparator() + return new BufferComparator() { @Override public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) @@ -119,20 +119,29 @@ public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, }; } + @Override + public BufferComparator bufferComparatorWithAggregators( + AggregatorFactory[] aggregatorFactories, + int[] aggregatorOffsets + ) + { + return null; + } + @Override public void reset() {} }; } @Override - public Comparator objectComparator() + public Comparator> objectComparator(boolean forceDefaultOrder) { - return new Comparator() + return new Comparator>() { @Override - public int compare(Long o1, Long o2) + public int compare(Grouper.Entry o1, Grouper.Entry o2) { - return o1.compareTo(o2); + return o1.getKey().compareTo(o2.getKey()); } }; } @@ -184,7 +193,9 @@ public void testAggregate() throws InterruptedException, ExecutionException 1, null, null, - 8 + 8, + null, + false ); Future[] futures = new Future[8]; diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java index 5a3b4e829fbb..8f017caff953 100644 --- a/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/IntKeySerde.java @@ -20,6 +20,7 @@ package io.druid.query.groupby.epinephelinae; import com.google.common.primitives.Ints; +import io.druid.query.aggregation.AggregatorFactory; import java.nio.ByteBuffer; import java.util.Comparator; @@ -33,7 +34,7 @@ private IntKeySerde() // No instantiation } - private static final Grouper.KeyComparator KEY_COMPARATOR = new Grouper.KeyComparator() + private static final Grouper.BufferComparator KEY_COMPARATOR = new Grouper.BufferComparator() { @Override public int compare(ByteBuffer lhsBuffer, ByteBuffer rhsBuffer, int lhsPosition, int rhsPosition) @@ -80,7 +81,15 @@ public Integer fromByteBuffer(ByteBuffer buffer, int position) } @Override - public Grouper.KeyComparator bufferComparator() + public Grouper.BufferComparator bufferComparator() + { + return KEY_COMPARATOR; + } + + @Override + public Grouper.BufferComparator bufferComparatorWithAggregators( + AggregatorFactory[] aggregatorFactories, int[] aggregatorOffsets + ) { return KEY_COMPARATOR; } diff --git a/processing/src/test/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouperTest.java b/processing/src/test/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouperTest.java new file mode 100644 index 000000000000..7cd746379ee0 --- /dev/null +++ b/processing/src/test/java/io/druid/query/groupby/epinephelinae/LimitedBufferGrouperTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.query.groupby.epinephelinae; + +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import io.druid.data.input.MapBasedRow; +import io.druid.java.util.common.IAE; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; +import io.druid.query.aggregation.LongSumAggregatorFactory; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.nio.ByteBuffer; +import java.util.List; + +public class LimitedBufferGrouperTest +{ + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void testLimitAndBufferSwapping() + { + final int limit = 100; + final int keyBase = 100000; + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final LimitedBufferGrouper grouper = makeGrouper(columnSelectorFactory, 20000, 2, limit); + final int numRows = 1000; + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 10L))); + for (int i = 0; i < numRows; i++) { + Assert.assertTrue(String.valueOf(i + keyBase), grouper.aggregate(i + keyBase).isOk()); + } + + // bucket size is hash(int) + key(int) + aggs(2 longs) + heap offset(int) = 28 bytes + // limit is 100 so heap occupies 101 * 4 bytes = 404 bytes + // buffer is 20000 bytes, so table arena size is 20000 - 404 = 19596 bytes + // table arena is split in halves when doing push down, so each half is 9798 bytes + // each table arena half can hold 9798 / 28 = 349 buckets, with load factor of 0.5 max buckets per half is 174 + // First buffer swap occurs when we hit 174 buckets + // Subsequent buffer swaps occur after every 74 buckets, since we keep 100 buckets due to the limit + // With 1000 keys inserted, this results in one swap at the first 174 buckets, then 11 swaps afterwards. + // After the last swap, we have 100 keys + 12 new keys inserted. + Assert.assertEquals(12, grouper.getGrowthCount()); + Assert.assertEquals(112, grouper.getSize()); + Assert.assertEquals(349, grouper.getBuckets()); + Assert.assertEquals(174, grouper.getMaxSize()); + Assert.assertEquals(100, grouper.getLimit()); + + // Aggregate slightly different row + // Since these keys are smaller, they will evict the previous 100 top entries + // First 100 of these new rows will be the expected results. + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 11L))); + for (int i = 0; i < numRows; i++) { + Assert.assertTrue(String.valueOf(i), grouper.aggregate(i).isOk()); + } + + // we added another 1000 unique keys + // previous size is 112, so next swap occurs after 62 rows + // after that, there are 1000 - 62 = 938 rows, 938 / 74 = 12 additional swaps after the first, + // with 50 keys being added after the final swap. 
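+ // Putting those numbers together: growth count = 12 + 1 + 12 = 25, and size = 100 + (938 - 12 * 74) = 150,
+ // which is what the assertions below check.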
+ Assert.assertEquals(25, grouper.getGrowthCount()); + Assert.assertEquals(150, grouper.getSize()); + Assert.assertEquals(349, grouper.getBuckets()); + Assert.assertEquals(174, grouper.getMaxSize()); + Assert.assertEquals(100, grouper.getLimit()); + + final List> expected = Lists.newArrayList(); + for (int i = 0; i < limit; i++) { + expected.add(new Grouper.Entry<>(i, new Object[]{11L, 1L})); + } + + Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true))); + } + + @Test + public void testBufferTooSmall() + { + expectedException.expect(IAE.class); + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final LimitedBufferGrouper grouper = makeGrouper(columnSelectorFactory, 10, 2, 100); + } + + @Test + public void testMinBufferSize() + { + final int limit = 100; + final int keyBase = 100000; + final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); + final LimitedBufferGrouper grouper = makeGrouper(columnSelectorFactory, 11716, 2, limit); + final int numRows = 1000; + + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 10L))); + for (int i = 0; i < numRows; i++) { + Assert.assertTrue(String.valueOf(i + keyBase), grouper.aggregate(i + keyBase).isOk()); + } + + // With minimum buffer size, after the first swap, every new key added will result in a swap + Assert.assertEquals(899, grouper.getGrowthCount()); + Assert.assertEquals(101, grouper.getSize()); + Assert.assertEquals(202, grouper.getBuckets()); + Assert.assertEquals(101, grouper.getMaxSize()); + Assert.assertEquals(100, grouper.getLimit()); + + // Aggregate slightly different row + // Since these keys are smaller, they will evict the previous 100 top entries + // First 100 of these new rows will be the expected results. + columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 11L))); + for (int i = 0; i < numRows; i++) { + Assert.assertTrue(String.valueOf(i), grouper.aggregate(i).isOk()); + } + + Assert.assertEquals(1899, grouper.getGrowthCount()); + Assert.assertEquals(101, grouper.getSize()); + Assert.assertEquals(202, grouper.getBuckets()); + Assert.assertEquals(101, grouper.getMaxSize()); + Assert.assertEquals(100, grouper.getLimit()); + + final List> expected = Lists.newArrayList(); + for (int i = 0; i < limit; i++) { + expected.add(new Grouper.Entry<>(i, new Object[]{11L, 1L})); + } + + Assert.assertEquals(expected, Lists.newArrayList(grouper.iterator(true))); + } + + private static LimitedBufferGrouper makeGrouper( + TestColumnSelectorFactory columnSelectorFactory, + int bufferSize, + int initialBuckets, + int limit + ) + { + LimitedBufferGrouper grouper = new LimitedBufferGrouper<>( + Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)), + GrouperTestUtil.intKeySerde(), + columnSelectorFactory, + new AggregatorFactory[]{ + new LongSumAggregatorFactory("valueSum", "value"), + new CountAggregatorFactory("count") + }, + Integer.MAX_VALUE, + 0.5f, + initialBuckets, + limit, + false + ); + + grouper.init(); + return grouper; + } +}
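
For reference, a minimal sketch (mirroring the tests above, with a hypothetical data source name) of how a caller opts into limit push down for an orderby on a non-grouping column by setting the force flag in the query context; ordering by the aggregator "idx" is not covered by the grouping key, so push down stays off by default and forcing it may trade accuracy for speed:

GroupByQuery query = new GroupByQuery.Builder()
    .setDataSource("some_datasource")   // hypothetical data source
    .setGranularity(Granularities.ALL)
    .setInterval("2011-04-02/2011-04-04")
    .setDimensions(Lists.newArrayList(new DefaultDimensionSpec("quality", "alias")))
    .setAggregatorSpecs(Lists.newArrayList(new LongSumAggregatorFactory("idx", "index")))
    .setLimitSpec(new DefaultLimitSpec(
        Lists.newArrayList(new OrderByColumnSpec("idx", OrderByColumnSpec.Direction.DESCENDING)), 5))
    .setContext(ImmutableMap.of(GroupByQueryConfig.CTX_KEY_FORCE_LIMIT_PUSH_DOWN, true))
    .build();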