Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,19 @@
import com.yahoo.sketches.theta.Union;
import io.druid.query.aggregation.BufferAggregator;
import io.druid.segment.ObjectColumnSelector;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;

import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.IdentityHashMap;

public class SketchBufferAggregator implements BufferAggregator
{
private final ObjectColumnSelector selector;
private final int size;
private final int maxIntermediateSize;

private NativeMemory nm;

private final Map<Integer, Union> unions = new HashMap<>(); //position in BB -> Union Object
private final IdentityHashMap<ByteBuffer, Int2ObjectMap<Union>> unions = new IdentityHashMap<>();
private final IdentityHashMap<ByteBuffer, NativeMemory> nmCache = new IdentityHashMap<>();

public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIntermediateSize)
{
Expand All @@ -52,12 +51,7 @@ public SketchBufferAggregator(ObjectColumnSelector selector, int size, int maxIn
@Override
public void init(ByteBuffer buf, int position)
{
if (nm == null) {
nm = new NativeMemory(buf);
}

Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
unions.put(position, (Union) SetOperation.builder().initMemory(mem).build(size, Family.UNION));
createNewUnion(buf, position, false);
}

@Override
Expand Down Expand Up @@ -86,12 +80,27 @@ public Object get(ByteBuffer buf, int position)
//Note that this is not threadsafe and I don't think it needs to be
private Union getUnion(ByteBuffer buf, int position)
{
Union union = unions.get(position);
if (union == null) {
Memory mem = new MemoryRegion(nm, position, maxIntermediateSize);
union = (Union) SetOperation.wrap(mem);
unions.put(position, union);
Int2ObjectMap<Union> unionMap = unions.get(buf);
Union union = unionMap != null ? unionMap.get(position) : null;
if (union != null) {
return union;
}
return createNewUnion(buf, position, true);
}

/**
 * Creates a theta Union backed by the region of {@code buf} starting at {@code position}
 * and registers it in the per-buffer position map so later lookups can reuse it.
 *
 * @param buf       aggregation buffer backing the union's memory
 * @param position  offset of this aggregator's slot within {@code buf}
 * @param isWrapped true to wrap already-initialized memory, false to initialize a fresh union
 * @return the newly created (or wrapped) Union
 */
private Union createNewUnion(ByteBuffer buf, int position, boolean isWrapped)
{
  final Memory region = new MemoryRegion(getNativeMemory(buf), position, maxIntermediateSize);
  final Union union;
  if (isWrapped) {
    // The slot already holds union state (e.g. after a relocation) — wrap it in place.
    union = (Union) SetOperation.wrap(region);
  } else {
    // Fresh slot — initialize new union state directly into the buffer region.
    union = (Union) SetOperation.builder().initMemory(region).build(size, Family.UNION);
  }
  Int2ObjectMap<Union> perBufferUnions = unions.get(buf);
  if (perBufferUnions == null) {
    perBufferUnions = new Int2ObjectOpenHashMap<>();
    unions.put(buf, perBufferUnions);
  }
  perBufferUnions.put(position, union);
  return union;
}

Expand All @@ -113,4 +122,28 @@ public void close()
unions.clear();
}

@Override
public void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
  // Re-wrap the union state that was copied into the new buffer position.
  createNewUnion(newBuffer, newPosition, true);

  // Drop the stale entry for the old position; clean up per-buffer caches once empty.
  final Int2ObjectMap<Union> oldBufferUnions = unions.get(oldBuffer);
  if (oldBufferUnions == null) {
    return;
  }
  oldBufferUnions.remove(oldPosition);
  if (oldBufferUnions.isEmpty()) {
    unions.remove(oldBuffer);
    nmCache.remove(oldBuffer);
  }
}

/**
 * Returns the {@link NativeMemory} wrapper for the given aggregation buffer, creating and
 * caching one on first use. The cache is keyed by buffer identity (see {@code nmCache}),
 * so each distinct ByteBuffer instance gets exactly one wrapper.
 */
private NativeMemory getNativeMemory(ByteBuffer buffer)
{
  // computeIfAbsent keeps identity semantics: the Map default implementation delegates
  // to IdentityHashMap's own get/put.
  return nmCache.computeIfAbsent(buffer, NativeMemory::new);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,16 @@ public static SketchHolder sketchSetOperation(Func func, int sketchSize, Object.
throw new IllegalArgumentException("Unknown sketch operation " + func);
}
}

/**
 * Two SketchHolders are equal iff they hold equal sketches.
 */
@Override
public boolean equals(Object o)
{
  if (this == o) {
    return true;
  }
  if (o == null || getClass() != o.getClass()) {
    return false;
  }
  return this.getSketch().equals(((SketchHolder) o).getSketch());
}

/**
 * Overridden alongside {@link #equals(Object)} to preserve the Object contract:
 * equal holders must produce equal hash codes, otherwise this class misbehaves
 * as a key in hash-based collections.
 */
@Override
public int hashCode()
{
  return getSketch().hashCode();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package io.druid.query.aggregation.datasketches.theta;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.groupby.epinephelinae.BufferGrouper;
import io.druid.query.groupby.epinephelinae.Grouper;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.junit.Assert;
import org.junit.Test;

import java.nio.ByteBuffer;

/**
 * Verifies that sketch merge aggregation keeps working while a BufferGrouper grows its
 * backing table (which relocates aggregator state to new buffer positions).
 */
public class BufferGrouperUsingSketchMergeAggregatorFactoryTest
{
// Builds and initializes a BufferGrouper over a heap ByteBuffer of the given size, with a
// sketch-merge aggregator plus a count aggregator reading from the supplied selector factory.
private static BufferGrouper<Integer> makeGrouper(
TestColumnSelectorFactory columnSelectorFactory,
int bufferSize,
int initialBuckets
)
{
final BufferGrouper<Integer> grouper = new BufferGrouper<>(
Suppliers.ofInstance(ByteBuffer.allocate(bufferSize)),
GrouperTestUtil.intKeySerde(),
columnSelectorFactory,
new AggregatorFactory[]{
new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2),
new CountAggregatorFactory("count")
},
Integer.MAX_VALUE,
0.75f,
initialBuckets
);
grouper.init();
return grouper;
}

@Test
public void testGrowingBufferGrouper()
{
final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
// initialBuckets = 2 with a large buffer forces at least one table-growing resize below.
final Grouper<Integer> grouper = makeGrouper(columnSelectorFactory, 100000, 2);
try {
final int expectedMaxSize = 5;

SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();
updateSketch.update(1);

columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder)));

// Aggregate enough distinct keys to trigger growth/relocation of the table.
for (int i = 0; i < expectedMaxSize; i++) {
Assert.assertTrue(String.valueOf(i), grouper.aggregate(i));
}

// Feed a second distinct value through the same sketch and aggregate again,
// so post-relocation union state must still accept updates.
updateSketch.update(3);
columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder)));

