diff --git a/processing/src/main/java/io/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/io/druid/query/aggregation/post/ExpressionPostAggregator.java index 090363d7a2aa..3e37c5d38b4e 100644 --- a/processing/src/main/java/io/druid/query/aggregation/post/ExpressionPostAggregator.java +++ b/processing/src/main/java/io/druid/query/aggregation/post/ExpressionPostAggregator.java @@ -22,8 +22,11 @@ import com.fasterxml.jackson.annotation.JacksonInject; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Function; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; import io.druid.java.util.common.guava.Comparators; import io.druid.math.expr.Expr; import io.druid.math.expr.ExprMacroTable; @@ -32,10 +35,12 @@ import io.druid.query.aggregation.PostAggregator; import io.druid.query.cache.CacheKeyBuilder; +import javax.annotation.Nullable; import java.util.Comparator; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; public class ExpressionPostAggregator implements PostAggregator { @@ -55,6 +60,7 @@ public class ExpressionPostAggregator implements PostAggregator private final Comparator comparator; private final String ordering; private final ExprMacroTable macroTable; + private final Map> finalizers; private final Expr parsed; private final Set dependentFields; @@ -69,6 +75,20 @@ public ExpressionPostAggregator( @JsonProperty("ordering") String ordering, @JacksonInject ExprMacroTable macroTable ) + { + this(name, expression, ordering, macroTable, ImmutableMap.of()); + } + + /** + * Constructor for {@link #decorate(Map)}. + */ + private ExpressionPostAggregator( + final String name, + final String expression, + @Nullable final String ordering, + final ExprMacroTable macroTable, + final Map> finalizers + ) { Preconditions.checkArgument(expression != null, "expression cannot be null"); @@ -77,15 +97,12 @@ public ExpressionPostAggregator( this.ordering = ordering; this.comparator = ordering == null ? DEFAULT_COMPARATOR : Ordering.valueOf(ordering); this.macroTable = macroTable; + this.finalizers = finalizers; + this.parsed = Parser.parse(expression, macroTable); this.dependentFields = ImmutableSet.copyOf(Parser.findRequiredBindings(parsed)); } - public ExpressionPostAggregator(String name, String fnName) - { - this(name, fnName, null, ExprMacroTable.nil()); - } - @Override public Set getDependentFields() { @@ -101,7 +118,16 @@ public Comparator getComparator() @Override public Object compute(Map values) { - return parsed.eval(Parser.withMap(values)).value(); + // Maps.transformEntries is lazy, will only finalize values we actually read. + final Map finalizedValues = Maps.transformEntries( + values, + (String k, Object v) -> { + final Function finalizer = finalizers.get(k); + return finalizer != null ? finalizer.apply(v) : v; + } + ); + + return parsed.eval(Parser.withMap(finalizedValues)).value(); } @Override @@ -112,9 +138,20 @@ public String getName() } @Override - public ExpressionPostAggregator decorate(Map aggregators) + public ExpressionPostAggregator decorate(final Map aggregators) { - return this; + return new ExpressionPostAggregator( + name, + expression, + ordering, + macroTable, + aggregators.entrySet().stream().collect( + Collectors.toMap( + entry -> entry.getKey(), + entry -> entry.getValue()::finalizeComputation + ) + ) + ); } @JsonProperty("expression") diff --git a/processing/src/test/java/io/druid/query/topn/TopNQueryRunnerTest.java b/processing/src/test/java/io/druid/query/topn/TopNQueryRunnerTest.java index bc3f4f3243eb..c686e08c9a1a 100644 --- a/processing/src/test/java/io/druid/query/topn/TopNQueryRunnerTest.java +++ b/processing/src/test/java/io/druid/query/topn/TopNQueryRunnerTest.java @@ -640,6 +640,56 @@ public void testTopNOverHyperUniqueFinalizingPostAggregator() assertExpectedResults(expectedResults, query); } + @Test + public void testTopNOverHyperUniqueExpression() + { + TopNQuery query = new TopNQueryBuilder() + .dataSource(QueryRunnerTestHelper.dataSource) + .granularity(QueryRunnerTestHelper.allGran) + .dimension(QueryRunnerTestHelper.marketDimension) + .metric(QueryRunnerTestHelper.hyperUniqueFinalizingPostAggMetric) + .threshold(3) + .intervals(QueryRunnerTestHelper.fullOnInterval) + .aggregators( + Arrays.asList(QueryRunnerTestHelper.qualityUniques) + ) + .postAggregators( + Collections.singletonList(new ExpressionPostAggregator( + QueryRunnerTestHelper.hyperUniqueFinalizingPostAggMetric, + "uniques + 1", + null, + TestExprMacroTable.INSTANCE + )) + ) + .build(); + + List> expectedResults = Arrays.asList( + new Result<>( + new DateTime("2011-01-12T00:00:00.000Z"), + new TopNResultValue( + Arrays.>asList( + ImmutableMap.builder() + .put("market", "spot") + .put(QueryRunnerTestHelper.uniqueMetric, QueryRunnerTestHelper.UNIQUES_9) + .put(QueryRunnerTestHelper.hyperUniqueFinalizingPostAggMetric, QueryRunnerTestHelper.UNIQUES_9 + 1) + .build(), + ImmutableMap.builder() + .put("market", "total_market") + .put(QueryRunnerTestHelper.uniqueMetric, QueryRunnerTestHelper.UNIQUES_2) + .put(QueryRunnerTestHelper.hyperUniqueFinalizingPostAggMetric, QueryRunnerTestHelper.UNIQUES_2 + 1) + .build(), + ImmutableMap.builder() + .put("market", "upfront") + .put(QueryRunnerTestHelper.uniqueMetric, QueryRunnerTestHelper.UNIQUES_2) + .put(QueryRunnerTestHelper.hyperUniqueFinalizingPostAggMetric, QueryRunnerTestHelper.UNIQUES_2 + 1) + .build() + ) + ) + ) + ); + assertExpectedResults(expectedResults, query); + } + @Test public void testTopNOverFirstLastAggregator() { diff --git a/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java index 0874f3cfb975..a6bcc94d2881 100644 --- a/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/io/druid/sql/calcite/expression/Expressions.java @@ -210,7 +210,8 @@ public static PostAggregator toPostAggregator( final String name, final List rowOrder, final List finalizingPostAggregatorFactories, - final RexNode expression + final RexNode expression, + final PlannerContext plannerContext ) { final PostAggregator retVal; @@ -226,7 +227,7 @@ public static PostAggregator toPostAggregator( // types internally and there isn't much we can do to respect // TODO(gianm): Probably not a good idea to ignore CAST like this. final RexNode operand = ((RexCall) expression).getOperands().get(0); - retVal = toPostAggregator(name, rowOrder, finalizingPostAggregatorFactories, operand); + retVal = toPostAggregator(name, rowOrder, finalizingPostAggregatorFactories, operand, plannerContext); } else if (expression.getKind() == SqlKind.LITERAL && SqlTypeName.NUMERIC_TYPES.contains(expression.getType().getSqlTypeName())) { retVal = new ConstantPostAggregator(name, (Number) RexLiteral.value(expression)); @@ -246,7 +247,8 @@ public static PostAggregator toPostAggregator( null, rowOrder, finalizingPostAggregatorFactories, - operand + operand, + plannerContext ); if (translatedOperand == null) { return null; @@ -260,7 +262,7 @@ public static PostAggregator toPostAggregator( if (mathExpression == null) { retVal = null; } else { - retVal = new ExpressionPostAggregator(name, mathExpression); + retVal = new ExpressionPostAggregator(name, mathExpression, null, plannerContext.getExprMacroTable()); } } diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index f5dafc3471ec..ba791141939e 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -540,7 +540,8 @@ private static DruidRel applyPostAggregation(final DruidRel druidRel, final Proj postAggregatorName, rowOrder, finalizingPostAggregatorFactories, - projectExpression + projectExpression, + druidRel.getPlannerContext() ); if (postAggregator != null) { newAggregations.add(Aggregation.create(postAggregator)); diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 5438f229e711..6b82125ea64c 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1529,7 +1529,7 @@ public void testExpressionAggregations() throws Exception new DoubleSumAggregatorFactory("a2", "m1", null, macroTable) )) .postAggregators(ImmutableList.of( - new ExpressionPostAggregator("a3", "log((\"a1\" + \"a2\"))"), + EXPRESSION_POST_AGG("a3", "log((\"a1\" + \"a2\"))"), new ArithmeticPostAggregator("a4", "quotient", ImmutableList.of( new FieldAccessPostAggregator(null, "a1"), new ConstantPostAggregator(null, 0.25) @@ -4416,4 +4416,9 @@ private static List AGGS(final AggregatorFactory... aggregato { return Arrays.asList(aggregators); } + + private static ExpressionPostAggregator EXPRESSION_POST_AGG(final String name, final String expression) + { + return new ExpressionPostAggregator(name, expression, null, CalciteTests.createExprMacroTable()); + } }