diff --git a/core/src/main/java/org/apache/druid/memory/BufferHolder.java b/core/src/main/java/org/apache/druid/memory/BufferHolder.java
new file mode 100644
index 000000000000..e3844047abc3
--- /dev/null
+++ b/core/src/main/java/org/apache/druid/memory/BufferHolder.java
@@ -0,0 +1,10 @@
+package org.apache.druid.memory;
+
+import java.nio.ByteBuffer;
+
+public interface BufferHolder
+{
+  int position();
+  int capacity();
+  ByteBuffer get();
+}
diff --git a/core/src/main/java/org/apache/druid/memory/MemoryAllocator.java b/core/src/main/java/org/apache/druid/memory/MemoryAllocator.java
new file mode 100644
index 000000000000..3f4fcf1b64d6
--- /dev/null
+++ b/core/src/main/java/org/apache/druid/memory/MemoryAllocator.java
@@ -0,0 +1,9 @@
+package org.apache.druid.memory;
+
+import java.nio.ByteBuffer;
+
+public interface MemoryAllocator
+{
+  BufferHolder allocate(int capacity);
+  void free(BufferHolder bh);
+}
diff --git a/core/src/main/java/org/apache/druid/memory/SimpleOnHeapMemoryAllocator.java b/core/src/main/java/org/apache/druid/memory/SimpleOnHeapMemoryAllocator.java
new file mode 100644
index 000000000000..e9448fef4425
--- /dev/null
+++ b/core/src/main/java/org/apache/druid/memory/SimpleOnHeapMemoryAllocator.java
@@ -0,0 +1,49 @@
+package org.apache.druid.memory;
+
+import com.google.common.base.Suppliers;
+
+import java.nio.ByteBuffer;
+import java.util.function.Supplier;
+
+public class SimpleOnHeapMemoryAllocator implements MemoryAllocator
+{
+  @Override
+  public BufferHolder allocate(int capacity)
+  {
+    return new SimplerBufferHolder(ByteBuffer.allocate(capacity));
+  }
+
+  @Override
+  public void free(BufferHolder ignored)
+  {
+
+  }
+
+  private static class SimplerBufferHolder implements BufferHolder
+  {
+    private final ByteBuffer bb;
+
+    public SimplerBufferHolder(ByteBuffer bb)
+    {
+      this.bb = bb;
+    }
+
+    @Override
+    public int position()
+    {
+      return 0;
+    }
+
+    @Override
+    public int capacity()
+    {
+      return bb.capacity();
+    }
+
+    @Override
+    public ByteBuffer get()
+    {
+      return bb;
+    }
+  }
+}
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorFactory.java
index ced087bd7de7..0b38cd488233 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/AggregatorFactory.java
@@ -28,6 +28,7 @@
 import org.apache.druid.segment.vector.VectorColumnSelectorFactory;
 
 import javax.annotation.Nullable;
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.LinkedHashMap;
@@ -212,6 +213,26 @@ public AggregatorFactory getMergingFactory(AggregatorFactory other) throws Aggre
    */
   public abstract int getMaxIntermediateSize();
 
