diff --git a/processing/src/main/java/io/druid/segment/VirtualColumns.java b/processing/src/main/java/io/druid/segment/VirtualColumns.java index 409e3d693d42..189f3cd44585 100644 --- a/processing/src/main/java/io/druid/segment/VirtualColumns.java +++ b/processing/src/main/java/io/druid/segment/VirtualColumns.java @@ -25,6 +25,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import io.druid.java.util.common.Cacheable; @@ -70,6 +71,11 @@ public static Pair splitColumnName(String columnName) } } + public static VirtualColumns create(VirtualColumn...virtualColumns) + { + return create(Lists.newArrayList(virtualColumns)); + } + @JsonCreator public static VirtualColumns create(List virtualColumns) { diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java index 5503f50adf98..0d0532e5ff21 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidQuery.java @@ -778,6 +778,12 @@ public TimeseriesQuery toTimeseriesQuery() } final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); + + final List postAggregators = new ArrayList<>(grouping.getPostAggregators()); + if (sortProject != null) { + postAggregators.addAll(sortProject.getPostAggregators()); + } + final Map theContext = Maps.newHashMap(); theContext.put("skipEmptyBuckets", true); theContext.putAll(plannerContext.getQueryContext()); @@ -790,7 +796,7 @@ public TimeseriesQuery toTimeseriesQuery() filtration.getDimFilter(), queryGranularity, grouping.getAggregatorFactories(), - grouping.getPostAggregators(), + postAggregators, ImmutableSortedMap.copyOf(theContext) ); } @@ -849,6 +855,11 @@ public TopNQuery toTopNQuery() final Filtration filtration = Filtration.create(filter).optimize(sourceRowSignature); + final List postAggregators = new ArrayList<>(grouping.getPostAggregators()); + if (sortProject != null) { + postAggregators.addAll(sortProject.getPostAggregators()); + } + return new TopNQuery( dataSource, getVirtualColumns(plannerContext.getExprMacroTable(), true), @@ -859,7 +870,7 @@ public TopNQuery toTopNQuery() filtration.getDimFilter(), Granularities.ALL, grouping.getAggregatorFactories(), - grouping.getPostAggregators(), + postAggregators, ImmutableSortedMap.copyOf(plannerContext.getQueryContext()) ); } diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 3bd269028302..d558f777b425 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -82,6 +82,7 @@ import io.druid.query.topn.InvertedTopNMetricSpec; import io.druid.query.topn.NumericTopNMetricSpec; import io.druid.query.topn.TopNQueryBuilder; +import io.druid.segment.VirtualColumns; import io.druid.segment.column.Column; import io.druid.segment.column.ValueType; import io.druid.segment.virtual.ExpressionVirtualColumn; @@ -6721,6 +6722,104 @@ public void testSortProjectAfterNestedGroupBy() throws Exception ); } + @Test + public void testPostAggWithTimeseries() throws Exception + { + testQuery( + "SELECT " + + " FLOOR(__time TO YEAR), " + + " SUM(m1), " + + " SUM(m1) + SUM(m2) " + + "FROM " + + " druid.foo " + + "WHERE " + + " dim2 = 'a' " + + "GROUP BY FLOOR(__time TO YEAR) " + + "ORDER BY FLOOR(__time TO YEAR) desc", + Collections.singletonList( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .filters(SELECTOR("dim2", "a", null)) + .granularity(Granularities.YEAR) + .aggregators( + AGGS( + new DoubleSumAggregatorFactory("a0", "m1"), + new DoubleSumAggregatorFactory("a1", "m2") + ) + ) + .postAggregators( + EXPRESSION_POST_AGG("p0", "(\"a0\" + \"a1\")") + ) + .descending(true) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{978307200000L, 4.0, 8.0}, + new Object[]{946684800000L, 1.0, 2.0} + ) + ); + } + + @Test + public void testPostAggWithTopN() throws Exception + { + testQuery( + "SELECT " + + " FLOOR(__time TO SECOND), " + + " AVG(m2), " + + " SUM(m1) + SUM(m2) " + + "FROM " + + " druid.foo " + + "WHERE " + + " dim2 = 'a' " + + "GROUP BY FLOOR(__time TO SECOND) " + + "ORDER BY FLOOR(__time TO SECOND) " + + "LIMIT 5", + Collections.singletonList( + new TopNQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("d0:v", "d0", ValueType.LONG)) + .virtualColumns( + VirtualColumns.create( + EXPRESSION_VIRTUAL_COLUMN("d0:v", "timestamp_floor(\"__time\",'PT1S','','UTC')", ValueType.LONG) + ) + ) + .filters("dim2", "a") + .aggregators(AGGS( + new DoubleSumAggregatorFactory("a0:sum", "m2"), + new CountAggregatorFactory("a0:count"), + new DoubleSumAggregatorFactory("a1", "m1"), + new DoubleSumAggregatorFactory("a2", "m2") + )) + .postAggregators( + ImmutableList.of( + new ArithmeticPostAggregator( + "a0", + "quotient", + ImmutableList.of( + new FieldAccessPostAggregator(null, "a0:sum"), + new FieldAccessPostAggregator(null, "a0:count") + ) + ), + EXPRESSION_POST_AGG("p0", "(\"a1\" + \"a2\")") + ) + ) + .metric(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC)) + .threshold(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{946684800000L, 1.0, 2.0}, + new Object[]{978307200000L, 4.0, 8.0} + ) + ); + } + private void testQuery( final String sql, final List expectedQueries,