for (int i = 0; i < expectedMaxSize; i++) {
Assert.assertTrue(String.valueOf(i), grouper.aggregate(i));
}

Object[] holders = Lists.newArrayList(grouper.iterator(true)).get(0).getValues();

// Two distinct values (1 and 3) were seen, so the merged estimate must be 2.
Assert.assertEquals(2.0d, ((SketchHolder) holders[0]).getEstimate(), 0);
}
finally {
grouper.close();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.Union;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.java.util.common.granularity.Granularities;
Expand All @@ -39,6 +40,8 @@
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
Expand Down Expand Up @@ -389,6 +392,23 @@ public void testSketchAggregatorFactoryComparator()
Assert.assertEquals(1, comparator.compare(SketchHolder.of(union2), SketchHolder.of(sketch1)));
}

@Test
public void testRelocation()
{
// Builds a single-update sketch and checks that the aggregator's estimate is
// unchanged after BufferAggregator.relocate moves its state to a new buffer.
final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();
updateSketch.update(1);

columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder)));
// holders[0] = value before relocation, holders[1] = value after relocation.
SketchHolder[] holders = helper.runRelocateVerificationTest(
new SketchMergeAggregatorFactory("sketch", "sketch", 16, false, true, 2),
columnSelectorFactory,
SketchHolder.class
);
Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0);
}

