From c001230554091914823cadd9514941b7b20872db Mon Sep 17 00:00:00 2001 From: Nishant Date: Tue, 21 Feb 2017 23:08:08 +0530 Subject: [PATCH 1/2] Thread safe reads for aggregators in IncrementalIndex Aggregators are *NOT* thread safe, and if two threads concurrently try to read/write to the aggregator, the reader may read absurd values since the aggregate method is not *atomic*. In the case of IncrementalIndex, the writes are protected by a synchronized block but the reads are unprotected, so it's possible for queries to read absurd values from aggregator.get(). This PR adds a test that can reproduce that behavior by wrapping the aggregator factories inside a ThreadSafetyAssertingAggregatorFactory. TODO: test any performance impacts. --- .../incremental/OffheapIncrementalIndex.java | 16 +- .../incremental/OnheapIncrementalIndex.java | 19 +- .../segment/data/IncrementalIndexTest.java | 15 +- ...hreadSafetyAssertingAggregatorFactory.java | 216 ++++++++++++++++++ 4 files changed, 256 insertions(+), 10 deletions(-) create mode 100644 processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java diff --git a/processing/src/main/java/io/druid/segment/incremental/OffheapIncrementalIndex.java b/processing/src/main/java/io/druid/segment/incremental/OffheapIncrementalIndex.java index 4655d7c27623..cd0bc54be092 100644 --- a/processing/src/main/java/io/druid/segment/incremental/OffheapIncrementalIndex.java +++ b/processing/src/main/java/io/druid/segment/incremental/OffheapIncrementalIndex.java @@ -303,7 +303,9 @@ protected Object getAggVal(BufferAggregator agg, int rowOffset, int aggPosition) { int[] indexAndOffset = indexAndOffsets.get(rowOffset); ByteBuffer bb = aggBuffers.get(indexAndOffset[0]).get(); - return agg.get(bb, indexAndOffset[1] + aggOffsetInBuffer[aggPosition]); + synchronized (agg) { + return agg.get(bb, indexAndOffset[1] + aggOffsetInBuffer[aggPosition]); + } } @Override @@ -312,7 +314,9 @@ public float getMetricFloatValue(int rowOffset, int aggOffset) 
BufferAggregator agg = getAggs()[aggOffset]; int[] indexAndOffset = indexAndOffsets.get(rowOffset); ByteBuffer bb = aggBuffers.get(indexAndOffset[0]).get(); - return agg.getFloat(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + synchronized (agg) { + return agg.getFloat(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + } } @Override @@ -321,7 +325,9 @@ public long getMetricLongValue(int rowOffset, int aggOffset) BufferAggregator agg = getAggs()[aggOffset]; int[] indexAndOffset = indexAndOffsets.get(rowOffset); ByteBuffer bb = aggBuffers.get(indexAndOffset[0]).get(); - return agg.getLong(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + synchronized (agg) { + return agg.getLong(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + } } @Override @@ -330,7 +336,9 @@ public Object getMetricObjectValue(int rowOffset, int aggOffset) BufferAggregator agg = getAggs()[aggOffset]; int[] indexAndOffset = indexAndOffsets.get(rowOffset); ByteBuffer bb = aggBuffers.get(indexAndOffset[0]).get(); - return agg.get(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + synchronized (agg) { + return agg.get(bb, indexAndOffset[1] + aggOffsetInBuffer[aggOffset]); + } } /** diff --git a/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java b/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java index 7d9514d85f0d..f38dc6c0f36b 100644 --- a/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java +++ b/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java @@ -318,25 +318,36 @@ protected Aggregator[] getAggsForRow(int rowOffset) @Override protected Object getAggVal(Aggregator agg, int rowOffset, int aggPosition) { - return agg.get(); + synchronized (agg) { + return agg.get(); + } } @Override public float getMetricFloatValue(int rowOffset, int aggOffset) { - return concurrentGet(rowOffset)[aggOffset].getFloat(); + Aggregator aggregator = 
concurrentGet(rowOffset)[aggOffset]; + synchronized (aggregator) { + return aggregator.getFloat(); + } } @Override public long getMetricLongValue(int rowOffset, int aggOffset) { - return concurrentGet(rowOffset)[aggOffset].getLong(); + Aggregator aggregator = concurrentGet(rowOffset)[aggOffset]; + synchronized (aggregator) { + return aggregator.getLong(); + } } @Override public Object getMetricObjectValue(int rowOffset, int aggOffset) { - return concurrentGet(rowOffset)[aggOffset].get(); + Aggregator aggregator = concurrentGet(rowOffset)[aggOffset]; + synchronized (aggregator) { + return aggregator.get(); + } } /** diff --git a/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java b/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java index 5069fe50b481..b80cd0ffa413 100644 --- a/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java +++ b/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java @@ -19,6 +19,7 @@ package io.druid.segment.data; +import com.google.common.base.Function; import com.google.common.base.Supplier; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -142,7 +143,7 @@ public IncrementalIndex createIndex(AggregatorFactory[] factories) @Override public ByteBuffer get() { - return ByteBuffer.allocate(256 * 1024); + return ByteBuffer.allocate(512 * 1024); } } ) @@ -527,7 +528,17 @@ public void testConcurrentAddRead() throws InterruptedException, ExecutionExcept final IncrementalIndex index = closer.closeLater( - indexCreator.createIndex(ingestAggregatorFactories.toArray(new AggregatorFactory[dimensionCount])) + indexCreator.createIndex(Lists.transform( + ingestAggregatorFactories, + new Function() + { + @Override + public AggregatorFactory apply(AggregatorFactory input) + { + return new ThreadSafetyAssertingAggregatorFactory(input); + } + } + ).toArray(new AggregatorFactory[dimensionCount])) ); final int concurrentThreads = 2; final 
int elementsPerThread = 10_000; diff --git a/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java b/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java new file mode 100644 index 000000000000..9f3b5fb7a28c --- /dev/null +++ b/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java @@ -0,0 +1,216 @@ +package io.druid.segment.data; + +import io.druid.query.aggregation.Aggregator; +import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.BufferAggregator; +import io.druid.segment.ColumnSelectorFactory; +import org.junit.Assert; + +import java.nio.ByteBuffer; +import java.util.Comparator; +import java.util.List; + +/** + * An AggregatorFactory that asserts thread safety. + * If the delegate aggregator factory is accessed in a thread unsafe manner throws AssertionError during read. + */ +public class ThreadSafetyAssertingAggregatorFactory extends AggregatorFactory +{ + private final AggregatorFactory delegate; + + public ThreadSafetyAssertingAggregatorFactory(AggregatorFactory delegate) + { + this.delegate = delegate; + } + + + @Override + public Aggregator factorize(ColumnSelectorFactory metricFactory) + { + final Aggregator delegate1 = delegate.factorize(metricFactory); + final Aggregator delegate2 = delegate.factorize(metricFactory); + return new Aggregator() + { + @Override + public void aggregate() + { + delegate1.aggregate(); + Thread.yield(); + delegate2.aggregate(); + } + + @Override + public void reset() + { + delegate1.reset(); + Thread.yield(); + delegate2.reset(); + } + + @Override + public Object get() + { + Object o1 = delegate1.get(); + Thread.yield(); + Object o2 = delegate2.get(); + Assert.assertEquals("Unsafe Call to aggregator.get()", o1, o2); + return o1; + } + + @Override + public float getFloat() + { + float o1 = delegate1.getFloat(); + Thread.yield(); + float o2 = delegate2.getFloat(); + 
Assert.assertTrue("Unsafe Call to aggregator.get()", o1 == o2); + return o1; + } + + @Override + public void close() + { + delegate1.close(); + delegate2.close(); + } + + @Override + public long getLong() + { + long o1 = delegate1.getLong(); + Thread.yield(); + long o2 = delegate2.getLong(); + Assert.assertEquals("Unsafe Call to aggregator.get()", o1, o2); + return o1; + } + }; + } + + @Override + public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) + { + final BufferAggregator delegate1 = delegate.factorizeBuffered(metricFactory); + final BufferAggregator delegate2 = delegate.factorizeBuffered(metricFactory); + final int intermediateSize = delegate.getMaxIntermediateSize(); + return new BufferAggregator() + { + @Override + public void init(ByteBuffer buf, int position) + { + delegate1.init(buf, position); + delegate2.init(buf, position + intermediateSize); + } + + @Override + public void aggregate(ByteBuffer buf, int position) + { + delegate1.aggregate(buf, position); + Thread.yield(); + delegate2.aggregate(buf, position + intermediateSize); + } + + @Override + public Object get(ByteBuffer buf, int position) + { + Object o1 = delegate1.get(buf, position); + Thread.yield(); + Object o2 = delegate2.get(buf, position + intermediateSize); + Assert.assertEquals("Unsafe Call to aggregator.get()", o1, o2); + return o1; + } + + @Override + public float getFloat(ByteBuffer buf, int position) + { + float o1 = delegate1.getFloat(buf, position); + Thread.yield(); + float o2 = delegate2.getFloat(buf, position + intermediateSize); + Assert.assertTrue("Unsafe Call to aggregator.get()", o1 == o2); + return o1; + } + + @Override + public void close() + { + delegate1.close(); + delegate2.close(); + } + + @Override + public long getLong(ByteBuffer buf, int position) + { + long o1 = delegate1.getLong(buf, position); + Thread.yield(); + long o2 = delegate2.getLong(buf, position + intermediateSize); + Assert.assertEquals("Unsafe Call to aggregator.get()", o1, 
o2); + return o1; + } + }; + } + + @Override + public Comparator getComparator() + { + return delegate.getComparator(); + } + + @Override + public Object combine(Object lhs, Object rhs) + { + return delegate.combine(lhs, rhs); + } + + @Override + public AggregatorFactory getCombiningFactory() + { + return delegate.getCombiningFactory(); + } + + @Override + public List getRequiredColumns() + { + return delegate.getRequiredColumns(); + } + + @Override + public Object deserialize(Object object) + { + return delegate.deserialize(object); + } + + @Override + public Object finalizeComputation(Object object) + { + return delegate.finalizeComputation(object); + } + + @Override + public String getName() + { + return delegate.getName(); + } + + @Override + public List requiredFields() + { + return delegate.requiredFields(); + } + + @Override + public String getTypeName() + { + return delegate.getTypeName(); + } + + @Override + public int getMaxIntermediateSize() + { + return 2 * delegate.getMaxIntermediateSize(); + } + + @Override + public byte[] getCacheKey() + { + return delegate.getCacheKey(); + } +} From 8e2e67002d90afc406a6dacaf36f8074296e7f9d Mon Sep 17 00:00:00 2001 From: Nishant Date: Tue, 21 Feb 2017 23:15:26 +0530 Subject: [PATCH 2/2] Add license --- ...hreadSafetyAssertingAggregatorFactory.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java b/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java index 9f3b5fb7a28c..c750a3b862a6 100644 --- a/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java +++ b/processing/src/test/java/io/druid/segment/data/ThreadSafetyAssertingAggregatorFactory.java @@ -1,3 +1,22 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package io.druid.segment.data; import io.druid.query.aggregation.Aggregator;