From 57d03cd7df994721ce988228a23cd38cd83fd5c3 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Thu, 18 Apr 2024 22:01:09 -0700 Subject: [PATCH 1/3] Four changes to scalar_in_array as follow-ups to #16306: 1) Align behavior for `null` scalars to the behavior of the native `in` and `inType` filters: return `true` if the array itself contains null, else return `null`. 2) Rename the class to more closely match the function name. 3) Add a specialization for constant arrays, where we build a `HashSet`. 4) Use `castForEqualityComparison` to properly handle cross-type comparisons. Additional tests verify comparisons between LONG and DOUBLE are now handled properly. --- docs/querying/math-expr.md | 2 +- docs/querying/sql-array-functions.md | 6 +- docs/querying/sql-functions.md | 9 +- .../org/apache/druid/math/expr/Function.java | 88 +++++++++++++++++-- .../apache/druid/math/expr/FunctionTest.java | 12 ++- 5 files changed, 103 insertions(+), 14 deletions(-) diff --git a/docs/querying/math-expr.md b/docs/querying/math-expr.md index d5255544a03e..d8de55dc00ef 100644 --- a/docs/querying/math-expr.md +++ b/docs/querying/math-expr.md @@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function. | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index | | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 | | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 | -| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 | +| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is nonnull or null if the expr is null | | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array | diff --git a/docs/querying/sql-array-functions.md b/docs/querying/sql-array-functions.md index ab84c664dee7..b29e8a1bfc0e 100644 --- a/docs/querying/sql-array-functions.md +++ b/docs/querying/sql-array-functions.md @@ -52,9 +52,9 @@ The following table describes array functions. To learn more about array aggrega |`ARRAY_LENGTH(arr)`|Returns length of the array expression.| |`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.| |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.| -|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.| -|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| -| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.| +|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.| +|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.| +|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is nonnull or `UNKNOWN` if the scalar `expr` is `NULL`.| |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.| diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md index 093e7ce60fde..8a82a00c76cf 100644 --- a/docs/querying/sql-functions.md +++ b/docs/querying/sql-functions.md @@ -156,7 +156,7 @@ Concatenates array inputs into a single array. **Function type:** [Array](./sql-array-functions.md) -If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0. +If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns false. ## ARRAY_LENGTH @@ -204,7 +204,7 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no **Function type:** [Array](./sql-array-functions.md) -Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| +Returns true if `arr1` and `arr2` have any elements in common, else false. ## SCALAR_IN_ARRAY @@ -212,7 +212,10 @@ Returns 1 if `arr1` and `arr2` have any elements in common, else 0.| **Function type:** [Array](./sql-array-functions.md) -Returns 1 if the scalar `expr` is present in `arr`, else 0.| +Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is nonnull or +`UNKNOWN` if the scalar `expr` is `NULL`. + +Returns `UNKNOWN` if `arr` is `NULL`. ## ARRAY_PREPEND diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index aa54409e132e..b42f46394780 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -45,6 +45,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -3724,8 +3725,11 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) } } - class ArrayScalarInFunction extends ArrayScalarFunction + class ScalarInArrayFunction extends ArrayScalarFunction { + private static final int SCALAR_ARG = 0; + private static final int ARRAY_ARG = 1; + @Override public String name() { @@ -3742,23 +3746,95 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) { - return args.get(0); + return args.get(SCALAR_ARG); } @Override Expr getArrayArgument(List args) { - return args.get(1); + return args.get(ARRAY_ARG); } @Override - ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval) { - final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray(); + final Object[] array = arrayEval.asArray(); if (array == null) { return ExprEval.ofLong(null); } - return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value())); + + if (scalarEval.value() == null) { + return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null); + } + + final ExpressionType matchType = arrayEval.elementType(); + final ExprEval scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType); + + if (scalarEvalForComparison == null) { + return ExprEval.ofLongBoolean(false); + } else { + return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value())); + } + } + + @Override + public Function asSingleThreaded(List args, Expr.InputBindingInspector inspector) + { + if (args.get(ARRAY_ARG).isLiteral()) { + final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector); + if (lhsType == null) { + return this; + } + + final ExprEval rhsEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings()); + return new WithConstantArray(rhsEval); + } + return this; + } + + /** + * Specialization of {@link ScalarInArrayFunction} for constant {@link #ARRAY_ARG}. + */ + private static final class WithConstantArray extends ScalarInArrayFunction + { + private final Set matchValues; + + @Nullable + private final ExpressionType matchType; + + public WithConstantArray(final ExprEval arrayEval) + { + final Object[] arrayValues = arrayEval.asArray(); + + if (arrayValues == null) { + matchValues = Collections.emptySet(); + matchType = null; + } else { + matchValues = new HashSet<>(); + Collections.addAll(matchValues, arrayValues); + matchType = arrayEval.elementType(); + } + } + + @Override + ExprEval doApply(final ExprEval arrayExpr, final ExprEval scalarEval) + { + if (matchType == null) { + return ExprEval.ofLong(null); + } + + if (scalarEval.value() == null) { + return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null); + } + + final ExprEval scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType); + + if (scalarEvalForComparison == null) { + return ExprEval.ofLongBoolean(false); + } else { + return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value())); + } + } } } diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java index da81a556b0b7..d6143fd1fa15 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -373,12 +373,15 @@ public void testArrayOrdinalOf() public void testScalarInArray() { assertExpr("scalar_in_array(2, [1, 2, 3])", 1L); + assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L); + assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L); + assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L); assertExpr("scalar_in_array(4, [1, 2, 3])", 0L); assertExpr("scalar_in_array(b, [3, 4])", 0L); assertExpr("scalar_in_array(1, null)", null); assertExpr("scalar_in_array(null, null)", null); assertExpr("scalar_in_array(null, [1, null, 2])", 1L); - assertExpr("scalar_in_array(null, [1, 2])", 0L); + assertExpr("scalar_in_array(null, [1, 2])", null); } @Test @@ -1290,6 +1293,13 @@ private void assertExpr( final Expr singleThreaded = Expr.singleThreaded(expr, bindings); Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value()); + final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings); + Assert.assertEquals( + singleThreadedNoFlatten.stringify(), + expectedResult, + singleThreadedNoFlatten.eval(bindings).value() + ); + Assert.assertEquals(expr.stringify(), roundTrip.stringify()); Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify()); Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey()); From a135b75c29cf3d14ec6289febfb96f03d427b27c Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Sat, 20 Apr 2024 12:14:50 -0700 Subject: [PATCH 2/3] Fix spelling. --- docs/querying/math-expr.md | 2 +- docs/querying/sql-array-functions.md | 2 +- docs/querying/sql-functions.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/querying/math-expr.md b/docs/querying/math-expr.md index d8de55dc00ef..38ced649c06c 100644 --- a/docs/querying/math-expr.md +++ b/docs/querying/math-expr.md @@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function. | array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index | | array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 | | array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 | -| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is nonnull or null if the expr is null | +| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null | | array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. | | array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array | diff --git a/docs/querying/sql-array-functions.md b/docs/querying/sql-array-functions.md index b29e8a1bfc0e..7b0f2112b6f7 100644 --- a/docs/querying/sql-array-functions.md +++ b/docs/querying/sql-array-functions.md @@ -54,7 +54,7 @@ The following table describes array functions. To learn more about array aggrega |`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.| |`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.| |`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.| -|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is nonnull or `UNKNOWN` if the scalar `expr` is `NULL`.| +|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.| |`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).| |`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.| diff --git a/docs/querying/sql-functions.md b/docs/querying/sql-functions.md index 8a82a00c76cf..883f3b209ace 100644 --- a/docs/querying/sql-functions.md +++ b/docs/querying/sql-functions.md @@ -212,7 +212,7 @@ Returns true if `arr1` and `arr2` have any elements in common, else false. **Function type:** [Array](./sql-array-functions.md) -Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is nonnull or +Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`. Returns `UNKNOWN` if `arr` is `NULL`. From 623d1653bdbf37160123ddccd0a36faf82ca3a08 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Fri, 26 Apr 2024 08:58:36 -0700 Subject: [PATCH 3/3] Adjustments from review. --- .../org/apache/druid/math/expr/Function.java | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index b42f46394780..48bc0570aaa3 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -3786,42 +3786,52 @@ public Function asSingleThreaded(List args, Expr.InputBindingInspector ins return this; } - final ExprEval rhsEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings()); - return new WithConstantArray(rhsEval); + final ExprEval arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings()); + final Object[] arrayValues = arrayEval.asArray(); + + if (arrayValues == null) { + return WithNullArray.INSTANCE; + } else { + final Set matchValues = new HashSet<>(Arrays.asList(arrayValues)); + final ExpressionType matchType = arrayEval.elementType(); + return new WithConstantArray(matchValues, matchType); + } } return this; } /** - * Specialization of {@link ScalarInArrayFunction} for constant {@link #ARRAY_ARG}. + * Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}. + */ + private static final class WithNullArray extends ScalarInArrayFunction + { + private static final WithNullArray INSTANCE = new WithNullArray(); + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + return ExprEval.of(null); + } + } + + /** + * Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}. */ private static final class WithConstantArray extends ScalarInArrayFunction { private final Set matchValues; - - @Nullable private final ExpressionType matchType; - public WithConstantArray(final ExprEval arrayEval) + public WithConstantArray(Set matchValues, ExpressionType matchType) { - final Object[] arrayValues = arrayEval.asArray(); - - if (arrayValues == null) { - matchValues = Collections.emptySet(); - matchType = null; - } else { - matchValues = new HashSet<>(); - Collections.addAll(matchValues, arrayValues); - matchType = arrayEval.elementType(); - } + this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues"); + this.matchType = Preconditions.checkNotNull(matchType, "matchType"); } @Override - ExprEval doApply(final ExprEval arrayExpr, final ExprEval scalarEval) + public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (matchType == null) { - return ExprEval.ofLong(null); - } + final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings); if (scalarEval.value() == null) { return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);