private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import com.yahoo.sketches.theta.Sketches;
import com.yahoo.sketches.theta.UpdateSketch;
import io.druid.data.input.MapBasedRow;
import io.druid.java.util.common.granularity.Granularities;
import io.druid.java.util.common.guava.Sequence;
import io.druid.java.util.common.guava.Sequences;
import io.druid.query.aggregation.AggregationTestHelper;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.datasketches.theta.SketchHolder;
import io.druid.query.aggregation.post.FieldAccessPostAggregator;
import io.druid.query.groupby.GroupByQueryConfig;
import io.druid.query.groupby.GroupByQueryRunnerTest;
import io.druid.query.groupby.epinephelinae.GrouperTestUtil;
import io.druid.query.groupby.epinephelinae.TestColumnSelectorFactory;
import org.joda.time.DateTime;
import org.junit.Assert;
import org.junit.Rule;
Expand Down Expand Up @@ -194,6 +199,23 @@ public void testSketchSetPostAggregatorSerde() throws Exception
);
}

@Test
public void testRelocation()
{
// Same relocation check as the theta-package test, but exercised through the
// legacy OldSketchMergeAggregatorFactory: the estimate must survive relocation.
final TestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory();
SketchHolder sketchHolder = SketchHolder.of(Sketches.updateSketchBuilder().build(16));
UpdateSketch updateSketch = (UpdateSketch) sketchHolder.getSketch();
updateSketch.update(1);

columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.<String, Object>of("sketch", sketchHolder)));
// holders[0] = value before relocation, holders[1] = value after relocation.
SketchHolder[] holders = helper.runRelocateVerificationTest(
new OldSketchMergeAggregatorFactory("sketch", "sketch", 16, false),
columnSelectorFactory,
SketchHolder.class
);
Assert.assertEquals(holders[0].getEstimate(), holders[1].getEstimate(), 0);
}

private void assertPostAggregatorSerde(PostAggregator agg) throws Exception
{
Assert.assertEquals(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,27 @@ public interface BufferAggregator
* Release any resources used by the aggregator
*/
void close();

/**
 * Relocates any cached objects.
 * <p>
 * If the underlying ByteBuffer used for the aggregation buffer relocates to a new ByteBuffer, positional caches
 * (if any) built on top of the old ByteBuffer can not be used for further
 * {@link BufferAggregator#aggregate(ByteBuffer, int)} calls. This method tells the BufferAggregator that cached
 * objects at a certain location have been relocated to a different location.
 * <p>
 * Only used if there are any positional caches/objects in the BufferAggregator implementation.
 * <p>
 * If the relocation happens to be across multiple new ByteBuffers (say n ByteBuffers), this method should be
 * called multiple times (n times), given that all the new positions/old positions exist in newBuffer/oldBuffer.
 * <p>
 * <b>Implementations must not change the position, limit or mark of the given buffers.</b>
 *
 * @param oldPosition old position of a cached object before the aggregation buffer relocated to a new ByteBuffer
 * @param newPosition new position of a cached object after the aggregation buffer relocated to a new ByteBuffer
 * @param oldBuffer   old aggregation buffer
 * @param newBuffer   new aggregation buffer
 */
default void relocate(int oldPosition, int newPosition, ByteBuffer oldBuffer, ByteBuffer newBuffer)
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,9 @@ private void growIfPossible()

for (int oldBucket = 0; oldBucket < buckets; oldBucket++) {
if (isUsed(oldBucket)) {
int oldPosition = oldBucket * bucketSize;
entryBuffer.limit((oldBucket + 1) * bucketSize);
entryBuffer.position(oldBucket * bucketSize);
entryBuffer.position(oldPosition);
keyBuffer.limit(entryBuffer.position() + HASH_SIZE + keySize);
keyBuffer.position(entryBuffer.position() + HASH_SIZE);

Expand All @@ -442,9 +443,19 @@ private void growIfPossible()
throw new ISE("WTF?! Couldn't find a bucket while resizing?!");
}

newTableBuffer.position(newBucket * bucketSize);
int newPosition = newBucket * bucketSize;
newTableBuffer.position(newPosition);
newTableBuffer.put(entryBuffer);

for (int i = 0; i < aggregators.length; i++) {
aggregators[i].relocate(
oldPosition + aggregatorOffsets[i],
newPosition + aggregatorOffsets[i],
tableBuffer,
newTableBuffer
);
}

buffer.putInt(tableArenaSize + newSize * Ints.BYTES, newBucket * bucketSize);
newSize++;
}
Expand Down
Loading