Merged
2 changes: 1 addition & 1 deletion docs/querying/math-expr.md
@@ -184,7 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
| array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index |
| array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 |
| array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 |
| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 |
| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 if the expr is non-null, or null if the expr is null |
| array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
| array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
| array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |
6 changes: 3 additions & 3 deletions docs/querying/sql-array-functions.md
@@ -52,9 +52,9 @@ The following table describes array functions. To learn more about array aggrega
|`ARRAY_LENGTH(arr)`|Returns length of the array expression.|
|`ARRAY_OFFSET(arr, long)`|Returns the array element at the 0-based index supplied, or null for an out of range index.|
|`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.|
|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.|
|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
| `SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`. else 0.|
|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.|
|`ARRAY_OVERLAP(arr1, arr2)`|Returns true if `arr1` and `arr2` have any elements in common, else false.|
|`SCALAR_IN_ARRAY(expr, arr)`|Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or `UNKNOWN` if the scalar `expr` is `NULL`.|
|`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
|`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
|`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.|
9 changes: 6 additions & 3 deletions docs/querying/sql-functions.md
@@ -156,7 +156,7 @@ Concatenates array inputs into a single array.

**Function type:** [Array](./sql-array-functions.md)

If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.
If `expr` is a scalar type, returns true if `arr` contains `expr`. If `expr` is an array, returns true if `arr` contains all elements of `expr`. Otherwise returns false.


## ARRAY_LENGTH
@@ -204,15 +204,18 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no

**Function type:** [Array](./sql-array-functions.md)

Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
Returns true if `arr1` and `arr2` have any elements in common, else false.

## SCALAR_IN_ARRAY

`SCALAR_IN_ARRAY(expr, arr)`

**Function type:** [Array](./sql-array-functions.md)

Returns 1 if the scalar `expr` is present in `arr`, else 0.|
Returns true if the scalar `expr` is present in `arr`. Otherwise, returns false if the scalar `expr` is non-null or
`UNKNOWN` if the scalar `expr` is `NULL`.

Returns `UNKNOWN` if `arr` is `NULL`.

## ARRAY_PREPEND

98 changes: 92 additions & 6 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
@@ -45,6 +45,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -3724,8 +3725,11 @@ ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
}
}

class ArrayScalarInFunction extends ArrayScalarFunction
class ScalarInArrayFunction extends ArrayScalarFunction
{
private static final int SCALAR_ARG = 0;
private static final int ARRAY_ARG = 1;

@Override
public String name()
{
@@ -3742,23 +3746,105 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
@Override
Expr getScalarArgument(List<Expr> args)
{
return args.get(0);
return args.get(SCALAR_ARG);
}

@Override
Expr getArrayArgument(List<Expr> args)
{
return args.get(1);
return args.get(ARRAY_ARG);
}

@Override
ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
ExprEval doApply(ExprEval arrayEval, ExprEval scalarEval)
{
final Object[] array = arrayExpr.castTo(scalarExpr.asArrayType()).asArray();
final Object[] array = arrayEval.asArray();
if (array == null) {
return ExprEval.ofLong(null);
}
return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarExpr.value()));

if (scalarEval.value() == null) {
return Arrays.asList(array).contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
Member

I wonder: if the array contains null, shouldn't every place where false would have been returned return null instead? For example:

c IN (2,null)
c = 2 OR c = null
c = 2 OR null

This last rewrite essentially means that false is no longer a possible result.
The other way around is to think of = as IS NOT DISTINCT FROM; however, in that case the return value would never be null, since <=> is two-valued.

This is not a blocking comment; I was just wondering...
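
The rewrite chain in this comment can be traced with a small model of SQL three-valued logic. The sketch below is illustrative only and is not Druid code: `ThreeValuedIn` and its methods are hypothetical names, and a null `Boolean` stands in for UNKNOWN.

```java
import java.util.Arrays;
import java.util.List;

final class ThreeValuedIn
{
  // SQL "=": the result is UNKNOWN (null here) if either side is NULL.
  static Boolean eq(Long a, Long b)
  {
    if (a == null || b == null) {
      return null;
    }
    return Boolean.valueOf(a.equals(b));
  }

  // SQL OR under three-valued logic: TRUE wins, then UNKNOWN, then FALSE.
  static Boolean or(Boolean x, Boolean y)
  {
    if (Boolean.TRUE.equals(x) || Boolean.TRUE.equals(y)) {
      return Boolean.TRUE;
    }
    if (x == null || y == null) {
      return null;
    }
    return Boolean.FALSE;
  }

  // c IN (v1, v2, ...) expanded as c = v1 OR c = v2 OR ...
  static Boolean sqlIn(Long c, List<Long> values)
  {
    Boolean result = Boolean.FALSE;
    for (Long v : values) {
      result = or(result, eq(c, v));
    }
    return result;
  }

  public static void main(String[] args)
  {
    List<Long> values = Arrays.asList(2L, null);
    System.out.println(sqlIn(2L, values)); // prints "true"
    System.out.println(sqlIn(3L, values)); // prints "null": UNKNOWN, never false
  }
}
```

With a NULL in the list, a non-matching `c` yields UNKNOWN rather than false, which is the effect the comment describes.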

Contributor Author

The idea is that if the array contains null then the comparison acts like IS NOT DISTINCT FROM (always returns true or false), whereas if the array does not contain null, then the comparison acts like = (returns null if the lhs is null). It's the same way the native in and inType filters behave, so this is designed to align with those.

A future patch would convert from SQL IN to this SCALAR_IN_ARRAY function. We'll need to make sure that patch handles NULL appropriately: it would not be OK to rewrite c IN (2, NULL) to SCALAR_IN_ARRAY(c, ARRAY[2, NULL]).
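
The behavior described in this reply can be sketched in plain Java. This is a simplified model, not the actual implementation: `ScalarInArraySemantics` is a hypothetical name, a null `Boolean` stands in for the null result, and the type coercion that the real function performs through `ExprEval.castForEqualityComparison` is left out.

```java
import java.util.Arrays;

final class ScalarInArraySemantics
{
  // Returns TRUE or FALSE, or null when the answer is unknown.
  static Boolean scalarInArray(Object scalar, Object[] array)
  {
    if (array == null) {
      return null; // null array: result is null
    }
    final boolean found = Arrays.asList(array).contains(scalar);
    if (scalar == null) {
      // A null scalar matches only if the array itself holds a null;
      // otherwise the result is null rather than false.
      return found ? Boolean.TRUE : null;
    }
    return Boolean.valueOf(found);
  }

  public static void main(String[] args)
  {
    System.out.println(scalarInArray(2L, new Object[]{1L, 2L, 3L}));     // prints "true"
    System.out.println(scalarInArray(4L, new Object[]{1L, 2L, 3L}));     // prints "false"
    System.out.println(scalarInArray(null, new Object[]{1L, null, 2L})); // prints "true"
    System.out.println(scalarInArray(null, new Object[]{1L, 2L}));       // prints "null"
  }
}
```

The four calls in `main` mirror cases from `testScalarInArray` in this pull request, with `1L`/`0L` replaced by true/false.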

}

final ExpressionType matchType = arrayEval.elementType();
final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);

if (scalarEvalForComparison == null) {
return ExprEval.ofLongBoolean(false);
} else {
return ExprEval.ofLongBoolean(Arrays.asList(array).contains(scalarEvalForComparison.value()));
}
}

@Override
public Function asSingleThreaded(List<Expr> args, Expr.InputBindingInspector inspector)
{
if (args.get(ARRAY_ARG).isLiteral()) {
final ExpressionType lhsType = args.get(SCALAR_ARG).getOutputType(inspector);
if (lhsType == null) {
return this;
}

final ExprEval<?> arrayEval = args.get(ARRAY_ARG).eval(InputBindings.nilBindings());
final Object[] arrayValues = arrayEval.asArray();

if (arrayValues == null) {
return WithNullArray.INSTANCE;
} else {
final Set<Object> matchValues = new HashSet<>(Arrays.asList(arrayValues));
final ExpressionType matchType = arrayEval.elementType();
return new WithConstantArray(matchValues, matchType);
}
}
return this;
}

/**
* Specialization of {@link ScalarInArrayFunction} for null {@link #ARRAY_ARG}.
*/
private static final class WithNullArray extends ScalarInArrayFunction
{
private static final WithNullArray INSTANCE = new WithNullArray();

@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
return ExprEval.of(null);
}
}

/**
* Specialization of {@link ScalarInArrayFunction} for constant, non-null {@link #ARRAY_ARG}.
*/
private static final class WithConstantArray extends ScalarInArrayFunction
{
private final Set<Object> matchValues;
private final ExpressionType matchType;

public WithConstantArray(Set<Object> matchValues, ExpressionType matchType)
{
this.matchValues = Preconditions.checkNotNull(matchValues, "matchValues");
this.matchType = Preconditions.checkNotNull(matchType, "matchType");
}

@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
final ExprEval scalarEval = args.get(SCALAR_ARG).eval(bindings);

if (scalarEval.value() == null) {
return matchValues.contains(null) ? ExprEval.ofLongBoolean(true) : ExprEval.ofLong(null);
}

final ExprEval<?> scalarEvalForComparison = ExprEval.castForEqualityComparison(scalarEval, matchType);

if (scalarEvalForComparison == null) {
return ExprEval.ofLongBoolean(false);
} else {
return ExprEval.ofLongBoolean(matchValues.contains(scalarEvalForComparison.value()));
}
}
}
}
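
The `WithConstantArray` specialization appears to rest on one idea: when the array argument is a literal, its values can be collected into a hash set once, so each evaluation costs a single lookup instead of a linear scan. A minimal sketch of that idea, assuming plain Java objects rather than Druid's `ExprEval` (the `ConstantArrayMatcher` class and `forConstantArray` method are hypothetical names, and type casting is omitted):

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;

final class ConstantArrayMatcher
{
  // Builds a matcher once for a constant array; the returned function is then
  // applied to each scalar value. A null result stands in for "unknown".
  static Function<Object, Boolean> forConstantArray(Object[] array)
  {
    if (array == null) {
      // Null array: every lookup yields null, so no set is needed.
      return scalar -> null;
    }
    // HashSet permits a null element, so contains(null) works as expected.
    final Set<Object> matchValues = new HashSet<>(Arrays.asList(array));
    return scalar -> {
      final boolean found = matchValues.contains(scalar);
      if (scalar == null) {
        return found ? Boolean.TRUE : null;
      }
      return Boolean.valueOf(found);
    };
  }

  public static void main(String[] args)
  {
    Function<Object, Boolean> matcher = forConstantArray(new Object[]{1L, 2L, 3L});
    System.out.println(matcher.apply(2L));   // prints "true"
    System.out.println(matcher.apply(4L));   // prints "false"
    System.out.println(matcher.apply(null)); // prints "null"
  }
}
```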

@@ -373,12 +373,15 @@ public void testArrayOrdinalOf()
public void testScalarInArray()
{
assertExpr("scalar_in_array(2, [1, 2, 3])", 1L);
assertExpr("scalar_in_array(2.1, [1, 2, 3])", 0L);
assertExpr("scalar_in_array(2, [1.1, 2.1, 3.1])", 0L);
assertExpr("scalar_in_array(2, [1.1, 2.0, 3.1])", 1L);
assertExpr("scalar_in_array(4, [1, 2, 3])", 0L);
assertExpr("scalar_in_array(b, [3, 4])", 0L);
assertExpr("scalar_in_array(1, null)", null);
assertExpr("scalar_in_array(null, null)", null);
assertExpr("scalar_in_array(null, [1, null, 2])", 1L);
assertExpr("scalar_in_array(null, [1, 2])", 0L);
assertExpr("scalar_in_array(null, [1, 2])", null);
}

@Test
@@ -1290,6 +1293,13 @@ private void assertExpr(
final Expr singleThreaded = Expr.singleThreaded(expr, bindings);
Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value());

final Expr singleThreadedNoFlatten = Expr.singleThreaded(exprNoFlatten, bindings);
Assert.assertEquals(
singleThreadedNoFlatten.stringify(),
expectedResult,
singleThreadedNoFlatten.eval(bindings).value()
);

Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());