30 changes: 30 additions & 0 deletions processing/src/main/java/org/apache/druid/query/CacheStrategy.java
@@ -22,8 +22,11 @@
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.google.common.base.Function;
 import org.apache.druid.guice.annotations.ExtensionPoint;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 
+import java.util.Iterator;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
 
 /**
  */
@@ -98,4 +101,31 @@ default Function<CacheType, T> pullFromSegmentLevelCache()
   {
     return pullFromCache(false);
   }
+
+  /**
+   * Helper function used by the TopN, GroupBy, and Timeseries queries in
+   * {@link #pullFromCache(boolean)}. When using the result-level cache, the agg values seen here
+   * are finalized values generated by AggregatorFactory.finalizeComputation(). These finalized
+   * values are deserialized from the cache as generic Objects, which will later be reserialized
+   * and returned to the user without further modification. Because the agg values are
+   * deserialized as generic Objects, they are subject to the same type-consistency issues handled
+   * by DimensionHandlerUtils.convertObjectToType() in the pullFromCache implementations for
+   * dimension values (e.g., a Float would become a Double).
+   */
+  static void fetchAggregatorsFromCache(
+      Iterator<AggregatorFactory> aggIter,
+      Iterator<Object> resultIter,
+      boolean isResultLevelCache,
+      BiFunction<String, Object, Void> addToResultFunction
+  )
+  {
+    while (aggIter.hasNext() && resultIter.hasNext()) {
+      final AggregatorFactory factory = aggIter.next();
+      if (isResultLevelCache) {
+        addToResultFunction.apply(factory.getName(), resultIter.next());
+      } else {
+        addToResultFunction.apply(factory.getName(), factory.deserialize(resultIter.next()));
+      }
+    }
+  }
 }
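The Float-to-Double caveat called out in the javadoc above comes from the JSON round-trip itself: Jackson maps untyped JSON numbers back to Double, Integer, or Long, whichever fits, so finalized aggregator values can change their boxed type on the way through the result-level cache. A minimal standalone sketch of that behavior, not part of this PR (the class and variable names are illustrative):

import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Arrays;
import java.util.List;

public class CacheNumberSerdeSketch
{
  public static void main(String[] args) throws Exception
  {
    final ObjectMapper mapper = new ObjectMapper();

    // Finalized aggregator values as they might be written to the result-level
    // cache: a Float-valued aggregation and a long row count.
    final List<Object> written = Arrays.asList(2.1f, 2L);

    // Round-trip through JSON, the way pullFromCache(true) sees cached values.
    final List<?> read = mapper.readValue(mapper.writeValueAsBytes(written), List.class);

    System.out.println(read.get(0).getClass()); // class java.lang.Double
    System.out.println(read.get(1).getClass()); // class java.lang.Integer
  }
}

This is the same coercion the new GroupBy test below accounts for, where typeAdjustedResult2 expects 2.1d for the FLOAT case and the Integer 2 for the LONG case.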
processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java
@@ -555,7 +555,7 @@ public Row apply(Object input)

       DateTime timestamp = granularity.toDateTime(((Number) results.next()).longValue());
 
-      Map<String, Object> event = Maps.newLinkedHashMap();
+      final Map<String, Object> event = Maps.newLinkedHashMap();
       Iterator<DimensionSpec> dimsIter = dims.iterator();
       while (dimsIter.hasNext() && results.hasNext()) {
         final DimensionSpec dimensionSpec = dimsIter.next();
@@ -566,12 +566,18 @@ public Row apply(Object input)
           DimensionHandlerUtils.convertObjectToType(results.next(), dimensionSpec.getOutputType())
         );
       }

       Iterator<AggregatorFactory> aggsIter = aggs.iterator();
-      while (aggsIter.hasNext() && results.hasNext()) {
-        final AggregatorFactory factory = aggsIter.next();
-        event.put(factory.getName(), factory.deserialize(results.next()));
-      }
+
+      CacheStrategy.fetchAggregatorsFromCache(
+          aggsIter,
+          results,
+          isResultLevelCache,
+          (aggName, aggValueObject) -> {
+            event.put(aggName, aggValueObject);
+            return null;
+          }
+      );

       if (isResultLevelCache) {
         Iterator<PostAggregator> postItr = query.getPostAggregatorSpecs().iterator();
         while (postItr.hasNext() && results.hasNext()) {
processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java
@@ -327,17 +327,23 @@ public Function<Object, Result<TimeseriesResultValue>> pullFromCache(boolean isR
       public Result<TimeseriesResultValue> apply(@Nullable Object input)
       {
         List<Object> results = (List<Object>) input;
-        Map<String, Object> retVal = Maps.newLinkedHashMap();
+        final Map<String, Object> retVal = Maps.newLinkedHashMap();
 
         Iterator<AggregatorFactory> aggsIter = aggs.iterator();
         Iterator<Object> resultIter = results.iterator();
 
         DateTime timestamp = granularity.toDateTime(((Number) resultIter.next()).longValue());
 
-        while (aggsIter.hasNext() && resultIter.hasNext()) {
-          final AggregatorFactory factory = aggsIter.next();
-          retVal.put(factory.getName(), factory.deserialize(resultIter.next()));
-        }
+        CacheStrategy.fetchAggregatorsFromCache(
+            aggsIter,
+            resultIter,
+            isResultLevelCache,
+            (aggName, aggValueObject) -> {
+              retVal.put(aggName, aggValueObject);
+              return null;
+            }
+        );
 
         if (isResultLevelCache) {
           Iterator<PostAggregator> postItr = query.getPostAggregatorSpecs().iterator();
           while (postItr.hasNext() && resultIter.hasNext()) {
processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java
@@ -398,7 +398,7 @@ public Result<TopNResultValue> apply(Object input)

           while (inputIter.hasNext()) {
             List<Object> result = (List<Object>) inputIter.next();
-            Map<String, Object> vals = Maps.newLinkedHashMap();
+            final Map<String, Object> vals = Maps.newLinkedHashMap();
 
             Iterator<AggregatorFactory> aggIter = aggs.iterator();
             Iterator<Object> resultIter = result.iterator();
@@ -409,10 +409,15 @@ public Result<TopNResultValue> apply(Object input)
               DimensionHandlerUtils.convertObjectToType(resultIter.next(), query.getDimensionSpec().getOutputType())
             );
 
-            while (aggIter.hasNext() && resultIter.hasNext()) {
-              final AggregatorFactory factory = aggIter.next();
-              vals.put(factory.getName(), factory.deserialize(resultIter.next()));
-            }
+            CacheStrategy.fetchAggregatorsFromCache(
+                aggIter,
+                resultIter,
+                isResultLevelCache,
+                (aggName, aggValueObject) -> {
+                  vals.put(aggName, aggValueObject);
+                  return null;
+                }
+            );
 
             for (PostAggregator postAgg : postAggs) {
               vals.put(postAgg.getName(), postAgg.compute(vals));
processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java
@@ -19,15 +19,26 @@

 package org.apache.druid.query.groupby;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
+import org.apache.druid.collections.SerializablePair;
+import org.apache.druid.data.input.MapBasedRow;
 import org.apache.druid.data.input.Row;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.query.CacheStrategy;
 import org.apache.druid.query.QueryRunnerTestHelper;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
 import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
 import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
 import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.expression.TestExprMacroTable;
@@ -46,10 +57,14 @@
 import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
 import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
 import org.apache.druid.query.ordering.StringComparators;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.ValueType;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 
 public class GroupByQueryQueryToolChestTest
@@ -483,4 +498,143 @@ public void testResultLevelCacheKeyWithSubTotalsSpec()
     ));
   }
 
+  @Test
+  public void testCacheStrategy() throws Exception
+  {
+    doTestCacheStrategy(ValueType.STRING, "val1");
+    doTestCacheStrategy(ValueType.FLOAT, 2.1f);
+    doTestCacheStrategy(ValueType.DOUBLE, 2.1d);
+    doTestCacheStrategy(ValueType.LONG, 2L);
+  }
+
+  private AggregatorFactory getComplexAggregatorFactoryForValueType(final ValueType valueType)
+  {
+    switch (valueType) {
+      case LONG:
+        return new LongLastAggregatorFactory("complexMetric", "test");
+      case DOUBLE:
+        return new DoubleLastAggregatorFactory("complexMetric", "test");
+      case FLOAT:
+        return new FloatLastAggregatorFactory("complexMetric", "test");
+      case STRING:
+        return new StringLastAggregatorFactory("complexMetric", "test", null);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
+  private SerializablePair getIntermediateComplexValue(final ValueType valueType, final Object dimValue)
+  {
+    switch (valueType) {
+      case LONG:
+      case DOUBLE:
+      case FLOAT:
+        return new SerializablePair<>(123L, dimValue);
+      case STRING:
+        return new SerializablePairLongString(123L, (String) dimValue);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
+  private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException
+  {
+    final GroupByQuery query1 = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.dataSource)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird)
+        .setDimensions(Collections.singletonList(
+            new DefaultDimensionSpec("test", "test", valueType)
+        ))
+        .setAggregatorSpecs(
+            Arrays.asList(
+                QueryRunnerTestHelper.rowsCount,
+                getComplexAggregatorFactoryForValueType(valueType)
+            )
+        )
+        .setPostAggregatorSpecs(
+            ImmutableList.of(new ConstantPostAggregator("post", 10))
+        )
+        .setGranularity(QueryRunnerTestHelper.dayGran)
+        .build();
+
+    CacheStrategy<Row, Object, GroupByQuery> strategy =
+        new GroupByQueryQueryToolChest(null, null).getCacheStrategy(
+            query1
+        );
+
+    final Row result1 = new MapBasedRow(
+        // test timestamps that result in integer size millis
+        DateTimes.utc(123L),
+        ImmutableMap.of(
+            "test", dimValue,
+            "rows", 1,
+            "complexMetric", getIntermediateComplexValue(valueType, dimValue)
+        )
+    );
+
+    Object preparedValue = strategy.prepareForSegmentLevelCache().apply(
+        result1
+    );
+
+    ObjectMapper objectMapper = TestHelper.makeJsonMapper();
+    Object fromCacheValue = objectMapper.readValue(
+        objectMapper.writeValueAsBytes(preparedValue),
+        strategy.getCacheObjectClazz()
+    );
+
+    Row fromCacheResult = strategy.pullFromSegmentLevelCache().apply(fromCacheValue);
+
+    Assert.assertEquals(result1, fromCacheResult);
+
+    final Row result2 = new MapBasedRow(
+        // test timestamps that result in integer size millis
+        DateTimes.utc(123L),
+        ImmutableMap.of(
+            "test", dimValue,
+            "rows", 1,
+            "complexMetric", dimValue,
+            "post", 10
+        )
+    );
+
+    // Please see the comments on aggregator serde and type handling in CacheStrategy.fetchAggregatorsFromCache()
+    final Row typeAdjustedResult2;
+    if (valueType == ValueType.FLOAT) {
+      typeAdjustedResult2 = new MapBasedRow(
+          DateTimes.utc(123L),
+          ImmutableMap.of(
+              "test", dimValue,
+              "rows", 1,
+              "complexMetric", 2.1d,
+              "post", 10
+          )
+      );
+    } else if (valueType == ValueType.LONG) {
+      typeAdjustedResult2 = new MapBasedRow(
+          DateTimes.utc(123L),
+          ImmutableMap.of(
+              "test", dimValue,
+              "rows", 1,
+              "complexMetric", 2,
+              "post", 10
+          )
+      );
+    } else {
+      typeAdjustedResult2 = result2;
+    }
+
+
+    Object preparedResultCacheValue = strategy.prepareForCache(true).apply(
+        result2
+    );
+
+    Object fromResultCacheValue = objectMapper.readValue(
+        objectMapper.writeValueAsBytes(preparedResultCacheValue),
+        strategy.getCacheObjectClazz()
+    );
+
+    Row fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue);
+    Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult);
+  }
 }
processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java
@@ -32,6 +32,8 @@
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
 import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
 import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
 import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
@@ -77,7 +79,8 @@ public void testCacheStrategy() throws Exception
         Granularities.ALL,
         ImmutableList.of(
             new CountAggregatorFactory("metric1"),
-            new LongSumAggregatorFactory("metric0", "metric0")
+            new LongSumAggregatorFactory("metric0", "metric0"),
+            new StringLastAggregatorFactory("complexMetric", "test", null)
         ),
         ImmutableList.of(new ConstantPostAggregator("post", 10)),
         0,
@@ -89,7 +92,11 @@ public void testCacheStrategy() throws Exception
         // test timestamps that result in integer size millis
         DateTimes.utc(123L),
         new TimeseriesResultValue(
-            ImmutableMap.of("metric1", 2, "metric0", 3)
+            ImmutableMap.of(
+                "metric1", 2,
+                "metric0", 3,
+                "complexMetric", new SerializablePairLongString(123L, "val1")
+            )
         )
     );

@@ -109,7 +116,12 @@ public void testCacheStrategy() throws Exception
         // test timestamps that result in integer size millis
         DateTimes.utc(123L),
         new TimeseriesResultValue(
-            ImmutableMap.of("metric1", 2, "metric0", 3, "post", 10)
+            ImmutableMap.of(
+                "metric1", 2,
+                "metric0", 3,
+                "complexMetric", "val1",
+                "post", 10
+            )
         )
     );
