From 499c5e2b03a37ab3dd87289e62537e8e71da0f24 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 7 Mar 2023 00:47:19 -0800 Subject: [PATCH 1/3] use Calcites.getColumnTypeForRelDataType for SQL CAST operator conversion --- .../builtin/CastOperatorConversion.java | 67 ++++----------- .../druid/sql/calcite/planner/Calcites.java | 4 +- .../CalciteMultiValueStringQueryTest.java | 86 +++++++++++++++++++ .../calcite/expression/ExpressionsTest.java | 6 +- .../expression/GreatestExpressionTest.java | 8 +- .../expression/LeastExpressionTest.java | 8 +- 6 files changed, 114 insertions(+), 65 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java index 9062a32d0baf..e1c0f5964d1b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.expression.builtin; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; @@ -29,7 +28,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.PeriodGranularity; -import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; @@ -39,46 +38,10 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.joda.time.Period; -import java.util.Map; import java.util.function.Function; public class CastOperatorConversion implements SqlOperatorConversion { - private static final Map EXPRESSION_TYPES; - - static { - final ImmutableMap.Builder builder = ImmutableMap.builder(); - - for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) { - builder.put(type, ExprType.DOUBLE); - } - - for (SqlTypeName type : SqlTypeName.INT_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.STRING_TYPES) { - builder.put(type, ExprType.STRING); - } - - // Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false). - builder.put(SqlTypeName.BOOLEAN, ExprType.LONG); - - // Timestamps are treated as longs (millis since the epoch) in Druid expressions. - builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG); - builder.put(SqlTypeName.DATE, ExprType.LONG); - - for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - EXPRESSION_TYPES = builder.build(); - } - @Override public SqlOperator calciteOperator() { @@ -118,28 +81,32 @@ public DruidExpression toDruidExpression( } else { // Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls // downstream, Druid will try its best - final ExprType fromExprType = SqlTypeName.ANY.equals(fromType) - ? EXPRESSION_TYPES.get(toType) - : EXPRESSION_TYPES.get(fromType); - final ExprType toExprType = SqlTypeName.ANY.equals(toType) - ? EXPRESSION_TYPES.get(fromType) - : EXPRESSION_TYPES.get(toType); - - if (fromExprType == null || toExprType == null) { + final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType()); + final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); + + final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType) + ? ExpressionType.fromColumnType(toDruidType) + : ExpressionType.fromColumnType(fromDruidType); + final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType) + ? ExpressionType.fromColumnType(fromDruidType) + : ExpressionType.fromColumnType(toDruidType); + + if (fromExpressionType == null || toExpressionType == null) { // We have no runtime type for these SQL types. return null; } final DruidExpression typeCastExpression; - if (fromExprType != toExprType) { + if (fromExpressionType.equals(toExpressionType)) { + // Ignore casts for simple extractions since it is ok in many cases. + typeCastExpression = operandExpression; + } else { // Ignore casts for simple extractions (use Function.identity) since it is ok in many cases. typeCastExpression = operandExpression.map( Function.identity(), - expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString()) + expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString()) ); - } else { - typeCastExpression = operandExpression; } if (toType == SqlTypeName.DATE) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index aa95fee28716..157c51b8059f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -188,7 +188,9 @@ public static boolean isLongType(SqlTypeName sqlTypeName) return SqlTypeName.TIMESTAMP == sqlTypeName || SqlTypeName.DATE == sqlTypeName || SqlTypeName.BOOLEAN == sqlTypeName || - SqlTypeName.INT_TYPES.contains(sqlTypeName); + SqlTypeName.INT_TYPES.contains(sqlTypeName) || + SqlTypeName.DAY_INTERVAL_TYPES.contains(sqlTypeName) || + SqlTypeName.YEAR_INTERVAL_TYPES.contains(sqlTypeName); } public static StringComparator getStringComparatorForRelDataType(RelDataType dataType) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 534f334b8345..449a6ad75505 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -29,8 +29,10 @@ import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.filter.AndDimFilter; +import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; +import org.apache.druid.query.filter.OrDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; @@ -39,6 +41,7 @@ import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.virtual.ListFilteredVirtualColumn; import org.apache.druid.sql.SqlPlanningException; import org.apache.druid.sql.calcite.filtration.Filtration; @@ -1847,4 +1850,87 @@ public void testMultiValueToArrayArgsWithArray() exception -> exception.expect(RuntimeException.class) ); } + + @Test + public void testMultiValueStringOverlapFilterCoalesceNvl() + { + testQuery( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) OR " + + "MV_OVERLAP(NVL(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .eternityInterval() + .virtualColumns( + new ExpressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim3\"),\"dim3\",'other')", + ColumnType.STRING, + queryFramework().macroTable() + ) + ) + .filters( + new OrDimFilter( + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ), + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ) + ) + ) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() + ? ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + ); + } + + @Test + public void testMultiValueStringOverlapFilterInconsistentUsage() + { + testQueryThrows( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + e -> { + e.expect(SqlPlanningException.class); + e.expectMessage("Illegal mixing of types in CASE or COALESCE statement"); + } + + ); + } + + @Test + public void testMultiValueStringOverlapFilterInconsistentUsage2() + { + testQueryThrows( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, 'other'), ARRAY['a', 'b', 'other']) LIMIT 5", + e -> { + e.expect(RuntimeException.class); + e.expectMessage("Invalid expression: (case_searched [(notnull [dim3]), (array_overlap [dim3, [a, b, other]]), 1]); [dim3] used as both scalar and array variables"); + } + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index 91ab2a839f2c..4d7cad0882c3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -1751,8 +1751,7 @@ public void testTimeMinusDayTimeInterval() (args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")", ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type of "interval day to minute" is not converted to druid long... yet - DruidExpression.ofLiteral(null, "90060000") + DruidExpression.ofLiteral(ColumnType.LONG, "90060000") ) ), DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis() @@ -1779,8 +1778,7 @@ public void testTimeMinusYearMonthInterval() DruidExpression.functionCall("timestamp_shift"), ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type "interval year to month" is not reported as ColumnType.STRING - DruidExpression.ofLiteral(null, DruidExpression.stringLiteral("P13M")), + DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.stringLiteral("P13M")), DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)), DruidExpression.ofStringLiteral("UTC") ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java index 87ce28ef3f6c..8418edf994e8 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java @@ -246,10 +246,8 @@ public void testTimestamp() } @Test - public void testInvalidType() + public void testIntervalYearMonth() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -257,8 +255,8 @@ public void testInvalidType() new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java index 047f6936d307..6702769e927a 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java @@ -247,10 +247,8 @@ public void testTimestamp() } @Test - public void testInvalidType() + public void testIntervalYearMonth() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -258,8 +256,8 @@ public void testInvalidType() new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); } From a338dce70d2f6f47c8b0fdeb14e33a92b2aa64ed Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 7 Mar 2023 01:03:09 -0800 Subject: [PATCH 2/3] fix comment --- .../sql/calcite/expression/builtin/CastOperatorConversion.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java index e1c0f5964d1b..cd68c94951a9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java @@ -99,7 +99,6 @@ public DruidExpression toDruidExpression( final DruidExpression typeCastExpression; if (fromExpressionType.equals(toExpressionType)) { - // Ignore casts for simple extractions since it is ok in many cases. typeCastExpression = operandExpression; } else { // Ignore casts for simple extractions (use Function.identity) since it is ok in many cases. From bd63a89ecfbab9e275c135948f01ea57e61c152b Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 7 Mar 2023 06:16:17 -0800 Subject: [PATCH 3/3] intervals are strings but also longs --- .../expression/builtin/CastOperatorConversion.java | 4 ++++ .../builtin/ReductionOperatorConversionHelper.java | 13 ++++++++++--- .../apache/druid/sql/calcite/planner/Calcites.java | 12 ++++++++---- .../sql/calcite/expression/ExpressionsTest.java | 4 ++-- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java index cd68c94951a9..7f8c1ddee887 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java @@ -28,6 +28,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.PeriodGranularity; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; @@ -100,6 +101,9 @@ public DruidExpression toDruidExpression( if (fromExpressionType.equals(toExpressionType)) { typeCastExpression = operandExpression; + } else if (SqlTypeName.INTERVAL_TYPES.contains(fromType) && toExpressionType.is(ExprType.LONG)) { + // intervals can be longs without an explicit cast + typeCastExpression = operandExpression; } else { // Ignore casts for simple extractions (use Function.identity) since it is ok in many cases. typeCastExpression = operandExpression.map( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index a747e9d27dab..427c93a28782 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -56,9 +56,16 @@ private ReductionOperatorConversionHelper() boolean hasDouble = false; boolean isString = false; for (int i = 0; i < n; i++) { - RelDataType type = opBinding.getOperandType(i); - SqlTypeName sqlTypeName = type.getSqlTypeName(); - ColumnType valueType = Calcites.getColumnTypeForRelDataType(type); + final RelDataType type = opBinding.getOperandType(i); + final SqlTypeName sqlTypeName = type.getSqlTypeName(); + final ColumnType valueType; + + if (SqlTypeName.INTERVAL_TYPES.contains(type.getSqlTypeName())) { + // handle intervals as a LONG type even though it is a string + valueType = ColumnType.LONG; + } else { + valueType = Calcites.getColumnTypeForRelDataType(type); + } // Return types are listed in order of preference: if (valueType != null) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index 157c51b8059f..331a61a1f50b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -160,7 +160,7 @@ public static ColumnType getValueTypeForRelDataTypeFull(final RelDataType type) return ColumnType.DOUBLE; } else if (isLongType(sqlTypeName)) { return ColumnType.LONG; - } else if (SqlTypeName.CHAR_TYPES.contains(sqlTypeName)) { + } else if (isStringType(sqlTypeName)) { return ColumnType.STRING; } else if (SqlTypeName.OTHER == sqlTypeName) { if (type instanceof RowSignatures.ComplexSqlType) { @@ -178,6 +178,12 @@ public static ColumnType getValueTypeForRelDataTypeFull(final RelDataType type) } } + public static boolean isStringType(SqlTypeName sqlTypeName) + { + return SqlTypeName.CHAR_TYPES.contains(sqlTypeName) || + SqlTypeName.INTERVAL_TYPES.contains(sqlTypeName); + } + public static boolean isDoubleType(SqlTypeName sqlTypeName) { return SqlTypeName.FRACTIONAL_TYPES.contains(sqlTypeName) || SqlTypeName.APPROX_TYPES.contains(sqlTypeName); @@ -188,9 +194,7 @@ public static boolean isLongType(SqlTypeName sqlTypeName) return SqlTypeName.TIMESTAMP == sqlTypeName || SqlTypeName.DATE == sqlTypeName || SqlTypeName.BOOLEAN == sqlTypeName || - SqlTypeName.INT_TYPES.contains(sqlTypeName) || - SqlTypeName.DAY_INTERVAL_TYPES.contains(sqlTypeName) || - SqlTypeName.YEAR_INTERVAL_TYPES.contains(sqlTypeName); + SqlTypeName.INT_TYPES.contains(sqlTypeName); } public static StringComparator getStringComparatorForRelDataType(RelDataType dataType) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index 4d7cad0882c3..f7fec59032f1 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -1751,7 +1751,7 @@ public void testTimeMinusDayTimeInterval() (args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")", ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - DruidExpression.ofLiteral(ColumnType.LONG, "90060000") + DruidExpression.ofLiteral(ColumnType.STRING, "90060000") ) ), DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis() @@ -1778,7 +1778,7 @@ public void testTimeMinusYearMonthInterval() DruidExpression.functionCall("timestamp_shift"), ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.stringLiteral("P13M")), + DruidExpression.ofLiteral(ColumnType.STRING, DruidExpression.stringLiteral("P13M")), DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)), DruidExpression.ofStringLiteral("UTC") )