+  /**
+   * Whether the corresponding {@link BufferAggregator} supports varying ByteBuffer sizes by overriding
+   * {@link BufferAggregator#aggregate(ByteBuffer, int, int)}.
+   * @return true if the BufferAggregator can work with dynamically resized buffers
+   */
+  public boolean isDynamicallyResizable()
+  {
+    return getMinIntermediateSize() < getMaxIntermediateSize();
+  }
+
+  /**
+   * Initial size of the ByteBuffer to be used with the BufferAggregator.
+   * @return initial intermediate size in bytes
+   */
+  public int getMinIntermediateSize()
+  {
+    return getMaxIntermediateSize();
+  }
+
+
   /**
    * Returns the maximum size that this aggregator will require in bytes for intermediate storage of results.
   * Implementations of {@link AggregatorFactory} which need to Support Nullable Aggregations are encouraged
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/BufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/BufferAggregator.java
index 98608546ccc4..57b5c1b6ceac 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/BufferAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/BufferAggregator.java
@@ -20,6 +20,8 @@
 package org.apache.druid.query.aggregation;
 
 import org.apache.druid.guice.annotations.ExtensionPoint;
+import org.apache.druid.memory.BufferHolder;
+import org.apache.druid.memory.MemoryAllocator;
 import org.apache.druid.query.monomorphicprocessing.CalledFromHotLoop;
 import org.apache.druid.query.monomorphicprocessing.HotLoopCallee;
 import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
@@ -209,4 +211,52 @@ default boolean isNull(ByteBuffer buf, int position)
     return false;
   }
 
+  /**
+   * Returns false if the aggregation requires a buffer larger than {@code capacity}; returns true otherwise.
+   * The return value must be used exclusively to signal the "not enough room in the buffer" condition and
+   * nothing else.
+   */
+  default boolean aggregate(ByteBuffer buff, int position, int capacity)
+  {
+    aggregate(buff, position);
+    return true;
+  }
+
+  // The following methods are equivalents of the older methods with the same names, except that they provide
+  // access to the ByteBuffer capacity, which the older methods assume to be AggregatorFactory.getMaxIntermediateSize().
+
+  default void init(ByteBuffer buff, int position, int capacity)
+  {
+    init(buff, position);
+  }
+
+  default void relocate(ByteBuffer oldBuff, int oldPosition, int oldCapacity, ByteBuffer newBuff, int newPosition, int newCapacity)
+  {
+    relocate(oldPosition, newPosition, oldBuff, newBuff);
+  }
+
+  default Object get(ByteBuffer buff, int position, int capacity)
+  {
+    return get(buff, position);
+  }
+
+  default float getFloat(ByteBuffer buff, int position, int capacity)
+  {
+    return getFloat(buff, position);
+  }
+
+  default double getDouble(ByteBuffer buff, int position, int capacity)
+  {
+    return getDouble(buff, position);
+  }
+
+  default long getLong(ByteBuffer buff, int position, int capacity)
+  {
+    return getLong(buff, position);
+  }
+
+  default boolean isNull(ByteBuffer buff, int position, int capacity)
+  {
+    return isNull(buff, position);
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
index 515c47571f42..dc08d56b14cd 100644
--- a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
+++ b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java
@@ -489,10 +489,6 @@ protected abstract AddToFactsResult addToFacts(
 
   public abstract int getLastRowIndex();
 
-  protected abstract AggregatorType[] getAggsForRow(int rowOffset);
-
-  protected abstract Object getAggVal(AggregatorType agg, int rowOffset, int aggPosition);
-
   protected abstract float getMetricFloatValue(int rowOffset, int aggOffset);
 
   protected abstract long getMetricLongValue(int rowOffset, int aggOffset);
@@ -1029,9 +1025,8 @@ public Iterator iterator()
           theVals.put(dimensionName, rowVals);
         }
 
-        AggregatorType[] aggs = getAggsForRow(rowOffset);
-        for (int i = 0; i < aggs.length; ++i) {
-          theVals.put(metrics[i].getName(), getAggVal(aggs[i], rowOffset, i));
+        for (int i = 0; i < metrics.length; ++i) {
+          theVals.put(metrics[i].getName(), getMetricObjectValue(rowOffset, i));
         }
 
         if (postAggs != null) {
diff --git a/processing/src/main/java/org/apache/druid/segment/incremental/OffheapIncrementalIndex.java b/processing/src/main/java/org/apache/druid/segment/incremental/OffheapIncrementalIndex.java
index 95c88fc9606a..b8ecfc24bdd6 100644
--- a/processing/src/main/java/org/apache/druid/segment/incremental/OffheapIncrementalIndex.java
+++ b/processing/src/main/java/org/apache/druid/segment/incremental/OffheapIncrementalIndex.java
@@ -262,20 +262,6 @@ public String getOutOfRowsReason()
     return outOfRowsReason;
   }
 
-  @Override
-  protected BufferAggregator[] getAggsForRow(int rowOffset)
-  {
-    return getAggs();
-  }
-
-  @Override
-  protected Object getAggVal(BufferAggregator agg, int rowOffset, int aggPosition)
-  {
-    int[] indexAndOffset = indexAndOffsets.get(rowOffset);
-    ByteBuffer bb = aggBuffers.get(indexAndOffset[0]).get();
-    return agg.get(bb, indexAndOffset[1] + aggOffsetInBuffer[aggPosition]);
-  }
-
   @Override
   public float getMetricFloatValue(int rowOffset, int aggOffset)
   {
diff --git a/processing/src/main/java/org/apache/druid/segment/incremental/OnheapIncrementalIndex.java b/processing/src/main/java/org/apache/druid/segment/incremental/OnheapIncrementalIndex.java
index 80e21a08c493..6b6bef82bebd 100644
--- a/processing/src/main/java/org/apache/druid/segment/incremental/OnheapIncrementalIndex.java
+++ b/processing/src/main/java/org/apache/druid/segment/incremental/OnheapIncrementalIndex.java
@@ -25,8 +25,11 @@
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.java.util.common.parsers.ParseException;
-import org.apache.druid.query.aggregation.Aggregator;
+import org.apache.druid.memory.BufferHolder;
+import org.apache.druid.memory.MemoryAllocator;
+import org.apache.druid.memory.SimpleOnHeapMemoryAllocator;
 import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.ColumnValueSelector;
@@ -47,20 +50,20 @@
 /**
  *
  */
-public class OnheapIncrementalIndex extends IncrementalIndex<Aggregator>
+public class OnheapIncrementalIndex extends IncrementalIndex<BufferAggregator>
 {
   private static final Logger log = new Logger(OnheapIncrementalIndex.class);
 
   /**
    * overhead per {@link ConcurrentHashMap.Node} or {@link java.util.concurrent.ConcurrentSkipListMap.Node} object
    */
   private static final int ROUGH_OVERHEAD_PER_MAP_ENTRY = Long.BYTES * 5 + Integer.BYTES;
-  private final ConcurrentHashMap<Integer, Aggregator[]> aggregators = new ConcurrentHashMap<>();
+  private final ConcurrentHashMap<Integer, BufferHolder[]> aggregators = new ConcurrentHashMap<>();
   private final FactsHolder facts;
   private final AtomicInteger indexIncrement = new AtomicInteger(0);
-  private final long maxBytesPerRowForAggregators;
+  private final long initialBytesPerRowForAggregators;
   protected final int maxRowCount;
   protected final long maxBytesInMemory;
-  private volatile Map<String, ColumnSelectorFactory> selectors;
+  private final MemoryAllocator memoryAllocator = new SimpleOnHeapMemoryAllocator();
 
   private String outOfRowsReason = null;
@@ -79,7 +82,7 @@ public class OnheapIncrementalIndex extends IncrementalIndex
     this.maxBytesInMemory = maxBytesInMemory == 0 ? Long.MAX_VALUE : maxBytesInMemory;
     this.facts = incrementalIndexSchema.isRollup() ? new RollupFactsHolder(sortFacts, dimsComparator(), getDimensions())
                                                    : new PlainFactsHolder(sortFacts, dimsComparator());
-    maxBytesPerRowForAggregators = getMaxBytesPerRowForAggregators(incrementalIndexSchema);
+    initialBytesPerRowForAggregators = getInitialBytesPerRowForAggregators(incrementalIndexSchema);
   }
 
   /**
@@ -101,14 +104,14 @@ public class OnheapIncrementalIndex extends IncrementalIndex
    *
    * @return long max aggregator size in bytes
    */
-  private static long getMaxBytesPerRowForAggregators(IncrementalIndexSchema incrementalIndexSchema)
+  private static long getInitialBytesPerRowForAggregators(IncrementalIndexSchema incrementalIndexSchema)
   {
-    long maxAggregatorIntermediateSize = Integer.BYTES * incrementalIndexSchema.getMetrics().length;
-    maxAggregatorIntermediateSize += Arrays.stream(incrementalIndexSchema.getMetrics())
-                                           .mapToLong(aggregator -> aggregator.getMaxIntermediateSizeWithNulls()
+    long initialSize = Integer.BYTES * incrementalIndexSchema.getMetrics().length;
+    initialSize += Arrays.stream(incrementalIndexSchema.getMetrics())
+                         .mapToLong(aggregator -> aggregator.getMinIntermediateSize()
                                                   + Long.BYTES * 2)
                          .sum();
-    return maxAggregatorIntermediateSize;
+    return initialSize;
   }
 
   @Override
@@ -118,25 +121,22 @@ public FactsHolder getFacts()
   }
 
   @Override
-  protected Aggregator[] initAggs(
+  protected BufferAggregator[] initAggs(
       final AggregatorFactory[] metrics,
       final Supplier<InputRow> rowSupplier,
       final boolean deserializeComplexMetrics,
       final boolean concurrentEventAdd
   )
   {
-    selectors = new HashMap<>();
-    for (AggregatorFactory agg : metrics) {
-      selectors.put(
-          agg.getName(),
-          new CachingColumnSelectorFactory(
-              makeColumnSelectorFactory(agg, rowSupplier, deserializeComplexMetrics),
-              concurrentEventAdd
-          )
-      );
+    BufferAggregator[] aggs = new BufferAggregator[metrics.length];
+    for (int i = 0; i < metrics.length; i++) {
+      aggs[i] = metrics[i].factorizeBuffered(new CachingColumnSelectorFactory(
+          makeColumnSelectorFactory(metrics[i], rowSupplier, deserializeComplexMetrics),
+          concurrentEventAdd
+      ));
     }
 
-    return new Aggregator[metrics.length];
+    return aggs;
   }
 
   @Override
@@ -151,20 +151,19 @@ protected AddToFactsResult addToFacts(
     List<String> parseExceptionMessages;
     final int priorIndex = facts.getPriorIndex(key);
 
-    Aggregator[] aggs;
+    BufferHolder[] bufferHolders;
     final AggregatorFactory[] metrics = getMetrics();
     final AtomicInteger numEntries = getNumEntries();
    final AtomicLong sizeInBytes = getBytesInMemory();
     if (IncrementalIndexRow.EMPTY_ROW_INDEX != priorIndex) {
-      aggs = concurrentGet(priorIndex);
-      parseExceptionMessages = doAggregate(metrics, aggs, rowContainer, row);
+      bufferHolders = concurrentGet(priorIndex);
+      parseExceptionMessages = doAggregate(metrics, getAggs(), bufferHolders, rowContainer, row);
     } else {
-      aggs = new Aggregator[metrics.length];
-      factorizeAggs(metrics, aggs, rowContainer, row);
-      parseExceptionMessages = doAggregate(metrics, aggs, rowContainer, row);
+      bufferHolders = factorizeAggs(metrics, getAggs(), rowContainer, row);
+      parseExceptionMessages = doAggregate(metrics, getAggs(), bufferHolders, rowContainer, row);
 
       final int rowIndex = indexIncrement.getAndIncrement();
-      concurrentSet(rowIndex, aggs);
+      concurrentSet(rowIndex, bufferHolders);
 
       // Last ditch sanity checks
       if ((numEntries.get() >= maxRowCount || sizeInBytes.get() >= maxBytesInMemory)
@@ -179,12 +178,12 @@ protected AddToFactsResult addToFacts(
       final int prev = facts.putIfAbsent(key, rowIndex);
       if (IncrementalIndexRow.EMPTY_ROW_INDEX == prev) {
         numEntries.incrementAndGet();
-        long estimatedRowSize = estimateRowSizeInBytes(key, maxBytesPerRowForAggregators);
+        long estimatedRowSize = estimateRowSizeInBytes(key, initialBytesPerRowForAggregators);
         sizeInBytes.addAndGet(estimatedRowSize);
       } else {
         // We lost a race
-        aggs = concurrentGet(prev);
-        parseExceptionMessages = doAggregate(metrics, aggs, rowContainer, row);
+        bufferHolders = concurrentGet(prev);
+        parseExceptionMessages = doAggregate(metrics, getAggs(), bufferHolders, rowContainer, row);
         // Free up the misfire
         concurrentRemove(rowIndex);
         // This is expected to occur ~80% of the time in the worst scenarios
@@ -203,13 +202,13 @@
    *
    *
    * @param key TimeAndDims key
-   * @param maxBytesPerRowForAggregators max size per aggregator
+   * @param initialBytesPerRowForAggregators initial size for aggregators per row
    *
    * @return estimated size of row
    */
-  private long estimateRowSizeInBytes(IncrementalIndexRow key, long maxBytesPerRowForAggregators)
+  private long estimateRowSizeInBytes(IncrementalIndexRow key, long initialBytesPerRowForAggregators)
   {
-    return ROUGH_OVERHEAD_PER_MAP_ENTRY + key.estimateBytesInMemory() + maxBytesPerRowForAggregators;
+    return ROUGH_OVERHEAD_PER_MAP_ENTRY + key.estimateBytesInMemory() + initialBytesPerRowForAggregators;
   }
 
   @Override
@@ -218,24 +217,28 @@ public int getLastRowIndex()
     return indexIncrement.get() - 1;
   }
 
-  private void factorizeAggs(
+  private BufferHolder[] factorizeAggs(
       AggregatorFactory[] metrics,
-      Aggregator[] aggs,
+      BufferAggregator[] aggs,
       ThreadLocal<InputRow> rowContainer,
       InputRow row
   )
   {
     rowContainer.set(row);
+
+    BufferHolder[] buffs = new BufferHolder[metrics.length];
     for (int i = 0; i < metrics.length; i++) {
-      final AggregatorFactory agg = metrics[i];
-      aggs[i] = agg.factorize(selectors.get(agg.getName()));
+      buffs[i] = memoryAllocator.allocate(metrics[i].getMinIntermediateSize());
+      aggs[i].init(buffs[i].get(), buffs[i].position(), buffs[i].capacity());
     }
     rowContainer.set(null);
+    return buffs;
   }
 
   private List<String> doAggregate(
       AggregatorFactory[] metrics,
-      Aggregator[] aggs,
+      BufferAggregator[] aggs,
+      BufferHolder[] bufferHolders,
       ThreadLocal<InputRow> rowContainer,
       InputRow row
   )
@@ -244,10 +247,15 @@ private List doAggregate(
     rowContainer.set(row);
 
     for (int i = 0; i < aggs.length; i++) {
-      final Aggregator agg = aggs[i];
-      synchronized (agg) {
+      synchronized (aggs[i]) {
         try {
-          agg.aggregate();
+          while (!aggs[i].aggregate(bufferHolders[i].get(), bufferHolders[i].position(), bufferHolders[i].capacity())) {
+            BufferHolder old = bufferHolders[i];
+            BufferHolder bigger = bufferHolders[i] = memoryAllocator.allocate(2 * old.capacity());
+            aggs[i].relocate(old.get(), old.position(), old.capacity(), bigger.get(), bigger.position(), bigger.capacity());
+            getBytesInMemory().addAndGet(bigger.capacity() - old.capacity());
+            memoryAllocator.free(old);
+          }
         }
         catch (ParseException e) {
           // "aggregate" can throw ParseExceptions if a selector expects something but gets something else.
@@ -264,9 +272,9 @@ private List doAggregate(
   private void closeAggregators()
   {
     Closer closer = Closer.create();
-    for (Aggregator[] aggs : aggregators.values()) {
-      for (Aggregator agg : aggs) {
-        closer.register(agg);
+    for (BufferHolder[] bufferHolders : aggregators.values()) {
+      for (BufferHolder bh : bufferHolders) {
+        closer.register(() -> memoryAllocator.free(bh));
       }
     }
 
@@ -278,13 +286,13 @@ private void closeAggregators()
     }
   }
 
-  protected Aggregator[] concurrentGet(int offset)
+  protected BufferHolder[] concurrentGet(int offset)
   {
     // All get operations should be fine
     return aggregators.get(offset);
   }
 
-  protected void concurrentSet(int offset, Aggregator[] value)
+  protected void concurrentSet(int offset, BufferHolder[] value)
   {
     aggregators.put(offset, value);
   }
@@ -324,46 +332,39 @@ public String getOutOfRowsReason()
     return outOfRowsReason;
   }
 
-  @Override
-  protected Aggregator[] getAggsForRow(int rowOffset)
-  {
-    return concurrentGet(rowOffset);
-  }
-
-  @Override
-  protected Object getAggVal(Aggregator agg, int rowOffset, int aggPosition)
-  {
-    return agg.get();
-  }
-
   @Override
   public float getMetricFloatValue(int rowOffset, int aggOffset)
   {
-    return concurrentGet(rowOffset)[aggOffset].getFloat();
+    BufferHolder bh = concurrentGet(rowOffset)[aggOffset];
+    return getAggs()[aggOffset].getFloat(bh.get(), bh.position(), bh.capacity());
   }
 
   @Override
   public long getMetricLongValue(int rowOffset, int aggOffset)
   {
-    return concurrentGet(rowOffset)[aggOffset].getLong();
+    BufferHolder bh = concurrentGet(rowOffset)[aggOffset];
+    return getAggs()[aggOffset].getLong(bh.get(), bh.position(), bh.capacity());
   }
 
   @Override
   public Object getMetricObjectValue(int rowOffset, int aggOffset)
   {
-    return concurrentGet(rowOffset)[aggOffset].get();
+    BufferHolder bh = concurrentGet(rowOffset)[aggOffset];
+    return getAggs()[aggOffset].get(bh.get(), bh.position(), bh.capacity());
  }
 
   @Override
   protected double getMetricDoubleValue(int rowOffset, int aggOffset)
   {
-    return concurrentGet(rowOffset)[aggOffset].getDouble();
+    BufferHolder bh = concurrentGet(rowOffset)[aggOffset];
+    return getAggs()[aggOffset].getDouble(bh.get(), bh.position(), bh.capacity());
   }
 
   @Override
   public boolean isNull(int rowOffset, int aggOffset)
   {
-    return concurrentGet(rowOffset)[aggOffset].isNull();
+    BufferHolder bh = concurrentGet(rowOffset)[aggOffset];
+    return getAggs()[aggOffset].isNull(bh.get(), bh.position(), bh.capacity());
   }
 
   /**
@@ -377,9 +378,6 @@ public void close()
     closeAggregators();
     aggregators.clear();
     facts.clear();
-    if (selectors != null) {
-      selectors.clear();
-    }
   }
 
   /**