diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java index 7b3e71aecec8..c04d4e31f959 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlExpressionBenchmark.java @@ -213,7 +213,12 @@ public String getFormatString() // 40: regex filtering "SELECT string4, COUNT(*) FROM foo WHERE REGEXP_EXTRACT(string1, '^1') IS NOT NULL OR REGEXP_EXTRACT('Z' || string2, '^Z2') IS NOT NULL GROUP BY 1", // 41: complicated filtering - "SELECT string2, SUM(long1) FROM foo WHERE string1 = '1000' AND string5 LIKE '%1%' AND (string3 in ('1', '10', '20', '22', '32') AND long2 IN (1, 19, 21, 23, 25, 26, 46) AND double3 < 1010.0 AND double3 > 1000.0 AND (string4 = '1' OR REGEXP_EXTRACT(string1, '^1') IS NOT NULL OR REGEXP_EXTRACT('Z' || string2, '^Z2') IS NOT NULL)) GROUP BY 1 ORDER BY 2" + "SELECT string2, SUM(long1) FROM foo WHERE string1 = '1000' AND string5 LIKE '%1%' AND (string3 in ('1', '10', '20', '22', '32') AND long2 IN (1, 19, 21, 23, 25, 26, 46) AND double3 < 1010.0 AND double3 > 1000.0 AND (string4 = '1' OR REGEXP_EXTRACT(string1, '^1') IS NOT NULL OR REGEXP_EXTRACT('Z' || string2, '^Z2') IS NOT NULL)) GROUP BY 1 ORDER BY 2", + // 42: array_contains expr + "SELECT ARRAY_CONTAINS(\"multi-string3\", 100) FROM foo", + "SELECT ARRAY_CONTAINS(\"multi-string3\", ARRAY[1, 2, 10, 11, 20, 22, 30, 33, 40, 44, 50, 55, 100]) FROM foo", + "SELECT ARRAY_OVERLAP(\"multi-string3\", ARRAY[1, 100]) FROM foo", + "SELECT ARRAY_OVERLAP(\"multi-string3\", ARRAY[1, 2, 10, 11, 20, 22, 30, 33, 40, 44, 50, 55, 100]) FROM foo" ); @Param({"5000000"}) @@ -275,7 +280,11 @@ public String getFormatString() "38", "39", "40", - "41" + "41", + "42", + "43", + "44", + "45" }) private String query; @@ -369,8 +378,8 @@ public void setup() .writeValueAsString(jsonMapper.readValue((String) planResult[0], List.class)) ); } - catch (JsonProcessingException e) { - throw new RuntimeException(e); + catch (JsonProcessingException ignored) { + } try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, ImmutableMap.of())) { @@ -384,6 +393,9 @@ public void setup() } log.info("Total result row count:" + rowCounter); } + catch (Throwable ignored) { + + } } @TearDown(Level.Trial) diff --git a/processing/src/main/java/org/apache/druid/math/expr/ConstantExpr.java b/processing/src/main/java/org/apache/druid/math/expr/ConstantExpr.java index 8dc66b4306ef..85cebd478ee0 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/ConstantExpr.java +++ b/processing/src/main/java/org/apache/druid/math/expr/ConstantExpr.java @@ -43,7 +43,7 @@ * {@link Expr}. */ @Immutable -abstract class ConstantExpr implements Expr, Expr.SingleThreadSpecializable +abstract class ConstantExpr implements Expr { final ExpressionType outputType; @@ -122,7 +122,7 @@ public String toString() } @Override - public Expr toSingleThreaded() + public Expr asSingleThreaded(InputBindingInspector inspector) { return new ExprEvalBasedConstantExpr(realEval()); } diff --git a/processing/src/main/java/org/apache/druid/math/expr/Expr.java b/processing/src/main/java/org/apache/druid/math/expr/Expr.java index 3eb7ac467e58..8fa025cf9145 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Expr.java @@ -182,6 +182,16 @@ default boolean canVectorize(InputBindingInspector inspector) return false; } + /** + * Possibly convert the {@link Expr} into an optimized, possibly not thread-safe {@link Expr}. Does not convert + * child {@link Expr}. Most callers should use {@link Expr#singleThreaded(Expr, InputBindingInspector)} to convert + * an entire tree, which delegates to this method to translate individual nodes. + */ + default Expr asSingleThreaded(InputBindingInspector inspector) + { + return this; + } + /** * Builds a 'vectorized' expression processor, that can operate on batches of input values for use in vectorized * query engines. @@ -769,30 +779,9 @@ private static Set map( * Returns the single-threaded version of the given expression tree. * * Nested expressions in the subtree are also optimized. - * Individual {@link Expr}-s which have a singleThreaded implementation via {@link SingleThreadSpecializable} are substituted. */ - static Expr singleThreaded(Expr expr) + static Expr singleThreaded(Expr expr, InputBindingInspector inspector) { - return expr.visit( - node -> { - if (node instanceof SingleThreadSpecializable) { - SingleThreadSpecializable canBeSingleThreaded = (SingleThreadSpecializable) node; - return canBeSingleThreaded.toSingleThreaded(); - } else { - return node; - } - } - ); - } - - /** - * Implementing this interface allows to provide a non-threadsafe {@link Expr} implementation. - */ - interface SingleThreadSpecializable - { - /** - * Non-threadsafe version of this expression. - */ - Expr toSingleThreaded(); + return expr.visit(node -> node.asSingleThreaded(inspector)); } } diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 2c8a26759f34..e8ff45d90f13 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -21,6 +21,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; +import it.unimi.dsi.fastutil.objects.ObjectAVLTreeSet; import org.apache.druid.common.config.NullHandling; import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.DateTimes; @@ -63,6 +64,14 @@ @SuppressWarnings("unused") public interface Function extends NamedFunction { + /** + * Possibly convert a {@link Function} into an optimized, possibly not thread-safe {@link Function}. + */ + default Function asSingleThreaded(List args, Expr.InputBindingInspector inspector) + { + return this; + } + /** * Evaluate the function, given a list of arguments and a set of bindings to provide values for {@link IdentifierExpr}. */ @@ -3243,6 +3252,67 @@ public Set getArrayInputs(List args) } } + /** + * Primarily internal helper function used to coerce null, [], and [null] into [null], similar to the logic done + * by {@link org.apache.druid.segment.virtual.ExpressionSelectors#supplierFromDimensionSelector} when the 3rd + * argument is true, which is done when implicitly mapping scalar functions over mvd values. + */ + class MultiValueStringHarmonizeNullsFunction implements Function + { + @Override + public String name() + { + return "mv_harmonize_nulls"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval eval = args.get(0).eval(bindings).castTo(ExpressionType.STRING_ARRAY); + if (eval.value() == null || eval.asArray().length == 0) { + return ExprEval.ofArray(ExpressionType.STRING_ARRAY, new Object[]{null}); + } + return eval; + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 1); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.STRING_ARRAY; + } + + @Override + public boolean hasArrayInputs() + { + return true; + } + + @Override + public boolean hasArrayOutput() + { + return true; + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + } + class ArrayToMultiValueStringFunction implements Function { @Override @@ -3757,7 +3827,7 @@ Object[] merge(TypeSignature elementType, T[] array1, T[] array2) } } - class ArrayContainsFunction extends ArraysFunction + class ArrayContainsFunction implements Function { @Override public String name() @@ -3779,15 +3849,124 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args, Expr.ObjectBinding bindings) { + final ExprEval lhsExpr = args.get(0).eval(bindings); + final ExprEval rhsExpr = args.get(1).eval(bindings); + final Object[] array1 = lhsExpr.asArray(); - final Object[] array2 = rhsExpr.asArray(); - return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2))); + if (array1 == null) { + return ExprEval.ofLong(null); + } + ExpressionType array1Type = lhsExpr.asArrayType(); + + if (rhsExpr.isArray()) { + final Object[] array2 = rhsExpr.castTo(array1Type).asArray(); + if (array2 == null) { + return ExprEval.ofLongBoolean(false); + } + return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2))); + } else { + final Object elem = rhsExpr.castTo((ExpressionType) array1Type.getElementType()).value(); + return ExprEval.ofLongBoolean(Arrays.asList(array1).contains(elem)); + } + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + @Override + public boolean hasArrayInputs() + { + return true; + } + + @Override + public Function asSingleThreaded(List args, Expr.InputBindingInspector inspector) + { + if (args.get(1).isLiteral()) { + final ExpressionType lhsType = args.get(0).getOutputType(inspector); + if (lhsType == null) { + return this; + } + final ExpressionType lhsArrayType = ExpressionType.asArrayType(lhsType); + final ExprEval rhsEval = args.get(1).eval(InputBindings.nilBindings()); + if (rhsEval.isArray()) { + final Object[] rhsArray = rhsEval.castTo(lhsArrayType).asArray(); + return new ContainsConstantArray(rhsArray); + } else { + final Object val = rhsEval.castTo((ExpressionType) lhsArrayType.getElementType()).value(); + return new ContainsConstantScalar(val); + } + } + return this; + } + + private static final class ContainsConstantArray extends ArrayContainsFunction + { + @Nullable + final Object[] rhsArray; + + public ContainsConstantArray(@Nullable Object[] rhsArray) + { + this.rhsArray = rhsArray; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval lhsExpr = args.get(0).eval(bindings); + final Object[] array1 = lhsExpr.asArray(); + if (array1 == null) { + return ExprEval.ofLong(null); + } + if (rhsArray == null) { + return ExprEval.ofLongBoolean(false); + } + return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(rhsArray))); + } + } + + private static final class ContainsConstantScalar extends ArrayContainsFunction + { + @Nullable + final Object val; + + public ContainsConstantScalar(@Nullable Object val) + { + this.val = val; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval lhsExpr = args.get(0).eval(bindings); + + final Object[] array1 = lhsExpr.asArray(); + if (array1 == null) { + return ExprEval.ofLong(null); + } + return ExprEval.ofLongBoolean(Arrays.asList(array1).contains(val)); + } } } - class ArrayOverlapFunction extends ArraysFunction + class ArrayOverlapFunction implements Function { @Override public String name() @@ -3803,15 +3982,110 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args, Expr.ObjectBinding bindings) { - final Object[] array1 = lhsExpr.asArray(); - final List array2 = Arrays.asList(rhsExpr.asArray()); - boolean any = false; + final ExprEval arrayExpr1 = args.get(0).eval(bindings); + final ExprEval arrayExpr2 = args.get(1).eval(bindings); + + final Object[] array1 = arrayExpr1.asArray(); + if (array1 == null) { + return ExprEval.ofLong(null); + } + ExpressionType array1Type = arrayExpr1.asArrayType(); + final Object[] array2 = arrayExpr2.castTo(array1Type).asArray(); + if (array2 == null) { + return ExprEval.ofLongBoolean(false); + } + List asList = Arrays.asList(array2); for (Object check : array1) { - any |= array2.contains(check); + if (asList.contains(check)) { + return ExprEval.ofLongBoolean(true); + } + } + return ExprEval.ofLongBoolean(false); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + @Override + public boolean hasArrayInputs() + { + return true; + } + + @Override + public Function asSingleThreaded(List args, Expr.InputBindingInspector inspector) + { + if (args.get(1).isLiteral()) { + final ExpressionType lhsType = args.get(0).getOutputType(inspector); + if (lhsType == null) { + return this; + } + final ExpressionType lhsArrayType = ExpressionType.asArrayType(lhsType); + final ExprEval rhsEval = args.get(1).eval(InputBindings.nilBindings()); + final Object[] rhsArray = rhsEval.castTo(lhsArrayType).asArray(); + if (rhsArray == null) { + return new ArrayOverlapFunction() + { + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arrayExpr1 = args.get(0).eval(bindings); + final Object[] array1 = arrayExpr1.asArray(); + if (array1 == null) { + return ExprEval.ofLong(null); + } + return ExprEval.ofLongBoolean(false); + } + }; + } + final Set set = new ObjectAVLTreeSet<>(lhsArrayType.getElementType().getNullableStrategy()); + set.addAll(Arrays.asList(rhsArray)); + return new OverlapConstantArray(set); + } + return this; + } + + private static final class OverlapConstantArray extends ArrayContainsFunction + { + final Set set; + + public OverlapConstantArray(Set set) + { + this.set = set; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval lhsExpr = args.get(0).eval(bindings); + final Object[] array1 = lhsExpr.asArray(); + if (array1 == null) { + return ExprEval.ofLong(null); + } + for (Object check : array1) { + if (set.contains(check)) { + return ExprEval.ofLongBoolean(true); + } + } + return ExprEval.ofLongBoolean(false); } - return ExprEval.ofLongBoolean(any); } } diff --git a/processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java b/processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java index a87d5d9cd683..6fe762a63374 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java +++ b/processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java @@ -179,6 +179,16 @@ class FunctionExpr implements Expr function.validateArguments(args); } + @Override + public Expr asSingleThreaded(InputBindingInspector inspector) + { + return new FunctionExpr( + function.asSingleThreaded(args, inspector), + name, + args + ); + } + @Override public String toString() { diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index a1cd7cd32d6e..ca2cda3e0e14 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -195,7 +195,10 @@ public static ColumnValueSelector makeExprEvalSelector( Expr expression ) { - ExpressionPlan plan = ExpressionPlanner.plan(columnSelectorFactory, Expr.singleThreaded(expression)); + ExpressionPlan plan = ExpressionPlanner.plan( + columnSelectorFactory, + Expr.singleThreaded(expression, columnSelectorFactory) + ); final RowIdSupplier rowIdSupplier = columnSelectorFactory.getRowIdSupplier(); if (plan.is(ExpressionPlan.Trait.SINGLE_INPUT_SCALAR)) { @@ -243,7 +246,10 @@ public static DimensionSelector makeDimensionSelector( @Nullable final ExtractionFn extractionFn ) { - final ExpressionPlan plan = ExpressionPlanner.plan(columnSelectorFactory, expression); + final ExpressionPlan plan = ExpressionPlanner.plan( + columnSelectorFactory, + Expr.singleThreaded(expression, columnSelectorFactory) + ); if (plan.any(ExpressionPlan.Trait.SINGLE_INPUT_SCALAR, ExpressionPlan.Trait.SINGLE_INPUT_MAPPABLE)) { final String column = plan.getSingleInputName(); diff --git a/processing/src/test/java/org/apache/druid/math/expr/ConstantExprTest.java b/processing/src/test/java/org/apache/druid/math/expr/ConstantExprTest.java index 9237d5e92809..a7812f2d160e 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/ConstantExprTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/ConstantExprTest.java @@ -24,14 +24,11 @@ import org.apache.druid.segment.column.TypeStrategiesTest.NullableLongPair; import org.apache.druid.segment.column.TypeStrategiesTest.NullableLongPairTypeStrategy; import org.apache.druid.testing.InitializedNullHandlingTest; +import org.junit.Assert; import org.junit.Test; import java.math.BigInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class ConstantExprTest extends InitializedNullHandlingTest { @Test @@ -40,7 +37,6 @@ public void testLongArrayExpr() ArrayExpr arrayExpr = new ArrayExpr(ExpressionType.LONG_ARRAY, new Long[] {1L, 3L}); checkExpr( arrayExpr, - true, "[1, 3]", "ARRAY[1, 3]", arrayExpr @@ -53,7 +49,6 @@ public void testStringArrayExpr() ArrayExpr arrayExpr = new ArrayExpr(ExpressionType.STRING_ARRAY, new String[] {"foo", "bar"}); checkExpr( arrayExpr, - true, "[foo, bar]", "ARRAY['foo', 'bar']", arrayExpr @@ -65,7 +60,6 @@ public void testBigIntegerExpr() { checkExpr( new BigIntegerExpr(BigInteger.valueOf(37L)), - true, "37", "37", // after reparsing it will become a LongExpr @@ -83,7 +77,6 @@ public void testComplexExpr() ); checkExpr( complexExpr, - true, "Pair{lhs=21, rhs=37}", "complex_decode_base64('nullablePair', 'AAAAAAAAAAAVAAAAAAAAAAAl')", complexExpr @@ -95,7 +88,6 @@ public void testDoubleExpr() { checkExpr( new DoubleExpr(11.73D), - true, "11.73", "11.73", new DoubleExpr(11.73D) @@ -108,7 +100,6 @@ public void testNullDoubleExpr() TypeStrategies.registerComplex("nullablePair", new NullableLongPairTypeStrategy()); checkExpr( new NullDoubleExpr(), - true, "null", "null", // the expressions 'null' is always parsed as a StringExpr(null) @@ -121,7 +112,6 @@ public void testNullLongExpr() { checkExpr( new NullLongExpr(), - true, "null", "null", // the expressions 'null' is always parsed as a StringExpr(null) @@ -134,7 +124,6 @@ public void testLong() { checkExpr( new LongExpr(11L), - true, "11", "11", new LongExpr(11L) @@ -146,7 +135,6 @@ public void testString() { checkExpr( new StringExpr("some"), - true, "some", "'some'", new StringExpr("some") @@ -158,7 +146,6 @@ public void testStringNull() { checkExpr( new StringExpr(null), - true, null, "null", new StringExpr(null) @@ -167,26 +154,24 @@ public void testStringNull() private void checkExpr( Expr expr, - boolean supportsSingleThreaded, String expectedToString, String expectedStringify, - Expr expectedReparsedExpr) + Expr expectedReparsedExpr + ) { - ObjectBinding bindings = InputBindings.nilBindings(); + final ObjectBinding bindings = InputBindings.nilBindings(); if (expr.getLiteralValue() != null) { - assertNotSame(expr.eval(bindings), expr.eval(bindings)); - } - Expr singleExpr = Expr.singleThreaded(expr); - if (supportsSingleThreaded) { - assertSame(singleExpr.eval(bindings), singleExpr.eval(bindings)); - } else { - assertNotSame(singleExpr.eval(bindings), singleExpr.eval(bindings)); + Assert.assertNotSame(expr.eval(bindings), expr.eval(bindings)); } - assertEquals(expectedToString, expr.toString()); - assertEquals(expectedStringify, expr.stringify()); - assertEquals(expectedToString, singleExpr.toString()); - String stringify = singleExpr.stringify(); - Expr reParsedExpr = Parser.parse(stringify, ExprMacroTable.nil()); - assertEquals(expectedReparsedExpr, reParsedExpr); + final Expr singleExpr = Expr.singleThreaded(expr, bindings); + Assert.assertArrayEquals(expr.getCacheKey(), singleExpr.getCacheKey()); + Assert.assertSame(singleExpr.eval(bindings), singleExpr.eval(bindings)); + Assert.assertEquals(expectedToString, expr.toString()); + Assert.assertEquals(expectedStringify, expr.stringify()); + Assert.assertEquals(expectedToString, singleExpr.toString()); + final String stringify = singleExpr.stringify(); + final Expr reParsedExpr = Parser.parse(stringify, ExprMacroTable.nil()); + Assert.assertEquals(expectedReparsedExpr, reParsedExpr); + Assert.assertArrayEquals(expr.getCacheKey(), expectedReparsedExpr.getCacheKey()); } } diff --git a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 0338efa46640..fae7bca736f0 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -92,7 +92,8 @@ public void setup() .put("someComplex", ExpressionType.fromColumnType(TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE)) .put("str1", ExpressionType.STRING) .put("str2", ExpressionType.STRING) - .put("nestedArray", ExpressionType.NESTED_DATA); + .put("nestedArray", ExpressionType.NESTED_DATA) + .put("emptyArray", ExpressionType.STRING_ARRAY); final StructuredData nestedArray = StructuredData.wrap( ImmutableList.of( @@ -120,7 +121,8 @@ public void setup() .put("someComplex", new TypeStrategiesTest.NullableLongPair(1L, 2L)) .put("str1", "v1") .put("str2", "v2") - .put("nestedArray", nestedArray); + .put("nestedArray", nestedArray) + .put("emptyArray", new Object[]{}); bestEffortBindings = InputBindings.forMap(builder.build()); typedBindings = InputBindings.forMap( builder.build(), InputBindings.inspectorFromTypeMap(inputTypesBuilder.build()) @@ -373,6 +375,11 @@ public void testArrayContains() assertExpr("array_contains([1, 2, 3], [2, 3])", 1L); assertExpr("array_contains([1, 2, 3], [3, 4])", 0L); assertExpr("array_contains(b, [3, 4])", 1L); + assertExpr("array_contains(null, [3, 4])", null); + assertExpr("array_contains(null, null)", null); + assertExpr("array_contains([1, null, 2], null)", 1L); + assertExpr("array_contains([1, null, 2], [null])", 1L); + assertExpr("array_contains([1, 2], null)", 0L); } @Test @@ -380,6 +387,12 @@ public void testArrayOverlap() { assertExpr("array_overlap([1, 2, 3], [2, 4, 6])", 1L); assertExpr("array_overlap([1, 2, 3], [4, 5, 6])", 0L); + assertExpr("array_overlap(null, [4, 5, 6])", null); + assertExpr("array_overlap([4, null], [4, 5, 6])", 1L); + assertExpr("array_overlap([4, 5, 6], null)", 0L); + assertExpr("array_overlap([4, 5, 6], [null])", 0L); + assertExpr("array_overlap([4, 5, null, 6], null)", 0L); + assertExpr("array_overlap([4, 5, null, 6], [null])", 1L); } @Test @@ -1226,6 +1239,17 @@ public void testDivOnString() ); } + @Test + public void testMvHarmonizeNulls() + { + assertArrayExpr("mv_harmonize_nulls(null)", new Object[]{null}); + assertArrayExpr("mv_harmonize_nulls(emptyArray)", new Object[]{null}); + // does nothing + assertArrayExpr("mv_harmonize_nulls(array(null))", new Object[]{null}); + // does nothing + assertArrayExpr("mv_harmonize_nulls(a)", new Object[]{"foo", "bar", "baz", "foobar"}); + } + private void assertExpr(final String expression, @Nullable final Object expectedResult) { for (Expr.ObjectBinding toUse : allBindings) { @@ -1249,6 +1273,9 @@ private void assertExpr( final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil()); Assert.assertEquals(expr.stringify(), expectedResult, roundTripFlatten.eval(bindings).value()); + final Expr singleThreaded = Expr.singleThreaded(expr, bindings); + Assert.assertEquals(singleThreaded.stringify(), expectedResult, singleThreaded.eval(bindings).value()); + Assert.assertEquals(expr.stringify(), roundTrip.stringify()); Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify()); Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey()); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java index 8bfc3aecb3e7..3d902ac068a6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayContainsOperatorConversion.java @@ -67,7 +67,7 @@ public class ArrayContainsOperatorConversion extends BaseExpressionDimFilterOper ) ) ) - .returnTypeInference(ReturnTypes.BOOLEAN) + .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) .build(); public ArrayContainsOperatorConversion() @@ -191,6 +191,6 @@ public DimFilter toDruidFilter( } } } - return toExpressionFilter(plannerContext, getDruidFunctionName(), druidExpressions); + return toExpressionFilter(plannerContext, druidExpressions); } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java index 23cfcfaa4a45..be01d4beb3b6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java @@ -67,7 +67,7 @@ public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOpera ) ) ) - .returnTypeInference(ReturnTypes.BOOLEAN) + .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) .build(); public ArrayOverlapOperatorConversion() @@ -111,7 +111,7 @@ public DimFilter toDruidFilter( complexExpr = leftExpr; } } else { - return toExpressionFilter(plannerContext, getDruidFunctionName(), druidExpressions); + return toExpressionFilter(plannerContext, druidExpressions); } final Expr expr = plannerContext.parseExpression(complexExpr.getExpression()); @@ -202,6 +202,6 @@ public DimFilter toDruidFilter( } } - return toExpressionFilter(plannerContext, getDruidFunctionName(), druidExpressions); + return toExpressionFilter(plannerContext, druidExpressions); } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BaseExpressionDimFilterOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BaseExpressionDimFilterOperatorConversion.java index d24ea8ea8911..afadf79770f8 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BaseExpressionDimFilterOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/BaseExpressionDimFilterOperatorConversion.java @@ -40,13 +40,17 @@ public BaseExpressionDimFilterOperatorConversion( super(operator, druidFunctionName); } - protected static DimFilter toExpressionFilter( + protected String getFilterExpression(List druidExpressions) + { + return DruidExpression.functionCall(getDruidFunctionName()).compile(druidExpressions); + } + + protected DimFilter toExpressionFilter( PlannerContext plannerContext, - String druidFunctionName, List druidExpressions ) { - final String filterExpr = DruidExpression.functionCall(druidFunctionName, druidExpressions); + final String filterExpr = getFilterExpression(druidExpressions); return new ExpressionDimFilter( filterExpr, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java index 5a32b06c544d..da07083774ce 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/MultiValueStringOperatorConversions.java @@ -19,6 +19,7 @@ package org.apache.druid.sql.calcite.expression.builtin; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; @@ -39,11 +40,13 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.expression.OperatorConversions; +import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; import javax.annotation.Nullable; +import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -156,7 +159,7 @@ private static class Contains extends ArrayContainsOperatorConversion ) ) ) - .returnTypeInference(ReturnTypes.BOOLEAN) + .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) .build(); @Override @@ -164,6 +167,53 @@ public SqlOperator calciteOperator() { return SQL_FUNCTION; } + + @Override + protected String getFilterExpression(List druidExpressions) + { + return super.getFilterExpression(harmonizeNullsMvdArg0OperandList(druidExpressions)); + } + + @Override + public DruidExpression toDruidExpression( + PlannerContext plannerContext, + RowSignature rowSignature, + RexNode rexNode + ) + { + return OperatorConversions.convertCall( + plannerContext, + rowSignature, + rexNode, + druidExpressions -> DruidExpression.ofFunctionCall( + Calcites.getColumnTypeForRelDataType(rexNode.getType()), + getDruidFunctionName(), + harmonizeNullsMvdArg0OperandList(druidExpressions) + ) + ); + } + + @Nullable + @Override + public DruidExpression toDruidExpressionWithPostAggOperands( + PlannerContext plannerContext, + RowSignature rowSignature, + RexNode rexNode, + PostAggregatorVisitor postAggregatorVisitor + ) + { + return OperatorConversions.convertCallWithPostAggOperands( + plannerContext, + rowSignature, + rexNode, + operands -> DruidExpression.ofFunctionCall( + Calcites.getColumnTypeForRelDataType(rexNode.getType()), + getDruidFunctionName(), + harmonizeNullsMvdArg0OperandList(operands) + ), + postAggregatorVisitor + ); + } } public static class Offset extends ArrayOffsetOperatorConversion @@ -309,11 +359,81 @@ public OrdinalOf() /** * Private: use singleton {@link #OVERLAP}. */ - private static class Overlap extends AliasedOperatorConversion + private static class Overlap extends ArrayOverlapOperatorConversion { - public Overlap() + private static final SqlFunction SQL_FUNCTION = OperatorConversions + .operatorBuilder("MV_OVERLAP") + .operandTypeChecker( + OperandTypes.sequence( + "'MV_OVERLAP(array, array)'", + OperandTypes.or( + OperandTypes.family(SqlTypeFamily.ARRAY), + OperandTypes.family(SqlTypeFamily.STRING) + ), + OperandTypes.or( + OperandTypes.family(SqlTypeFamily.ARRAY), + OperandTypes.family(SqlTypeFamily.STRING), + OperandTypes.family(SqlTypeFamily.NUMERIC) + ) + ) + ) + .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE) + .build(); + + @Override + public SqlOperator calciteOperator() + { + return SQL_FUNCTION; + } + + @Override + protected String getFilterExpression(List druidExpressions) { - super(new ArrayOverlapOperatorConversion(), "MV_OVERLAP"); + return super.getFilterExpression(harmonizeNullsMvdArg0OperandList(druidExpressions)); + } + + @Override + public DruidExpression toDruidExpression( + PlannerContext plannerContext, + RowSignature rowSignature, + RexNode rexNode + ) + { + return OperatorConversions.convertCall( + plannerContext, + rowSignature, + rexNode, + druidExpressions -> { + final List newArgs = harmonizeNullsMvdArg0OperandList(druidExpressions); + return DruidExpression.ofFunctionCall( + Calcites.getColumnTypeForRelDataType(rexNode.getType()), + getDruidFunctionName(), + newArgs + ); + } + ); + } + + @Nullable + @Override + public DruidExpression toDruidExpressionWithPostAggOperands( + PlannerContext plannerContext, + RowSignature rowSignature, + RexNode rexNode, + PostAggregatorVisitor postAggregatorVisitor + ) + { + return OperatorConversions.convertCallWithPostAggOperands( + plannerContext, + rowSignature, + rexNode, + operands -> DruidExpression.ofFunctionCall( + Calcites.getColumnTypeForRelDataType(rexNode.getType()), + getDruidFunctionName(), + harmonizeNullsMvdArg0OperandList(operands) + ), + postAggregatorVisitor + ); } } @@ -458,6 +578,28 @@ boolean isAllowList() } } + + private static List harmonizeNullsMvdArg0OperandList(List druidExpressions) + { + final List newArgs; + if (druidExpressions.get(0).isDirectColumnAccess()) { + // rewrite first argument to wrap with mv_harmonize_nulls function + newArgs = Lists.newArrayListWithCapacity(2); + newArgs.add( + 0, + DruidExpression.ofFunctionCall( + druidExpressions.get(0).getDruidType(), + "mv_harmonize_nulls", + Collections.singletonList(druidExpressions.get(0)) + ) + ); + newArgs.add(1, druidExpressions.get(1)); + } else { + newArgs = druidExpressions; + } + return newArgs; + } + private MultiValueStringOperatorConversions() { // no instantiation diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java index 4bed8c83fc54..715c5caaadc1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java @@ -84,7 +84,11 @@ public void reduce( final RexNode literal; if (sqlTypeName == SqlTypeName.BOOLEAN) { - literal = rexBuilder.makeLiteral(exprResult.asBoolean(), constExp.getType(), true); + if (exprResult.valueOrDefault() == null) { + literal = rexBuilder.makeNullLiteral(constExp.getType()); + } else { + literal = rexBuilder.makeLiteral(exprResult.asBoolean(), constExp.getType(), true); + } } else if (sqlTypeName == SqlTypeName.DATE) { // It is possible for an expression to have a non-null String value but it can return null when parsed // as a primitive long/float/double. diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index ce1f992a2ad2..adfc7a7aff40 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -1293,6 +1293,36 @@ public void testArrayContainsFilterArrayStringColumns() ); } + @Test + public void testArrayContainsArrayStringColumns() + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_CONTAINS(arrayStringNulls, ARRAY['a', 'b']), ARRAY_CONTAINS(arrayStringNulls, arrayString) FROM druid.arrays LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(DATA_SOURCE_ARRAYS) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("v0", "v1") + .virtualColumns( + expressionVirtualColumn("v0", "array_contains(\"arrayStringNulls\",array('a','b'))", ColumnType.LONG), + expressionVirtualColumn("v1", "array_contains(\"arrayStringNulls\",\"arrayString\")", ColumnType.LONG) + ) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{NullHandling.sqlCompatible() ? null : false, NullHandling.sqlCompatible() ? null : false}, + new Object[]{true, false}, + new Object[]{false, false}, + new Object[]{NullHandling.sqlCompatible() ? null : false, NullHandling.sqlCompatible() ? null : false}, + new Object[]{true, true} + ) + ); + } + @Test public void testArrayContainsFilterArrayLongColumns() { @@ -1342,6 +1372,45 @@ public void testArrayContainsFilterArrayDoubleColumns() ); } + @Test + public void testArrayContainsConstantNull() + { + testQuery( + "SELECT ARRAY_CONTAINS(null, ARRAY['a','b'])", + ImmutableList.of( + NullHandling.sqlCompatible() + ? newScanQueryBuilder() + .dataSource( + InlineDataSource.fromIterable( + ImmutableList.of(new Object[]{NullHandling.defaultLongValue()}), + RowSignature.builder().add("EXPR$0", ColumnType.LONG).build() + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("EXPR$0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + : newScanQueryBuilder() + .dataSource( + InlineDataSource.fromIterable( + ImmutableList.of(new Object[]{0L}), + RowSignature.builder().add("ZERO", ColumnType.LONG).build() + ) + ) + .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG)) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{NullHandling.sqlCompatible() ? null : false} + ) + ); + } + @Test public void testArraySlice() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java index 41885eabb499..5ef6c7ab9123 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteLookupFunctionQueryTest.java @@ -772,7 +772,7 @@ public void testFilterMvContainsNullInjective() buildFilterTestSql("MV_CONTAINS(LOOKUP(dim1, 'lookyloo121'), NULL)"), QUERY_CONTEXT, NullHandling.sqlCompatible() - ? buildFilterTestExpectedQuery(expressionFilter("array_contains(\"dim1\",null)")) + ? buildFilterTestExpectedQuery(expressionFilter("array_contains(mv_harmonize_nulls(\"dim1\"),null)")) : buildFilterTestExpectedQueryAlwaysFalse(), ImmutableList.of() ); @@ -856,17 +856,12 @@ public void testFilterMvContainsIsNotTrue() testQuery( buildFilterTestSql("MV_CONTAINS(lookup(dim1, 'lookyloo'), 'xabc') IS NOT TRUE"), QUERY_CONTEXT, - NullHandling.sqlCompatible() - ? buildFilterTestExpectedQuery( - expressionVirtualColumn("v0", "lookup(\"dim1\",'lookyloo')", ColumnType.STRING), - not(equality("v0", "xabc", ColumnType.STRING)) - ) - : buildFilterTestExpectedQuery( - not(equality("dim1", "abc", ColumnType.STRING)) + buildFilterTestExpectedQuery( + NullHandling.sqlCompatible() + ? not(istrue(equality("dim1", "abc", ColumnType.STRING))) + : not(equality("dim1", "abc", ColumnType.STRING)) ), - NullHandling.sqlCompatible() - ? ImmutableList.of() - : ImmutableList.of(new Object[]{"", 5L}) + ImmutableList.of(new Object[]{NullHandling.defaultStringValue(), 5L}) ); } @@ -912,12 +907,10 @@ public void testFilterMvOverlapIsNotTrue() QUERY_CONTEXT, buildFilterTestExpectedQuery( NullHandling.sqlCompatible() - ? not(in("dim1", ImmutableList.of("x6", "xabc", "nonexistent"), EXTRACTION_FN)) + ? not(istrue(in("dim1", ImmutableList.of("6", "abc"), null))) : not(in("dim1", ImmutableList.of("6", "abc"), null)) ), - NullHandling.sqlCompatible() - ? Collections.emptyList() - : ImmutableList.of(new Object[]{NullHandling.defaultStringValue(), 5L}) + ImmutableList.of(new Object[]{NullHandling.defaultStringValue(), 5L}) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index a93a02a543d2..c118d83c13a3 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -327,14 +327,17 @@ public void testMultiValueStringOverlapFilterNonLiteral() newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE3) .eternityInterval() - .filters(expressionFilter("array_overlap(\"dim3\",array(\"dim2\"))")) + .filters(expressionFilter("array_overlap(mv_harmonize_nulls(\"dim3\"),array(\"dim2\"))")) .columns("dim3") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) .context(QUERY_CONTEXT_DEFAULT) .build() ), - ImmutableList.of(new Object[]{"[\"a\",\"b\"]"}) + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{NullHandling.defaultStringValue()} + ) ); } @@ -427,7 +430,7 @@ public void testMultiValueStringContainsArrayOfNonLiteral() newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE3) .eternityInterval() - .filters(expressionFilter("array_contains(\"dim3\",array(\"dim2\"))")) + .filters(expressionFilter("array_contains(mv_harmonize_nulls(\"dim3\"),array(\"dim2\"))")) .columns("dim3") .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .limit(5) @@ -435,7 +438,8 @@ public void testMultiValueStringContainsArrayOfNonLiteral() .build() ), ImmutableList.of( - new Object[]{"[\"a\",\"b\"]"} + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{NullHandling.defaultStringValue()} ) ); } @@ -2315,4 +2319,34 @@ public void testMvContainsFilterWithExtractionFn() ) ); } + + @Test + public void testMvContainsSelectColumns() + { + cannotVectorize(); + testQuery( + "SELECT MV_CONTAINS(dim3, ARRAY['a', 'b']), MV_OVERLAP(dim3, ARRAY['a', 'b']) FROM druid.numfoo LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("v0", "v1") + .virtualColumns( + expressionVirtualColumn("v0", "array_contains(mv_harmonize_nulls(\"dim3\"),array('a','b'))", ColumnType.LONG), + expressionVirtualColumn("v1", "array_overlap(mv_harmonize_nulls(\"dim3\"),array('a','b'))", ColumnType.LONG) + ) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{true, true}, + new Object[]{false, true}, + new Object[]{false, false}, + new Object[]{false, false}, + new Object[]{false, false} + ) + ); + } }