Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions sql/src/main/java/io/druid/sql/calcite/planner/Calcites.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import io.druid.sql.calcite.schema.DruidSchema;
import io.druid.sql.calcite.schema.InformationSchema;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.type.SqlTypeName;
Expand Down Expand Up @@ -164,4 +166,17 @@ public static DateTime calciteDateToJoda(final int date, final DateTimeZone time
{
return new DateTime(0L, DateTimeZone.UTC).plusDays(date).withZoneRetainFields(timeZone);
}

/**
* Checks if a RexNode is a literal int or not. If this returns true, then {@code RexLiteral.intValue(literal)} can be
* used to get the value of the literal.
*
* @param rexNode the node
*
* @return true if this is an int
*/
public static boolean isIntLiteral(final RexNode rexNode)
{
return rexNode instanceof RexLiteral && SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName());
}
}
15 changes: 8 additions & 7 deletions sql/src/main/java/io/druid/sql/calcite/rule/GroupByRules.java
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,8 @@ private static Aggregation translateAggregateCall(
input = foe;
} else if (rexNode.getKind() == SqlKind.CASE && ((RexCall) rexNode).getOperands().size() == 3) {
// Possibly a CASE-style filtered aggregation. Styles supported:
// A: SUM(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A1: AGG(CASE WHEN x = 'foo' THEN cnt END) => operands (x = 'foo', cnt, null)
// A2: SUM(CASE WHEN x = 'foo' THEN cnt ELSE 0 END) => operands (x = 'foo', cnt, 0); must be SUM
// B: SUM(CASE WHEN x = 'foo' THEN 1 ELSE 0 END) => operands (x = 'foo', 1, 0)
// C: COUNT(CASE WHEN x = 'foo' THEN 'dummy' END) => operands (x = 'foo', 'dummy', null)
// If the null and non-null args are switched, "flip" is set, which negates the filter.
Expand Down Expand Up @@ -839,15 +840,15 @@ private static Aggregation translateAggregateCall(
forceCount = true;
input = null;
} else if (call.getAggregation().getKind() == SqlKind.SUM
&& arg1 instanceof RexLiteral
&& ((Number) RexLiteral.value(arg1)).intValue() == 1
&& arg2 instanceof RexLiteral
&& ((Number) RexLiteral.value(arg2)).intValue() == 0) {
&& Calcites.isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
&& Calcites.isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
// Case B
forceCount = true;
input = null;
} else if (RexLiteral.isNullLiteral(arg2)) {
// Maybe case A
} else if (RexLiteral.isNullLiteral(arg2) /* Case A1 */
|| (kind == SqlKind.SUM
&& Calcites.isIntLiteral(arg2)
&& RexLiteral.intValue(arg2) == 0) /* Case A2 */) {
input = FieldOrExpression.fromRexNode(operatorTable, plannerContext, rowOrder, arg1);
if (input == null) {
return null;
Expand Down
14 changes: 12 additions & 2 deletions sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1430,7 +1430,9 @@ public void testFilteredAggregations() throws Exception
+ "COUNT(CASE WHEN dim1 <> '1' THEN 'dummy' END), "
+ "SUM(CASE WHEN dim1 <> '1' THEN 1 ELSE 0 END), "
+ "SUM(cnt) filter(WHERE dim2 = 'a'), "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a') "
+ "SUM(case when dim1 <> '1' then cnt end) filter(WHERE dim2 = 'a'), "
+ "SUM(CASE WHEN dim1 <> '1' THEN cnt ELSE 0 END), "
+ "MAX(CASE WHEN dim1 <> '1' THEN cnt END) "
+ "FROM druid.foo",
ImmutableList.<Query>of(
Druids.newTimeseriesQueryBuilder()
Expand Down Expand Up @@ -1472,13 +1474,21 @@ public void testFilteredAggregations() throws Exception
SELECTOR("dim2", "a", null),
NOT(SELECTOR("dim1", "1", null))
)
),
new FilteredAggregatorFactory(
new LongSumAggregatorFactory("a8", "cnt"),
NOT(SELECTOR("dim1", "1", null))
),
new FilteredAggregatorFactory(
new LongMaxAggregatorFactory("a9", "cnt"),
NOT(SELECTOR("dim1", "1", null))
)
))
.context(TIMESERIES_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L}
new Object[]{1L, 5L, 1L, 5L, 5L, 5, 2L, 1L, 5L, 1L}
)
);
}
Expand Down