From 37b088d767c4cc8513ad45c30062951fae3cea1c Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Sun, 5 Mar 2023 22:25:21 -0800 Subject: [PATCH] relax native expression multi-value string usage validation for conditional and null coalescing functions --- .../java/org/apache/druid/math/expr/Expr.java | 10 +-- .../org/apache/druid/math/expr/Function.java | 63 ++++++++++---- .../segment/virtual/ExpressionPlanner.java | 1 + .../apache/druid/math/expr/ParserTest.java | 7 +- .../builtin/CastOperatorConversion.java | 69 +++++---------- .../druid/sql/calcite/planner/Calcites.java | 4 +- .../CalciteMultiValueStringQueryTest.java | 85 +++++++++++++++++++ .../calcite/expression/ExpressionsTest.java | 6 +- .../expression/GreatestExpressionTest.java | 8 +- .../expression/LeastExpressionTest.java | 8 +- 10 files changed, 175 insertions(+), 86 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/math/expr/Expr.java b/processing/src/main/java/org/apache/druid/math/expr/Expr.java index 9844a0526bf1..53654c65953c 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Expr.java @@ -224,7 +224,7 @@ default boolean areNumeric(List args) if (argType == null) { continue; } - numeric &= argType.isNumeric(); + numeric = numeric && argType.isNumeric(); } return numeric; } @@ -265,7 +265,7 @@ default boolean areSameTypes(List args) if (currentType == null) { currentType = argType; } - allSame &= Objects.equals(argType, currentType); + allSame = allSame && Objects.equals(argType, currentType); } return allSame; } @@ -302,7 +302,7 @@ default boolean areScalar(List args) if (argType == null) { continue; } - scalar &= argType.isPrimitive(); + scalar = scalar && argType.isPrimitive(); } return scalar; } @@ -330,7 +330,7 @@ default boolean canVectorize(List args) { boolean canVectorize = true; for (Expr arg : args) { - canVectorize &= arg.canVectorize(this); + canVectorize = canVectorize && arg.canVectorize(this); } return canVectorize; } @@ -498,7 +498,7 @@ public Set getRequiredBindings() /** * Set of {@link IdentifierExpr#binding} which are used as scalar inputs to operators and functions. */ - Set getScalarBindings() + public Set getScalarBindings() { return map(scalarVariables, IdentifierExpr::getBindingIfIdentifier); } diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 70cd0f8f278a..0801f248067c 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -1962,14 +1962,11 @@ public Set getScalarInputs(List args) ExpressionType castTo = ExpressionType.fromString( StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()) ); - switch (castTo.getType()) { - case ARRAY: - return Collections.emptySet(); - default: - return ImmutableSet.of(args.get(0)); + if (!castTo.getType().isArray()) { + return ImmutableSet.of(args.get(0)); } } - // unknown cast, can't safely assume either way + // either has array inputs or unknown inputs return Collections.emptySet(); } @@ -1980,16 +1977,11 @@ public Set getArrayInputs(List args) ExpressionType castTo = ExpressionType.fromString( StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()) ); - switch (castTo.getType()) { - case LONG: - case DOUBLE: - case STRING: - return Collections.emptySet(); - default: - return ImmutableSet.of(args.get(0)); + if (castTo.getType().isArray()) { + return ImmutableSet.of(args.get(0)); } } - // unknown cast, can't safely assume either way + // not an array, or unknown input types return Collections.emptySet(); } @@ -2087,6 +2079,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List getScalarInputs(List args) + { + // could potentially look for constants in the return positions and examine type... + return Collections.emptySet(); + } } /** @@ -2134,6 +2133,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List getScalarInputs(List args) + { + // could potentially look for constants in the return positions and examine type... + return Collections.emptySet(); + } } /** @@ -2181,6 +2187,13 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List getScalarInputs(List args) + { + // could potentially look for constants in the return positions and examine type... + return Collections.emptySet(); + } } class NvlFunc implements Function @@ -2222,6 +2235,13 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe { return VectorProcessors.nvl(inspector, args.get(0), args.get(1)); } + + @Override + public Set getScalarInputs(List args) + { + // output is same as input, doesn't matter the type + return Collections.emptySet(); + } } class IsNullFunc implements Function @@ -2263,6 +2283,13 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe { return VectorProcessors.isNull(inspector, args.get(0)); } + + @Override + public Set getScalarInputs(List args) + { + // null or not, doesnt matter if the inputs are arrays or scalars + return Collections.emptySet(); + } } class IsNotNullFunc implements Function @@ -2293,7 +2320,6 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) { @@ -2305,6 +2331,13 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe { return VectorProcessors.isNotNull(inspector, args.get(0)); } + + @Override + public Set getScalarInputs(List args) + { + // null or not, doesnt matter if the inputs are arrays or scalars + return Collections.emptySet(); + } } class ConcatFunc implements Function diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java index bd77c6ca32e8..bfddcec6c3a2 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionPlanner.java @@ -142,6 +142,7 @@ public static ExpressionPlan plan(ColumnInspector inspector, Expr expression) c -> !definitelyArray.contains(c) && definitelyMultiValued.contains(c) && !analysis.getArrayBindings().contains(c) + && analysis.getScalarBindings().contains(c) ) .collect(Collectors.toList()); diff --git a/processing/src/test/java/org/apache/druid/math/expr/ParserTest.java b/processing/src/test/java/org/apache/druid/math/expr/ParserTest.java index b5344aa129b1..aabe1b75f7a8 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -470,7 +470,12 @@ public void testLiteralArraysExplicitDoubleParseException() public void testFunctions() { validateParser("sqrt(x)", "(sqrt [x])", ImmutableList.of("x")); - validateParser("if(cond,then,else)", "(if [cond, then, else])", ImmutableList.of("cond", "else", "then")); + validateParser("if(cond,then,else)", "(if [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet()); + validateParser("case_simple(cond,then,else)", "(case_simple [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet()); + validateParser("case_searched(cond,then,else)", "(case_searched [cond, then, else])", ImmutableList.of("cond", "else", "then"), Collections.emptySet(), Collections.emptySet()); + validateParser("nvl(x, fallback)", "(nvl [x, fallback])", ImmutableList.of("x", "fallback"), Collections.emptySet(), Collections.emptySet()); + validateParser("nvl(x, 1)", "(nvl [x, 1])", ImmutableList.of("x"), ImmutableSet.of(), Collections.emptySet()); + validateParser("nvl(x, [1,2,3])", "(nvl [x, [1, 2, 3]])", ImmutableList.of("x"), Collections.emptySet(), ImmutableSet.of()); validateParser("cast(x, 'STRING')", "(cast [x, STRING])", ImmutableList.of("x")); validateParser("cast(x, 'LONG')", "(cast [x, LONG])", ImmutableList.of("x")); validateParser("cast(x, 'DOUBLE')", "(cast [x, DOUBLE])", ImmutableList.of("x")); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java index 9062a32d0baf..c22c216b1312 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/CastOperatorConversion.java @@ -20,7 +20,6 @@ package org.apache.druid.sql.calcite.expression.builtin; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; @@ -29,7 +28,7 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.PeriodGranularity; -import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; @@ -39,45 +38,10 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.joda.time.Period; -import java.util.Map; import java.util.function.Function; public class CastOperatorConversion implements SqlOperatorConversion { - private static final Map EXPRESSION_TYPES; - - static { - final ImmutableMap.Builder builder = ImmutableMap.builder(); - - for (SqlTypeName type : SqlTypeName.FRACTIONAL_TYPES) { - builder.put(type, ExprType.DOUBLE); - } - - for (SqlTypeName type : SqlTypeName.INT_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.STRING_TYPES) { - builder.put(type, ExprType.STRING); - } - - // Booleans are treated as longs in Druid expressions, using two-value logic (positive = true, nonpositive = false). - builder.put(SqlTypeName.BOOLEAN, ExprType.LONG); - - // Timestamps are treated as longs (millis since the epoch) in Druid expressions. - builder.put(SqlTypeName.TIMESTAMP, ExprType.LONG); - builder.put(SqlTypeName.DATE, ExprType.LONG); - - for (SqlTypeName type : SqlTypeName.DAY_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - for (SqlTypeName type : SqlTypeName.YEAR_INTERVAL_TYPES) { - builder.put(type, ExprType.LONG); - } - - EXPRESSION_TYPES = builder.build(); - } @Override public SqlOperator calciteOperator() @@ -103,6 +67,7 @@ public DruidExpression toDruidExpression( return null; } + final SqlTypeName fromType = operand.getType().getSqlTypeName(); final SqlTypeName toType = rexNode.getType().getSqlTypeName(); @@ -118,28 +83,32 @@ public DruidExpression toDruidExpression( } else { // Handle other casts. If either type is ANY, use the other type instead. If both are ANY, this means nulls // downstream, Druid will try its best - final ExprType fromExprType = SqlTypeName.ANY.equals(fromType) - ? EXPRESSION_TYPES.get(toType) - : EXPRESSION_TYPES.get(fromType); - final ExprType toExprType = SqlTypeName.ANY.equals(toType) - ? EXPRESSION_TYPES.get(fromType) - : EXPRESSION_TYPES.get(toType); - - if (fromExprType == null || toExprType == null) { + + final ColumnType fromDruidType = Calcites.getColumnTypeForRelDataType(operand.getType()); + final ColumnType toDruidType = Calcites.getColumnTypeForRelDataType(rexNode.getType()); + + final ExpressionType fromExpressionType = SqlTypeName.ANY.equals(fromType) + ? ExpressionType.fromColumnType(toDruidType) + : ExpressionType.fromColumnType(fromDruidType); + final ExpressionType toExpressionType = SqlTypeName.ANY.equals(toType) + ? ExpressionType.fromColumnType(fromDruidType) + : ExpressionType.fromColumnType(toDruidType); + + if (fromExpressionType == null || toExpressionType == null) { // We have no runtime type for these SQL types. return null; } final DruidExpression typeCastExpression; - if (fromExprType != toExprType) { - // Ignore casts for simple extractions (use Function.identity) since it is ok in many cases. + if (fromExpressionType.equals(toExpressionType)) { + // Ignore casts for simple extractions since it is ok in many cases. + typeCastExpression = operandExpression; + } else { typeCastExpression = operandExpression.map( Function.identity(), - expression -> StringUtils.format("CAST(%s, '%s')", expression, toExprType.toString()) + expression -> StringUtils.format("CAST(%s, '%s')", expression, toExpressionType.asTypeString()) ); - } else { - typeCastExpression = operandExpression; } if (toType == SqlTypeName.DATE) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index aa95fee28716..157c51b8059f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -188,7 +188,9 @@ public static boolean isLongType(SqlTypeName sqlTypeName) return SqlTypeName.TIMESTAMP == sqlTypeName || SqlTypeName.DATE == sqlTypeName || SqlTypeName.BOOLEAN == sqlTypeName || - SqlTypeName.INT_TYPES.contains(sqlTypeName); + SqlTypeName.INT_TYPES.contains(sqlTypeName) || + SqlTypeName.DAY_INTERVAL_TYPES.contains(sqlTypeName) || + SqlTypeName.YEAR_INTERVAL_TYPES.contains(sqlTypeName); } public static StringComparator getStringComparatorForRelDataType(RelDataType dataType) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 534f334b8345..b5f5b561b86f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -29,8 +29,10 @@ import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.filter.AndDimFilter; +import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; +import org.apache.druid.query.filter.OrDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; @@ -39,6 +41,7 @@ import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.virtual.ListFilteredVirtualColumn; import org.apache.druid.sql.SqlPlanningException; import org.apache.druid.sql.calcite.filtration.Filtration; @@ -279,6 +282,88 @@ public void testMultiValueStringOverlapFilter() ); } + @Test + public void testMultiValueStringOverlapFilterCoalesceNvl() + { + testQuery( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, 'other'), ARRAY['a', 'b', 'other']) OR " + + "MV_OVERLAP(NVL(dim3, 'other'), ARRAY['a', 'b', 'other']) OR " + + "MV_OVERLAP(COALESCE(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) OR " + + "MV_OVERLAP(NVL(MV_TO_ARRAY(dim3), ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .eternityInterval() + .virtualColumns( + new ExpressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim3\"),\"dim3\",'other')", + ColumnType.STRING, + queryFramework().macroTable() + ) + ) + .filters( + new OrDimFilter( + new ExpressionDimFilter( + "case_searched(notnull(\"dim3\"),array_overlap(\"dim3\",array('a','b','other')),1)", + null, + queryFramework().macroTable() + ), + new ExpressionDimFilter( + "case_searched(notnull(\"dim3\"),array_overlap(\"dim3\",array('a','b','other')),1)", + null, + queryFramework().macroTable() + ), + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ), + new ExpressionDimFilter( + "case_searched(notnull(mv_to_array(\"dim3\")),array_overlap(mv_to_array(\"dim3\"),array('a','b','other')),1)", + null, + queryFramework().macroTable() + ) + ) + ) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() + ? ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{"other"}, + new Object[]{"other"} + ) + ); + } + + @Test + public void testMultiValueStringOverlapFilterInconsistentUsage() + { + testQueryThrows( + "SELECT COALESCE(dim3, 'other') FROM druid.numfoo " + + "WHERE MV_OVERLAP(COALESCE(dim3, ARRAY['other']), ARRAY['a', 'b', 'other']) LIMIT 5", + e -> { + e.expect(SqlPlanningException.class); + e.expectMessage("Illegal mixing of types in CASE or COALESCE statement"); + } + + ); + } + @Test public void testMultiValueStringOverlapFilterNonLiteral() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java index 91ab2a839f2c..4d7cad0882c3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionsTest.java @@ -1751,8 +1751,7 @@ public void testTimeMinusDayTimeInterval() (args) -> "(" + args.get(0).getExpression() + " - " + args.get(1).getExpression() + ")", ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type of "interval day to minute" is not converted to druid long... yet - DruidExpression.ofLiteral(null, "90060000") + DruidExpression.ofLiteral(ColumnType.LONG, "90060000") ) ), DateTimes.of("2000-02-03T04:05:06").minus(period).getMillis() @@ -1779,8 +1778,7 @@ public void testTimeMinusYearMonthInterval() DruidExpression.functionCall("timestamp_shift"), ImmutableList.of( DruidExpression.ofColumn(ColumnType.LONG, "t"), - // RexNode type "interval year to month" is not reported as ColumnType.STRING - DruidExpression.ofLiteral(null, DruidExpression.stringLiteral("P13M")), + DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.stringLiteral("P13M")), DruidExpression.ofLiteral(ColumnType.LONG, DruidExpression.longLiteral(-1)), DruidExpression.ofStringLiteral("UTC") ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java index 87ce28ef3f6c..aefa534b6489 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/GreatestExpressionTest.java @@ -246,10 +246,8 @@ public void testTimestamp() } @Test - public void testInvalidType() + public void testBigDecimal() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -257,8 +255,8 @@ public void testInvalidType() new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java index 047f6936d307..dc5412bdf6a7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/LeastExpressionTest.java @@ -247,10 +247,8 @@ public void testTimestamp() } @Test - public void testInvalidType() + public void testBigDecimal() { - expectException(IllegalArgumentException.class, "Argument 0 has invalid type: INTERVAL_YEAR_MONTH"); - testExpression( Collections.singletonList( testHelper.makeLiteral( @@ -258,8 +256,8 @@ public void testInvalidType() new SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.MONTH, SqlParserPos.ZERO) ) ), - null, - null + buildExpectedExpression(13), + 13L ); }