diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 711acadf08be..a0b8e59a17ff 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -2721,6 +2721,9 @@ public ExprType getOutputType(Expr.InputBindingInspector inspector, List a @Override protected ExprEval eval(String x, int y) { + if (x == null) { + return ExprEval.of(null); + } return ExprEval.of(y < 1 ? NullHandling.defaultStringValue() : StringUtils.repeat(x, y)); } } diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 155774958b43..cb211798e7a7 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -596,6 +596,16 @@ public void testBitwise() assertExpr("bitwiseConvertDoubleToLongBits(null)", null); } + @Test + public void testRepeat() + { + assertExpr("repeat('hello', 2)", "hellohello"); + assertExpr("repeat('hello', -1)", null); + assertExpr("repeat(null, 10)", null); + assertExpr("repeat(nonexistent, 10)", null); + } + + private void assertExpr(final String expression, @Nullable final Object expectedResult) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 6f837de8ec12..12d0da989d06 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -303,7 +303,7 @@ columns in this mode are not nullable; any null or missing values will be treate In SQL compatible mode (`false`), NULLs are treated more closely to the SQL standard. The property affects both storage and querying, so for correct behavior, it should be set on all Druid service types to be available at both ingestion time and query time. There is some overhead associated with the ability to handle NULLs; see -the [segment internals](../design/segments.md#sql-compatible-null-handling)documentation for more details. +the [segment internals](../design/segments.md#sql-compatible-null-handling) documentation for more details. ## Aggregation functions diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java index e7c39d00a85d..de1293aef005 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java @@ -99,7 +99,11 @@ private static class TimestampShiftExpr extends ExprMacroTable.BaseScalarMacroFu @Override public ExprEval eval(final ObjectBinding bindings) { - return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step)); + ExprEval timestamp = args.get(0).eval(bindings); + if (timestamp.isNumericNull()) { + return ExprEval.of(null); + } + return ExprEval.of(chronology.add(period, timestamp.asLong(), step)); } @Override @@ -128,10 +132,14 @@ private static class TimestampShiftDynamicExpr extends ExprMacroTable.BaseScalar @Override public ExprEval eval(final ObjectBinding bindings) { + ExprEval timestamp = args.get(0).eval(bindings); + if (timestamp.isNumericNull()) { + return ExprEval.of(null); + } final Period period = getPeriod(args, bindings); final Chronology chronology = getTimeZone(args, bindings); final int step = getStep(args, bindings); - return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step)); + return ExprEval.of(chronology.add(period, timestamp.asLong(), step)); } @Override diff --git a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java index c4710f9c3603..05945b1cc70f 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/TimestampShiftMacroTest.java @@ -20,6 +20,7 @@ package org.apache.druid.query.expression; import com.google.common.collect.ImmutableList; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.math.expr.Expr; @@ -219,6 +220,24 @@ public Object get(String name) ); } + @Test + public void testNull() + { + Expr expr = apply( + ImmutableList.of( + ExprEval.ofLong(null).toExpr(), + ExprEval.of("P1M").toExpr(), + ExprEval.of(1L).toExpr() + ) + ); + + if (NullHandling.replaceWithDefault()) { + Assert.assertEquals(2678400000L, expr.eval(ExprUtils.nilBindings()).value()); + } else { + Assert.assertNull(expr.eval(ExprUtils.nilBindings()).value()); + } + } + private static class NotLiteralExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { NotLiteralExpr(String name) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index 6f060f19c923..36607bfcb435 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -48,6 +48,7 @@ import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeTransforms; import org.apache.calcite.util.Static; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; @@ -255,11 +256,12 @@ private OperatorBuilder(final String name) } /** - * Sets the return type of the operator to "typeName", marked as non-nullable. + * Sets the return type of the operator to "typeName", marked as non-nullable. If this method is used it implies the + * operator should never, ever, return null. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName) { @@ -274,9 +276,9 @@ public OperatorBuilder returnTypeNonNull(final SqlTypeName typeName) /** * Sets the return type of the operator to "typeName", marked as nullable. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNullable(final SqlTypeName typeName) { @@ -287,12 +289,27 @@ public OperatorBuilder returnTypeNullable(final SqlTypeName typeName) ); return this; } + + /** + * Sets the return type of the operator to "typeName", marked as nullable if any of its operands are nullable. + * + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. + */ + public OperatorBuilder returnTypeCascadeNullable(final SqlTypeName typeName) + { + Preconditions.checkState(this.returnTypeInference == null, "Cannot set return type multiple times"); + this.returnTypeInference = ReturnTypes.cascade(ReturnTypes.explicit(typeName), SqlTypeTransforms.TO_NULLABLE); + return this; + } + /** * Sets the return type of the operator to an array type with elements of "typeName", marked as nullable. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName) { @@ -308,9 +325,9 @@ public OperatorBuilder returnTypeNullableArray(final SqlTypeName elementTypeName /** * Provides customized return type inference logic. * - * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeNullableArray}, or - * {@link #returnTypeInference(SqlReturnTypeInference)} must be used before calling {@link #build()}. These methods - * cannot be mixed; you must call exactly one. + * One of {@link #returnTypeNonNull}, {@link #returnTypeNullable}, {@link #returnTypeCascadeNullable(SqlTypeName)} + * {@link #returnTypeNullableArray}, or {@link #returnTypeInference(SqlReturnTypeInference)} must be used before + * calling {@link #build()}. These methods cannot be mixed; you must call exactly one. */ public OperatorBuilder returnTypeInference(final SqlReturnTypeInference returnTypeInference) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java index 073d93556d82..9e67cc3ea31c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayLengthOperatorConversion.java @@ -43,7 +43,7 @@ public class ArrayLengthOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java index 51cad2feda4a..ca026c5e10be 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOffsetOfOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayOffsetOfOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java index 12edb572743b..dfc1501d52bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOfOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayOrdinalOfOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java index 5d316a591163..285993b399c5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayToStringOperatorConversion.java @@ -47,7 +47,7 @@ public class ArrayToStringOperatorConversion implements SqlOperatorConversion ) ) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java index d77c20b5d787..648d54b9380a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class BTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("BTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java index e7dbf5046141..7ffc47dd17e9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java @@ -22,29 +22,22 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.OperatorConversions; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; -import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; public class ConcatOperatorConversion implements SqlOperatorConversion { - private static final SqlFunction SQL_FUNCTION = new SqlFunction( - "CONCAT", - SqlKind.OTHER_FUNCTION, - ReturnTypes.explicit( - factory -> Calcites.createSqlType(factory, SqlTypeName.VARCHAR) - ), - null, - OperandTypes.SAME_VARIADIC, - SqlFunctionCategory.STRING - ); + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("CONCAT") + .operandTypeChecker(OperandTypes.SAME_VARIADIC) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) + .functionCategory(SqlFunctionCategory.STRING) + .build(); @Override public SqlFunction calciteOperator() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java index f496e0ab9671..574fa2fd46b0 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/DateTruncOperatorConversion.java @@ -67,7 +67,7 @@ public class DateTruncOperatorConversion implements SqlOperatorConversion .operatorBuilder("DATE_TRUNC") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java index 2d13b02fcbbd..3d98d3e9f052 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LPadOperatorConversion.java @@ -37,7 +37,7 @@ public class LPadOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LPAD") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(2) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java index 70ec0c97e621..233ded0acb6b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class LTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("LTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java index 252343cddba1..deeffa50760a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeftOperatorConversion.java @@ -39,7 +39,7 @@ public class LeftOperatorConversion implements SqlOperatorConversion .operatorBuilder("LEFT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java index e8b8e748a640..2456f059f50c 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MillisToTimestampOperatorConversion.java @@ -39,7 +39,7 @@ public class MillisToTimestampOperatorConversion implements SqlOperatorConversio private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("MILLIS_TO_TIMESTAMP") .operandTypes(SqlTypeFamily.EXACT_NUMERIC) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java index 9fd710fb1cb4..4de200002d0e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ParseLongOperatorConversion.java @@ -38,7 +38,7 @@ public class ParseLongOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder(NAME) .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java index 47c8eadc2f85..5ab8454643c3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RPadOperatorConversion.java @@ -37,7 +37,7 @@ public class RPadOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("RPAD") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(2) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java index 6aa8f1b28a6f..bc96610d126d 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RTrimOperatorConversion.java @@ -37,7 +37,7 @@ public class RTrimOperatorConversion implements SqlOperatorConversion private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("RTRIM") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .requiredOperands(1) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java index 9521a0443bcc..55b01be9c5b9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RepeatOperatorConversion.java @@ -39,7 +39,7 @@ public class RepeatOperatorConversion implements SqlOperatorConversion .operatorBuilder("REPEAT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java index 70280abf2f98..6014231ab549 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReverseOperatorConversion.java @@ -37,7 +37,7 @@ public class ReverseOperatorConversion implements SqlOperatorConversion .operatorBuilder("REVERSE") .operandTypes(SqlTypeFamily.CHARACTER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java index 863bbccd5578..5f454a5f9801 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/RightOperatorConversion.java @@ -39,7 +39,7 @@ public class RightOperatorConversion implements SqlOperatorConversion .operatorBuilder("RIGHT") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java index b2aabbb2d11c..133d6226dc8e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StringFormatOperatorConversion.java @@ -42,7 +42,7 @@ public class StringFormatOperatorConversion implements SqlOperatorConversion .operatorBuilder("STRING_FORMAT") .operandTypeChecker(new StringFormatOperandTypeChecker()) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java index e18c0896a5d4..c36405f06629 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/StrposOperatorConversion.java @@ -38,7 +38,7 @@ public class StrposOperatorConversion implements SqlOperatorConversion .operatorBuilder("STRPOS") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .functionCategory(SqlFunctionCategory.STRING) - .returnTypeNonNull(SqlTypeName.INTEGER) + .returnTypeCascadeNullable(SqlTypeName.INTEGER) .build(); @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java index ee160d6b3ef6..c44375c131d9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java @@ -36,7 +36,7 @@ public class TextcatOperatorConversion implements SqlOperatorConversion .operatorBuilder("textcat") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.STRING) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java index 81b2dfa12ae2..359612c08523 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeCeilOperatorConversion.java @@ -41,7 +41,7 @@ public class TimeCeilOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_CEIL") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java index 35accd1f9b3f..000923c4fd64 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeExtractOperatorConversion.java @@ -44,7 +44,7 @@ public class TimeExtractOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_EXTRACT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java index 87c07f25b7e5..20377a03aacb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFloorOperatorConversion.java @@ -56,7 +56,7 @@ public class TimeFloorOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_FLOOR") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER) .requiredOperands(2) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java index 1f7b6f95d32e..e44734f84bab 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeFormatOperatorConversion.java @@ -47,7 +47,7 @@ public class TimeFormatOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_FORMAT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperands(1) - .returnTypeNonNull(SqlTypeName.VARCHAR) + .returnTypeCascadeNullable(SqlTypeName.VARCHAR) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java index 25b05c40f1d1..a4fd210aa4b4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimeShiftOperatorConversion.java @@ -45,7 +45,7 @@ public class TimeShiftOperatorConversion implements SqlOperatorConversion .operatorBuilder("TIME_SHIFT") .operandTypes(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER) .requiredOperands(3) - .returnTypeNonNull(SqlTypeName.TIMESTAMP) + .returnTypeCascadeNullable(SqlTypeName.TIMESTAMP) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java index ae4565579fb7..ece14e2dd63f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TimestampToMillisOperatorConversion.java @@ -39,7 +39,7 @@ public class TimestampToMillisOperatorConversion implements SqlOperatorConversio private static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("TIMESTAMP_TO_MILLIS") .operandTypes(SqlTypeFamily.TIMESTAMP) - .returnTypeNonNull(SqlTypeName.BIGINT) + .returnTypeCascadeNullable(SqlTypeName.BIGINT) .functionCategory(SqlFunctionCategory.TIMEDATE) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 1e32e765ab56..d64eb8de66e5 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -17523,4 +17523,53 @@ public void testJoinWithTimeDimension() throws Exception .build()), ImmutableList.of(new Object[]{6L})); } + + @Test + public void testExpressionCounts() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT\n" + + " COUNT(reverse(dim2)),\n" + + " COUNT(left(dim2, 5)),\n" + + " COUNT(strpos(dim2, 'a'))\n" + + "FROM druid.numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .virtualColumns( + expressionVirtualColumn("v0", "reverse(\"dim2\")", ValueType.STRING), + expressionVirtualColumn("v1", "left(\"dim2\",5)", ValueType.STRING), + expressionVirtualColumn("v2", "(strpos(\"dim2\",'a') + 1)", ValueType.LONG) + ) + .aggregators( + aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + not(selector("v0", null, null)) + ), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a1"), + not(selector("v1", null, null)) + ), + new FilteredAggregatorFactory( + new CountAggregatorFactory("a2"), + not(selector("v2", null, null)) + ) + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + // in default mode strpos is 6 because the '+ 1' of the expression (no null numbers in + // default mode so is 0 + 1 for null rows) + ? new Object[]{3L, 3L, 6L} + : new Object[]{4L, 4L, 4L} + ) + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java index 0268bb636d85..5f70dc5902a3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/OperatorConversionsTest.java @@ -31,12 +31,14 @@ import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.sql.calcite.expression.OperatorConversions.DefaultOperandTypeChecker; +import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -275,6 +277,69 @@ public void testNullLiteralForNullableOperand() ); } + @Test + public void testNullForNullableOperandNonNullOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableNonnull") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeNonNull(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, true) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertFalse(returnType.isNullable()); + } + + @Test + public void testNullForNullableOperandCascadeNullOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableCascade") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeCascadeNullable(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, true) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertTrue(returnType.isNullable()); + } + + @Test + public void testNullForNullableOperandAlwaysNullableOutput() + { + SqlFunction function = OperatorConversions + .operatorBuilder("testNullForNullableNonnull") + .operandTypes(SqlTypeFamily.CHARACTER) + .requiredOperands(1) + .returnTypeNullable(SqlTypeName.CHAR) + .build(); + SqlOperandTypeChecker typeChecker = function.getOperandTypeChecker(); + SqlCallBinding binding = mockCallBinding( + function, + ImmutableList.of( + new OperandSpec(SqlTypeName.CHAR, false, false) + ) + ); + Assert.assertTrue(typeChecker.checkOperandTypes(binding, true)); + RelDataType returnType = function.getReturnTypeInference().inferReturnType(binding); + Assert.assertTrue(returnType.isNullable()); + } + @Test public void testNullForNonNullableOperand() { @@ -359,6 +424,7 @@ private static SqlCallBinding mockCallBinding( ) { SqlValidator validator = Mockito.mock(SqlValidator.class); + Mockito.when(validator.getTypeFactory()).thenReturn(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE)); List operands = new ArrayList<>(actualOperands.size()); for (OperandSpec operand : actualOperands) { final SqlNode node; @@ -368,6 +434,12 @@ private static SqlCallBinding mockCallBinding( node = Mockito.mock(SqlNode.class); } RelDataType relDataType = Mockito.mock(RelDataType.class); + + if (operand.isNullable) { + Mockito.when(relDataType.isNullable()).thenReturn(true); + } else { + Mockito.when(relDataType.isNullable()).thenReturn(false); + } Mockito.when(validator.deriveType(ArgumentMatchers.any(), ArgumentMatchers.eq(node))) .thenReturn(relDataType); Mockito.when(relDataType.getSqlTypeName()).thenReturn(operand.type); @@ -394,11 +466,18 @@ private static class OperandSpec { private final SqlTypeName type; private final boolean isLiteral; + private final boolean isNullable; private OperandSpec(SqlTypeName type, boolean isLiteral) + { + this(type, isLiteral, type == SqlTypeName.NULL); + } + + private OperandSpec(SqlTypeName type, boolean isLiteral, boolean isNullable) { this.type = type; this.isLiteral = isLiteral; + this.isNullable = isNullable; } } }