From 7038855cd12add60bc477a9cad0e92dab7cce00c Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 24 Mar 2017 00:21:49 -0700 Subject: [PATCH 1/2] SQL: Support for another form of filtered aggregator. --- .../io/druid/sql/calcite/planner/Calcites.java | 15 +++++++++++++++ .../io/druid/sql/calcite/rule/GroupByRules.java | 17 +++++++++-------- .../io/druid/sql/calcite/CalciteQueryTest.java | 9 +++++++-- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java index 6f33cb6c6e5d..b5bda9870ea9 100644 --- a/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java @@ -25,6 +25,8 @@ import io.druid.sql.calcite.schema.DruidSchema; import io.druid.sql.calcite.schema.InformationSchema; import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.type.SqlTypeName; @@ -164,4 +166,17 @@ public static DateTime calciteDateToJoda(final int date, final DateTimeZone time { return new DateTime(0L, DateTimeZone.UTC).plusDays(date).withZoneRetainFields(timeZone); } + + /** + * Checks if a RexNode is a literal int or not. If this returns true, then {@code RexLiteral.intValue(literal)} can be + * used to get the value of the literal. + * + * @param rexNode the node + * + * @return true if this is an int + */ + public static boolean isIntLiteral(final RexNode rexNode) + { + return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName()); + } } diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index 05acca45fafe..fbeb3049fda7 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -807,8 +807,9 @@ private static Aggregation translateAggregateCall( input = foe; } else if (rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3) { // Possibly a CASE-style filtered aggregation. Styles supported: - // A: SUM(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) - // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) + // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) + // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM + // B: AGG(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null) // If the null and non-null args are switched, "flip" is set, which negates the filter. @@ -839,15 +840,15 @@ private static Aggregation translateAggregateCall( forceCount = true; input = null; } else if (call.getAggregation().getKind() == SqlKind.SUM - && arg1 instanceof RexLiteral - && ((Number) RexLiteral.value(arg1)).intValue() == 1 - && arg2 instanceof RexLiteral - && ((Number) RexLiteral.value(arg2)).intValue() == 0) { + && Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1 + && Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) { // Case B forceCount = true; input = null; - } else if (RexLiteral.isNullLiteral(arg2)) { - // Maybe case A + } else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */ + || (kind == SqlKind.SUM + && Calcites.isIntLiteral(arg2) + && RexLiteral.intValue(arg2) == 0) /* Case A2 */) { input = FieldOrExpression.fromRexNode(operatorTable, plannerContext, rowOrder, arg1); if (input == null) { return null; diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index b441d618bd72..85e19198aa0c 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1430,7 +1430,8 @@ public void testFilteredAggregations() throws Exception + "COUNT(CASE WHEN dim1 <> '1' THEN 'dummy' END), " + "SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END), " + "SUM(cnt) filter(WHERE dim2 = 'a'), " - + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a') " + + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), " + + "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END) " + "FROM druid.foo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -1472,13 +1473,17 @@ public void testFilteredAggregations() throws Exception SELECTOR("dim2", "a", null), NOT(SELECTOR("dim1", "1", null)) ) + ), + new FilteredAggregatorFactory( + new LongSumAggregatorFactory("a8", "cnt"), + NOT(SELECTOR("dim1", "1", null)) ) )) .context(TIMESERIES_CONTEXT_DEFAULT) .build() ), ImmutableList.of( - new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L} + new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L} ) ); } From ed7cca62d2731acde2f976d62c34b7e10e0d24ff Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 24 Mar 2017 01:10:00 -0700 Subject: [PATCH 2/2] Fix comment, add test for MAX too. --- .../java/io/druid/sql/calcite/rule/GroupByRules.java | 2 +- .../test/java/io/druid/sql/calcite/CalciteQueryTest.java | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java index fbeb3049fda7..c3455bf5252a 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java +++ b/sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java @@ -809,7 +809,7 @@ private static Aggregation translateAggregateCall( // Possibly a CASE-style filtered aggregation. Styles supported: // A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null) // A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM - // B: AGG(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) + // B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0) // C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null) // If the null and non-null args are switched, "flip" is set, which negates the filter. diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index 85e19198aa0c..dafdcdb8c5f2 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -1431,7 +1431,8 @@ public void testFilteredAggregations() throws Exception + "SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END), " + "SUM(cnt) filter(WHERE dim2 = 'a'), " + "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), " - + "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END) " + + "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), " + + "MAX(CASE WHEN dim1 <> '1' THEN cnt END) " + "FROM druid.foo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -1477,13 +1478,17 @@ public void testFilteredAggregations() throws Exception new FilteredAggregatorFactory( new LongSumAggregatorFactory("a8", "cnt"), NOT(SELECTOR("dim1", "1", null)) + ), + new FilteredAggregatorFactory( + new LongMaxAggregatorFactory("a9", "cnt"), + NOT(SELECTOR("dim1", "1", null)) ) )) .context(TIMESERIES_CONTEXT_DEFAULT) .build() ), ImmutableList.of( - new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L} + new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L, 1L} ) ); }