From 834cb13aeb531a60685ce9eb270894e74a5201c3 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 00:46:52 -0700 Subject: [PATCH 01/15] push down ValueType to ExprType conversion, tidy up --- .../apache/druid/math/expr/ApplyFunction.java | 8 ++-- .../math/expr/BinaryLogicalOperatorExpr.java | 12 +++--- .../org/apache/druid/math/expr/ExprEval.java | 5 +++ .../org/apache/druid/math/expr/ExprType.java | 39 ++++++++++++++++++- .../org/apache/druid/math/expr/Function.java | 8 ++-- .../expressions/BloomFilterExprMacro.java | 3 +- .../expression/IPv4AddressMatchExprMacro.java | 3 +- .../druid/query/expression/LikeExprMacro.java | 3 +- .../query/expression/RegexpLikeExprMacro.java | 5 +-- .../segment/virtual/ExpressionSelectors.java | 9 ++--- .../expression/RegexpLikeExprMacroTest.java | 17 ++++---- .../druid/sql/calcite/rel/Projection.java | 4 +- 12 files changed, 75 insertions(+), 41 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index 4bf2fa5e934b..df9b7fb30b5d 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -274,7 +274,7 @@ ExprEval applyFold(LambdaExpr lambdaExpr, Object accumulator, IndexableFoldLambd accumulator = evaluated.value(); } if (accumulator instanceof Boolean) { - return ExprEval.of((boolean) accumulator, ExprType.LONG); + return ExprEval.ofLongBoolean((boolean) accumulator); } return ExprEval.bestEffortOf(accumulator); } @@ -501,7 +501,7 @@ public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBin final Object[] array = arrayEval.asArray(); if (array == null) { - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings); @@ -550,7 +550,7 @@ public ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bi { boolean anyMatch = Arrays.stream(values) .anyMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); - return ExprEval.of(anyMatch, ExprType.LONG); + return ExprEval.ofLongBoolean(anyMatch); } } @@ -573,7 +573,7 @@ public ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bi { boolean allMatch = Arrays.stream(values) .allMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); - return ExprEval.of(allMatch, ExprType.LONG); + return ExprEval.ofLongBoolean(allMatch); } } diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java index dad35f30560a..3e4c9b8218f7 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java @@ -42,7 +42,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) < 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) < 0); } @Override @@ -75,7 +75,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) <= 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) <= 0); } @Override @@ -108,7 +108,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) > 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) > 0); } @Override @@ -141,7 +141,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Comparators.naturalNullsFirst().compare(left, right) >= 0, ExprType.LONG); + return ExprEval.ofLongBoolean(Comparators.naturalNullsFirst().compare(left, right) >= 0); } @Override @@ -174,7 +174,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(Objects.equals(left, right), ExprType.LONG); + return ExprEval.ofLongBoolean(Objects.equals(left, right)); } @Override @@ -206,7 +206,7 @@ protected BinaryOpExprBase copy(Expr left, Expr right) @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { - return ExprEval.of(!Objects.equals(left, right), ExprType.LONG); + return ExprEval.ofLongBoolean(!Objects.equals(left, right)); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 61cdc26f6dd1..1c02186296be 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -89,6 +89,11 @@ public static ExprEval of(boolean value, ExprType type) } } + public static ExprEval ofLongBoolean(boolean value) + { + return ExprEval.of(Evals.asLong(value)); + } + public static ExprEval bestEffortOf(@Nullable Object val) { if (val instanceof ExprEval) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 0bc1573bef56..41f1b60c61df 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -19,6 +19,11 @@ package org.apache.druid.math.expr; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.segment.column.ValueType; + +import javax.annotation.Nullable; + /** * Base 'value' types of Druid expression language, all {@link Expr} must evaluate to one of these types. */ @@ -29,5 +34,37 @@ public enum ExprType STRING, DOUBLE_ARRAY, LONG_ARRAY, - STRING_ARRAY + STRING_ARRAY; + + /** + * The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and + * cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to + * {@link #DOUBLE}, or throw an exception if a {@link ValueType#COMPLEX} is encountered. + * + * @throws IllegalStateException + */ + public static ExprType fromValueType(@Nullable ValueType valueType) + { + if (valueType == null) { + throw new IllegalStateException("Unsupported unknown value type"); + } + switch (valueType) { + case LONG: + return LONG; + case LONG_ARRAY: + return LONG_ARRAY; + case FLOAT: + case DOUBLE: + return DOUBLE; + case DOUBLE_ARRAY: + return DOUBLE_ARRAY; + case STRING: + return STRING; + case STRING_ARRAY: + return STRING_ARRAY; + case COMPLEX: + default: + throw new ISE("Unsupported value type[%s]", valueType); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index a20863929875..7fc6901d3739 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -1828,7 +1828,7 @@ public String name() public ExprEval apply(List args, Expr.ObjectBinding bindings) { final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() == null, ExprType.LONG); + return ExprEval.ofLongBoolean(expr.value() == null); } @Override @@ -1852,7 +1852,7 @@ public String name() public ExprEval apply(List args, Expr.ObjectBinding bindings) { final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() != null, ExprType.LONG); + return ExprEval.ofLongBoolean(expr.value() != null); } @Override @@ -2414,7 +2414,7 @@ ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { final Object[] array1 = lhsExpr.asArray(); final Object[] array2 = rhsExpr.asArray(); - return ExprEval.of(Arrays.asList(array1).containsAll(Arrays.asList(array2)), ExprType.LONG); + return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2))); } } @@ -2435,7 +2435,7 @@ ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) for (Object check : array1) { any |= array2.contains(check); } - return ExprEval.of(any, ExprType.LONG); + return ExprEval.ofLongBoolean(any); } } diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java index dcc3a16ccad6..6cbcfd16bc66 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java @@ -25,7 +25,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.BloomKFilter; import javax.annotation.Nonnull; @@ -108,7 +107,7 @@ public ExprEval eval(final ObjectBinding bindings) break; } - return ExprEval.of(matches, ExprType.LONG); + return ExprEval.ofLongBoolean(matches); } private boolean nullMatch() diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java index 5e9cc85fe545..05d510e2d3f6 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java @@ -25,7 +25,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; import java.util.List; @@ -98,7 +97,7 @@ public ExprEval eval(final ObjectBinding bindings) default: match = false; } - return ExprEval.of(match, ExprType.LONG); + return ExprEval.ofLongBoolean(match); } private boolean isStringMatch(String stringValue) diff --git a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java index 2332b2858eaf..a124722e8f8b 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java @@ -25,7 +25,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.LikeDimFilter; import javax.annotation.Nonnull; @@ -81,7 +80,7 @@ private LikeExtractExpr(Expr arg) @Override public ExprEval eval(final ObjectBinding bindings) { - return ExprEval.of(likeMatcher.matches(arg.eval(bindings).asString()), ExprType.LONG); + return ExprEval.ofLongBoolean(likeMatcher.matches(arg.eval(bindings).asString())); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java index 83735e863494..a4909194d484 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java @@ -25,7 +25,6 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; import java.util.List; @@ -76,10 +75,10 @@ public ExprEval eval(final ObjectBinding bindings) if (s == null) { // True nulls do not match anything. Note: this branch only executes in SQL-compatible null handling mode. - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } else { final Matcher matcher = pattern.matcher(s); - return ExprEval.of(matcher.find(), ExprType.LONG); + return ExprEval.ofLongBoolean(matcher.find()); } } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 5dd6a9970432..860e47b0f6c0 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -376,16 +376,13 @@ private static Expr.ObjectBinding createBindings( final Supplier supplier; if (nativeType == ValueType.FLOAT) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getFloat); } else if (nativeType == ValueType.LONG) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getLong); } else if (nativeType == ValueType.DOUBLE) { - ColumnValueSelector selector = columnSelectorFactory - .makeColumnValueSelector(columnName); + ColumnValueSelector selector = columnSelectorFactory.makeColumnValueSelector(columnName); supplier = makeNullableNumericSupplier(selector, selector::getDouble); } else if (nativeType == ValueType.STRING) { supplier = supplierFromDimensionSelector( diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java index a6bdfb36a03a..b57db64b2327 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java @@ -22,7 +22,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.ExprType; import org.apache.druid.math.expr.Parser; import org.junit.Assert; import org.junit.Test; @@ -53,7 +52,7 @@ public void testMatch() { final ExprEval result = eval("regexp_like(a, 'f.o')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -63,7 +62,7 @@ public void testNoMatch() { final ExprEval result = eval("regexp_like(a, 'f.x')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofLongBoolean(false).value(), result.value() ); } @@ -77,7 +76,7 @@ public void testNullPattern() final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -87,7 +86,7 @@ public void testEmptyStringPattern() { final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -101,7 +100,7 @@ public void testNullPatternOnEmptyString() final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -111,7 +110,7 @@ public void testEmptyStringPatternOnEmptyString() { final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -125,7 +124,7 @@ public void testNullPatternOnNull() final ExprEval result = eval("regexp_like(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofLongBoolean(true).value(), result.value() ); } @@ -135,7 +134,7 @@ public void testEmptyStringPatternOnNull() { final ExprEval result = eval("regexp_like(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(NullHandling.replaceWithDefault(), ExprType.LONG).value(), + ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(), result.value() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java index d353a4ae116c..02cc3ce2aae3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Projection.java @@ -316,8 +316,8 @@ private static boolean postAggregatorDirectColumnIsOk( } // Check if a cast is necessary. - final ExprType toExprType = Expressions.exprTypeForValueType(columnValueType); - final ExprType fromExprType = Expressions.exprTypeForValueType( + final ExprType toExprType = ExprType.fromValueType(columnValueType); + final ExprType fromExprType = ExprType.fromValueType( Calcites.getValueTypeForRelDataType(rexNode.getType()) ); From 472918e282c5787dbaa801d85ea6c052e2e29c8a Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 17:57:36 -0700 Subject: [PATCH 02/15] determine expr output type for given input types --- .../apache/druid/math/expr/ApplyFunction.java | 60 +- .../math/expr/BinaryLogicalOperatorExpr.java | 67 +- .../druid/math/expr/BinaryOperatorExpr.java | 9 +- .../apache/druid/math/expr/ConstantExpr.java | 40 +- .../java/org/apache/druid/math/expr/Expr.java | 56 +- .../druid/math/expr/ExprListenerImpl.java | 2 +- .../druid/math/expr/ExprMacroTable.java | 14 +- .../org/apache/druid/math/expr/ExprType.java | 72 ++ .../org/apache/druid/math/expr/Function.java | 918 ++++++++++++------ .../druid/math/expr/FunctionalExpr.java | 55 +- .../druid/math/expr/IdentifierExpr.java | 10 +- .../org/apache/druid/math/expr/Parser.java | 28 +- .../druid/math/expr/UnaryOperatorExpr.java | 21 +- .../org/apache/druid/math/expr/ExprTest.java | 40 +- .../druid/math/expr/OutputTypeTest.java | 367 +++++++ .../apache/druid/math/expr/ParserTest.java | 8 +- .../expressions/BloomFilterExprMacro.java | 9 + .../expression/IPv4AddressMatchExprMacro.java | 9 + .../expression/IPv4AddressParseExprMacro.java | 9 + .../IPv4AddressStringifyExprMacro.java | 9 + .../druid/query/expression/LikeExprMacro.java | 9 + .../query/expression/LookupExprMacro.java | 9 + .../expression/RegexpExtractExprMacro.java | 9 + .../query/expression/RegexpLikeExprMacro.java | 9 + .../expression/TimestampCeilExprMacro.java | 16 + .../expression/TimestampExtractExprMacro.java | 15 + .../expression/TimestampFloorExprMacro.java | 16 + .../expression/TimestampFormatExprMacro.java | 9 + .../expression/TimestampParseExprMacro.java | 9 + .../expression/TimestampShiftExprMacro.java | 16 + .../druid/query/expression/TrimExprMacro.java | 18 +- .../segment/filter/ExpressionFilter.java | 4 +- .../join/filter/JoinFilterCorrelations.java | 4 +- .../segment/virtual/ExpressionSelectors.java | 14 +- ...RowBasedExpressionColumnValueSelector.java | 10 +- .../IPv4AddressMatchExprMacroTest.java | 9 + .../builtin/GreatestOperatorConversion.java | 4 +- .../builtin/LeastOperatorConversion.java | 4 +- .../ReductionOperatorConversionHelper.java | 4 +- 39 files changed, 1577 insertions(+), 414 deletions(-) create mode 100644 core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index df9b7fb30b5d..8a0f68c24a78 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableSet; import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; @@ -74,6 +75,8 @@ default boolean hasArrayOutput(LambdaExpr lambdaExpr) */ void validateArguments(LambdaExpr lambdaExpr, List args); + ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args); + /** * Base class for "map" functions, which are a class of {@link ApplyFunction} which take a lambda function that is * mapped to the values of an {@link IndexableMapLambdaObjectBinding} which is created from the outer @@ -87,6 +90,12 @@ public boolean hasArrayOutput(LambdaExpr lambdaExpr) return true; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + return ExprType.asArrayType(expr.getOutputType(new LambdaInputBindingTypes(inputTypes, expr, args))); + } + /** * Evaluate {@link LambdaExpr} against every index position of an {@link IndexableMapLambdaObjectBinding} */ @@ -282,8 +291,15 @@ ExprEval applyFold(LambdaExpr lambdaExpr, Object accumulator, IndexableFoldLambd @Override public boolean hasArrayOutput(LambdaExpr lambdaExpr) { - Expr.BindingDetails lambdaBindingDetails = lambdaExpr.analyzeInputs(); - return lambdaBindingDetails.isOutputArray(); + Expr.ExprInputBindingAnalysis lambdaExprInputBindingAnalysis = lambdaExpr.analyzeInputs(); + return lambdaExprInputBindingAnalysis.isOutputArray(); + } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + // output type is accumulator type, which is last argument + return args.get(args.size() - 1).getOutputType(inputTypes); } } @@ -481,6 +497,13 @@ public void validateArguments(LambdaExpr lambdaExpr, List args) ); } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + // output type is input array type + return args.get(0).getOutputType(inputTypes); + } + private Stream filter(T[] array, LambdaExpr expr, SettableLambdaBinding binding) { return Arrays.stream(array).filter(s -> expr.eval(binding.withBinding(expr.getIdentifier(), s)).asBoolean()); @@ -528,6 +551,12 @@ public void validateArguments(LambdaExpr lambdaExpr, List args) ); } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + return ExprType.LONG; + } + public abstract ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bindings); } @@ -848,4 +877,31 @@ public CartesianFoldLambdaBinding accumulateWithIndex(int index, Object acc) return this; } } + + class LambdaInputBindingTypes implements Expr.InputBindingTypes + { + private final Object2IntMap lambdaIdentifiers; + private final Expr.InputBindingTypes inputTypes; + private final List args; + + public LambdaInputBindingTypes(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) + { + this.inputTypes = inputTypes; + this.args = args; + List identifiers = expr.getIdentifiers(); + this.lambdaIdentifiers = new Object2IntOpenHashMap<>(args.size()); + for (int i = 0; i < args.size(); i++) { + lambdaIdentifiers.put(identifiers.get(i), i); + } + } + + @Override + public ExprType getType(String name) + { + if (lambdaIdentifiers.containsKey(name)) { + return ExprType.elementType(args.get(lambdaIdentifiers.getInt(name)).getOutputType(inputTypes)); + } + return inputTypes.getType(name); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java index 3e4c9b8218f7..58cb5a08de87 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryLogicalOperatorExpr.java @@ -57,6 +57,17 @@ protected final double evalDouble(double left, double right) // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) < 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinLeqExpr extends BinaryEvalOpExprBase @@ -90,6 +101,17 @@ protected final double evalDouble(double left, double right) // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) <= 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinGtExpr extends BinaryEvalOpExprBase @@ -123,6 +145,17 @@ protected final double evalDouble(double left, double right) // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) > 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinGeqExpr extends BinaryEvalOpExprBase @@ -156,6 +189,17 @@ protected final double evalDouble(double left, double right) // Use Double.compare for more consistent NaN handling. return Evals.asDouble(Double.compare(left, right) >= 0); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinEqExpr extends BinaryEvalOpExprBase @@ -188,6 +232,17 @@ protected final double evalDouble(double left, double right) { return Evals.asDouble(left == right); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinNeqExpr extends BinaryEvalOpExprBase @@ -220,6 +275,17 @@ protected final double evalDouble(double left, double right) { return Evals.asDouble(left != right); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } class BinAndExpr extends BinaryOpExprBase @@ -262,5 +328,4 @@ public ExprEval eval(ObjectBinding bindings) ExprEval leftVal = left.eval(bindings); return leftVal.asBoolean() ? leftVal : right.eval(bindings); } - } diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java index 9c390587bd4e..7b7423d29733 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java @@ -81,12 +81,19 @@ public String stringify() protected abstract BinaryOpExprBase copy(Expr left, Expr right); @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { // currently all binary operators operate on scalar inputs return left.analyzeInputs().with(right).withScalarArguments(ImmutableSet.of(left, right)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.implicitCast(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); + } + @Override public boolean equals(Object o) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index 4f6099cf3665..c0daed09da2b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -35,6 +35,19 @@ */ abstract class ConstantExpr implements Expr { + final ExprType outputType; + + protected ConstantExpr(ExprType outputType) + { + this.outputType = outputType; + } + + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return outputType; + } + @Override public boolean isLiteral() { @@ -54,9 +67,9 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { - return new BindingDetails(); + return new ExprInputBindingAnalysis(); } @Override @@ -71,6 +84,11 @@ public String stringify() */ abstract class NullNumericConstantExpr extends ConstantExpr { + protected NullNumericConstantExpr(ExprType outputType) + { + super(outputType); + } + @Override public Object getLiteralValue() { @@ -82,6 +100,8 @@ public String toString() { return NULL_LITERAL; } + + } class LongExpr extends ConstantExpr @@ -90,6 +110,7 @@ class LongExpr extends ConstantExpr LongExpr(Long value) { + super(ExprType.LONG); this.value = Preconditions.checkNotNull(value, "value"); } @@ -133,6 +154,11 @@ public int hashCode() class NullLongExpr extends NullNumericConstantExpr { + NullLongExpr() + { + super(ExprType.LONG); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -158,6 +184,7 @@ class LongArrayExpr extends ConstantExpr LongArrayExpr(Long[] value) { + super(ExprType.LONG_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } @@ -215,6 +242,7 @@ class StringExpr extends ConstantExpr StringExpr(@Nullable String value) { + super(ExprType.STRING); this.value = NullHandling.emptyToNullIfNeeded(value); } @@ -270,6 +298,7 @@ class StringArrayExpr extends ConstantExpr StringArrayExpr(String[] value) { + super(ExprType.STRING_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } @@ -338,6 +367,7 @@ class DoubleExpr extends ConstantExpr DoubleExpr(Double value) { + super(ExprType.DOUBLE); this.value = Preconditions.checkNotNull(value, "value"); } @@ -381,6 +411,11 @@ public int hashCode() class NullDoubleExpr extends NullNumericConstantExpr { + NullDoubleExpr() + { + super(ExprType.DOUBLE); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -406,6 +441,7 @@ class DoubleArrayExpr extends ConstantExpr DoubleArrayExpr(Double[] value) { + super(ExprType.DOUBLE_ARRAY); this.value = Preconditions.checkNotNull(value, "value"); } diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index e0a1525c7df5..ef4d8022d4a0 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -123,9 +123,17 @@ default String getBindingIfIdentifier() Expr visit(Shuttle shuttle); /** - * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link BindingDetails} + * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link ExprInputBindingAnalysis} */ - BindingDetails analyzeInputs(); + ExprInputBindingAnalysis analyzeInputs(); + + @Nullable + ExprType getOutputType(InputBindingTypes inputTypes); + + interface InputBindingTypes + { + ExprType getType(String name); + } /** * Mechanism to supply values to back {@link IdentifierExpr} during expression evaluation @@ -180,7 +188,7 @@ interface Shuttle * * This means in rare cases and mostly for "questionable" expressions which we still allow to function 'correctly', * these lists might not be fully reliable without a complete type inference system in place. Due to this shortcoming, - * boolean values {@link BindingDetails#hasInputArrays()} and {@link BindingDetails#isOutputArray()} are provided to + * boolean values {@link ExprInputBindingAnalysis#hasInputArrays()} and {@link ExprInputBindingAnalysis#isOutputArray()} are provided to * allow functions to explicitly declare that they utilize array typed values, used when determining if some types of * optimizations can be applied when constructing the expression column value selector. * @@ -194,7 +202,7 @@ interface Shuttle * @see org.apache.druid.segment.virtual.ExpressionSelectors#makeColumnValueSelector */ @SuppressWarnings("JavadocReference") - class BindingDetails + class ExprInputBindingAnalysis { private final ImmutableSet freeVariables; private final ImmutableSet scalarVariables; @@ -202,17 +210,17 @@ class BindingDetails private final boolean hasInputArrays; private final boolean isOutputArray; - BindingDetails() + ExprInputBindingAnalysis() { this(ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, false); } - BindingDetails(IdentifierExpr expr) + ExprInputBindingAnalysis(IdentifierExpr expr) { this(ImmutableSet.of(expr), ImmutableSet.of(), ImmutableSet.of(), false, false); } - private BindingDetails( + private ExprInputBindingAnalysis( ImmutableSet freeVariables, ImmutableSet scalarVariables, ImmutableSet arrayVariables, @@ -310,19 +318,19 @@ public boolean isOutputArray() } /** - * Combine with {@link BindingDetails} from {@link Expr#analyzeInputs()} + * Combine with {@link ExprInputBindingAnalysis} from {@link Expr#analyzeInputs()} */ - public BindingDetails with(Expr other) + public ExprInputBindingAnalysis with(Expr other) { return with(other.analyzeInputs()); } /** - * Combine (union) another {@link BindingDetails} + * Combine (union) another {@link ExprInputBindingAnalysis} */ - public BindingDetails with(BindingDetails other) + public ExprInputBindingAnalysis with(ExprInputBindingAnalysis other) { - return new BindingDetails( + return new ExprInputBindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, other.freeVariables)), ImmutableSet.copyOf(Sets.union(scalarVariables, other.scalarVariables)), ImmutableSet.copyOf(Sets.union(arrayVariables, other.arrayVariables)), @@ -332,10 +340,10 @@ public BindingDetails with(BindingDetails other) } /** - * Add set of arguments as {@link BindingDetails#scalarVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link ExprInputBindingAnalysis#scalarVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - public BindingDetails withScalarArguments(Set scalarArguments) + public ExprInputBindingAnalysis withScalarArguments(Set scalarArguments) { Set moreScalars = new HashSet<>(); for (Expr expr : scalarArguments) { @@ -344,7 +352,7 @@ public BindingDetails withScalarArguments(Set scalarArguments) moreScalars.add((IdentifierExpr) expr); } } - return new BindingDetails( + return new ExprInputBindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, moreScalars)), ImmutableSet.copyOf(Sets.union(scalarVariables, moreScalars)), arrayVariables, @@ -354,10 +362,10 @@ public BindingDetails withScalarArguments(Set scalarArguments) } /** - * Add set of arguments as {@link BindingDetails#arrayVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link ExprInputBindingAnalysis#arrayVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - BindingDetails withArrayArguments(Set arrayArguments) + ExprInputBindingAnalysis withArrayArguments(Set arrayArguments) { Set arrayIdentifiers = new HashSet<>(); for (Expr expr : arrayArguments) { @@ -366,7 +374,7 @@ BindingDetails withArrayArguments(Set arrayArguments) arrayIdentifiers.add((IdentifierExpr) expr); } } - return new BindingDetails( + return new ExprInputBindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, arrayIdentifiers)), scalarVariables, ImmutableSet.copyOf(Sets.union(arrayVariables, arrayIdentifiers)), @@ -378,9 +386,9 @@ BindingDetails withArrayArguments(Set arrayArguments) /** * Copy, setting if an expression has array inputs */ - BindingDetails withArrayInputs(boolean hasArrays) + ExprInputBindingAnalysis withArrayInputs(boolean hasArrays) { - return new BindingDetails( + return new ExprInputBindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -392,9 +400,9 @@ BindingDetails withArrayInputs(boolean hasArrays) /** * Copy, setting if an expression produces an array output */ - BindingDetails withArrayOutput(boolean isOutputArray) + ExprInputBindingAnalysis withArrayOutput(boolean isOutputArray) { - return new BindingDetails( + return new ExprInputBindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -407,9 +415,9 @@ BindingDetails withArrayOutput(boolean isOutputArray) * Remove any {@link IdentifierExpr} that are from a {@link LambdaExpr}, since the {@link ApplyFunction} will * provide bindings for these variables. */ - BindingDetails removeLambdaArguments(Set lambda) + ExprInputBindingAnalysis removeLambdaArguments(Set lambda) { - return new BindingDetails( + return new ExprInputBindingAnalysis( ImmutableSet.copyOf(freeVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(scalarVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(arrayVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java index ae41653950f9..69a64e13d2b3 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java @@ -482,7 +482,7 @@ public void exitExplicitStringArray(ExprParser.ExplicitStringArrayContext ctx) * {@link IdentifierExpr#identifier} be the same as {@link IdentifierExpr#binding} because they have * synthetic bindings set at evaluation time. This is done to aid in analysis needed for the automatic expression * translation which maps scalar expressions to multi-value inputs. See - * {@link Parser#applyUnappliedBindings(Expr, Expr.BindingDetails, List)}} for additional details. + * {@link Parser#applyUnappliedBindings(Expr, Expr.ExprInputBindingAnalysis, List)}} for additional details. */ private IdentifierExpr createIdentifierExpr(String binding) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java index f7cf1d0f6489..4396969693b6 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java @@ -102,7 +102,7 @@ public abstract static class BaseScalarUnivariateMacroFunctionExpr implements Ex protected final Expr arg; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarUnivariateMacroFunctionExpr(String name, Expr arg) { @@ -119,7 +119,7 @@ public void visit(final Visitor visitor) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -150,7 +150,7 @@ public int hashCode() return Objects.hash(name, arg); } - private BindingDetails supplyAnalyzeInputs() + private ExprInputBindingAnalysis supplyAnalyzeInputs() { return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg)); } @@ -165,7 +165,7 @@ public abstract static class BaseScalarMacroFunctionExpr implements Expr protected final List args; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarMacroFunctionExpr(String name, final List args) { @@ -194,7 +194,7 @@ public void visit(final Visitor visitor) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -219,10 +219,10 @@ public int hashCode() return Objects.hash(name, args); } - private BindingDetails supplyAnalyzeInputs() + private ExprInputBindingAnalysis supplyAnalyzeInputs() { final Set argSet = Sets.newHashSetWithExpectedSize(args.size()); - BindingDetails accumulator = new BindingDetails(); + ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); argSet.add(arg); diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 41f1b60c61df..89700545f161 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -19,6 +19,7 @@ package org.apache.druid.math.expr; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.segment.column.ValueType; @@ -36,6 +37,12 @@ public enum ExprType LONG_ARRAY, STRING_ARRAY; + + public boolean isNumeric() + { + return isNumeric(this); + } + /** * The expression system does not distinguish between {@link ValueType#FLOAT} and {@link ValueType#DOUBLE}, and * cannot currently handle {@link ValueType#COMPLEX} inputs. This method will convert {@link ValueType#FLOAT} to @@ -67,4 +74,69 @@ public static ExprType fromValueType(@Nullable ValueType valueType) throw new ISE("Unsupported value type[%s]", valueType); } } + + public static boolean isNumeric(ExprType type) + { + return LONG.equals(type) || DOUBLE.equals(type); + } + + public static boolean isArray(@Nullable ExprType type) + { + return LONG_ARRAY.equals(type) || DOUBLE_ARRAY.equals(type) || STRING_ARRAY.equals(type); + } + + @Nullable + public static ExprType elementType(ExprType type) + { + if (isArray(type)) { + switch (type) { + case STRING_ARRAY: + return STRING; + case LONG_ARRAY: + return LONG; + case DOUBLE_ARRAY: + return DOUBLE; + } + } + return type; + } + + @Nullable + public static ExprType asArrayType(ExprType elementType) + { + if (!isArray(elementType)) { + switch (elementType) { + case STRING: + return STRING_ARRAY; + case LONG: + return LONG_ARRAY; + case DOUBLE: + return DOUBLE_ARRAY; + } + } + return null; + } + + public static ExprType implicitCast(@Nullable ExprType type, @Nullable ExprType other) + { + if (type == null || other == null) { + throw new IAE("Cannot implicitly cast unknown types"); + } + // arrays cannot be implicitly cast + if (isArray(type)) { + if (!type.equals(other)) { + throw new IAE("Cannot implicitly cast %s to %s", type, other); + } + return type; + } + // if either argument is a string, type becomes a string + if (STRING.equals(type) || STRING.equals(other)) { + return STRING; + } + // all numbers win over Integer + if (LONG.equals(type)) { + return other; + } + return type; + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 7fc6901d3739..31f452482972 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -39,6 +39,7 @@ import java.util.EnumSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.BinaryOperator; import java.util.function.DoubleBinaryOperator; @@ -104,6 +105,8 @@ default boolean hasArrayOutput() */ void validateArguments(List args); + ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args); + /** * Base class for a single variable input {@link Function} implementation */ @@ -180,6 +183,24 @@ protected ExprEval eval(double param) { return eval((long) param); } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + } + + /** + * Many math functions always output a {@link Double} primitive, regardless of input type. + */ + abstract class DoubleUnivariateMathFunction extends UnivariateMathFunction + { + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } /** @@ -210,6 +231,24 @@ protected ExprEval eval(double x, double y) { return eval((long) x, (long) y); } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.implicitCast(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); + } + } + + /** + * Many math functions always output a {@link Double} primitive, regardless of input type. + */ + abstract class DoubleBivariateMathFunction extends BivariateMathFunction + { + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } /** @@ -324,6 +363,141 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr); } + abstract class ReduceFunction implements Function + { + private final DoubleBinaryOperator doubleReducer; + private final LongBinaryOperator longReducer; + private final BinaryOperator stringReducer; + + ReduceFunction( + DoubleBinaryOperator doubleReducer, + LongBinaryOperator longReducer, + BinaryOperator stringReducer + ) + { + this.doubleReducer = doubleReducer; + this.longReducer = longReducer; + this.stringReducer = stringReducer; + } + + @Override + public void validateArguments(List args) + { + // anything goes + } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType outputType = ExprType.LONG; + for (Expr expr : args) { + outputType = ExprType.implicitCast(outputType, expr.getOutputType(inputTypes)); + } + return outputType; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + if (args.isEmpty()) { + return ExprEval.of(null); + } + + ExprAnalysis exprAnalysis = analyzeExprs(args, bindings); + if (exprAnalysis.exprEvals.isEmpty()) { + // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if + // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of + // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other + // databases (e.g., MySQL) return null if any expression is null. + // https://www.postgresql.org/docs/9.5/functions-conditional.html + // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + return ExprEval.of(null); + } + + Stream> exprEvalStream = exprAnalysis.exprEvals.stream(); + switch (exprAnalysis.comparisonType) { + case DOUBLE: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); + case LONG: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); + default: + //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) + return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get()); + } + } + + /** + * Determines which {@link ExprType} to use to compare non-null evaluated expressions. + * + * @param exprs Expressions to analyze + * @param bindings Bindings for expressions + * + * @return Comparison type and non-null evaluated expressions. + */ + private ExprAnalysis analyzeExprs(List exprs, Expr.ObjectBinding bindings) + { + Set presentTypes = EnumSet.noneOf(ExprType.class); + List> exprEvals = new ArrayList<>(); + + for (Expr expr : exprs) { + ExprEval exprEval = expr.eval(bindings); + ExprType exprType = exprEval.type(); + + if (isValidType(exprType)) { + presentTypes.add(exprType); + } + + if (exprEval.value() != null) { + exprEvals.add(exprEval); + } + } + + ExprType comparisonType = getComparisionType(presentTypes); + return new ExprAnalysis(comparisonType, exprEvals); + } + + private boolean isValidType(ExprType exprType) + { + switch (exprType) { + case DOUBLE: + case LONG: + case STRING: + return true; + default: + throw new IAE("Function[%s] does not accept %s types", name(), exprType); + } + } + + /** + * Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least + * + * @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE + */ + private static ExprType getComparisionType(Set exprTypes) + { + if (exprTypes.contains(ExprType.STRING)) { + return ExprType.STRING; + } else if (exprTypes.contains(ExprType.DOUBLE)) { + return ExprType.DOUBLE; + } else { + return ExprType.LONG; + } + } + + private static class ExprAnalysis + { + final ExprType comparisonType; + final List> exprEvals; + + ExprAnalysis(ExprType comparisonType, List> exprEvals) + { + this.comparisonType = comparisonType; + this.exprEvals = exprEvals; + } + } + } // ------------------------------ implementations ------------------------------ @@ -343,6 +517,12 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { @@ -393,6 +573,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 0 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.DOUBLE; + } } class Abs extends UnivariateMathFunction @@ -416,7 +602,7 @@ protected ExprEval eval(double param) } } - class Acos extends UnivariateMathFunction + class Acos extends DoubleUnivariateMathFunction { @Override public String name() @@ -431,7 +617,7 @@ protected ExprEval eval(double param) } } - class Asin extends UnivariateMathFunction + class Asin extends DoubleUnivariateMathFunction { @Override public String name() @@ -446,7 +632,7 @@ protected ExprEval eval(double param) } } - class Atan extends UnivariateMathFunction + class Atan extends DoubleUnivariateMathFunction { @Override public String name() @@ -461,7 +647,7 @@ protected ExprEval eval(double param) } } - class Cbrt extends UnivariateMathFunction + class Cbrt extends DoubleUnivariateMathFunction { @Override public String name() @@ -476,7 +662,7 @@ protected ExprEval eval(double param) } } - class Ceil extends UnivariateMathFunction + class Ceil extends DoubleUnivariateMathFunction { @Override public String name() @@ -491,7 +677,7 @@ protected ExprEval eval(double param) } } - class Cos extends UnivariateMathFunction + class Cos extends DoubleUnivariateMathFunction { @Override public String name() @@ -506,7 +692,7 @@ protected ExprEval eval(double param) } } - class Cosh extends UnivariateMathFunction + class Cosh extends DoubleUnivariateMathFunction { @Override public String name() @@ -521,7 +707,7 @@ protected ExprEval eval(double param) } } - class Cot extends UnivariateMathFunction + class Cot extends DoubleUnivariateMathFunction { @Override public String name() @@ -557,7 +743,7 @@ protected ExprEval eval(final double x, final double y) } } - class Exp extends UnivariateMathFunction + class Exp extends DoubleUnivariateMathFunction { @Override public String name() @@ -572,7 +758,7 @@ protected ExprEval eval(double param) } } - class Expm1 extends UnivariateMathFunction + class Expm1 extends DoubleUnivariateMathFunction { @Override public String name() @@ -587,7 +773,7 @@ protected ExprEval eval(double param) } } - class Floor extends UnivariateMathFunction + class Floor extends DoubleUnivariateMathFunction { @Override public String name() @@ -617,7 +803,7 @@ protected ExprEval eval(double param) } } - class Log extends UnivariateMathFunction + class Log extends DoubleUnivariateMathFunction { @Override public String name() @@ -632,7 +818,7 @@ protected ExprEval eval(double param) } } - class Log10 extends UnivariateMathFunction + class Log10 extends DoubleUnivariateMathFunction { @Override public String name() @@ -647,7 +833,7 @@ protected ExprEval eval(double param) } } - class Log1p extends UnivariateMathFunction + class Log1p extends DoubleUnivariateMathFunction { @Override public String name() @@ -662,7 +848,7 @@ protected ExprEval eval(double param) } } - class NextUp extends UnivariateMathFunction + class NextUp extends DoubleUnivariateMathFunction { @Override public String name() @@ -677,7 +863,7 @@ protected ExprEval eval(double param) } } - class Rint extends UnivariateMathFunction + class Rint extends DoubleUnivariateMathFunction { @Override public String name() @@ -740,6 +926,12 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + private ExprEval eval(ExprEval param) { return eval(param, 0); @@ -773,7 +965,7 @@ private static BigDecimal safeGetFromDouble(double val) } } - class Signum extends UnivariateMathFunction + class Signum extends DoubleUnivariateMathFunction { @Override public String name() @@ -788,7 +980,7 @@ protected ExprEval eval(double param) } } - class Sin extends UnivariateMathFunction + class Sin extends DoubleUnivariateMathFunction { @Override public String name() @@ -803,7 +995,7 @@ protected ExprEval eval(double param) } } - class Sinh extends UnivariateMathFunction + class Sinh extends DoubleUnivariateMathFunction { @Override public String name() @@ -818,7 +1010,7 @@ protected ExprEval eval(double param) } } - class Sqrt extends UnivariateMathFunction + class Sqrt extends DoubleUnivariateMathFunction { @Override public String name() @@ -833,7 +1025,7 @@ protected ExprEval eval(double param) } } - class Tan extends UnivariateMathFunction + class Tan extends DoubleUnivariateMathFunction { @Override public String name() @@ -848,7 +1040,7 @@ protected ExprEval eval(double param) } } - class Tanh extends UnivariateMathFunction + class Tanh extends DoubleUnivariateMathFunction { @Override public String name() @@ -863,7 +1055,7 @@ protected ExprEval eval(double param) } } - class ToDegrees extends UnivariateMathFunction + class ToDegrees extends DoubleUnivariateMathFunction { @Override public String name() @@ -878,7 +1070,7 @@ protected ExprEval eval(double param) } } - class ToRadians extends UnivariateMathFunction + class ToRadians extends DoubleUnivariateMathFunction { @Override public String name() @@ -893,7 +1085,7 @@ protected ExprEval eval(double param) } } - class Ulp extends UnivariateMathFunction + class Ulp extends DoubleUnivariateMathFunction { @Override public String name() @@ -908,7 +1100,7 @@ protected ExprEval eval(double param) } } - class Atan2 extends BivariateMathFunction + class Atan2 extends DoubleBivariateMathFunction { @Override public String name() @@ -923,7 +1115,7 @@ protected ExprEval eval(double y, double x) } } - class CopySign extends BivariateMathFunction + class CopySign extends DoubleBivariateMathFunction { @Override public String name() @@ -938,7 +1130,7 @@ protected ExprEval eval(double x, double y) } } - class Hypot extends BivariateMathFunction + class Hypot extends DoubleBivariateMathFunction { @Override public String name() @@ -953,7 +1145,7 @@ protected ExprEval eval(double x, double y) } } - class Remainder extends BivariateMathFunction + class Remainder extends DoubleBivariateMathFunction { @Override public String name() @@ -1010,214 +1202,165 @@ protected ExprEval eval(double x, double y) } } - class GreatestFunc extends ReduceFunc + class NextAfter extends DoubleBivariateMathFunction { - public static final String NAME = "greatest"; - - public GreatestFunc() + @Override + public String name() { - super( - Math::max, - Math::max, - BinaryOperator.maxBy(Comparator.naturalOrder()) - ); + return "nextAfter"; } @Override - public String name() + protected ExprEval eval(double x, double y) { - return NAME; + return ExprEval.of(Math.nextAfter(x, y)); } } - class LeastFunc extends ReduceFunc + class Pow extends DoubleBivariateMathFunction { - public static final String NAME = "least"; - - public LeastFunc() + @Override + public String name() { - super( - Math::min, - Math::min, - BinaryOperator.minBy(Comparator.naturalOrder()) - ); + return "pow"; } @Override - public String name() + protected ExprEval eval(double x, double y) { - return NAME; + return ExprEval.of(Math.pow(x, y)); } } - abstract class ReduceFunc implements Function + class Scalb extends BivariateFunction { - private final DoubleBinaryOperator doubleReducer; - private final LongBinaryOperator longReducer; - private final BinaryOperator stringReducer; - - ReduceFunc( - DoubleBinaryOperator doubleReducer, - LongBinaryOperator longReducer, - BinaryOperator stringReducer - ) + @Override + public String name() { - this.doubleReducer = doubleReducer; - this.longReducer = longReducer; - this.stringReducer = stringReducer; + return "scalb"; } @Override - public void validateArguments(List args) + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - // anything goes + return ExprType.DOUBLE; } @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) + protected ExprEval eval(ExprEval x, ExprEval y) { - if (args.isEmpty()) { - return ExprEval.of(null); - } - - ExprAnalysis exprAnalysis = analyzeExprs(args, bindings); - if (exprAnalysis.exprEvals.isEmpty()) { - // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if - // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of - // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other - // databases (e.g., MySQL) return null if any expression is null. - // https://www.postgresql.org/docs/9.5/functions-conditional.html - // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least - return ExprEval.of(null); - } - - Stream> exprEvalStream = exprAnalysis.exprEvals.stream(); - switch (exprAnalysis.comparisonType) { - case DOUBLE: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); - case LONG: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); - default: - //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get()); - } + return ExprEval.of(Math.scalb(x.asDouble(), y.asInt())); } + } - /** - * Determines which {@link ExprType} to use to compare non-null evaluated expressions. - * - * @param exprs Expressions to analyze - * @param bindings Bindings for expressions - * - * @return Comparison type and non-null evaluated expressions. - */ - private ExprAnalysis analyzeExprs(List exprs, Expr.ObjectBinding bindings) + class CastFunc extends BivariateFunction + { + @Override + public String name() { - Set presentTypes = EnumSet.noneOf(ExprType.class); - List> exprEvals = new ArrayList<>(); - - for (Expr expr : exprs) { - ExprEval exprEval = expr.eval(bindings); - ExprType exprType = exprEval.type(); - - if (isValidType(exprType)) { - presentTypes.add(exprType); - } - - if (exprEval.value() != null) { - exprEvals.add(exprEval); - } - } - - ExprType comparisonType = getComparisionType(presentTypes); - return new ExprAnalysis(comparisonType, exprEvals); + return "cast"; } - private boolean isValidType(ExprType exprType) + @Override + protected ExprEval eval(ExprEval x, ExprEval y) { - switch (exprType) { - case DOUBLE: - case LONG: - case STRING: - return true; - default: - throw new IAE("Function[%s] does not accept %s types", name(), exprType); + if (NullHandling.sqlCompatible() && x.value() == null) { + return ExprEval.of(null); } - } - - /** - * Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least - * - * @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE - */ - private static ExprType getComparisionType(Set exprTypes) - { - if (exprTypes.contains(ExprType.STRING)) { - return ExprType.STRING; - } else if (exprTypes.contains(ExprType.DOUBLE)) { - return ExprType.DOUBLE; - } else { - return ExprType.LONG; + ExprType castTo; + try { + castTo = ExprType.valueOf(StringUtils.toUpperCase(y.asString())); + } + catch (IllegalArgumentException e) { + throw new IAE("invalid type '%s'", y.asString()); } + return x.castTo(castTo); } - private static class ExprAnalysis + @Override + public Set getScalarInputs(List args) { - final ExprType comparisonType; - final List> exprEvals; - - ExprAnalysis(ExprType comparisonType, List> exprEvals) - { - this.comparisonType = comparisonType; - this.exprEvals = exprEvals; + if (args.get(1).isLiteral()) { + ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + switch (castTo) { + case LONG_ARRAY: + case DOUBLE_ARRAY: + case STRING_ARRAY: + return Collections.emptySet(); + default: + return ImmutableSet.of(args.get(0)); + } } + // unknown cast, can't safely assume either way + return Collections.emptySet(); } - } - class NextAfter extends BivariateMathFunction - { @Override - public String name() + public Set getArrayInputs(List args) { - return "nextAfter"; + if (args.get(1).isLiteral()) { + ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + switch (castTo) { + case LONG: + case DOUBLE: + case STRING: + return Collections.emptySet(); + default: + return ImmutableSet.of(args.get(0)); + } + } + // unknown cast, can't safely assume either way + return Collections.emptySet(); } @Override - protected ExprEval eval(double x, double y) + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return ExprEval.of(Math.nextAfter(x, y)); + // can only know cast output type if cast to argument is constant + if (args.get(1).isLiteral()) { + return ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + } + return null; } } - class Pow extends BivariateMathFunction + class GreatestFunction extends ReduceFunction { - @Override - public String name() + public static final String NAME = "greatest"; + + public GreatestFunction() { - return "pow"; + super( + Math::max, + Math::max, + BinaryOperator.maxBy(Comparator.naturalOrder()) + ); } @Override - protected ExprEval eval(double x, double y) + public String name() { - return ExprEval.of(Math.pow(x, y)); + return NAME; } } - class Scalb extends BivariateFunction + class LeastFunction extends ReduceFunction { - @Override - public String name() + public static final String NAME = "least"; + + public LeastFunction() { - return "scalb"; + super( + Math::min, + Math::min, + BinaryOperator.minBy(Comparator.naturalOrder()) + ); } @Override - protected ExprEval eval(ExprEval x, ExprEval y) + public String name() { - return ExprEval.of(Math.scalb(x.asDouble(), y.asInt())); + return NAME; } } @@ -1243,6 +1386,13 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + // output type is defined by else + return args.get(2).getOutputType(inputTypes); + } } /** @@ -1279,6 +1429,13 @@ public void validateArguments(List args) throw new IAE("Function[%s] must have at least 2 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + // output type is defined by else + return args.get(args.size() - 1).getOutputType(inputTypes); + } } /** @@ -1315,154 +1472,103 @@ public void validateArguments(List args) throw new IAE("Function[%s] must have at least 3 arguments", name()); } } - } - class CastFunc extends BivariateFunction - { @Override - public String name() + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return "cast"; - } - - @Override - protected ExprEval eval(ExprEval x, ExprEval y) - { - if (NullHandling.sqlCompatible() && x.value() == null) { - return ExprEval.of(null); - } - ExprType castTo; - try { - castTo = ExprType.valueOf(StringUtils.toUpperCase(y.asString())); - } - catch (IllegalArgumentException e) { - throw new IAE("invalid type '%s'", y.asString()); - } - return x.castTo(castTo); - } - - @Override - public Set getScalarInputs(List args) - { - if (args.get(1).isLiteral()) { - ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); - switch (castTo) { - case LONG_ARRAY: - case DOUBLE_ARRAY: - case STRING_ARRAY: - return Collections.emptySet(); - default: - return ImmutableSet.of(args.get(0)); - } - } - // unknown cast, can't safely assume either way - return Collections.emptySet(); - } - - @Override - public Set getArrayInputs(List args) - { - if (args.get(1).isLiteral()) { - ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); - switch (castTo) { - case LONG: - case DOUBLE: - case STRING: - return Collections.emptySet(); - default: - return ImmutableSet.of(args.get(0)); - } - } - // unknown cast, can't safely assume either way - return Collections.emptySet(); + // output type is defined by else + return args.get(args.size() - 1).getOutputType(inputTypes); } } - class TimestampFromEpochFunc implements Function + class NvlFunc implements Function { @Override public String name() { - return "timestamp"; + return "nvl"; } @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - ExprEval value = args.get(0).eval(bindings); - if (value.type() != ExprType.STRING) { - throw new IAE("first argument should be string type but got %s type", value.type()); - } - - DateTimes.UtcFormatter formatter = DateTimes.ISO_DATE_OPTIONAL_TIME; - if (args.size() > 1) { - ExprEval format = args.get(1).eval(bindings); - if (format.type() != ExprType.STRING) { - throw new IAE("second argument should be string type but got %s type", format.type()); - } - formatter = DateTimes.wrapFormatter(DateTimeFormat.forPattern(format.asString())); - } - DateTime date; - try { - date = formatter.parse(value.asString()); - } - catch (IllegalArgumentException e) { - throw new IAE(e, "invalid value %s", value.asString()); - } - return toValue(date); + final ExprEval eval = args.get(0).eval(bindings); + return eval.value() == null ? args.get(1).eval(bindings) : eval; } @Override public void validateArguments(List args) { - if (args.size() != 1 && args.size() != 2) { - throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + if (args.size() != 2) { + throw new IAE("Function[%s] needs 2 arguments", name()); } } - protected ExprEval toValue(DateTime date) + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return ExprEval.of(date.getMillis()); + return args.get(0).getOutputType(inputTypes); } } - class UnixTimestampFunc extends TimestampFromEpochFunc + class IsNullFunc implements Function { @Override public String name() { - return "unix_timestamp"; + return "isnull"; } @Override - protected final ExprEval toValue(DateTime date) + public ExprEval apply(List args, Expr.ObjectBinding bindings) { - return ExprEval.of(date.getMillis() / 1000); + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(expr.value() == null); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; } } - class NvlFunc implements Function + class IsNotNullFunc implements Function { @Override public String name() { - return "nvl"; + return "notnull"; } @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - final ExprEval eval = args.get(0).eval(bindings); - return eval.value() == null ? args.get(1).eval(bindings) : eval; + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.ofLongBoolean(expr.value() != null); } @Override public void validateArguments(List args) { - if (args.size() != 2) { - throw new IAE("Function[%s] needs 2 arguments", name()); + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class ConcatFunc implements Function @@ -1506,6 +1612,12 @@ public void validateArguments(List args) { // anything goes } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class StrlenFunc implements Function @@ -1530,6 +1642,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class StringFormatFunc implements Function @@ -1564,6 +1682,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 1 or more arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class StrposFunc implements Function @@ -1602,6 +1726,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 2 or 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class SubstringFunc implements Function @@ -1645,6 +1775,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class RightFunc extends StringLongFunction @@ -1655,6 +1791,12 @@ public String name() return "right"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(@Nullable String x, int y) { @@ -1680,6 +1822,12 @@ public String name() return "left"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(@Nullable String x, int y) { @@ -1723,6 +1871,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class LowerFunc implements Function @@ -1750,6 +1904,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class UpperFunc implements Function @@ -1777,6 +1937,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 1 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } class ReverseFunc extends UnivariateFunction @@ -1787,6 +1953,12 @@ public String name() return "reverse"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(ExprEval param) { @@ -1809,6 +1981,12 @@ public String name() return "repeat"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override protected ExprEval eval(String x, int y) { @@ -1816,60 +1994,50 @@ protected ExprEval eval(String x, int y) } } - class IsNullFunc implements Function + class LpadFunc implements Function { @Override public String name() { - return "isnull"; + return "lpad"; } @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.ofLongBoolean(expr.value() == null); - } + String base = args.get(0).eval(bindings).asString(); + int len = args.get(1).eval(bindings).asInt(); + String pad = args.get(2).eval(bindings).asString(); - @Override - public void validateArguments(List args) - { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); + if (base == null || pad == null) { + return ExprEval.of(null); + } else { + return ExprEval.of(len == 0 ? NullHandling.defaultStringValue() : StringUtils.lpad(base, len, pad)); } - } - } - class IsNotNullFunc implements Function - { - @Override - public String name() - { - return "notnull"; } @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) + public void validateArguments(List args) { - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.ofLongBoolean(expr.value() != null); + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } } @Override - public void validateArguments(List args) + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); - } + return ExprType.STRING; } } - class LpadFunc implements Function + class RpadFunc implements Function { @Override public String name() { - return "lpad"; + return "rpad"; } @Override @@ -1882,7 +2050,7 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) if (base == null || pad == null) { return ExprEval.of(null); } else { - return ExprEval.of(len == 0 ? NullHandling.defaultStringValue() : StringUtils.lpad(base, len, pad)); + return ExprEval.of(len == 0 ? NullHandling.defaultStringValue() : StringUtils.rpad(base, len, pad)); } } @@ -1894,38 +2062,81 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } } - class RpadFunc implements Function + class TimestampFromEpochFunc implements Function { @Override public String name() { - return "rpad"; + return "timestamp"; } @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - String base = args.get(0).eval(bindings).asString(); - int len = args.get(1).eval(bindings).asInt(); - String pad = args.get(2).eval(bindings).asString(); - - if (base == null || pad == null) { - return ExprEval.of(null); - } else { - return ExprEval.of(len == 0 ? NullHandling.defaultStringValue() : StringUtils.rpad(base, len, pad)); + ExprEval value = args.get(0).eval(bindings); + if (value.type() != ExprType.STRING) { + throw new IAE("first argument should be string type but got %s type", value.type()); } + DateTimes.UtcFormatter formatter = DateTimes.ISO_DATE_OPTIONAL_TIME; + if (args.size() > 1) { + ExprEval format = args.get(1).eval(bindings); + if (format.type() != ExprType.STRING) { + throw new IAE("second argument should be string type but got %s type", format.type()); + } + formatter = DateTimes.wrapFormatter(DateTimeFormat.forPattern(format.asString())); + } + DateTime date; + try { + date = formatter.parse(value.asString()); + } + catch (IllegalArgumentException e) { + throw new IAE(e, "invalid value %s", value.asString()); + } + return toValue(date); } @Override public void validateArguments(List args) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); + if (args.size() != 1 && args.size() != 2) { + throw new IAE("Function[%s] needs 1 or 2 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + + protected ExprEval toValue(DateTime date) + { + return ExprEval.of(date.getMillis()); + } + } + + class UnixTimestampFunc extends TimestampFromEpochFunc + { + @Override + public String name() + { + return "unix_timestamp"; + } + + @Override + protected final ExprEval toValue(DateTime date) + { + return ExprEval.of(date.getMillis() / 1000); + } } class SubMonthFunc implements Function @@ -1958,6 +2169,12 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs 3 arguments", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } } class ArrayConstructorFunction implements Function @@ -2064,6 +2281,16 @@ public void validateArguments(List args) throw new IAE("Function[%s] needs at least 1 argument", name()); } } + + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType type = ExprType.LONG; + for (Expr arg : args) { + type = ExprType.implicitCast(type, arg.getOutputType(inputTypes)); + } + return ExprType.asArrayType(type); + } } class ArrayLengthFunction implements Function @@ -2110,6 +2337,12 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override public Set getScalarInputs(List args) { @@ -2133,6 +2366,12 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING_ARRAY; + } + @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { @@ -2167,6 +2406,12 @@ public String name() return "array_to_string"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.STRING; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2189,6 +2434,12 @@ public String name() return "array_offset"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.elementType(args.get(0).getOutputType(inputTypes)); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2210,6 +2461,12 @@ public String name() return "array_ordinal"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.elementType(args.get(0).getOutputType(inputTypes)); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2231,6 +2488,12 @@ public String name() return "array_offset_of"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2262,6 +2525,12 @@ public String name() return "array_ordinal_of"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2298,6 +2567,13 @@ public boolean hasArrayOutput() return true; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(0).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) { @@ -2354,6 +2630,13 @@ public boolean hasArrayOutput() return true; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(0).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { @@ -2409,6 +2692,12 @@ public boolean hasArrayOutput() return true; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { @@ -2426,6 +2715,12 @@ public String name() return "array_overlap"; } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return ExprType.LONG; + } + @Override ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) { @@ -2455,6 +2750,12 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + return args.get(0).getOutputType(inputTypes); + } + @Override public Set getScalarInputs(List args) { @@ -2534,6 +2835,13 @@ public void validateArguments(List args) } } + @Override + public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) + { + ExprType arrayType = args.get(1).getOutputType(inputTypes); + return Optional.ofNullable(ExprType.asArrayType(arrayType)).orElse(arrayType); + } + @Override public Set getScalarInputs(List args) { diff --git a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java index 2b3474a6c821..bcad13b4ccb4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java @@ -105,13 +105,19 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { final Set lambdaArgs = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet()); - BindingDetails bodyDetails = expr.analyzeInputs(); + ExprInputBindingAnalysis bodyDetails = expr.analyzeInputs(); return bodyDetails.removeLambdaArguments(lambdaArgs); } + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return expr.getOutputType(inputTypes); + } + @Override public boolean equals(Object o) { @@ -187,9 +193,9 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { - BindingDetails accumulator = new BindingDetails(); + ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); @@ -200,6 +206,12 @@ public BindingDetails analyzeInputs() .withArrayOutput(function.hasArrayOutput()); } + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return function.getOutputType(inputTypes, args); + } + @Override public boolean equals(Object o) { @@ -232,9 +244,9 @@ class ApplyFunctionExpr implements Expr final String name; final LambdaExpr lambdaExpr; final ImmutableList argsExpr; - final BindingDetails bindingDetails; - final BindingDetails lambdaBindingDetails; - final ImmutableList argsBindingDetails; + final ExprInputBindingAnalysis exprInputBindingAnalysis; + final ExprInputBindingAnalysis lambdaExprInputBindingAnalysis; + final ImmutableList argsExprInputBindingAnalysis; ApplyFunctionExpr(ApplyFunction function, String name, LambdaExpr expr, List args) { @@ -247,21 +259,21 @@ class ApplyFunctionExpr implements Expr // apply function expressions are examined during expression selector creation, so precompute and cache the // binding details of children - ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); - BindingDetails accumulator = new BindingDetails(); + ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); + ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); for (Expr arg : argsExpr) { - BindingDetails argDetails = arg.analyzeInputs(); + ExprInputBindingAnalysis argDetails = arg.analyzeInputs(); argBindingDetailsBuilder.add(argDetails); accumulator = accumulator.with(argDetails); } - lambdaBindingDetails = lambdaExpr.analyzeInputs(); + lambdaExprInputBindingAnalysis = lambdaExpr.analyzeInputs(); - bindingDetails = accumulator.with(lambdaBindingDetails) - .withArrayArguments(function.getArrayInputs(argsExpr)) - .withArrayInputs(true) - .withArrayOutput(function.hasArrayOutput(lambdaExpr)); - argsBindingDetails = argBindingDetailsBuilder.build(); + exprInputBindingAnalysis = accumulator.with(lambdaExprInputBindingAnalysis) + .withArrayArguments(function.getArrayInputs(argsExpr)) + .withArrayInputs(true) + .withArrayOutput(function.hasArrayOutput(lambdaExpr)); + argsExprInputBindingAnalysis = argBindingDetailsBuilder.build(); } @Override @@ -306,9 +318,16 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() + { + return exprInputBindingAnalysis; + } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) { - return bindingDetails; + return function.getOutputType(inputTypes, lambdaExpr, argsExpr); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java index d23657a3bd90..9b4f888f2694 100644 --- a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java @@ -102,9 +102,15 @@ public IdentifierExpr getIdentifierExprIfIdentifierExpr() } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { - return new BindingDetails(this); + return new ExprInputBindingAnalysis(this); + } + + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return inputTypes.getType(binding); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index d8fb564c4f2d..ec2686fbd406 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -169,7 +169,7 @@ public static Expr flatten(Expr expr) * @param bindingsToApply * @return */ - public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails bindingDetails, List bindingsToApply) + public static Expr applyUnappliedBindings(Expr expr, Expr.ExprInputBindingAnalysis exprInputBindingAnalysis, List bindingsToApply) { if (bindingsToApply.isEmpty()) { // nothing to do, expression is fine as is @@ -177,7 +177,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails binding } // filter the list of bindings to those which are used in this expression List unappliedBindingsInExpression = bindingsToApply.stream() - .filter(x -> bindingDetails.getRequiredBindings().contains(x)) + .filter(x -> exprInputBindingAnalysis.getRequiredBindings().contains(x)) .collect(Collectors.toList()); // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten @@ -193,7 +193,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails binding List newArgs = new ArrayList<>(); for (Expr arg : fnExpr.args) { if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { - Expr newArg = applyUnappliedBindings(arg, bindingDetails, unappliedBindingsInExpression); + Expr newArg = applyUnappliedBindings(arg, exprInputBindingAnalysis, unappliedBindingsInExpression); newArgs.add(newArg); } else { newArgs.add(arg); @@ -207,7 +207,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails binding } ); - Expr.BindingDetails newExprBindings = newExpr.analyzeInputs(); + Expr.ExprInputBindingAnalysis newExprBindings = newExpr.analyzeInputs(); final Set expectedArrays = newExprBindings.getArrayVariables(); List remainingUnappliedBindings = @@ -288,11 +288,11 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List unappliedInThisApply = unappliedArgs.stream() - .filter(u -> !expr.bindingDetails.getArrayBindings().contains(u)) + .filter(u -> !expr.exprInputBindingAnalysis.getArrayBindings().contains(u)) .collect(Collectors.toSet()); List unappliedIdentifiers = - expr.bindingDetails + expr.exprInputBindingAnalysis .getFreeVariables() .stream() .filter(x -> unappliedInThisApply.contains(x.getBindingIfIdentifier())) @@ -304,7 +304,7 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List unappliedLambdaBindings = - expr.lambdaBindingDetails.getFreeVariables() - .stream() - .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) - .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) - .collect(Collectors.toList()); + expr.lambdaExprInputBindingAnalysis.getFreeVariables() + .stream() + .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) + .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) + .collect(Collectors.toList()); if (unappliedLambdaBindings.isEmpty()) { return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, newArgs); @@ -397,10 +397,10 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List conflicted = - Sets.intersection(bindingDetails.getScalarBindings(), bindingDetails.getArrayBindings()); + Sets.intersection(exprInputBindingAnalysis.getScalarBindings(), exprInputBindingAnalysis.getArrayBindings()); if (!conflicted.isEmpty()) { throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted); } diff --git a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java index 5a41e9042509..f3304971a0ee 100644 --- a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java @@ -24,6 +24,7 @@ import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; +import javax.annotation.Nullable; import java.util.Objects; /** @@ -59,12 +60,19 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { // currently all unary operators only operate on scalar inputs return expr.analyzeInputs().withScalarArguments(ImmutableSet.of(expr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return expr.getOutputType(inputTypes); + } + @Override public boolean equals(Object o) { @@ -163,4 +171,15 @@ public String toString() { return StringUtils.format("!%s", expr); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + ExprType implicitCast = super.getOutputType(inputTypes); + if (ExprType.STRING.equals(implicitCast)) { + return ExprType.LONG; + } + return implicitCast; + } } diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java index ff12669bbc6c..02e373834179 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java @@ -113,7 +113,7 @@ public void testEqualsContractForApplyFunctionExpr() { EqualsVerifier.forClass(ApplyFunctionExpr.class) .usingGetClass() - .withIgnoredFields("function", "bindingDetails", "lambdaBindingDetails", "argsBindingDetails") + .withIgnoredFields("function", "exprInputBindingAnalysis", "lambdaExprInputBindingAnalysis", "argsExprInputBindingAnalysis") .verify(); } @@ -132,37 +132,55 @@ public void testEqualsContractForUnaryMinusExpr() @Test public void testEqualsContractForStringExpr() { - EqualsVerifier.forClass(StringExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(StringExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForDoubleExpr() { - EqualsVerifier.forClass(DoubleExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(DoubleExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForLongExpr() { - EqualsVerifier.forClass(LongExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(LongExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForStringArrayExpr() { - EqualsVerifier.forClass(StringArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(StringArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForLongArrayExpr() { - EqualsVerifier.forClass(LongArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(LongArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test public void testEqualsContractForDoubleArrayExpr() { - EqualsVerifier.forClass(DoubleArrayExpr.class).usingGetClass().verify(); + EqualsVerifier.forClass(DoubleArrayExpr.class) + .withIgnoredFields("outputType") + .usingGetClass() + .verify(); } @Test @@ -179,12 +197,16 @@ public void testEqualsContractForLambdaExpr() @Test public void testEqualsContractForNullLongExpr() { - EqualsVerifier.forClass(NullLongExpr.class).verify(); + EqualsVerifier.forClass(NullLongExpr.class) + .withIgnoredFields("outputType") + .verify(); } @Test public void testEqualsContractForNullDoubleExpr() { - EqualsVerifier.forClass(NullDoubleExpr.class).verify(); + EqualsVerifier.forClass(NullDoubleExpr.class) + .withIgnoredFields("outputType") + .verify(); } } diff --git a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java new file mode 100644 index 000000000000..d7b5d32b4f86 --- /dev/null +++ b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Map; + +public class OutputTypeTest extends InitializedNullHandlingTest +{ + private final Expr.InputBindingTypes inputTypes = inputTypesFromMap( + ImmutableMap.builder().put("x", ExprType.STRING) + .put("x_", ExprType.STRING) + .put("y", ExprType.LONG) + .put("y_", ExprType.LONG) + .put("z", ExprType.DOUBLE) + .put("z_", ExprType.DOUBLE) + .put("a", ExprType.STRING_ARRAY) + .put("a_", ExprType.STRING_ARRAY) + .put("b", ExprType.LONG_ARRAY) + .put("b_", ExprType.LONG_ARRAY) + .put("c", ExprType.DOUBLE_ARRAY) + .put("c_", ExprType.DOUBLE_ARRAY) + .build() + ); + + @Test + public void testConstantsAndIdentifiers() + { + assertOutputType("'hello'", inputTypes, ExprType.STRING); + assertOutputType("23", inputTypes, ExprType.LONG); + assertOutputType("3.2", inputTypes, ExprType.DOUBLE); + assertOutputType("['a', 'b']", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("[1,2,3]", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("[1.0]", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("x", inputTypes, ExprType.STRING); + assertOutputType("y", inputTypes, ExprType.LONG); + assertOutputType("z", inputTypes, ExprType.DOUBLE); + assertOutputType("a", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("b", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("c", inputTypes, ExprType.DOUBLE_ARRAY); + } + + @Test + public void testUnaryOperators() + { + assertOutputType("-1", inputTypes, ExprType.LONG); + assertOutputType("-1.1", inputTypes, ExprType.DOUBLE); + assertOutputType("-y", inputTypes, ExprType.LONG); + assertOutputType("-z", inputTypes, ExprType.DOUBLE); + + assertOutputType("!'true'", inputTypes, ExprType.LONG); + assertOutputType("!1", inputTypes, ExprType.LONG); + assertOutputType("!1.1", inputTypes, ExprType.DOUBLE); + assertOutputType("!x", inputTypes, ExprType.LONG); + assertOutputType("!y", inputTypes, ExprType.LONG); + assertOutputType("!z", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testBinaryMathOperators() + { + assertOutputType("1+1", inputTypes, ExprType.LONG); + assertOutputType("1-1", inputTypes, ExprType.LONG); + assertOutputType("1*1", inputTypes, ExprType.LONG); + assertOutputType("1/1", inputTypes, ExprType.LONG); + assertOutputType("1^1", inputTypes, ExprType.LONG); + assertOutputType("1%1", inputTypes, ExprType.LONG); + + assertOutputType("y+y_", inputTypes, ExprType.LONG); + assertOutputType("y-y_", inputTypes, ExprType.LONG); + assertOutputType("y*y_", inputTypes, ExprType.LONG); + assertOutputType("y/y_", inputTypes, ExprType.LONG); + assertOutputType("y^y_", inputTypes, ExprType.LONG); + assertOutputType("y%y_", inputTypes, ExprType.LONG); + + assertOutputType("y+z", inputTypes, ExprType.DOUBLE); + assertOutputType("y-z", inputTypes, ExprType.DOUBLE); + assertOutputType("y*z", inputTypes, ExprType.DOUBLE); + assertOutputType("y/z", inputTypes, ExprType.DOUBLE); + assertOutputType("y^z", inputTypes, ExprType.DOUBLE); + assertOutputType("y%z", inputTypes, ExprType.DOUBLE); + + assertOutputType("z+z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z-z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z*z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z/z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z^z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z%z_", inputTypes, ExprType.DOUBLE); + + assertOutputType("y>y_", inputTypes, ExprType.LONG); + assertOutputType("y_=y", inputTypes, ExprType.LONG); + assertOutputType("y_==y", inputTypes, ExprType.LONG); + assertOutputType("y_!=y", inputTypes, ExprType.LONG); + assertOutputType("y_ && y", inputTypes, ExprType.LONG); + assertOutputType("y_ || y", inputTypes, ExprType.LONG); + + assertOutputType("z>y_", inputTypes, ExprType.DOUBLE); + assertOutputType("z=z", inputTypes, ExprType.DOUBLE); + assertOutputType("z==y", inputTypes, ExprType.DOUBLE); + assertOutputType("z!=y", inputTypes, ExprType.DOUBLE); + assertOutputType("z && y", inputTypes, ExprType.DOUBLE); + assertOutputType("y || z", inputTypes, ExprType.DOUBLE); + + assertOutputType("z>z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z=z", inputTypes, ExprType.DOUBLE); + assertOutputType("z==z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z!=z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z && z_", inputTypes, ExprType.DOUBLE); + assertOutputType("z_ || z", inputTypes, ExprType.DOUBLE); + + assertOutputType("1*(2 + 3.0)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testUnivariateMathFunctions() + { + assertOutputType("pi()", inputTypes, ExprType.DOUBLE); + assertOutputType("abs(x)", inputTypes, ExprType.STRING); + assertOutputType("abs(y)", inputTypes, ExprType.LONG); + assertOutputType("abs(z)", inputTypes, ExprType.DOUBLE); + assertOutputType("cos(y)", inputTypes, ExprType.DOUBLE); + assertOutputType("cos(z)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testBivariateMathFunctions() + { + assertOutputType("div(y,y_)", inputTypes, ExprType.LONG); + assertOutputType("div(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("div(z,z_)", inputTypes, ExprType.DOUBLE); + + assertOutputType("max(y,y_)", inputTypes, ExprType.LONG); + assertOutputType("max(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("max(z,z_)", inputTypes, ExprType.DOUBLE); + + assertOutputType("hypot(y,y_)", inputTypes, ExprType.DOUBLE); + assertOutputType("hypot(y,z_)", inputTypes, ExprType.DOUBLE); + assertOutputType("hypot(z,z_)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testConditionalFunctions() + { + assertOutputType("if(y, 'foo', 'bar')", inputTypes, ExprType.STRING); + assertOutputType("if(y,2,3)", inputTypes, ExprType.LONG); + assertOutputType("if(y,2,3.0)", inputTypes, ExprType.DOUBLE); + + assertOutputType( + "case_simple(x,'baz','is baz','foo','is foo','is other')", + inputTypes, + ExprType.STRING + ); + assertOutputType( + "case_simple(y,2,2,3,3,4)", + inputTypes, + ExprType.LONG + ); + assertOutputType( + "case_simple(z,2.0,2.0,3.0,3.0,4.0)", + inputTypes, + ExprType.DOUBLE + ); + + assertOutputType( + "case_searched(x=='baz','is baz',x=='foo','is foo','is other')", + inputTypes, + ExprType.STRING + ); + assertOutputType( + "case_searched(y==1,1,y==2,2,0)", + inputTypes, + ExprType.LONG + ); + assertOutputType( + "case_searched(z==1.0,1.0,z==2.0,2.0,0.0)", + inputTypes, + ExprType.DOUBLE + ); + + assertOutputType("nvl(x, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("nvl(y, 1)", inputTypes, ExprType.LONG); + assertOutputType("nvl(z, 2.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("isnull(x)", inputTypes, ExprType.LONG); + assertOutputType("isnull(y)", inputTypes, ExprType.LONG); + assertOutputType("isnull(z)", inputTypes, ExprType.LONG); + assertOutputType("notnull(x)", inputTypes, ExprType.LONG); + assertOutputType("notnull(y)", inputTypes, ExprType.LONG); + assertOutputType("notnull(z)", inputTypes, ExprType.LONG); + } + + @Test + public void testStringFunctions() + { + assertOutputType("concat(x, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("concat(y, 'foo')", inputTypes, ExprType.STRING); + assertOutputType("concat(z, 'foo')", inputTypes, ExprType.STRING); + + assertOutputType("strlen(x)", inputTypes, ExprType.LONG); + assertOutputType("format('%s', x)", inputTypes, ExprType.STRING); + assertOutputType("format('%s', y)", inputTypes, ExprType.STRING); + assertOutputType("format('%s', z)", inputTypes, ExprType.STRING); + assertOutputType("strpos(x, x_)", inputTypes, ExprType.LONG); + assertOutputType("strpos(x, y)", inputTypes, ExprType.LONG); + assertOutputType("strpos(x, z)", inputTypes, ExprType.LONG); + assertOutputType("substring(x, 1, 2)", inputTypes, ExprType.STRING); + assertOutputType("left(x, 1)", inputTypes, ExprType.STRING); + assertOutputType("right(x, 1)", inputTypes, ExprType.STRING); + assertOutputType("replace(x, 'foo', '')", inputTypes, ExprType.STRING); + assertOutputType("lower(x)", inputTypes, ExprType.STRING); + assertOutputType("upper(x)", inputTypes, ExprType.STRING); + assertOutputType("reverse(x)", inputTypes, ExprType.STRING); + assertOutputType("repeat(x, 4)", inputTypes, ExprType.STRING); + } + + @Test + public void testArrayFunctions() + { + assertOutputType("array(1, 2, 3)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array(1, 2, 3.0)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_length(a)", inputTypes, ExprType.LONG); + assertOutputType("array_length(b)", inputTypes, ExprType.LONG); + assertOutputType("array_length(c)", inputTypes, ExprType.LONG); + + assertOutputType("string_to_array(x, ',')", inputTypes, ExprType.STRING_ARRAY); + + assertOutputType("array_to_string(a, ',')", inputTypes, ExprType.STRING); + assertOutputType("array_to_string(b, ',')", inputTypes, ExprType.STRING); + assertOutputType("array_to_string(c, ',')", inputTypes, ExprType.STRING); + + assertOutputType("array_offset(a, 1)", inputTypes, ExprType.STRING); + assertOutputType("array_offset(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_offset(c, 1)", inputTypes, ExprType.DOUBLE); + + assertOutputType("array_ordinal(a, 1)", inputTypes, ExprType.STRING); + assertOutputType("array_ordinal(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal(c, 1)", inputTypes, ExprType.DOUBLE); + + assertOutputType("array_offset_of(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_offset_of(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_offset_of(c, 1.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_ordinal_of(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal_of(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_ordinal_of(c, 1.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_append(x, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_append(a, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_append(y, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_append(b, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_append(z, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_append(c, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_concat(x, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_concat(a, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_concat(y, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_concat(b, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_concat(z, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_concat(c, c)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_contains(a, 'a')", inputTypes, ExprType.LONG); + assertOutputType("array_contains(b, 1)", inputTypes, ExprType.LONG); + assertOutputType("array_contains(c, 2.0)", inputTypes, ExprType.LONG); + + assertOutputType("array_overlap(a, a)", inputTypes, ExprType.LONG); + assertOutputType("array_overlap(b, b)", inputTypes, ExprType.LONG); + assertOutputType("array_overlap(c, c)", inputTypes, ExprType.LONG); + + assertOutputType("array_slice(a, 1, 2)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_slice(b, 1, 2)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_slice(c, 1, 2)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("array_prepend(x, a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_prepend(x, x_)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("array_prepend(y, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_prepend(y, y_)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("array_prepend(z, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("array_prepend(z, z_)", inputTypes, ExprType.DOUBLE_ARRAY); + } + + @Test + public void testReduceFunctions() + { + assertOutputType("greatest('B', x, 'A')", inputTypes, ExprType.STRING); + assertOutputType("greatest(y, 0)", inputTypes, ExprType.LONG); + assertOutputType("greatest(34.0, z, 5.0, 767.0)", inputTypes, ExprType.DOUBLE); + + assertOutputType("least('B', x, 'A')", inputTypes, ExprType.STRING); + assertOutputType("least(y, 0)", inputTypes, ExprType.LONG); + assertOutputType("least(34.0, z, 5.0, 767.0)", inputTypes, ExprType.DOUBLE); + } + + @Test + public void testApplyFunctions() + { + assertOutputType("map((x) -> concat(x, 'foo'), x)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("map((x) -> x + x, y)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("map((x) -> x + x, z)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType("map((x) -> concat(x, 'foo'), a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("map((x) -> x + x, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("map((x) -> x + x, c)", inputTypes, ExprType.DOUBLE_ARRAY); + assertOutputType( + "cartesian_map((x, y) -> concat(x, y), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'])", + inputTypes, + ExprType.STRING_ARRAY + ); + assertOutputType("fold((x, acc) -> x + acc, y, 0)", inputTypes, ExprType.LONG); + assertOutputType("fold((x, acc) -> x + acc, y, y)", inputTypes, ExprType.LONG); + assertOutputType("fold((x, acc) -> x + acc, y, 1.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("fold((x, acc) -> x + acc, y, z)", inputTypes, ExprType.DOUBLE); + + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, 0)", inputTypes, ExprType.LONG); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, y)", inputTypes, ExprType.LONG); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, 1.0)", inputTypes, ExprType.DOUBLE); + assertOutputType("cartesian_fold((x, y, acc) -> x + y + acc, y, z, z)", inputTypes, ExprType.DOUBLE); + + assertOutputType("filter((x) -> x == 'foo', a)", inputTypes, ExprType.STRING_ARRAY); + assertOutputType("filter((x) -> x > 1, b)", inputTypes, ExprType.LONG_ARRAY); + assertOutputType("filter((x) -> x > 1, c)", inputTypes, ExprType.DOUBLE_ARRAY); + + assertOutputType("any((x) -> x == 'foo', a)", inputTypes, ExprType.LONG); + assertOutputType("any((x) -> x > 1, b)", inputTypes, ExprType.LONG); + assertOutputType("any((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); + + assertOutputType("all((x) -> x == 'foo', a)", inputTypes, ExprType.LONG); + assertOutputType("all((x) -> x > 1, b)", inputTypes, ExprType.LONG); + assertOutputType("all((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); + } + + private void assertOutputType(String expression, Expr.InputBindingTypes inputTypes, ExprType outputType) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil(), false); + Assert.assertEquals(outputType, expr.getOutputType(inputTypes)); + } + + Expr.InputBindingTypes inputTypesFromMap(Map types) + { + return types::get; + } +} diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index b1ef6736ed5a..01bda4182c76 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -577,7 +577,7 @@ private void validateParser( ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - final Expr.BindingDetails deets = parsed.analyzeInputs(); + final Expr.ExprInputBindingAnalysis deets = parsed.analyzeInputs(); Assert.assertEquals(expression, expected, parsed.toString()); Assert.assertEquals(expression, identifiers, deets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, deets.getScalarVariables()); @@ -586,7 +586,7 @@ private void validateParser( final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr roundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); Assert.assertEquals(parsed.stringify(), roundTrip.stringify()); - final Expr.BindingDetails roundTripDeets = roundTrip.analyzeInputs(); + final Expr.ExprInputBindingAnalysis roundTripDeets = roundTrip.analyzeInputs(); Assert.assertEquals(expression, identifiers, roundTripDeets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, roundTripDeets.getScalarVariables()); Assert.assertEquals(expression, arrays, roundTripDeets.getArrayVariables()); @@ -600,7 +600,7 @@ private void validateApplyUnapplied( ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - Expr.BindingDetails deets = parsed.analyzeInputs(); + Expr.ExprInputBindingAnalysis deets = parsed.analyzeInputs(); Parser.validateExpr(parsed, deets); final Expr transformed = Parser.applyUnappliedBindings(parsed, deets, identifiers); Assert.assertEquals(expression, unapplied, parsed.toString()); @@ -608,7 +608,7 @@ private void validateApplyUnapplied( final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); - Expr.BindingDetails roundTripDeets = parsedRoundTrip.analyzeInputs(); + Expr.ExprInputBindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs(); Parser.validateExpr(parsedRoundTrip, roundTripDeets); final Expr transformedRoundTrip = Parser.applyUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers); Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString()); diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java index 6cbcfd16bc66..8e7d04ccb7d1 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java @@ -25,9 +25,11 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.BloomKFilter; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.io.IOException; import java.util.List; @@ -122,6 +124,13 @@ public Expr visit(Shuttle shuttle) Expr newArg = arg.visit(shuttle); return shuttle.visit(new BloomExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } return new BloomExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java index 05d510e2d3f6..1aff62d199f6 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; /** @@ -117,6 +119,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new IPv4AddressMatchExpr(newArg, subnetInfo)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java index fdf67b4cc5b1..a75fa323fdb3 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressParseExprMacro.java @@ -23,8 +23,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.net.Inet4Address; import java.util.List; @@ -92,6 +94,13 @@ public Expr visit(Shuttle shuttle) Expr newArg = arg.visit(shuttle); return shuttle.visit(new IPv4AddressParseExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } return new IPv4AddressParseExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java index 4aea0aa37182..17431a0e5923 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/IPv4AddressStringifyExprMacro.java @@ -23,8 +23,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.net.Inet4Address; import java.util.List; @@ -91,6 +93,13 @@ public Expr visit(Shuttle shuttle) Expr newArg = arg.visit(shuttle); return shuttle.visit(new IPv4AddressStringifyExpr(newArg)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } } return new IPv4AddressStringifyExpr(arg); diff --git a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java index a124722e8f8b..d5bbf02dad02 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java @@ -25,9 +25,11 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.filter.LikeDimFilter; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class LikeExprMacro implements ExprMacroTable.ExprMacro @@ -90,6 +92,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new LikeExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java index a827aea2601d..6ff028778a40 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java @@ -26,10 +26,12 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class LookupExprMacro implements ExprMacroTable.ExprMacro @@ -94,6 +96,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new LookupExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java index 9bef704a663e..3964c1793d7e 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -97,6 +99,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new RegexpExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java index a4909194d484..9279c84774dd 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpLikeExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -89,6 +91,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new RegexpLikeExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java index 8d6a628d97a3..6779bf6ddf74 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java @@ -27,9 +27,11 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTime; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -93,6 +95,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TimestampCeilExpr(newArgs)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public boolean equals(Object o) { @@ -153,5 +162,12 @@ public Expr visit(Shuttle shuttle) List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampCeilDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java index d3184dd5ee38..278076901872 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java @@ -25,11 +25,13 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.chrono.ISOChronology; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro @@ -162,6 +164,19 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TimestampExtractExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + switch (unit) { + case CENTURY: + case MILLENNIUM: + return ExprType.DOUBLE; + default: + return ExprType.LONG; + } + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java index aef159ae7cea..a3a95306c0ae 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java @@ -25,8 +25,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -111,6 +113,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TimestampFloorExpr(newArgs)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public boolean equals(Object o) { @@ -155,5 +164,12 @@ public Expr visit(Shuttle shuttle) List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampFloorDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java index e7f469666854..455d445fe980 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java @@ -25,12 +25,14 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampFormatExprMacro implements ExprMacroTable.ExprMacro @@ -97,6 +99,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TimestampFormatExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java index 535c7332554b..935a2b7cbae7 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java @@ -25,6 +25,7 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; @@ -33,6 +34,7 @@ import org.joda.time.format.ISODateTimeFormat; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; public class TimestampParseExprMacro implements ExprMacroTable.ExprMacro @@ -100,6 +102,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TimestampParseExpr(newArg)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public String stringify() { diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java index b3f8d1e767f2..259d054e411a 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java @@ -24,11 +24,13 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.joda.time.Chronology; import org.joda.time.Period; import org.joda.time.chrono.ISOChronology; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.List; import java.util.stream.Collectors; @@ -101,6 +103,13 @@ public Expr visit(Shuttle shuttle) List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } private static class TimestampShiftDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr @@ -127,5 +136,12 @@ public Expr visit(Shuttle shuttle) List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftDynamicExpr(newArgs)); } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java index c7ce44f8fe91..71c7e6db7bda 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java @@ -26,8 +26,10 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -168,6 +170,13 @@ public Expr visit(Shuttle shuttle) return shuttle.visit(new TrimStaticCharsExpr(mode, newStringExpr, chars, charsExpr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public String stringify() { @@ -290,13 +299,20 @@ public Expr visit(Shuttle shuttle) } @Override - public BindingDetails analyzeInputs() + public ExprInputBindingAnalysis analyzeInputs() { return stringExpr.analyzeInputs() .with(charsExpr) .withScalarArguments(ImmutableSet.of(stringExpr, charsExpr)); } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.STRING; + } + @Override public boolean equals(Object o) { diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java index 880e90ced2e8..76072c5df955 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java @@ -45,7 +45,7 @@ public class ExpressionFilter implements Filter { private final Supplier expr; - private final Supplier bindingDetails; + private final Supplier bindingDetails; private final FilterTuning filterTuning; public ExpressionFilter(final Supplier expr, final FilterTuning filterTuning) @@ -107,7 +107,7 @@ public void inspectRuntimeShape(final RuntimeShapeInspector inspector) @Override public boolean supportsBitmapIndex(final BitmapIndexSelector selector) { - final Expr.BindingDetails details = this.bindingDetails.get(); + final Expr.ExprInputBindingAnalysis details = this.bindingDetails.get(); if (details.getRequiredBindings().isEmpty()) { // Constant expression. diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java index 84fbccd8a572..57e0808ce62a 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java @@ -380,8 +380,8 @@ private static void getCorrelationForRHSColumn( String identifier = lhsExpr.getBindingIfIdentifier(); if (identifier == null) { // We push down if the function only requires base table columns - Expr.BindingDetails bindingDetails = lhsExpr.analyzeInputs(); - Set requiredBindings = bindingDetails.getRequiredBindings(); + Expr.ExprInputBindingAnalysis exprInputBindingAnalysis = lhsExpr.analyzeInputs(); + Set requiredBindings = exprInputBindingAnalysis.getRequiredBindings(); if (joinableClauses.areSomeColumnsFromJoin(requiredBindings)) { break; diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 860e47b0f6c0..6e2e71cdca8b 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -136,7 +136,7 @@ public static ColumnValueSelector makeExprEvalSelector( Expr expression ) { - final Expr.BindingDetails exprDetails = expression.analyzeInputs(); + final Expr.ExprInputBindingAnalysis exprDetails = expression.analyzeInputs(); Parser.validateExpr(expression, exprDetails); final List columns = exprDetails.getRequiredBindingsList(); @@ -212,7 +212,7 @@ public static DimensionSelector makeDimensionSelector( @Nullable final ExtractionFn extractionFn ) { - final Expr.BindingDetails exprDetails = expression.analyzeInputs(); + final Expr.ExprInputBindingAnalysis exprDetails = expression.analyzeInputs(); Parser.validateExpr(expression, exprDetails); final List columns = exprDetails.getRequiredBindingsList(); @@ -348,7 +348,7 @@ public void inspectRuntimeShape(RuntimeShapeInspector inspector) * @param hasMultipleValues result of calling {@link ColumnCapabilities#hasMultipleValues()} */ public static boolean canMapOverDictionary( - final Expr.BindingDetails exprDetails, + final Expr.ExprInputBindingAnalysis exprDetails, final ColumnCapabilities.Capable hasMultipleValues ) { @@ -357,17 +357,17 @@ public static boolean canMapOverDictionary( } /** - * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingDetails} which + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.ExprInputBindingAnalysis} which * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they * are used as array or scalar inputs */ private static Expr.ObjectBinding createBindings( - Expr.BindingDetails bindingDetails, + Expr.ExprInputBindingAnalysis exprInputBindingAnalysis, ColumnSelectorFactory columnSelectorFactory ) { final Map> suppliers = new HashMap<>(); - final List columns = bindingDetails.getRequiredBindingsList(); + final List columns = exprInputBindingAnalysis.getRequiredBindingsList(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory .getColumnCapabilities(columnName); @@ -601,7 +601,7 @@ public static Object coerceEvalToSelectorObject(ExprEval eval) */ private static Pair, Set> examineColumnSelectorFactoryArrays( ColumnSelectorFactory columnSelectorFactory, - Expr.BindingDetails exprDetails, + Expr.ExprInputBindingAnalysis exprDetails, List columns ) { diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java index 727f1e4a2432..37be181a0d13 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java @@ -40,22 +40,22 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector { private final List unknownColumns; - private final Expr.BindingDetails baseExprBindingDetails; + private final Expr.ExprInputBindingAnalysis baseExprExprInputBindingAnalysis; private final Set ignoredColumns; private final Int2ObjectMap transformedCache; public RowBasedExpressionColumnValueSelector( Expr expression, - Expr.BindingDetails baseExprBindingDetails, + Expr.ExprInputBindingAnalysis baseExprExprInputBindingAnalysis, Expr.ObjectBinding bindings, Set unknownColumnsSet ) { super(expression, bindings); this.unknownColumns = unknownColumnsSet.stream() - .filter(x -> !baseExprBindingDetails.getArrayBindings().contains(x)) + .filter(x -> !baseExprExprInputBindingAnalysis.getArrayBindings().contains(x)) .collect(Collectors.toList()); - this.baseExprBindingDetails = baseExprBindingDetails; + this.baseExprExprInputBindingAnalysis = baseExprExprInputBindingAnalysis; this.ignoredColumns = new HashSet<>(); this.transformedCache = new Int2ObjectArrayMap<>(unknownColumns.size()); } @@ -79,7 +79,7 @@ public ExprEval getObject() if (transformedCache.containsKey(key)) { return transformedCache.get(key).eval(bindings); } - Expr transformed = Parser.applyUnappliedBindings(expression, baseExprBindingDetails, arrayBindings); + Expr transformed = Parser.applyUnappliedBindings(expression, baseExprExprInputBindingAnalysis, arrayBindings); transformedCache.put(key, transformed); return transformed.eval(bindings); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java index aa5bd917bf13..e2c420994e3b 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java @@ -22,9 +22,11 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.ExprType; import org.junit.Assert; import org.junit.Test; +import javax.annotation.Nullable; import java.util.Arrays; import java.util.Collections; @@ -203,5 +205,12 @@ public Expr visit(Shuttle shuttle) { return null; } + + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return null; + } } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java index ebdb99931ad2..d1fc8d0cb056 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java @@ -30,14 +30,14 @@ public class GreatestOperatorConversion extends DirectOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunc.NAME)) + .operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunction.NAME)) .operandTypeChecker(OperandTypes.VARIADIC) .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) .build(); public GreatestOperatorConversion() { - super(SQL_FUNCTION, Function.GreatestFunc.NAME); + super(SQL_FUNCTION, Function.GreatestFunction.NAME); } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java index 09578c6eaa9e..93217ee9cd68 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java @@ -30,14 +30,14 @@ public class LeastOperatorConversion extends DirectOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder(StringUtils.toUpperCase(Function.LeastFunc.NAME)) + .operatorBuilder(StringUtils.toUpperCase(Function.LeastFunction.NAME)) .operandTypeChecker(OperandTypes.VARIADIC) .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) .build(); public LeastOperatorConversion() { - super(SQL_FUNCTION, Function.LeastFunc.NAME); + super(SQL_FUNCTION, Function.LeastFunction.NAME); } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index f76d8352bbb5..b5b4c21c7069 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -37,8 +37,8 @@ private ReductionOperatorConversionHelper() * Implements type precedence rules similar to: * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least * - * @see org.apache.druid.math.expr.Function.ReduceFunc#apply - * @see org.apache.druid.math.expr.Function.ReduceFunc#getComparisionType + * @see org.apache.druid.math.expr.Function.ReduceFunction#apply + * @see org.apache.druid.math.expr.Function.ReduceFunction#getComparisionType */ static final SqlReturnTypeInference TYPE_INFERENCE = opBinding -> { From 28c058816d50d40c90222d93a937257fafc72f1b Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 18:01:08 -0700 Subject: [PATCH 03/15] revert unintended name change --- .../main/java/org/apache/druid/math/expr/Function.java | 8 ++++---- .../expression/builtin/GreatestOperatorConversion.java | 4 ++-- .../expression/builtin/LeastOperatorConversion.java | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 31f452482972..c6cb99220fb8 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -1324,11 +1324,11 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args } } - class GreatestFunction extends ReduceFunction + class GreatestFunc extends ReduceFunction { public static final String NAME = "greatest"; - public GreatestFunction() + public GreatestFunc() { super( Math::max, @@ -1344,11 +1344,11 @@ public String name() } } - class LeastFunction extends ReduceFunction + class LeastFunc extends ReduceFunction { public static final String NAME = "least"; - public LeastFunction() + public LeastFunc() { super( Math::min, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java index d1fc8d0cb056..ebdb99931ad2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/GreatestOperatorConversion.java @@ -30,14 +30,14 @@ public class GreatestOperatorConversion extends DirectOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunction.NAME)) + .operatorBuilder(StringUtils.toUpperCase(Function.GreatestFunc.NAME)) .operandTypeChecker(OperandTypes.VARIADIC) .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) .build(); public GreatestOperatorConversion() { - super(SQL_FUNCTION, Function.GreatestFunction.NAME); + super(SQL_FUNCTION, Function.GreatestFunc.NAME); } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java index 93217ee9cd68..09578c6eaa9e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/LeastOperatorConversion.java @@ -30,14 +30,14 @@ public class LeastOperatorConversion extends DirectOperatorConversion { private static final SqlFunction SQL_FUNCTION = OperatorConversions - .operatorBuilder(StringUtils.toUpperCase(Function.LeastFunction.NAME)) + .operatorBuilder(StringUtils.toUpperCase(Function.LeastFunc.NAME)) .operandTypeChecker(OperandTypes.VARIADIC) .returnTypeInference(ReductionOperatorConversionHelper.TYPE_INFERENCE) .build(); public LeastOperatorConversion() { - super(SQL_FUNCTION, Function.LeastFunction.NAME); + super(SQL_FUNCTION, Function.LeastFunc.NAME); } @Override From 57282df134152becec26e1ae964f28b6159bdedb Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 18:26:21 -0700 Subject: [PATCH 04/15] add nullable --- core/src/main/java/org/apache/druid/math/expr/Expr.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index ef4d8022d4a0..1724083ca5c0 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -132,6 +132,7 @@ default String getBindingIfIdentifier() interface InputBindingTypes { + @Nullable ExprType getType(String name); } From 3b5a1fe7deb737304ad242605193c6ea3d884ae6 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 18:32:08 -0700 Subject: [PATCH 05/15] tidy up --- .../segment/virtual/ExpressionSelectors.java | 42 +++++++++---------- ...RowBasedExpressionColumnValueSelector.java | 10 ++--- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 6e2e71cdca8b..78c12f3617bc 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -136,9 +136,9 @@ public static ColumnValueSelector makeExprEvalSelector( Expr expression ) { - final Expr.ExprInputBindingAnalysis exprDetails = expression.analyzeInputs(); - Parser.validateExpr(expression, exprDetails); - final List columns = exprDetails.getRequiredBindingsList(); + final Expr.ExprInputBindingAnalysis inputBindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, inputBindingAnalysis); + final List columns = inputBindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -155,7 +155,7 @@ public static ColumnValueSelector makeExprEvalSelector( && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() && capabilities.hasMultipleValues().isFalse() - && exprDetails.getArrayBindings().isEmpty()) { + && inputBindingAnalysis.getArrayBindings().isEmpty()) { // Optimization for expressions that hit one scalar string column and nothing else. return new SingleStringInputCachingExpressionColumnValueSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -165,22 +165,22 @@ public static ColumnValueSelector makeExprEvalSelector( } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, inputBindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final List needsApplied = columns.stream() - .filter(c -> actualArrays.contains(c) && !exprDetails.getArrayBindings().contains(c)) + .filter(c -> actualArrays.contains(c) && !inputBindingAnalysis.getArrayBindings().contains(c)) .collect(Collectors.toList()); final Expr finalExpr; if (needsApplied.size() > 0) { - finalExpr = Parser.applyUnappliedBindings(expression, exprDetails, needsApplied); + finalExpr = Parser.applyUnappliedBindings(expression, inputBindingAnalysis, needsApplied); } else { finalExpr = expression; } - final Expr.ObjectBinding bindings = createBindings(exprDetails, columnSelectorFactory); + final Expr.ObjectBinding bindings = createBindings(inputBindingAnalysis, columnSelectorFactory); if (bindings.equals(ExprUtils.nilBindings())) { // Optimization for constant expressions. @@ -192,7 +192,7 @@ public static ColumnValueSelector makeExprEvalSelector( if (unknownIfArrays.size() > 0) { return new RowBasedExpressionColumnValueSelector( finalExpr, - exprDetails, + inputBindingAnalysis, bindings, unknownIfArrays ); @@ -212,9 +212,9 @@ public static DimensionSelector makeDimensionSelector( @Nullable final ExtractionFn extractionFn ) { - final Expr.ExprInputBindingAnalysis exprDetails = expression.analyzeInputs(); - Parser.validateExpr(expression, exprDetails); - final List columns = exprDetails.getRequiredBindingsList(); + final Expr.ExprInputBindingAnalysis inputBindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, inputBindingAnalysis); + final List columns = inputBindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -226,7 +226,7 @@ public static DimensionSelector makeDimensionSelector( if (capabilities != null && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() - && canMapOverDictionary(exprDetails, capabilities.hasMultipleValues()) + && canMapOverDictionary(inputBindingAnalysis, capabilities.hasMultipleValues()) ) { return new SingleStringInputDimensionSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -236,14 +236,14 @@ && canMapOverDictionary(exprDetails, capabilities.hasMultipleValues()) } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, inputBindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final ColumnValueSelector baseSelector = makeExprEvalSelector(columnSelectorFactory, expression); final boolean multiVal = actualArrays.size() > 0 || - exprDetails.getArrayBindings().size() > 0 || + inputBindingAnalysis.getArrayBindings().size() > 0 || unknownIfArrays.size() > 0; if (baseSelector instanceof ConstantExprEvalSelector) { @@ -344,16 +344,16 @@ public void inspectRuntimeShape(RuntimeShapeInspector inspector) * This function should only be called if you have already determined that an expression is over a single column, * and that single column has a dictionary. * - * @param exprDetails result of calling {@link Expr#analyzeInputs()} on an expression + * @param inputBindingAnalysis result of calling {@link Expr#analyzeInputs()} on an expression * @param hasMultipleValues result of calling {@link ColumnCapabilities#hasMultipleValues()} */ public static boolean canMapOverDictionary( - final Expr.ExprInputBindingAnalysis exprDetails, + final Expr.ExprInputBindingAnalysis inputBindingAnalysis, final ColumnCapabilities.Capable hasMultipleValues ) { - Preconditions.checkState(exprDetails.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); - return !hasMultipleValues.isUnknown() && !exprDetails.hasInputArrays() && !exprDetails.isOutputArray(); + Preconditions.checkState(inputBindingAnalysis.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); + return !hasMultipleValues.isUnknown() && !inputBindingAnalysis.hasInputArrays() && !inputBindingAnalysis.isOutputArray(); } /** @@ -601,7 +601,7 @@ public static Object coerceEvalToSelectorObject(ExprEval eval) */ private static Pair, Set> examineColumnSelectorFactoryArrays( ColumnSelectorFactory columnSelectorFactory, - Expr.ExprInputBindingAnalysis exprDetails, + Expr.ExprInputBindingAnalysis inputBindingAnalysis, List columns ) { @@ -615,7 +615,7 @@ private static Pair, Set> examineColumnSelectorFactoryArrays } else if ( capabilities.getType().equals(ValueType.STRING) && capabilities.hasMultipleValues().isMaybeTrue() && - !exprDetails.getArrayBindings().contains(column) + !inputBindingAnalysis.getArrayBindings().contains(column) ) { unknownIfArrays.add(column); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java index 37be181a0d13..51bae6ccf2ed 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java @@ -40,22 +40,22 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector { private final List unknownColumns; - private final Expr.ExprInputBindingAnalysis baseExprExprInputBindingAnalysis; + private final Expr.ExprInputBindingAnalysis baseExprInputBindingAnalysis; private final Set ignoredColumns; private final Int2ObjectMap transformedCache; public RowBasedExpressionColumnValueSelector( Expr expression, - Expr.ExprInputBindingAnalysis baseExprExprInputBindingAnalysis, + Expr.ExprInputBindingAnalysis baseExprInputBindingAnalysis, Expr.ObjectBinding bindings, Set unknownColumnsSet ) { super(expression, bindings); this.unknownColumns = unknownColumnsSet.stream() - .filter(x -> !baseExprExprInputBindingAnalysis.getArrayBindings().contains(x)) + .filter(x -> !baseExprInputBindingAnalysis.getArrayBindings().contains(x)) .collect(Collectors.toList()); - this.baseExprExprInputBindingAnalysis = baseExprExprInputBindingAnalysis; + this.baseExprInputBindingAnalysis = baseExprInputBindingAnalysis; this.ignoredColumns = new HashSet<>(); this.transformedCache = new Int2ObjectArrayMap<>(unknownColumns.size()); } @@ -79,7 +79,7 @@ public ExprEval getObject() if (transformedCache.containsKey(key)) { return transformedCache.get(key).eval(bindings); } - Expr transformed = Parser.applyUnappliedBindings(expression, baseExprExprInputBindingAnalysis, arrayBindings); + Expr transformed = Parser.applyUnappliedBindings(expression, baseExprInputBindingAnalysis, arrayBindings); transformedCache.put(key, transformed); return transformed.eval(bindings); } From 8fc125358d0f894b9d124414b464c4e212cef92d Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 18:36:21 -0700 Subject: [PATCH 06/15] fixup --- .../main/java/org/apache/druid/math/expr/ApplyFunction.java | 1 + core/src/main/java/org/apache/druid/math/expr/ExprType.java | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index 8a0f68c24a78..b2b8918bd8ef 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -895,6 +895,7 @@ public LambdaInputBindingTypes(Expr.InputBindingTypes inputTypes, LambdaExpr exp } } + @Nullable @Override public ExprType getType(String name) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 89700545f161..1f143ced1b8d 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -86,9 +86,9 @@ public static boolean isArray(@Nullable ExprType type) } @Nullable - public static ExprType elementType(ExprType type) + public static ExprType elementType(@Nullable ExprType type) { - if (isArray(type)) { + if (type != null && isArray(type)) { switch (type) { case STRING_ARRAY: return STRING; From 2c6fc87c5a0361363914a8f1ea1ef46403fd4e6e Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 19:27:50 -0700 Subject: [PATCH 07/15] more better --- .../java/org/apache/druid/math/expr/ExprType.java | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 1f143ced1b8d..6fa09b2ae90c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -37,7 +37,6 @@ public enum ExprType LONG_ARRAY, STRING_ARRAY; - public boolean isNumeric() { return isNumeric(this); @@ -102,9 +101,9 @@ public static ExprType elementType(@Nullable ExprType type) } @Nullable - public static ExprType asArrayType(ExprType elementType) + public static ExprType asArrayType(@Nullable ExprType elementType) { - if (!isArray(elementType)) { + if (elementType != null && !isArray(elementType)) { switch (elementType) { case STRING: return STRING_ARRAY; @@ -114,13 +113,15 @@ public static ExprType asArrayType(ExprType elementType) return DOUBLE_ARRAY; } } - return null; + return elementType; } + @Nullable public static ExprType implicitCast(@Nullable ExprType type, @Nullable ExprType other) { if (type == null || other == null) { - throw new IAE("Cannot implicitly cast unknown types"); + // cannot implicitly cast unknown types + return null; } // arrays cannot be implicitly cast if (isArray(type)) { From c36afdfea33cdffc6182ecf07f7317ad7c21b870 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 19:32:41 -0700 Subject: [PATCH 08/15] fix signatures --- .../apache/druid/math/expr/ApplyFunction.java | 5 ++ .../org/apache/druid/math/expr/Function.java | 47 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index b2b8918bd8ef..52b3fa0bdbd3 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -75,6 +75,7 @@ default boolean hasArrayOutput(LambdaExpr lambdaExpr) */ void validateArguments(LambdaExpr lambdaExpr, List args); + @Nullable ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args); /** @@ -90,6 +91,7 @@ public boolean hasArrayOutput(LambdaExpr lambdaExpr) return true; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) { @@ -295,6 +297,7 @@ public boolean hasArrayOutput(LambdaExpr lambdaExpr) return lambdaExprInputBindingAnalysis.isOutputArray(); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) { @@ -497,6 +500,7 @@ public void validateArguments(LambdaExpr lambdaExpr, List args) ); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) { @@ -551,6 +555,7 @@ public void validateArguments(LambdaExpr lambdaExpr, List args) ); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args) { diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index c6cb99220fb8..0a50a8d79e05 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -105,6 +105,7 @@ default boolean hasArrayOutput() */ void validateArguments(List args); + @Nullable ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args); /** @@ -184,6 +185,7 @@ protected ExprEval eval(double param) return eval((long) param); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -196,6 +198,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args */ abstract class DoubleUnivariateMathFunction extends UnivariateMathFunction { + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -232,6 +235,7 @@ protected ExprEval eval(double x, double y) return eval((long) x, (long) y); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -244,6 +248,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args */ abstract class DoubleBivariateMathFunction extends BivariateMathFunction { + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -386,6 +391,7 @@ public void validateArguments(List args) // anything goes } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -517,6 +523,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -574,6 +581,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -926,6 +934,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1240,6 +1249,7 @@ public String name() return "scalb"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1313,6 +1323,7 @@ public Set getArrayInputs(List args) return Collections.emptySet(); } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1387,6 +1398,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1430,6 +1442,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1473,6 +1486,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1504,6 +1518,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1534,6 +1549,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1564,6 +1580,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1613,6 +1630,7 @@ public void validateArguments(List args) // anything goes } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1643,6 +1661,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1683,6 +1702,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1727,6 +1747,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1776,6 +1797,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1791,6 +1813,7 @@ public String name() return "right"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1822,6 +1845,7 @@ public String name() return "left"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1872,6 +1896,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1905,6 +1930,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1938,6 +1964,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1953,6 +1980,7 @@ public String name() return "reverse"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -1981,6 +2009,7 @@ public String name() return "repeat"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2025,6 +2054,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2063,6 +2093,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2112,6 +2143,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2170,6 +2202,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2282,6 +2315,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2337,6 +2371,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2366,6 +2401,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2406,6 +2442,7 @@ public String name() return "array_to_string"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2434,6 +2471,7 @@ public String name() return "array_offset"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2461,6 +2499,7 @@ public String name() return "array_ordinal"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2488,6 +2527,7 @@ public String name() return "array_offset_of"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2525,6 +2565,7 @@ public String name() return "array_ordinal_of"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2567,6 +2608,7 @@ public boolean hasArrayOutput() return true; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2630,6 +2672,7 @@ public boolean hasArrayOutput() return true; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2692,6 +2735,7 @@ public boolean hasArrayOutput() return true; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2715,6 +2759,7 @@ public String name() return "array_overlap"; } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2750,6 +2795,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { @@ -2835,6 +2881,7 @@ public void validateArguments(List args) } } + @Nullable @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { From d2d7f9ad249d1189c121958726fa9b38d87ca886 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 20:52:39 -0700 Subject: [PATCH 09/15] naming things is hard --- .../apache/druid/math/expr/ApplyFunction.java | 4 +- .../druid/math/expr/BinaryOperatorExpr.java | 2 +- .../apache/druid/math/expr/ConstantExpr.java | 4 +- .../java/org/apache/druid/math/expr/Expr.java | 50 +++++------ .../druid/math/expr/ExprListenerImpl.java | 2 +- .../druid/math/expr/ExprMacroTable.java | 14 +-- .../org/apache/druid/math/expr/Function.java | 89 +++++-------------- .../druid/math/expr/FunctionalExpr.java | 36 ++++---- .../druid/math/expr/IdentifierExpr.java | 4 +- .../org/apache/druid/math/expr/Parser.java | 28 +++--- .../druid/math/expr/UnaryOperatorExpr.java | 2 +- .../org/apache/druid/math/expr/ExprTest.java | 2 +- .../apache/druid/math/expr/ParserTest.java | 8 +- .../druid/query/expression/TrimExprMacro.java | 2 +- .../segment/filter/ExpressionFilter.java | 4 +- .../join/filter/JoinFilterCorrelations.java | 4 +- .../segment/virtual/ExpressionSelectors.java | 48 +++++----- ...RowBasedExpressionColumnValueSelector.java | 10 +-- 18 files changed, 135 insertions(+), 178 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index 52b3fa0bdbd3..e04e1a388735 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -293,8 +293,8 @@ ExprEval applyFold(LambdaExpr lambdaExpr, Object accumulator, IndexableFoldLambd @Override public boolean hasArrayOutput(LambdaExpr lambdaExpr) { - Expr.ExprInputBindingAnalysis lambdaExprInputBindingAnalysis = lambdaExpr.analyzeInputs(); - return lambdaExprInputBindingAnalysis.isOutputArray(); + Expr.BindingAnalysis lambdaBindingAnalysis = lambdaExpr.analyzeInputs(); + return lambdaBindingAnalysis.isOutputArray(); } @Nullable diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java index 7b7423d29733..91811a5610fb 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java @@ -81,7 +81,7 @@ public String stringify() protected abstract BinaryOpExprBase copy(Expr left, Expr right); @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { // currently all binary operators operate on scalar inputs return left.analyzeInputs().with(right).withScalarArguments(ImmutableSet.of(left, right)); diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index c0daed09da2b..fe74491e5e07 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -67,9 +67,9 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { - return new ExprInputBindingAnalysis(); + return new BindingAnalysis(); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index 1724083ca5c0..06ed806b071c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -116,16 +116,16 @@ default String getBindingIfIdentifier() void visit(Visitor visitor); /** - * Programatically rewrite the {@link Expr} tree with a {@link Shuttle}.Each {@link Expr} is responsible for + * Programatically rewrite the {@link Expr} tree with a {@link Shuttle}. Each {@link Expr} is responsible for * ensuring the {@link Shuttle} can visit all of its {@link Expr} children, as well as updating its children * {@link Expr} with the results from the {@link Shuttle}, before finally visiting an updated form of itself. */ Expr visit(Shuttle shuttle); /** - * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link ExprInputBindingAnalysis} + * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link BindingAnalysis} */ - ExprInputBindingAnalysis analyzeInputs(); + BindingAnalysis analyzeInputs(); @Nullable ExprType getOutputType(InputBindingTypes inputTypes); @@ -189,7 +189,7 @@ interface Shuttle * * This means in rare cases and mostly for "questionable" expressions which we still allow to function 'correctly', * these lists might not be fully reliable without a complete type inference system in place. Due to this shortcoming, - * boolean values {@link ExprInputBindingAnalysis#hasInputArrays()} and {@link ExprInputBindingAnalysis#isOutputArray()} are provided to + * boolean values {@link BindingAnalysis#hasInputArrays()} and {@link BindingAnalysis#isOutputArray()} are provided to * allow functions to explicitly declare that they utilize array typed values, used when determining if some types of * optimizations can be applied when constructing the expression column value selector. * @@ -203,7 +203,7 @@ interface Shuttle * @see org.apache.druid.segment.virtual.ExpressionSelectors#makeColumnValueSelector */ @SuppressWarnings("JavadocReference") - class ExprInputBindingAnalysis + class BindingAnalysis { private final ImmutableSet freeVariables; private final ImmutableSet scalarVariables; @@ -211,17 +211,17 @@ class ExprInputBindingAnalysis private final boolean hasInputArrays; private final boolean isOutputArray; - ExprInputBindingAnalysis() + BindingAnalysis() { this(ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), false, false); } - ExprInputBindingAnalysis(IdentifierExpr expr) + BindingAnalysis(IdentifierExpr expr) { this(ImmutableSet.of(expr), ImmutableSet.of(), ImmutableSet.of(), false, false); } - private ExprInputBindingAnalysis( + private BindingAnalysis( ImmutableSet freeVariables, ImmutableSet scalarVariables, ImmutableSet arrayVariables, @@ -319,19 +319,19 @@ public boolean isOutputArray() } /** - * Combine with {@link ExprInputBindingAnalysis} from {@link Expr#analyzeInputs()} + * Combine with {@link BindingAnalysis} from {@link Expr#analyzeInputs()} */ - public ExprInputBindingAnalysis with(Expr other) + public BindingAnalysis with(Expr other) { return with(other.analyzeInputs()); } /** - * Combine (union) another {@link ExprInputBindingAnalysis} + * Combine (union) another {@link BindingAnalysis} */ - public ExprInputBindingAnalysis with(ExprInputBindingAnalysis other) + public BindingAnalysis with(BindingAnalysis other) { - return new ExprInputBindingAnalysis( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, other.freeVariables)), ImmutableSet.copyOf(Sets.union(scalarVariables, other.scalarVariables)), ImmutableSet.copyOf(Sets.union(arrayVariables, other.arrayVariables)), @@ -341,10 +341,10 @@ public ExprInputBindingAnalysis with(ExprInputBindingAnalysis other) } /** - * Add set of arguments as {@link ExprInputBindingAnalysis#scalarVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link BindingAnalysis#scalarVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - public ExprInputBindingAnalysis withScalarArguments(Set scalarArguments) + public BindingAnalysis withScalarArguments(Set scalarArguments) { Set moreScalars = new HashSet<>(); for (Expr expr : scalarArguments) { @@ -353,7 +353,7 @@ public ExprInputBindingAnalysis withScalarArguments(Set scalarArguments) moreScalars.add((IdentifierExpr) expr); } } - return new ExprInputBindingAnalysis( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, moreScalars)), ImmutableSet.copyOf(Sets.union(scalarVariables, moreScalars)), arrayVariables, @@ -363,10 +363,10 @@ public ExprInputBindingAnalysis withScalarArguments(Set scalarArguments) } /** - * Add set of arguments as {@link ExprInputBindingAnalysis#arrayVariables} that are *directly* {@link IdentifierExpr}, + * Add set of arguments as {@link BindingAnalysis#arrayVariables} that are *directly* {@link IdentifierExpr}, * else they are ignored. */ - ExprInputBindingAnalysis withArrayArguments(Set arrayArguments) + BindingAnalysis withArrayArguments(Set arrayArguments) { Set arrayIdentifiers = new HashSet<>(); for (Expr expr : arrayArguments) { @@ -375,7 +375,7 @@ ExprInputBindingAnalysis withArrayArguments(Set arrayArguments) arrayIdentifiers.add((IdentifierExpr) expr); } } - return new ExprInputBindingAnalysis( + return new BindingAnalysis( ImmutableSet.copyOf(Sets.union(freeVariables, arrayIdentifiers)), scalarVariables, ImmutableSet.copyOf(Sets.union(arrayVariables, arrayIdentifiers)), @@ -387,9 +387,9 @@ ExprInputBindingAnalysis withArrayArguments(Set arrayArguments) /** * Copy, setting if an expression has array inputs */ - ExprInputBindingAnalysis withArrayInputs(boolean hasArrays) + BindingAnalysis withArrayInputs(boolean hasArrays) { - return new ExprInputBindingAnalysis( + return new BindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -401,9 +401,9 @@ ExprInputBindingAnalysis withArrayInputs(boolean hasArrays) /** * Copy, setting if an expression produces an array output */ - ExprInputBindingAnalysis withArrayOutput(boolean isOutputArray) + BindingAnalysis withArrayOutput(boolean isOutputArray) { - return new ExprInputBindingAnalysis( + return new BindingAnalysis( freeVariables, scalarVariables, arrayVariables, @@ -416,9 +416,9 @@ ExprInputBindingAnalysis withArrayOutput(boolean isOutputArray) * Remove any {@link IdentifierExpr} that are from a {@link LambdaExpr}, since the {@link ApplyFunction} will * provide bindings for these variables. */ - ExprInputBindingAnalysis removeLambdaArguments(Set lambda) + BindingAnalysis removeLambdaArguments(Set lambda) { - return new ExprInputBindingAnalysis( + return new BindingAnalysis( ImmutableSet.copyOf(freeVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(scalarVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), ImmutableSet.copyOf(arrayVariables.stream().filter(x -> !lambda.contains(x.getIdentifier())).iterator()), diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java index 69a64e13d2b3..3f69f6e0b7e4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java @@ -482,7 +482,7 @@ public void exitExplicitStringArray(ExprParser.ExplicitStringArrayContext ctx) * {@link IdentifierExpr#identifier} be the same as {@link IdentifierExpr#binding} because they have * synthetic bindings set at evaluation time. This is done to aid in analysis needed for the automatic expression * translation which maps scalar expressions to multi-value inputs. See - * {@link Parser#applyUnappliedBindings(Expr, Expr.ExprInputBindingAnalysis, List)}} for additional details. + * {@link Parser#applyUnappliedBindings(Expr, Expr.BindingAnalysis, List)}} for additional details. */ private IdentifierExpr createIdentifierExpr(String binding) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java index 4396969693b6..616297a57c3b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java @@ -102,7 +102,7 @@ public abstract static class BaseScalarUnivariateMacroFunctionExpr implements Ex protected final Expr arg; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarUnivariateMacroFunctionExpr(String name, Expr arg) { @@ -119,7 +119,7 @@ public void visit(final Visitor visitor) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -150,7 +150,7 @@ public int hashCode() return Objects.hash(name, arg); } - private ExprInputBindingAnalysis supplyAnalyzeInputs() + private BindingAnalysis supplyAnalyzeInputs() { return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg)); } @@ -165,7 +165,7 @@ public abstract static class BaseScalarMacroFunctionExpr implements Expr protected final List args; // Use Supplier to memoize values as ExpressionSelectors#makeExprEvalSelector() can make repeated calls for them - private final Supplier analyzeInputsSupplier; + private final Supplier analyzeInputsSupplier; public BaseScalarMacroFunctionExpr(String name, final List args) { @@ -194,7 +194,7 @@ public void visit(final Visitor visitor) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { return analyzeInputsSupplier.get(); } @@ -219,10 +219,10 @@ public int hashCode() return Objects.hash(name, args); } - private ExprInputBindingAnalysis supplyAnalyzeInputs() + private BindingAnalysis supplyAnalyzeInputs() { final Set argSet = Sets.newHashSetWithExpectedSize(args.size()); - ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); argSet.add(arg); diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 0a50a8d79e05..5059996337c7 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -36,7 +36,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; -import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -368,6 +367,7 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr); } + abstract class ReduceFunction implements Function { private final DoubleBinaryOperator doubleReducer; @@ -409,8 +409,24 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return ExprEval.of(null); } - ExprAnalysis exprAnalysis = analyzeExprs(args, bindings); - if (exprAnalysis.exprEvals.isEmpty()) { + // evaluate arguments and collect output type + List> evals = new ArrayList<>(); + ExprType outputType = ExprType.LONG; + + for (Expr expr : args) { + ExprEval exprEval = expr.eval(bindings); + ExprType exprType = exprEval.type(); + + if (isValidType(exprType)) { + outputType = ExprType.implicitCast(outputType, exprType); + } + + if (exprEval.value() != null) { + evals.add(exprEval); + } + } + + if (evals.isEmpty()) { // The GREATEST/LEAST functions are not in the SQL standard. Emulate the behavior of postgres (return null if // all expressions are null, otherwise skip null values) since it is used as a base for a wide number of // databases. This also matches the behavior the the long/double greatest/least post aggregators. Some other @@ -420,48 +436,17 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return ExprEval.of(null); } - Stream> exprEvalStream = exprAnalysis.exprEvals.stream(); - switch (exprAnalysis.comparisonType) { + switch (outputType) { case DOUBLE: //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); + return ExprEval.of(evals.stream().mapToDouble(ExprEval::asDouble).reduce(doubleReducer).getAsDouble()); case LONG: //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); + return ExprEval.of(evals.stream().mapToLong(ExprEval::asLong).reduce(longReducer).getAsLong()); default: //noinspection OptionalGetWithoutIsPresent (empty list handled earlier) - return ExprEval.of(exprEvalStream.map(ExprEval::asString).reduce(stringReducer).get()); - } - } - - /** - * Determines which {@link ExprType} to use to compare non-null evaluated expressions. - * - * @param exprs Expressions to analyze - * @param bindings Bindings for expressions - * - * @return Comparison type and non-null evaluated expressions. - */ - private ExprAnalysis analyzeExprs(List exprs, Expr.ObjectBinding bindings) - { - Set presentTypes = EnumSet.noneOf(ExprType.class); - List> exprEvals = new ArrayList<>(); - - for (Expr expr : exprs) { - ExprEval exprEval = expr.eval(bindings); - ExprType exprType = exprEval.type(); - - if (isValidType(exprType)) { - presentTypes.add(exprType); - } - - if (exprEval.value() != null) { - exprEvals.add(exprEval); - } + return ExprEval.of(evals.stream().map(ExprEval::asString).reduce(stringReducer).get()); } - - ExprType comparisonType = getComparisionType(presentTypes); - return new ExprAnalysis(comparisonType, exprEvals); } private boolean isValidType(ExprType exprType) @@ -475,34 +460,6 @@ private boolean isValidType(ExprType exprType) throw new IAE("Function[%s] does not accept %s types", name(), exprType); } } - - /** - * Implements rules similar to: https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least - * - * @see org.apache.druid.sql.calcite.expression.builtin.ReductionOperatorConversionHelper#TYPE_INFERENCE - */ - private static ExprType getComparisionType(Set exprTypes) - { - if (exprTypes.contains(ExprType.STRING)) { - return ExprType.STRING; - } else if (exprTypes.contains(ExprType.DOUBLE)) { - return ExprType.DOUBLE; - } else { - return ExprType.LONG; - } - } - - private static class ExprAnalysis - { - final ExprType comparisonType; - final List> exprEvals; - - ExprAnalysis(ExprType comparisonType, List> exprEvals) - { - this.comparisonType = comparisonType; - this.exprEvals = exprEvals; - } - } } // ------------------------------ implementations ------------------------------ diff --git a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java index bcad13b4ccb4..e81d5bafd2cc 100644 --- a/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java @@ -105,10 +105,10 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { final Set lambdaArgs = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet()); - ExprInputBindingAnalysis bodyDetails = expr.analyzeInputs(); + BindingAnalysis bodyDetails = expr.analyzeInputs(); return bodyDetails.removeLambdaArguments(lambdaArgs); } @@ -193,9 +193,9 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { - ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : args) { accumulator = accumulator.with(arg); @@ -244,9 +244,9 @@ class ApplyFunctionExpr implements Expr final String name; final LambdaExpr lambdaExpr; final ImmutableList argsExpr; - final ExprInputBindingAnalysis exprInputBindingAnalysis; - final ExprInputBindingAnalysis lambdaExprInputBindingAnalysis; - final ImmutableList argsExprInputBindingAnalysis; + final BindingAnalysis bindingAnalysis; + final BindingAnalysis lambdaBindingAnalysis; + final ImmutableList argsBindingAnalyses; ApplyFunctionExpr(ApplyFunction function, String name, LambdaExpr expr, List args) { @@ -259,21 +259,21 @@ class ApplyFunctionExpr implements Expr // apply function expressions are examined during expression selector creation, so precompute and cache the // binding details of children - ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); - ExprInputBindingAnalysis accumulator = new ExprInputBindingAnalysis(); + ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); + BindingAnalysis accumulator = new BindingAnalysis(); for (Expr arg : argsExpr) { - ExprInputBindingAnalysis argDetails = arg.analyzeInputs(); + BindingAnalysis argDetails = arg.analyzeInputs(); argBindingDetailsBuilder.add(argDetails); accumulator = accumulator.with(argDetails); } - lambdaExprInputBindingAnalysis = lambdaExpr.analyzeInputs(); + lambdaBindingAnalysis = lambdaExpr.analyzeInputs(); - exprInputBindingAnalysis = accumulator.with(lambdaExprInputBindingAnalysis) - .withArrayArguments(function.getArrayInputs(argsExpr)) - .withArrayInputs(true) - .withArrayOutput(function.hasArrayOutput(lambdaExpr)); - argsExprInputBindingAnalysis = argBindingDetailsBuilder.build(); + bindingAnalysis = accumulator.with(lambdaBindingAnalysis) + .withArrayArguments(function.getArrayInputs(argsExpr)) + .withArrayInputs(true) + .withArrayOutput(function.hasArrayOutput(lambdaExpr)); + argsBindingAnalyses = argBindingDetailsBuilder.build(); } @Override @@ -318,9 +318,9 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { - return exprInputBindingAnalysis; + return bindingAnalysis; } @Nullable diff --git a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java index 9b4f888f2694..437370641c61 100644 --- a/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/IdentifierExpr.java @@ -102,9 +102,9 @@ public IdentifierExpr getIdentifierExprIfIdentifierExpr() } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { - return new ExprInputBindingAnalysis(this); + return new BindingAnalysis(this); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index ec2686fbd406..c9388bff17f1 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -169,7 +169,7 @@ public static Expr flatten(Expr expr) * @param bindingsToApply * @return */ - public static Expr applyUnappliedBindings(Expr expr, Expr.ExprInputBindingAnalysis exprInputBindingAnalysis, List bindingsToApply) + public static Expr applyUnappliedBindings(Expr expr, Expr.BindingAnalysis bindingAnalysis, List bindingsToApply) { if (bindingsToApply.isEmpty()) { // nothing to do, expression is fine as is @@ -177,7 +177,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.ExprInputBindingAnalys } // filter the list of bindings to those which are used in this expression List unappliedBindingsInExpression = bindingsToApply.stream() - .filter(x -> exprInputBindingAnalysis.getRequiredBindings().contains(x)) + .filter(x -> bindingAnalysis.getRequiredBindings().contains(x)) .collect(Collectors.toList()); // any unapplied bindings that are inside a lambda expression need that lambda expression to be rewritten @@ -193,7 +193,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.ExprInputBindingAnalys List newArgs = new ArrayList<>(); for (Expr arg : fnExpr.args) { if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { - Expr newArg = applyUnappliedBindings(arg, exprInputBindingAnalysis, unappliedBindingsInExpression); + Expr newArg = applyUnappliedBindings(arg, bindingAnalysis, unappliedBindingsInExpression); newArgs.add(newArg); } else { newArgs.add(arg); @@ -207,7 +207,7 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.ExprInputBindingAnalys } ); - Expr.ExprInputBindingAnalysis newExprBindings = newExpr.analyzeInputs(); + Expr.BindingAnalysis newExprBindings = newExpr.analyzeInputs(); final Set expectedArrays = newExprBindings.getArrayVariables(); List remainingUnappliedBindings = @@ -288,11 +288,11 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List unappliedInThisApply = unappliedArgs.stream() - .filter(u -> !expr.exprInputBindingAnalysis.getArrayBindings().contains(u)) + .filter(u -> !expr.bindingAnalysis.getArrayBindings().contains(u)) .collect(Collectors.toSet()); List unappliedIdentifiers = - expr.exprInputBindingAnalysis + expr.bindingAnalysis .getFreeVariables() .stream() .filter(x -> unappliedInThisApply.contains(x.getBindingIfIdentifier())) @@ -304,7 +304,7 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List unappliedLambdaBindings = - expr.lambdaExprInputBindingAnalysis.getFreeVariables() - .stream() - .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) - .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) - .collect(Collectors.toList()); + expr.lambdaBindingAnalysis.getFreeVariables() + .stream() + .filter(x -> unappliedArgs.contains(x.getBindingIfIdentifier())) + .map(x -> new IdentifierExpr(x.getIdentifier(), x.getBinding())) + .collect(Collectors.toList()); if (unappliedLambdaBindings.isEmpty()) { return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, newArgs); @@ -397,10 +397,10 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List conflicted = - Sets.intersection(exprInputBindingAnalysis.getScalarBindings(), exprInputBindingAnalysis.getArrayBindings()); + Sets.intersection(bindingAnalysis.getScalarBindings(), bindingAnalysis.getArrayBindings()); if (!conflicted.isEmpty()) { throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted); } diff --git a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java index f3304971a0ee..3d68430ea65b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java @@ -60,7 +60,7 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { // currently all unary operators only operate on scalar inputs return expr.analyzeInputs().withScalarArguments(ImmutableSet.of(expr)); diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java index 02e373834179..6dfa61d6d186 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ExprTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ExprTest.java @@ -113,7 +113,7 @@ public void testEqualsContractForApplyFunctionExpr() { EqualsVerifier.forClass(ApplyFunctionExpr.class) .usingGetClass() - .withIgnoredFields("function", "exprInputBindingAnalysis", "lambdaExprInputBindingAnalysis", "argsExprInputBindingAnalysis") + .withIgnoredFields("function", "bindingAnalysis", "lambdaBindingAnalysis", "argsBindingAnalyses") .verify(); } diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index 01bda4182c76..1ebd71f1cc43 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -577,7 +577,7 @@ private void validateParser( ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - final Expr.ExprInputBindingAnalysis deets = parsed.analyzeInputs(); + final Expr.BindingAnalysis deets = parsed.analyzeInputs(); Assert.assertEquals(expression, expected, parsed.toString()); Assert.assertEquals(expression, identifiers, deets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, deets.getScalarVariables()); @@ -586,7 +586,7 @@ private void validateParser( final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr roundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); Assert.assertEquals(parsed.stringify(), roundTrip.stringify()); - final Expr.ExprInputBindingAnalysis roundTripDeets = roundTrip.analyzeInputs(); + final Expr.BindingAnalysis roundTripDeets = roundTrip.analyzeInputs(); Assert.assertEquals(expression, identifiers, roundTripDeets.getRequiredBindingsList()); Assert.assertEquals(expression, scalars, roundTripDeets.getScalarVariables()); Assert.assertEquals(expression, arrays, roundTripDeets.getArrayVariables()); @@ -600,7 +600,7 @@ private void validateApplyUnapplied( ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - Expr.ExprInputBindingAnalysis deets = parsed.analyzeInputs(); + Expr.BindingAnalysis deets = parsed.analyzeInputs(); Parser.validateExpr(parsed, deets); final Expr transformed = Parser.applyUnappliedBindings(parsed, deets, identifiers); Assert.assertEquals(expression, unapplied, parsed.toString()); @@ -608,7 +608,7 @@ private void validateApplyUnapplied( final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); final Expr parsedRoundTrip = Parser.parse(parsedNoFlatten.stringify(), ExprMacroTable.nil()); - Expr.ExprInputBindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs(); + Expr.BindingAnalysis roundTripDeets = parsedRoundTrip.analyzeInputs(); Parser.validateExpr(parsedRoundTrip, roundTripDeets); final Expr transformedRoundTrip = Parser.applyUnappliedBindings(parsedRoundTrip, roundTripDeets, identifiers); Assert.assertEquals(expression, unapplied, parsedRoundTrip.toString()); diff --git a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java index 71c7e6db7bda..f019edc93e4a 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java @@ -299,7 +299,7 @@ public Expr visit(Shuttle shuttle) } @Override - public ExprInputBindingAnalysis analyzeInputs() + public BindingAnalysis analyzeInputs() { return stringExpr.analyzeInputs() .with(charsExpr) diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java index 76072c5df955..acf0dbeaf04e 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java @@ -45,7 +45,7 @@ public class ExpressionFilter implements Filter { private final Supplier expr; - private final Supplier bindingDetails; + private final Supplier bindingDetails; private final FilterTuning filterTuning; public ExpressionFilter(final Supplier expr, final FilterTuning filterTuning) @@ -107,7 +107,7 @@ public void inspectRuntimeShape(final RuntimeShapeInspector inspector) @Override public boolean supportsBitmapIndex(final BitmapIndexSelector selector) { - final Expr.ExprInputBindingAnalysis details = this.bindingDetails.get(); + final Expr.BindingAnalysis details = this.bindingDetails.get(); if (details.getRequiredBindings().isEmpty()) { // Constant expression. diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java index 57e0808ce62a..ed9fe0756251 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java @@ -380,8 +380,8 @@ private static void getCorrelationForRHSColumn( String identifier = lhsExpr.getBindingIfIdentifier(); if (identifier == null) { // We push down if the function only requires base table columns - Expr.ExprInputBindingAnalysis exprInputBindingAnalysis = lhsExpr.analyzeInputs(); - Set requiredBindings = exprInputBindingAnalysis.getRequiredBindings(); + Expr.BindingAnalysis bindingAnalysis = lhsExpr.analyzeInputs(); + Set requiredBindings = bindingAnalysis.getRequiredBindings(); if (joinableClauses.areSomeColumnsFromJoin(requiredBindings)) { break; diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 78c12f3617bc..f2a0571d610a 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -136,9 +136,9 @@ public static ColumnValueSelector makeExprEvalSelector( Expr expression ) { - final Expr.ExprInputBindingAnalysis inputBindingAnalysis = expression.analyzeInputs(); - Parser.validateExpr(expression, inputBindingAnalysis); - final List columns = inputBindingAnalysis.getRequiredBindingsList(); + final Expr.BindingAnalysis bindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, bindingAnalysis); + final List columns = bindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -155,7 +155,7 @@ public static ColumnValueSelector makeExprEvalSelector( && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() && capabilities.hasMultipleValues().isFalse() - && inputBindingAnalysis.getArrayBindings().isEmpty()) { + && bindingAnalysis.getArrayBindings().isEmpty()) { // Optimization for expressions that hit one scalar string column and nothing else. return new SingleStringInputCachingExpressionColumnValueSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -165,22 +165,22 @@ public static ColumnValueSelector makeExprEvalSelector( } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, inputBindingAnalysis, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, bindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final List needsApplied = columns.stream() - .filter(c -> actualArrays.contains(c) && !inputBindingAnalysis.getArrayBindings().contains(c)) + .filter(c -> actualArrays.contains(c) && !bindingAnalysis.getArrayBindings().contains(c)) .collect(Collectors.toList()); final Expr finalExpr; if (needsApplied.size() > 0) { - finalExpr = Parser.applyUnappliedBindings(expression, inputBindingAnalysis, needsApplied); + finalExpr = Parser.applyUnappliedBindings(expression, bindingAnalysis, needsApplied); } else { finalExpr = expression; } - final Expr.ObjectBinding bindings = createBindings(inputBindingAnalysis, columnSelectorFactory); + final Expr.ObjectBinding bindings = createBindings(bindingAnalysis, columnSelectorFactory); if (bindings.equals(ExprUtils.nilBindings())) { // Optimization for constant expressions. @@ -192,7 +192,7 @@ public static ColumnValueSelector makeExprEvalSelector( if (unknownIfArrays.size() > 0) { return new RowBasedExpressionColumnValueSelector( finalExpr, - inputBindingAnalysis, + bindingAnalysis, bindings, unknownIfArrays ); @@ -212,9 +212,9 @@ public static DimensionSelector makeDimensionSelector( @Nullable final ExtractionFn extractionFn ) { - final Expr.ExprInputBindingAnalysis inputBindingAnalysis = expression.analyzeInputs(); - Parser.validateExpr(expression, inputBindingAnalysis); - final List columns = inputBindingAnalysis.getRequiredBindingsList(); + final Expr.BindingAnalysis bindingAnalysis = expression.analyzeInputs(); + Parser.validateExpr(expression, bindingAnalysis); + final List columns = bindingAnalysis.getRequiredBindingsList(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -226,7 +226,7 @@ public static DimensionSelector makeDimensionSelector( if (capabilities != null && capabilities.getType() == ValueType.STRING && capabilities.isDictionaryEncoded().isTrue() - && canMapOverDictionary(inputBindingAnalysis, capabilities.hasMultipleValues()) + && canMapOverDictionary(bindingAnalysis, capabilities.hasMultipleValues()) ) { return new SingleStringInputDimensionSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -236,14 +236,14 @@ && canMapOverDictionary(inputBindingAnalysis, capabilities.hasMultipleValues()) } final Pair, Set> arrayUsage = - examineColumnSelectorFactoryArrays(columnSelectorFactory, inputBindingAnalysis, columns); + examineColumnSelectorFactoryArrays(columnSelectorFactory, bindingAnalysis, columns); final Set actualArrays = arrayUsage.lhs; final Set unknownIfArrays = arrayUsage.rhs; final ColumnValueSelector baseSelector = makeExprEvalSelector(columnSelectorFactory, expression); final boolean multiVal = actualArrays.size() > 0 || - inputBindingAnalysis.getArrayBindings().size() > 0 || + bindingAnalysis.getArrayBindings().size() > 0 || unknownIfArrays.size() > 0; if (baseSelector instanceof ConstantExprEvalSelector) { @@ -344,30 +344,30 @@ public void inspectRuntimeShape(RuntimeShapeInspector inspector) * This function should only be called if you have already determined that an expression is over a single column, * and that single column has a dictionary. * - * @param inputBindingAnalysis result of calling {@link Expr#analyzeInputs()} on an expression + * @param bindingAnalysis result of calling {@link Expr#analyzeInputs()} on an expression * @param hasMultipleValues result of calling {@link ColumnCapabilities#hasMultipleValues()} */ public static boolean canMapOverDictionary( - final Expr.ExprInputBindingAnalysis inputBindingAnalysis, + final Expr.BindingAnalysis bindingAnalysis, final ColumnCapabilities.Capable hasMultipleValues ) { - Preconditions.checkState(inputBindingAnalysis.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); - return !hasMultipleValues.isUnknown() && !inputBindingAnalysis.hasInputArrays() && !inputBindingAnalysis.isOutputArray(); + Preconditions.checkState(bindingAnalysis.getRequiredBindings().size() == 1, "requiredBindings.size == 1"); + return !hasMultipleValues.isUnknown() && !bindingAnalysis.hasInputArrays() && !bindingAnalysis.isOutputArray(); } /** - * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.ExprInputBindingAnalysis} which + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingAnalysis} which * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they * are used as array or scalar inputs */ private static Expr.ObjectBinding createBindings( - Expr.ExprInputBindingAnalysis exprInputBindingAnalysis, + Expr.BindingAnalysis bindingAnalysis, ColumnSelectorFactory columnSelectorFactory ) { final Map> suppliers = new HashMap<>(); - final List columns = exprInputBindingAnalysis.getRequiredBindingsList(); + final List columns = bindingAnalysis.getRequiredBindingsList(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory .getColumnCapabilities(columnName); @@ -601,7 +601,7 @@ public static Object coerceEvalToSelectorObject(ExprEval eval) */ private static Pair, Set> examineColumnSelectorFactoryArrays( ColumnSelectorFactory columnSelectorFactory, - Expr.ExprInputBindingAnalysis inputBindingAnalysis, + Expr.BindingAnalysis bindingAnalysis, List columns ) { @@ -615,7 +615,7 @@ private static Pair, Set> examineColumnSelectorFactoryArrays } else if ( capabilities.getType().equals(ValueType.STRING) && capabilities.hasMultipleValues().isMaybeTrue() && - !inputBindingAnalysis.getArrayBindings().contains(column) + !bindingAnalysis.getArrayBindings().contains(column) ) { unknownIfArrays.add(column); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java index 51bae6ccf2ed..5a33bc771b4d 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java @@ -40,22 +40,22 @@ public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector { private final List unknownColumns; - private final Expr.ExprInputBindingAnalysis baseExprInputBindingAnalysis; + private final Expr.BindingAnalysis baseBindingAnalysis; private final Set ignoredColumns; private final Int2ObjectMap transformedCache; public RowBasedExpressionColumnValueSelector( Expr expression, - Expr.ExprInputBindingAnalysis baseExprInputBindingAnalysis, + Expr.BindingAnalysis baseBindingAnalysis, Expr.ObjectBinding bindings, Set unknownColumnsSet ) { super(expression, bindings); this.unknownColumns = unknownColumnsSet.stream() - .filter(x -> !baseExprInputBindingAnalysis.getArrayBindings().contains(x)) + .filter(x -> !baseBindingAnalysis.getArrayBindings().contains(x)) .collect(Collectors.toList()); - this.baseExprInputBindingAnalysis = baseExprInputBindingAnalysis; + this.baseBindingAnalysis = baseBindingAnalysis; this.ignoredColumns = new HashSet<>(); this.transformedCache = new Int2ObjectArrayMap<>(unknownColumns.size()); } @@ -79,7 +79,7 @@ public ExprEval getObject() if (transformedCache.containsKey(key)) { return transformedCache.get(key).eval(bindings); } - Expr transformed = Parser.applyUnappliedBindings(expression, baseExprInputBindingAnalysis, arrayBindings); + Expr transformed = Parser.applyUnappliedBindings(expression, baseBindingAnalysis, arrayBindings); transformedCache.put(key, transformed); return transformed.eval(bindings); } From 8d64ecce0e8cc4a683c81cfbaf78b4aedcf2cbdb Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Tue, 8 Sep 2020 23:57:25 -0700 Subject: [PATCH 10/15] fix inspection --- .../java/org/apache/druid/math/expr/ExprType.java | 15 +++++++++++---- .../ReductionOperatorConversionHelper.java | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 6fa09b2ae90c..663b34e2c616 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -134,10 +134,17 @@ public static ExprType implicitCast(@Nullable ExprType type, @Nullable ExprType if (STRING.equals(type) || STRING.equals(other)) { return STRING; } - // all numbers win over Integer - if (LONG.equals(type)) { - return other; + + if (isNumeric(type) && isNumeric(other)) { + // all numbers win over longs + if (LONG.equals(type)) { + return other; + } + // floats vs longs would be handled here, but we currently only support doubles... + return type; } - return type; + + // unhandled is unknown + return null; } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index b5b4c21c7069..430cb2c58410 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -38,7 +38,7 @@ private ReductionOperatorConversionHelper() * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least * * @see org.apache.druid.math.expr.Function.ReduceFunction#apply - * @see org.apache.druid.math.expr.Function.ReduceFunction#getComparisionType + * @see org.apache.druid.math.expr.ExprType#implicitCast */ static final SqlReturnTypeInference TYPE_INFERENCE = opBinding -> { From ee4306cfecf643fe2a7d3fce4d8386b41b1af785 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 10 Sep 2020 14:50:12 -0700 Subject: [PATCH 11/15] javadoc --- .../org/apache/druid/math/expr/ApplyFunction.java | 12 ++++++++++++ .../main/java/org/apache/druid/math/expr/Expr.java | 11 +++++++++++ .../java/org/apache/druid/math/expr/Function.java | 5 +++++ 3 files changed, 28 insertions(+) diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java index e04e1a388735..d6f4ed2bd876 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -75,6 +75,12 @@ default boolean hasArrayOutput(LambdaExpr lambdaExpr) */ void validateArguments(LambdaExpr lambdaExpr, List args); + /** + * Compute the output type of this function for a given lambda and the argument expressions which will be applied as + * its inputs. + * + * @see Expr#getOutputType + */ @Nullable ExprType getOutputType(Expr.InputBindingTypes inputTypes, LambdaExpr expr, List args); @@ -883,6 +889,12 @@ public CartesianFoldLambdaBinding accumulateWithIndex(int index, Object acc) } } + /** + * Helper that can wrap another {@link Expr.InputBindingTypes} to use to supply the type information of a + * {@link LambdaExpr} when evaluating {@link ApplyFunctionExpr#getOutputType}. Lambda identifiers do not exist + * in the underlying {@link Expr.InputBindingTypes}, but can be created by mapping the lambda identifiers to the + * arguments that will be applied to them, to map the type information. + */ class LambdaInputBindingTypes implements Expr.InputBindingTypes { private final Object2IntMap lambdaIdentifiers; diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index 06ed806b071c..4d0cdbfb5c9b 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -127,9 +127,20 @@ default String getBindingIfIdentifier() */ BindingAnalysis analyzeInputs(); + /** + * Given an {@link InputBindingTypes}, compute what the output {@link ExprType} will be for this expression. A return + * value of null indicates that the given type information was not enough to resolve the output type, so the + * expression must be evaluated using default {@link #eval} handling where types are only known after evaluation, + * through {@link ExprEval#type}. + */ @Nullable ExprType getOutputType(InputBindingTypes inputTypes); + /** + * Mechanism to supply input types for the bindings which will back {@link IdentifierExpr}, to use in the aid of + * inferring the output type of an expression with {@link #getOutputType}. A null value means that either the binding + * doesn't exist, or, that the type information is unavailable. + */ interface InputBindingTypes { @Nullable diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 5059996337c7..8ed23e83af4e 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -104,6 +104,11 @@ default boolean hasArrayOutput() */ void validateArguments(List args); + /** + * Compute the output type of this function for a given set of argument expression inputs. + * + * @see Expr#getOutputType + */ @Nullable ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args); From b09b24d9ff6addcb0bb6dab22c83a72cd85c8969 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 10 Sep 2020 15:18:22 -0700 Subject: [PATCH 12/15] make default implementation of Expr.getOutputType that returns null --- core/src/main/java/org/apache/druid/math/expr/Expr.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index 4d0cdbfb5c9b..2a13be3f8459 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -134,7 +134,10 @@ default String getBindingIfIdentifier() * through {@link ExprEval#type}. */ @Nullable - ExprType getOutputType(InputBindingTypes inputTypes); + default ExprType getOutputType(InputBindingTypes inputTypes) + { + return null; + } /** * Mechanism to supply input types for the bindings which will back {@link IdentifierExpr}, to use in the aid of From 9ff744b360de695535cd76010120d3091bb572ad Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 10 Sep 2020 16:14:43 -0700 Subject: [PATCH 13/15] rename method --- .../org/apache/druid/math/expr/BinaryOperatorExpr.java | 2 +- .../main/java/org/apache/druid/math/expr/ExprType.java | 9 ++++++--- .../main/java/org/apache/druid/math/expr/Function.java | 8 ++++---- .../server/lookup/cache/LookupCoordinatorManager.java | 2 +- .../builtin/ReductionOperatorConversionHelper.java | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java index 91811a5610fb..ed6b8835868c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java @@ -91,7 +91,7 @@ public BindingAnalysis analyzeInputs() @Override public ExprType getOutputType(InputBindingTypes inputTypes) { - return ExprType.implicitCast(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); + return ExprType.autoTypeConversion(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 663b34e2c616..cfabf157cee5 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -116,14 +116,17 @@ public static ExprType asArrayType(@Nullable ExprType elementType) return elementType; } + /** + * Given 2 'input' types, choose the most appropriate combined type, if possible + */ @Nullable - public static ExprType implicitCast(@Nullable ExprType type, @Nullable ExprType other) + public static ExprType autoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) { if (type == null || other == null) { - // cannot implicitly cast unknown types + // cannot auto conversion unknown types return null; } - // arrays cannot be implicitly cast + // arrays cannot be auto converted if (isArray(type)) { if (!type.equals(other)) { throw new IAE("Cannot implicitly cast %s to %s", type, other); diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 8ed23e83af4e..0367014b415f 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -243,7 +243,7 @@ protected ExprEval eval(double x, double y) @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return ExprType.implicitCast(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); + return ExprType.autoTypeConversion(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); } } @@ -402,7 +402,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args { ExprType outputType = ExprType.LONG; for (Expr expr : args) { - outputType = ExprType.implicitCast(outputType, expr.getOutputType(inputTypes)); + outputType = ExprType.autoTypeConversion(outputType, expr.getOutputType(inputTypes)); } return outputType; } @@ -423,7 +423,7 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) ExprType exprType = exprEval.type(); if (isValidType(exprType)) { - outputType = ExprType.implicitCast(outputType, exprType); + outputType = ExprType.autoTypeConversion(outputType, exprType); } if (exprEval.value() != null) { @@ -2283,7 +2283,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args { ExprType type = ExprType.LONG; for (Expr arg : args) { - type = ExprType.implicitCast(type, arg.getOutputType(inputTypes)); + type = ExprType.autoTypeConversion(type, arg.getOutputType(inputTypes)); } return ExprType.asArrayType(type); } diff --git a/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java b/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java index 7526ecbcc10f..981469d12205 100644 --- a/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java +++ b/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java @@ -518,7 +518,7 @@ private void initializeLookupsConfigWatcher() configManager.set( LOOKUP_CONFIG_KEY, converted, - new AuditInfo("autoConversion", "autoConversion", "127.0.0.1") + new AuditInfo("autoTypeConversion", "autoTypeConversion", "127.0.0.1") ); } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index 430cb2c58410..c80335638474 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -38,7 +38,7 @@ private ReductionOperatorConversionHelper() * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least * * @see org.apache.druid.math.expr.Function.ReduceFunction#apply - * @see org.apache.druid.math.expr.ExprType#implicitCast + * @see org.apache.druid.math.expr.ExprType#autoTypeConversion */ static final SqlReturnTypeInference TYPE_INFERENCE = opBinding -> { From 40d0913bc0c44b35bc2cd1dee8bf5bbd660aeecb Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 10 Sep 2020 22:05:01 -0700 Subject: [PATCH 14/15] more test --- .../apache/druid/math/expr/ConstantExpr.java | 1 + .../org/apache/druid/math/expr/ExprType.java | 6 +- .../druid/math/expr/OutputTypeTest.java | 61 +++++++++++++++++++ .../IPv4AddressMatchExprMacroTest.java | 9 --- .../cache/LookupCoordinatorManager.java | 2 +- 5 files changed, 66 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index fe74491e5e07..ef090b5edd3a 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -42,6 +42,7 @@ protected ConstantExpr(ExprType outputType) this.outputType = outputType; } + @Nullable @Override public ExprType getOutputType(InputBindingTypes inputTypes) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index cfabf157cee5..28af28a3d7c5 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -87,7 +87,7 @@ public static boolean isArray(@Nullable ExprType type) @Nullable public static ExprType elementType(@Nullable ExprType type) { - if (type != null && isArray(type)) { + if (type != null) { switch (type) { case STRING_ARRAY: return STRING; @@ -103,7 +103,7 @@ public static ExprType elementType(@Nullable ExprType type) @Nullable public static ExprType asArrayType(@Nullable ExprType elementType) { - if (elementType != null && !isArray(elementType)) { + if (elementType != null) { switch (elementType) { case STRING: return STRING_ARRAY; @@ -127,7 +127,7 @@ public static ExprType autoTypeConversion(@Nullable ExprType type, @Nullable Exp return null; } // arrays cannot be auto converted - if (isArray(type)) { + if (isArray(type) || isArray(other)) { if (!type.equals(other)) { throw new IAE("Cannot implicitly cast %s to %s", type, other); } diff --git a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java index d7b5d32b4f86..61b0b6889d10 100644 --- a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java @@ -20,9 +20,12 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableMap; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.Map; @@ -44,6 +47,9 @@ public class OutputTypeTest extends InitializedNullHandlingTest .build() ); + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @Test public void testConstantsAndIdentifiers() { @@ -354,6 +360,61 @@ public void testApplyFunctions() assertOutputType("all((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); } + @Test + public void testAutoConversion() + { + // nulls output nulls + Assert.assertNull(ExprType.autoTypeConversion(ExprType.LONG, null)); + Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.LONG)); + Assert.assertNull(ExprType.autoTypeConversion(ExprType.DOUBLE, null)); + Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.DOUBLE)); + Assert.assertNull(ExprType.autoTypeConversion(ExprType.STRING, null)); + Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.STRING)); + // only long stays long + Assert.assertEquals(ExprType.LONG, ExprType.autoTypeConversion(ExprType.LONG, ExprType.LONG)); + // any double makes all doubles + Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); + // any string makes become string + Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.LONG, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.LONG)); + Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.STRING)); + // unless it is an array, and those have to be the same + Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.autoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); + Assert.assertEquals( + ExprType.DOUBLE_ARRAY, + ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) + ); + Assert.assertEquals( + ExprType.STRING_ARRAY, + ExprType.autoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) + ); + } + + @Test + public void testAutoConversionArrayMismatchArrays() + { + expectedException.expect(IAE.class); + ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG_ARRAY); + } + + @Test + public void testAutoConversionArrayMismatchArrayScalar() + { + expectedException.expect(IAE.class); + ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG); + } + + @Test + public void testAutoConversionArrayMismatchScalarArray() + { + expectedException.expect(IAE.class); + ExprType.autoTypeConversion(ExprType.STRING, ExprType.LONG_ARRAY); + } + private void assertOutputType(String expression, Expr.InputBindingTypes inputTypes, ExprType outputType) { final Expr expr = Parser.parse(expression, ExprMacroTable.nil(), false); diff --git a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java index e2c420994e3b..aa5bd917bf13 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/IPv4AddressMatchExprMacroTest.java @@ -22,11 +22,9 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.math.expr.ExprType; import org.junit.Assert; import org.junit.Test; -import javax.annotation.Nullable; import java.util.Arrays; import java.util.Collections; @@ -205,12 +203,5 @@ public Expr visit(Shuttle shuttle) { return null; } - - @Nullable - @Override - public ExprType getOutputType(InputBindingTypes inputTypes) - { - return null; - } } } diff --git a/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java b/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java index 981469d12205..7526ecbcc10f 100644 --- a/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java +++ b/server/src/main/java/org/apache/druid/server/lookup/cache/LookupCoordinatorManager.java @@ -518,7 +518,7 @@ private void initializeLookupsConfigWatcher() configManager.set( LOOKUP_CONFIG_KEY, converted, - new AuditInfo("autoTypeConversion", "autoTypeConversion", "127.0.0.1") + new AuditInfo("autoConversion", "autoConversion", "127.0.0.1") ); } } From b0b76aacee80065622ed04dd857d406a9d45d655 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 14 Sep 2020 12:31:33 -0700 Subject: [PATCH 15/15] add output for contains expr macro, split operation and function auto conversion --- .../druid/math/expr/BinaryOperatorExpr.java | 2 +- .../org/apache/druid/math/expr/ExprType.java | 46 ++++++++--- .../org/apache/druid/math/expr/Function.java | 8 +- .../druid/math/expr/OutputTypeTest.java | 79 +++++++++++++------ .../druid/query/expression/ContainsExpr.java | 12 ++- .../ReductionOperatorConversionHelper.java | 2 +- 6 files changed, 109 insertions(+), 40 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java index ed6b8835868c..9db527bf5b43 100644 --- a/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/BinaryOperatorExpr.java @@ -91,7 +91,7 @@ public BindingAnalysis analyzeInputs() @Override public ExprType getOutputType(InputBindingTypes inputTypes) { - return ExprType.autoTypeConversion(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); + return ExprType.operatorAutoTypeConversion(left.getOutputType(inputTypes), right.getOutputType(inputTypes)); } @Override diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 28af28a3d7c5..3b9108de9216 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ -120,7 +120,7 @@ public static ExprType asArrayType(@Nullable ExprType elementType) * Given 2 'input' types, choose the most appropriate combined type, if possible */ @Nullable - public static ExprType autoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) + public static ExprType operatorAutoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) { if (type == null || other == null) { // cannot auto conversion unknown types @@ -133,21 +133,47 @@ public static ExprType autoTypeConversion(@Nullable ExprType type, @Nullable Exp } return type; } - // if either argument is a string, type becomes a string - if (STRING.equals(type) || STRING.equals(other)) { + // if both arguments are a string, type becomes a string + if (STRING.equals(type) && STRING.equals(other)) { return STRING; } - if (isNumeric(type) && isNumeric(other)) { - // all numbers win over longs - if (LONG.equals(type)) { - return other; + return numericAutoTypeConversion(type, other); + } + + /** + * Given 2 'input' types, choose the most appropriate combined type, if possible + */ + @Nullable + public static ExprType functionAutoTypeConversion(@Nullable ExprType type, @Nullable ExprType other) + { + if (type == null || other == null) { + // cannot auto conversion unknown types + return null; + } + // arrays cannot be auto converted + if (isArray(type) || isArray(other)) { + if (!type.equals(other)) { + throw new IAE("Cannot implicitly cast %s to %s", type, other); } - // floats vs longs would be handled here, but we currently only support doubles... return type; } + // if either argument is a string, type becomes a string + if (STRING.equals(type) || STRING.equals(other)) { + return STRING; + } + + return numericAutoTypeConversion(type, other); + } - // unhandled is unknown - return null; + @Nullable + public static ExprType numericAutoTypeConversion(ExprType type, ExprType other) + { + // all numbers win over longs + if (LONG.equals(type) && LONG.equals(other)) { + return LONG; + } + // floats vs doubles would be handled here, but we currently only support doubles... + return DOUBLE; } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 0367014b415f..2e27aab84ae6 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -243,7 +243,7 @@ protected ExprEval eval(double x, double y) @Override public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args) { - return ExprType.autoTypeConversion(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); + return ExprType.functionAutoTypeConversion(args.get(0).getOutputType(inputTypes), args.get(1).getOutputType(inputTypes)); } } @@ -402,7 +402,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args { ExprType outputType = ExprType.LONG; for (Expr expr : args) { - outputType = ExprType.autoTypeConversion(outputType, expr.getOutputType(inputTypes)); + outputType = ExprType.functionAutoTypeConversion(outputType, expr.getOutputType(inputTypes)); } return outputType; } @@ -423,7 +423,7 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) ExprType exprType = exprEval.type(); if (isValidType(exprType)) { - outputType = ExprType.autoTypeConversion(outputType, exprType); + outputType = ExprType.functionAutoTypeConversion(outputType, exprType); } if (exprEval.value() != null) { @@ -2283,7 +2283,7 @@ public ExprType getOutputType(Expr.InputBindingTypes inputTypes, List args { ExprType type = ExprType.LONG; for (Expr arg : args) { - type = ExprType.autoTypeConversion(type, arg.getOutputType(inputTypes)); + type = ExprType.functionAutoTypeConversion(type, arg.getOutputType(inputTypes)); } return ExprType.asArrayType(type); } diff --git a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java index 61b0b6889d10..7b977d3b9ac2 100644 --- a/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/OutputTypeTest.java @@ -360,37 +360,72 @@ public void testApplyFunctions() assertOutputType("all((x) -> x > 1.2, c)", inputTypes, ExprType.LONG); } + + @Test + public void testOperatorAutoConversion() + { + // nulls output nulls + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.LONG, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.LONG)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.DOUBLE)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(ExprType.STRING, null)); + Assert.assertNull(ExprType.operatorAutoTypeConversion(null, ExprType.STRING)); + // only long stays long + Assert.assertEquals(ExprType.LONG, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.LONG)); + // only string stays string + Assert.assertEquals(ExprType.STRING, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.STRING)); + // for operators, doubles is the catch all + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.STRING, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.operatorAutoTypeConversion(ExprType.LONG, ExprType.STRING)); + // unless it is an array, and those have to be the same + Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.operatorAutoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); + Assert.assertEquals( + ExprType.DOUBLE_ARRAY, + ExprType.operatorAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) + ); + Assert.assertEquals( + ExprType.STRING_ARRAY, + ExprType.operatorAutoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) + ); + } + @Test - public void testAutoConversion() + public void testFunctionAutoConversion() { // nulls output nulls - Assert.assertNull(ExprType.autoTypeConversion(ExprType.LONG, null)); - Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.LONG)); - Assert.assertNull(ExprType.autoTypeConversion(ExprType.DOUBLE, null)); - Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.DOUBLE)); - Assert.assertNull(ExprType.autoTypeConversion(ExprType.STRING, null)); - Assert.assertNull(ExprType.autoTypeConversion(null, ExprType.STRING)); + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.LONG, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.LONG)); + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.DOUBLE, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.DOUBLE)); + Assert.assertNull(ExprType.functionAutoTypeConversion(ExprType.STRING, null)); + Assert.assertNull(ExprType.functionAutoTypeConversion(null, ExprType.STRING)); // only long stays long - Assert.assertEquals(ExprType.LONG, ExprType.autoTypeConversion(ExprType.LONG, ExprType.LONG)); + Assert.assertEquals(ExprType.LONG, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.LONG)); // any double makes all doubles - Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); - Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); - Assert.assertEquals(ExprType.DOUBLE, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.LONG)); + Assert.assertEquals(ExprType.DOUBLE, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.DOUBLE)); // any string makes become string - Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.LONG, ExprType.STRING)); - Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.LONG)); - Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); - Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); - Assert.assertEquals(ExprType.STRING, ExprType.autoTypeConversion(ExprType.STRING, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.LONG, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.LONG)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.DOUBLE, ExprType.STRING)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.DOUBLE)); + Assert.assertEquals(ExprType.STRING, ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.STRING)); // unless it is an array, and those have to be the same - Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.autoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); + Assert.assertEquals(ExprType.LONG_ARRAY, ExprType.functionAutoTypeConversion(ExprType.LONG_ARRAY, ExprType.LONG_ARRAY)); Assert.assertEquals( ExprType.DOUBLE_ARRAY, - ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.DOUBLE_ARRAY) ); Assert.assertEquals( ExprType.STRING_ARRAY, - ExprType.autoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) + ExprType.functionAutoTypeConversion(ExprType.STRING_ARRAY, ExprType.STRING_ARRAY) ); } @@ -398,21 +433,21 @@ public void testAutoConversion() public void testAutoConversionArrayMismatchArrays() { expectedException.expect(IAE.class); - ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG_ARRAY); + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG_ARRAY); } @Test public void testAutoConversionArrayMismatchArrayScalar() { expectedException.expect(IAE.class); - ExprType.autoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG); + ExprType.functionAutoTypeConversion(ExprType.DOUBLE_ARRAY, ExprType.LONG); } @Test public void testAutoConversionArrayMismatchScalarArray() { expectedException.expect(IAE.class); - ExprType.autoTypeConversion(ExprType.STRING, ExprType.LONG_ARRAY); + ExprType.functionAutoTypeConversion(ExprType.STRING, ExprType.LONG_ARRAY); } private void assertOutputType(String expression, Expr.InputBindingTypes inputTypes, ExprType outputType) diff --git a/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java b/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java index f9550f32429b..f36311229e79 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java +++ b/processing/src/main/java/org/apache/druid/query/expression/ContainsExpr.java @@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprType; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.util.function.Function; /** @@ -62,13 +63,20 @@ public ExprEval eval(final Expr.ObjectBinding bindings) if (s == null) { // same behavior as regexp_like. - return ExprEval.of(false, ExprType.LONG); + return ExprEval.ofLongBoolean(false); } else { final boolean doesContain = searchFunction.apply(s); - return ExprEval.of(doesContain, ExprType.LONG); + return ExprEval.ofLongBoolean(doesContain); } } + @Nullable + @Override + public ExprType getOutputType(InputBindingTypes inputTypes) + { + return ExprType.LONG; + } + @Override public Expr visit(Expr.Shuttle shuttle) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java index c80335638474..5aa9a9e3645f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ReductionOperatorConversionHelper.java @@ -38,7 +38,7 @@ private ReductionOperatorConversionHelper() * https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#function_least * * @see org.apache.druid.math.expr.Function.ReduceFunction#apply - * @see org.apache.druid.math.expr.ExprType#autoTypeConversion + * @see org.apache.druid.math.expr.ExprType#functionAutoTypeConversion */ static final SqlReturnTypeInference TYPE_INFERENCE = opBinding -> {