From 068001e5598dfc14005be08f36081d5f5ccd16d4 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 6 Nov 2017 15:42:21 -0800 Subject: [PATCH] SQL: Add rule to prune unused aggregations. --- .../io/druid/sql/calcite/planner/Rules.java | 2 + .../ProjectAggregatePruneUnusedCallRule.java | 129 ++++++++++++++++++ .../druid/sql/calcite/CalciteQueryTest.java | 83 ++++++++++- 3 files changed, 213 insertions(+), 1 deletion(-) create mode 100644 sql/src/main/java/io/druid/sql/calcite/rule/ProjectAggregatePruneUnusedCallRule.java diff --git a/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java b/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java index 4e91f265e263..4e98a4919f55 100644 --- a/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java +++ b/sql/src/main/java/io/druid/sql/calcite/planner/Rules.java @@ -27,6 +27,7 @@ import io.druid.sql.calcite.rule.DruidRules; import io.druid.sql.calcite.rule.DruidSemiJoinRule; import io.druid.sql.calcite.rule.DruidTableScanRule; +import io.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule; import io.druid.sql.calcite.rule.SortCollapseRule; import org.apache.calcite.interpreter.Bindables; import org.apache.calcite.plan.RelOptLattice; @@ -239,6 +240,7 @@ private static List baseRuleSet( rules.add(SortCollapseRule.instance()); rules.add(CaseFilteredAggregatorRule.instance()); + rules.add(ProjectAggregatePruneUnusedCallRule.instance()); // Druid-specific rules. rules.add(new DruidTableScanRule(queryMaker)); diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/ProjectAggregatePruneUnusedCallRule.java b/sql/src/main/java/io/druid/sql/calcite/rule/ProjectAggregatePruneUnusedCallRule.java new file mode 100644 index 000000000000..819d9a0dad8f --- /dev/null +++ b/sql/src/main/java/io/druid/sql/calcite/rule/ProjectAggregatePruneUnusedCallRule.java @@ -0,0 +1,129 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.sql.calcite.rule; + +import io.druid.java.util.common.ISE; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.List; + +/** + * Rule that prunes unused aggregators after a projection. + */ +public class ProjectAggregatePruneUnusedCallRule extends RelOptRule +{ + private static final ProjectAggregatePruneUnusedCallRule INSTANCE = new ProjectAggregatePruneUnusedCallRule(); + + private ProjectAggregatePruneUnusedCallRule() + { + super(operand(Project.class, operand(Aggregate.class, any()))); + } + + public static ProjectAggregatePruneUnusedCallRule instance() + { + return INSTANCE; + } + + @Override + public boolean matches(final RelOptRuleCall call) + { + final Aggregate aggregate = call.rel(1); + return !aggregate.indicator && aggregate.getGroupSets().size() == 1; + } + + @Override + public void onMatch(final RelOptRuleCall call) + { + final Project project = call.rel(0); + final Aggregate aggregate = call.rel(1); + + final ImmutableBitSet projectBits = RelOptUtil.InputFinder.bits(project.getChildExps(), null); + + final int fieldCount = aggregate.getGroupCount() + aggregate.getAggCallList().size(); + if (fieldCount != aggregate.getRowType().getFieldCount()) { + throw new ISE( + "WTF, expected[%s] to have[%s] fields but it had[%s]", + aggregate, + fieldCount, + aggregate.getRowType().getFieldCount() + ); + } + + final ImmutableBitSet callsToKeep = projectBits.intersect( + ImmutableBitSet.range(aggregate.getGroupCount(), fieldCount) + ); + + if (callsToKeep.cardinality() < aggregate.getAggCallList().size()) { + // There are some aggregate calls to prune. + final List newAggregateCalls = new ArrayList<>(); + + for (int i : callsToKeep) { + newAggregateCalls.add(aggregate.getAggCallList().get(i - aggregate.getGroupCount())); + } + + final Aggregate newAggregate = aggregate.copy( + aggregate.getTraitSet(), + aggregate.getInput(), + aggregate.indicator, + aggregate.getGroupSet(), + aggregate.getGroupSets(), + newAggregateCalls + ); + + // Project that will match the old Aggregate in its row type, so we can layer the original "project" on top. + final List fixUpProjects = new ArrayList<>(); + final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + + // Project the group unchanged. + for (int i = 0; i < aggregate.getGroupCount(); i++) { + fixUpProjects.add(rexBuilder.makeInputRef(newAggregate, i)); + } + + // Replace pruned-out aggregators with NULLs. + int j = aggregate.getGroupCount(); + for (int i = aggregate.getGroupCount(); i < fieldCount; i++) { + if (callsToKeep.get(i)) { + fixUpProjects.add(rexBuilder.makeInputRef(newAggregate, j++)); + } else { + fixUpProjects.add(rexBuilder.makeNullLiteral(aggregate.getRowType().getFieldList().get(i).getType())); + } + } + + call.transformTo( + call.builder() + .push(newAggregate) + .project(fixUpProjects) + .project(project.getChildExps()) + .build() + ); + + call.getPlanner().setImportance(project, 0.0); + } + } +} diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index d10d81296ee1..d8bf3e1e249a 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1467,6 +1467,87 @@ public void testTopNWithSelectAndOrderByProjections() throws Exception ); } + @Test + public void testPruneDeadAggregators() throws Exception + { + // Test for ProjectAggregatePruneUnusedCallRule. + + testQuery( + "SELECT\n" + + " CASE 'foo'\n" + + " WHEN 'bar' THEN SUM(cnt)\n" + + " WHEN 'foo' THEN SUM(m1)\n" + + " WHEN 'baz' THEN SUM(m2)\n" + + " END\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(AGGS(new DoubleSumAggregatorFactory("a0", "m1"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{21.0}) + ); + } + + @Test + public void testPruneDeadAggregatorsThroughPostProjection() throws Exception + { + // Test for ProjectAggregatePruneUnusedCallRule. + + testQuery( + "SELECT\n" + + " CASE 'foo'\n" + + " WHEN 'bar' THEN SUM(cnt) / 10\n" + + " WHEN 'foo' THEN SUM(m1) / 10\n" + + " WHEN 'baz' THEN SUM(m2) / 10\n" + + " END\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(QSS(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(AGGS(new DoubleSumAggregatorFactory("a0", "m1"))) + .postAggregators(ImmutableList.of(EXPRESSION_POST_AGG("p0", "(\"a0\" / 10)"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{2.1}) + ); + } + + @Test + public void testPruneDeadAggregatorsThroughHaving() throws Exception + { + // Test for ProjectAggregatePruneUnusedCallRule. + + testQuery( + "SELECT\n" + + " CASE 'foo'\n" + + " WHEN 'bar' THEN SUM(cnt)\n" + + " WHEN 'foo' THEN SUM(m1)\n" + + " WHEN 'baz' THEN SUM(m2)\n" + + " END AS theCase\n" + + "FROM foo\n" + + "HAVING theCase = 21", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(QSS(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(AGGS(new DoubleSumAggregatorFactory("a0", "m1"))) + .setHavingSpec(HAVING(NUMERIC_SELECTOR("a0", "21", null))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{21.0}) + ); + } + @Test public void testGroupByCaseWhen() throws Exception { @@ -3736,7 +3817,7 @@ public void testExplainDoubleNestedGroupBy() throws Exception { final String explanation = "DruidOuterQueryRel(query=[{\"queryType\":\"timeseries\",\"dataSource\":{\"type\":\"table\",\"name\":\"__subquery__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"descending\":false,\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"aggregations\":[{\"type\":\"longSum\",\"name\":\"a0\",\"fieldName\":\"cnt\",\"expression\":null},{\"type\":\"count\",\"name\":\"a1\"}],\"postAggregations\":[],\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"skipEmptyBuckets\":true,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\"}}], signature=[{a0:LONG, a1:LONG}])\n" - + " DruidOuterQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"table\",\"name\":\"__subquery__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[{\"type\":\"default\",\"dimension\":\"dim2\",\"outputName\":\"d0\",\"outputType\":\"STRING\"}],\"aggregations\":[{\"type\":\"longSum\",\"name\":\"a0\",\"fieldName\":\"cnt\",\"expression\":null}],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\"},\"descending\":false}], signature=[{a0:LONG}])\n" + + " DruidOuterQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"table\",\"name\":\"__subquery__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[{\"type\":\"default\",\"dimension\":\"dim2\",\"outputName\":\"d0\",\"outputType\":\"STRING\"}],\"aggregations\":[{\"type\":\"longSum\",\"name\":\"a0\",\"fieldName\":\"cnt\",\"expression\":null}],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\"},\"descending\":false}], signature=[{d0:STRING, a0:LONG}])\n" + " DruidQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"table\",\"name\":\"foo\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[{\"type\":\"default\",\"dimension\":\"dim1\",\"outputName\":\"d0\",\"outputType\":\"STRING\"},{\"type\":\"default\",\"dimension\":\"dim2\",\"outputName\":\"d1\",\"outputType\":\"STRING\"}],\"aggregations\":[{\"type\":\"count\",\"name\":\"a0\"}],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\"},\"descending\":false}], signature=[{d0:STRING, d1:STRING, a0:LONG}])\n"; testQuery(