From de973eecb023911a98a03f6cc5a629ef2acff43d Mon Sep 17 00:00:00 2001 From: Jihoon Son Date: Fri, 29 Jun 2018 10:36:55 -0700 Subject: [PATCH] [SQL] Fix missing postAggregations for Timeseries and TopN (#5912) * [SQL] Fix missing postAggregations for Timeseries and TopN * fix build * fix test --- .../java/io/druid/segment/VirtualColumns.java | 6 ++ .../io/druid/sql/calcite/rel/DruidQuery.java | 15 ++- .../druid/sql/calcite/CalciteQueryTest.java | 99 +++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/processing/src/main/java/io/druid/segment/VirtualColumns.java b/processing/src/main/java/io/druid/segment/VirtualColumns.java index 409e3d693d42..189f3cd44585 100644 --- a/processing/src/main/java/io/druid/segment/VirtualColumns.java +++ b/processing/src/main/java/io/druid/segment/VirtualColumns.java @@ -25,6 +25,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import io.druid.java.util.common.Cacheable; @@ -70,6 +71,11 @@ public static Pair splitColumnName(String columnName) } } + public static VirtualColumns create(VirtualColumn...virtualColumns) + { + return create(Lists.newArrayList(virtualColumns)); + } + @JsonCreator public static VirtualColumns create(List virtualColumns) { diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java index 2f6fde564358..d10751010bdd 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java @@ -786,6 +786,12 @@ public TimeseriesQuery toTimeseriesQuery() } final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); + + final List postAggregators = new ArrayList<>(grouping.getPostAggregators()); + if (sortProject != null) { + postAggregators.addAll(sortProject.getPostAggregators()); + } + final Map theContext = Maps.newHashMap(); theContext.put("skipEmptyBuckets", true); theContext.putAll(plannerContext.getQueryContext()); @@ -798,7 +804,7 @@ public TimeseriesQuery toTimeseriesQuery() filtration.getDimFilter(), queryGranularity, grouping.getAggregatorFactories(), - grouping.getPostAggregators(), + postAggregators, ImmutableSortedMap.copyOf(theContext) ); } @@ -857,6 +863,11 @@ public TopNQuery toTopNQuery() final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); + final List postAggregators = new ArrayList<>(grouping.getPostAggregators()); + if (sortProject != null) { + postAggregators.addAll(sortProject.getPostAggregators()); + } + return new TopNQuery( dataSource, getVirtualColumns(plannerContext.getExprMacroTable(), true), @@ -867,7 +878,7 @@ public TopNQuery toTopNQuery() filtration.getDimFilter(), Granularities.ALL, grouping.getAggregatorFactories(), - grouping.getPostAggregators(), + postAggregators, ImmutableSortedMap.copyOf(plannerContext.getQueryContext()) ); } diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index dccdc5f3ef08..837fd6e83c81 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -85,6 +85,7 @@ import io.druid.query.topn.InvertedTopNMetricSpec; import io.druid.query.topn.NumericTopNMetricSpec; import io.druid.query.topn.TopNQueryBuilder; +import io.druid.segment.VirtualColumns; import io.druid.segment.column.Column; import io.druid.segment.column.ValueType; import io.druid.segment.virtual.ExpressionVirtualColumn; @@ -6635,6 +6636,104 @@ public void testSortProjectAfterNestedGroupBy() throws Exception ); } + @Test + public void testPostAggWithTimeseries() throws Exception + { + testQuery( + "SELECT " + + " FLOOR(__time TO YEAR), " + + " SUM(m1), " + + " SUM(m1) + SUM(m2) " + + "FROM " + + " druid.foo " + + "WHERE " + + " dim2 = 'a' " + + "GROUP BY FLOOR(__time TO YEAR) " + + "ORDER BY FLOOR(__time TO YEAR) desc", + Collections.singletonList( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .filters(SELECTOR("dim2", "a", null)) + .granularity(Granularities.YEAR) + .aggregators( + AGGS( + new DoubleSumAggregatorFactory("a0", "m1"), + new DoubleSumAggregatorFactory("a1", "m2") + ) + ) + .postAggregators( + EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")") + ) + .descending(true) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{978307200000L, 4.0, 8.0}, + new Object[]{946684800000L, 1.0, 2.0} + ) + ); + } + + @Test + public void testPostAggWithTopN() throws Exception + { + testQuery( + "SELECT " + + " FLOOR(__time TO SECOND), " + + " AVG(m2), " + + " SUM(m1) + SUM(m2) " + + "FROM " + + " druid.foo " + + "WHERE " + + " dim2 = 'a' " + + "GROUP BY FLOOR(__time TO SECOND) " + + "ORDER BY FLOOR(__time TO SECOND) " + + "LIMIT 5", + Collections.singletonList( + new TopNQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("d0:v", "d0", ValueType.LONG)) + .virtualColumns( + VirtualColumns.create( + EXPRESSION_VIRTUAL_COLUMN("d0:v", "timestamp_floor(\"__time\",'PT1S','','UTC')", ValueType.LONG) + ) + ) + .filters("dim2", "a") + .aggregators(AGGS( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count"), + new DoubleSumAggregatorFactory("a1", "m1"), + new DoubleSumAggregatorFactory("a2", "m2") + )) + .postAggregators( + ImmutableList.of( + new ArithmeticPostAggregator( + "a0", + "quotient", + ImmutableList.of( + new FieldAccessPostAggregator(null, "a0:sum"), + new FieldAccessPostAggregator(null, "a0:count") + ) + ), + EXPRESSION_POST_AGG("p0", "(\"a1\" + \"a2\")") + ) + ) + .metric(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC)) + .threshold(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{946684800000L, 1.0, 2.0}, + new Object[]{978307200000L, 4.0, 8.0} + ) + ); + } + private void testQuery( final String sql, final List expectedQueries,