diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 1a9a96287055..a2ef91f0cfa4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -374,21 +374,11 @@ private static Class convertType(@Nullable Class existing, Class next) throw new UOE("Invalid array expression type: %s", next); } - public static ExprEval ofLong(@Nullable Number longValue) - { - return new LongExprEval(longValue); - } - public static ExprEval of(long longValue) { return new LongExprEval(longValue); } - public static ExprEval ofDouble(@Nullable Number doubleValue) - { - return new DoubleExprEval(doubleValue); - } - public static ExprEval of(double doubleValue) { return new DoubleExprEval(doubleValue); @@ -402,22 +392,50 @@ public static ExprEval of(@Nullable String stringValue) return new StringExprEval(stringValue); } + public static ExprEval ofLong(@Nullable Number longValue) + { + if (longValue == null) { + return LongExprEval.OF_NULL; + } + return new LongExprEval(longValue); + } + + public static ExprEval ofDouble(@Nullable Number doubleValue) + { + if (doubleValue == null) { + return DoubleExprEval.OF_NULL; + } + return new DoubleExprEval(doubleValue); + } + public static ExprEval ofLongArray(@Nullable Long[] longValue) { + if (longValue == null) { + return LongArrayExprEval.OF_NULL; + } return new LongArrayExprEval(longValue); } public static ExprEval ofDoubleArray(@Nullable Double[] doubleValue) { + if (doubleValue == null) { + return DoubleArrayExprEval.OF_NULL; + } return new DoubleArrayExprEval(doubleValue); } public static ExprEval ofStringArray(@Nullable String[] stringValue) { + if (stringValue == null) { + return StringArrayExprEval.OF_NULL; + } return new StringArrayExprEval(stringValue); } - public static ExprEval of(boolean value, ExprType type) + /** + * Convert a boolean back into native expression type + 
*/ + public static ExprEval ofBoolean(boolean value, ExprType type) { switch (type) { case DOUBLE: @@ -431,11 +449,17 @@ public static ExprEval of(boolean value, ExprType type) } } + /** + * Convert a boolean into a long expression type + */ public static ExprEval ofLongBoolean(boolean value) { return ExprEval.of(Evals.asLong(value)); } + /** + * Examine java type to find most appropriate expression type + */ public static ExprEval bestEffortOf(@Nullable Object val) { if (val instanceof ExprEval) { @@ -631,6 +655,8 @@ public boolean isNumericNull() private static class DoubleExprEval extends NumericExprEval { + private static final DoubleExprEval OF_NULL = new DoubleExprEval(null); + private DoubleExprEval(@Nullable Number value) { super(value == null ? NullHandling.defaultDoubleValue() : (Double) value.doubleValue()); @@ -691,6 +717,8 @@ public Expr toExpr() private static class LongExprEval extends NumericExprEval { + private static final LongExprEval OF_NULL = new LongExprEval(null); + private LongExprEval(@Nullable Number value) { super(value == null ? NullHandling.defaultLongValue() : (Long) value.longValue()); @@ -758,6 +786,8 @@ public Expr toExpr() private static class StringExprEval extends ExprEval { + private static final StringExprEval OF_NULL = new StringExprEval(null); + // Cached primitive values. 
private boolean intValueValid = false; private boolean longValueValid = false; @@ -768,8 +798,6 @@ private static class StringExprEval extends ExprEval private double doubleValue; private boolean booleanValue; - private static final StringExprEval OF_NULL = new StringExprEval(null); - @Nullable private Number numericVal; @@ -1014,6 +1042,8 @@ public T getIndex(int index) private static class LongArrayExprEval extends ArrayExprEval { + private static final LongArrayExprEval OF_NULL = new LongArrayExprEval(null); + private LongArrayExprEval(@Nullable Long[] value) { super(value); @@ -1073,6 +1103,8 @@ public Expr toExpr() private static class DoubleArrayExprEval extends ArrayExprEval { + private static final DoubleArrayExprEval OF_NULL = new DoubleArrayExprEval(null); + private DoubleArrayExprEval(@Nullable Double[] value) { super(value); @@ -1132,6 +1164,8 @@ public Expr toExpr() private static class StringArrayExprEval extends ArrayExprEval { + private static final StringArrayExprEval OF_NULL = new StringArrayExprEval(null); + private boolean longValueValid = false; private boolean doubleValueValid = false; private Long[] longValues; diff --git a/core/src/main/java/org/apache/druid/math/expr/InputBindings.java b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java new file mode 100644 index 000000000000..9862bcaaf872 --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.base.Supplier; + +import javax.annotation.Nullable; +import java.util.Map; + +public class InputBindings +{ + /** + * Create an {@link Expr.InputBindingInspector} backed by a map of binding identifiers to their {@link ExprType} + */ + public static Expr.InputBindingInspector inspectorFromTypeMap(final Map types) + { + return new Expr.InputBindingInspector() + { + @Nullable + @Override + public ExprType getType(String name) + { + return types.get(name); + } + }; + } + + /** + * Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr} + */ + public static Expr.ObjectBinding withMap(final Map bindings) + { + return bindings::get; + } + + /** + * Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate + * {@link Expr} + */ + public static Expr.ObjectBinding withSuppliers(final Map> bindings) + { + return (String name) -> { + Supplier supplier = bindings.get(name); + return supplier == null ? 
null : supplier.get(); + }; + } +} diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index b0c923c025fe..0a42bfa5d487 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -601,23 +601,4 @@ public static void validateExpr(Expr expression, Expr.BindingAnalysis bindingAna } } - /** - * Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr} - */ - public static Expr.ObjectBinding withMap(final Map bindings) - { - return bindings::get; - } - - /** - * Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate - * {@link Expr} - */ - public static Expr.ObjectBinding withSuppliers(final Map> bindings) - { - return (String name) -> { - Supplier supplier = bindings.get(name); - return supplier == null ? null : supplier.get(); - }; - } } diff --git a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java index e05cc08f6559..6993ed75c127 100644 --- a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java @@ -166,7 +166,7 @@ public ExprEval eval(ObjectBinding bindings) } // conforming to other boolean-returning binary operators ExprType retType = ret.type() == ExprType.DOUBLE ? 
ExprType.DOUBLE : ExprType.LONG; - return ExprEval.of(!ret.asBoolean(), retType); + return ExprEval.ofBoolean(!ret.asBoolean(), retType); } @Nullable diff --git a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java index 8dea860b33cb..d352bfd67725 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java @@ -48,7 +48,7 @@ public void setup() builder.put("d", new String[] {null}); builder.put("e", new String[] {null, "foo", "bar"}); builder.put("f", new String[0]); - bindings = Parser.withMap(builder.build()); + bindings = InputBindings.withMap(builder.build()); } @Test diff --git a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java index 732e744fd8fe..a80edba8f2db 100644 --- a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -51,7 +51,7 @@ private ExprEval eval(String x, Expr.ObjectBinding bindings) @Test public void testDoubleEval() { - Expr.ObjectBinding bindings = Parser.withMap(ImmutableMap.of("x", 2.0d)); + Expr.ObjectBinding bindings = InputBindings.withMap(ImmutableMap.of("x", 2.0d)); Assert.assertEquals(2.0, evalDouble("x", bindings), 0.0001); Assert.assertEquals(2.0, evalDouble("\"x\"", bindings), 0.0001); Assert.assertEquals(304.0, evalDouble("300 + \"x\" * 2", bindings), 0.0001); @@ -89,7 +89,7 @@ public void testDoubleEval() @Test public void testLongEval() { - Expr.ObjectBinding bindings = Parser.withMap(ImmutableMap.of("x", 9223372036854775807L)); + Expr.ObjectBinding bindings = InputBindings.withMap(ImmutableMap.of("x", 9223372036854775807L)); Assert.assertEquals(9223372036854775807L, evalLong("x", bindings)); Assert.assertEquals(9223372036854775807L, evalLong("\"x\"", bindings)); @@ -147,7 +147,7 @@ public void 
testLongEval() @Test public void testBooleanReturn() { - Expr.ObjectBinding bindings = Parser.withMap( + Expr.ObjectBinding bindings = InputBindings.withMap( ImmutableMap.of("x", 100L, "y", 100L, "z", 100D, "w", 100D) ); ExprEval eval = Parser.parse("x==y", ExprMacroTable.nil()).eval(bindings); diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 1bd423f57fb0..bc283d31844c 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -59,7 +59,7 @@ public void setup() .put("a", new String[] {"foo", "bar", "baz", "foobar"}) .put("b", new Long[] {1L, 2L, 3L, 4L, 5L}) .put("c", new Double[] {3.1, 4.2, 5.3}); - bindings = Parser.withMap(builder.build()); + bindings = InputBindings.withMap(builder.build()); } @Test diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index 51f991f3f8f0..12dcfee0b894 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -741,7 +741,7 @@ private void validateConstantExpression(String expression, Object expected) Assert.assertEquals( expression, expected, - parsed.eval(Parser.withMap(ImmutableMap.of())).value() + parsed.eval(InputBindings.withMap(ImmutableMap.of())).value() ); final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); @@ -749,7 +749,7 @@ private void validateConstantExpression(String expression, Object expected) Assert.assertEquals( expression, expected, - parsedRoundTrip.eval(Parser.withMap(ImmutableMap.of())).value() + parsedRoundTrip.eval(InputBindings.withMap(ImmutableMap.of())).value() ); Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify()); } @@ -757,7 +757,7 @@ private void validateConstantExpression(String expression, 
Object expected) private void validateConstantExpression(String expression, Object[] expected) { Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - Object evaluated = parsed.eval(Parser.withMap(ImmutableMap.of())).value(); + Object evaluated = parsed.eval(InputBindings.withMap(ImmutableMap.of())).value(); Assert.assertArrayEquals( expression, expected, @@ -770,7 +770,7 @@ private void validateConstantExpression(String expression, Object[] expected) Assert.assertArrayEquals( expression, expected, - (Object[]) roundTrip.eval(Parser.withMap(ImmutableMap.of())).value() + (Object[]) roundTrip.eval(InputBindings.withMap(ImmutableMap.of())).value() ); Assert.assertEquals(parsed.stringify(), roundTrip.stringify()); } diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 7dece55b74c2..da89ed4f8969 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -311,7 +311,7 @@ Aggregation functions can appear in the SELECT clause of any query. Any aggregat `AGG(expr) FILTER(WHERE whereExpr)`. Filtered aggregators will only aggregate rows that match their filter. It's possible for two aggregators in the same SQL query to have different filters. -Only the COUNT aggregation can accept DISTINCT. +Only the COUNT and ARRAY_AGG aggregations can accept DISTINCT. > The order of aggregation operations across segments is not deterministic. This means that non-commutative aggregation > functions can produce inconsistent results across the same query. @@ -353,6 +353,8 @@ Only the COUNT aggregation can accept DISTINCT. |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)| |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. 
This parameter should be set as low as possible, since high values will lead to wasted memory.| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.| +|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.| +|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx). 
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index 2da1abde8c0d..be8b1004e55b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -25,6 +25,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.StringUtils; @@ -33,6 +34,7 @@ import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.math.expr.SettableObjectBinding; import org.apache.druid.query.cache.CacheKeyBuilder; @@ -94,6 +96,7 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory Suppliers.memoize(() -> new SettableObjectBinding(2)); private final Supplier finalizeBindings = Suppliers.memoize(() -> new SettableObjectBinding(1)); + private final Supplier finalizeInspector; @JsonCreator public ExpressionLambdaAggregatorFactory( @@ -148,9 +151,15 @@ public ExpressionLambdaAggregatorFactory( this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable); this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable); this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable); + this.finalizeInspector = Suppliers.memoize( + () -> InputBindings.inspectorFromTypeMap( + ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type()) + ) + ); 
this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable); this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES; Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES); + } @JsonProperty @@ -363,9 +372,11 @@ public ValueType getFinalizedType() Expr finalizeExpr = finalizeExpression.get(); ExprEval initialVal = initialCombineValue.get(); if (finalizeExpr != null) { - return ExprType.toValueType( - finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)).type() - ); + ExprType type = finalizeExpr.getOutputType(finalizeInspector.get()); + if (type == null) { + type = initialVal.type(); + } + return ExprType.toValueType(type); } return ExprType.toValueType(initialVal.type()); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java index 34cbb16d1636..b05ecbe35089 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java @@ -31,6 +31,7 @@ import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.PostAggregator; @@ -170,7 +171,7 @@ public Object compute(Map values) } ); - return parsed.get().eval(Parser.withMap(finalizedValues)).value(); + return parsed.get().eval(InputBindings.withMap(finalizedValues)).value(); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java 
b/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java index 1f9b1dac6b92..2f05a2fa2dc8 100644 --- a/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java +++ b/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java @@ -190,7 +190,7 @@ public ColumnCapabilities getColumnCapabilities(String columnName) { if (virtualColumns.exists(columnName)) { return virtualColumns.getColumnCapabilities( - baseColumnName -> QueryableIndexStorageAdapter.getColumnCapabilities(index, baseColumnName), + QueryableIndexStorageAdapter.getColumnInspectorForIndex(index), columnName ); } diff --git a/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java index cbfe79b84367..84cd638d09a6 100644 --- a/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java @@ -315,15 +315,7 @@ public static ColumnCapabilities getColumnCapabilities(ColumnSelector index, Str public static ColumnInspector getColumnInspectorForIndex(ColumnSelector index) { - return new ColumnInspector() - { - @Nullable - @Override - public ColumnCapabilities getColumnCapabilities(String column) - { - return QueryableIndexStorageAdapter.getColumnCapabilities(index, column); - } - }; + return column -> getColumnCapabilities(index, column); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java b/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java index 6d1b9a89489f..40d2f950dcff 100644 --- a/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java +++ 
b/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java @@ -267,7 +267,7 @@ public ColumnCapabilities getColumnCapabilities(final String columnName) { if (virtualColumns.exists(columnName)) { return virtualColumns.getColumnCapabilities( - baseColumnName -> QueryableIndexStorageAdapter.getColumnCapabilities(index, baseColumnName), + QueryableIndexStorageAdapter.getColumnInspectorForIndex(index), columnName ); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index d910d7b54f79..fc6ddfff6064 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -26,7 +26,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.expression.ExprUtils; import org.apache.druid.query.extraction.ExtractionFn; @@ -308,7 +308,7 @@ public static Expr.ObjectBinding createBindings( return supplier.get(); }; } else { - return Parser.withSuppliers(suppliers); + return InputBindings.withSuppliers(suppliers); } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java index 5b5c2960af1b..a4143794791a 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java @@ -84,7 +84,8 @@ public void 
testEqualsAndHashCode() "finalizeExpression", "compareBindings", "combineBindings", - "finalizeBindings" + "finalizeBindings", + "finalizeInspector" ) .verify(); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java index 1722ad32fe08..54a15393a7c3 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java @@ -23,7 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprType; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -38,22 +38,22 @@ public CaseInsensitiveExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[icontains_string] must have 2 arguments"); - eval("icontains_string()", Parser.withMap(ImmutableMap.of())); + eval("icontains_string()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[icontains_string] must have 2 arguments"); - eval("icontains_string('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("icontains_string('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatchSearchLowerCase() { - final ExprEval result = eval("icontains_string(a, 'OBA')", Parser.withMap(ImmutableMap.of("a", "foobar"))); + final ExprEval result = eval("icontains_string(a, 'OBA')", InputBindings.withMap(ImmutableMap.of("a", "foobar"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -61,9 +61,9 @@ public void 
testMatchSearchLowerCase() @Test public void testMatchSearchUpperCase() { - final ExprEval result = eval("icontains_string(a, 'oba')", Parser.withMap(ImmutableMap.of("a", "FOOBAR"))); + final ExprEval result = eval("icontains_string(a, 'oba')", InputBindings.withMap(ImmutableMap.of("a", "FOOBAR"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -71,9 +71,9 @@ public void testMatchSearchUpperCase() @Test public void testNoMatch() { - final ExprEval result = eval("icontains_string(a, 'bar')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, 'bar')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofBoolean(false, ExprType.LONG).value(), result.value() ); } @@ -85,9 +85,9 @@ public void testNullSearch() expectException(IllegalArgumentException.class, "Function[icontains_string] substring must be a string literal"); } - final ExprEval result = eval("icontains_string(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -95,9 +95,9 @@ public void testNullSearch() @Test public void testEmptyStringSearch() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -109,9 +109,9 @@ public void testNullSearchOnEmptyString() expectException(IllegalArgumentException.class, "Function[icontains_string] substring must 
be a string literal"); } - final ExprEval result = eval("icontains_string(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("icontains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -119,9 +119,9 @@ public void testNullSearchOnEmptyString() @Test public void testEmptyStringSearchOnEmptyString() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -135,10 +135,10 @@ public void testNullSearchOnNull() final ExprEval result = eval( "icontains_string(a, null)", - Parser.withSuppliers(ImmutableMap.of("a", () -> null)) + InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)) ); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -146,9 +146,9 @@ public void testNullSearchOnNull() @Test public void testEmptyStringSearchOnNull() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(!NullHandling.sqlCompatible(), ExprType.LONG).value(), + ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(), result.value() ); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java index bfbff7d0ab56..decd899a727b 100644 --- 
a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java @@ -23,7 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprType; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -38,22 +38,22 @@ public ContainsExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[contains_string] must have 2 arguments"); - eval("contains_string()", Parser.withMap(ImmutableMap.of())); + eval("contains_string()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[contains_string] must have 2 arguments"); - eval("contains_string('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("contains_string('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("contains_string(a, 'oba')", Parser.withMap(ImmutableMap.of("a", "foobar"))); + final ExprEval result = eval("contains_string(a, 'oba')", InputBindings.withMap(ImmutableMap.of("a", "foobar"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -61,9 +61,9 @@ public void testMatch() @Test public void testNoMatch() { - final ExprEval result = eval("contains_string(a, 'bar')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, 'bar')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofBoolean(false, ExprType.LONG).value(), result.value() ); } @@ -75,9 +75,9 @@ public void testNullSearch() 
expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -85,9 +85,9 @@ public void testNullSearch() @Test public void testEmptyStringSearch() { - final ExprEval result = eval("contains_string(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -99,9 +99,9 @@ public void testNullSearchOnEmptyString() expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -109,9 +109,9 @@ public void testNullSearchOnEmptyString() @Test public void testEmptyStringSearchOnEmptyString() { - final ExprEval result = eval("contains_string(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -123,9 +123,9 @@ public void testNullSearchOnNull() expectException(IllegalArgumentException.class, 
"Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -133,9 +133,9 @@ public void testNullSearchOnNull() @Test public void testEmptyStringSearchOnNull() { - final ExprEval result = eval("contains_string(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(!NullHandling.sqlCompatible(), ExprType.LONG).value(), + ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(), result.value() ); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java index 5f6b59214820..77e17493388a 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java @@ -23,6 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.junit.Assert; import org.junit.BeforeClass; @@ -34,7 +35,7 @@ public class ExprMacroTest { private static final String IPV4_STRING = "192.168.0.1"; private static final long IPV4_LONG = 3232235521L; - private static final Expr.ObjectBinding BINDINGS = Parser.withMap( + private static final Expr.ObjectBinding BINDINGS = InputBindings.withMap( ImmutableMap.builder() .put("t", 
DateTimes.of("2000-02-03T04:05:06").getMillis()) .put("t1", DateTimes.of("2000-02-03T00:00:00").getMillis()) diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java index 2f811d788b4f..8d7a322efc05 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java @@ -22,7 +22,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -37,34 +37,34 @@ public RegexpExtractExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[regexp_extract] must have 2 to 3 arguments"); - eval("regexp_extract()", Parser.withMap(ImmutableMap.of())); + eval("regexp_extract()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorFourArguments() { expectException(IllegalArgumentException.class, "Function[regexp_extract] must have 2 to 3 arguments"); - eval("regexp_extract('a', 'b', 'c', 'd')", Parser.withMap(ImmutableMap.of())); + eval("regexp_extract('a', 'b', 'c', 'd')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("foo", result.value()); } @Test public void testMatchGroup0() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 0)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 
'f(.o)', 0)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("foo", result.value()); } @Test public void testMatchGroup1() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 1)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)', 1)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("oo", result.value()); } @@ -72,20 +72,20 @@ public void testMatchGroup1() public void testMatchGroup2() { expectedException.expectMessage("No group 2"); - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 2)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)', 2)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test public void testNoMatch() { - final ExprEval result = eval("regexp_extract(a, 'f(.x)')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.x)')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertNull(result.value()); } @Test public void testMatchInMiddle() { - final ExprEval result = eval("regexp_extract(a, '.o$')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, '.o$')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("oo", result.value()); } @@ -96,14 +96,14 @@ public void testNullPattern() expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); } - final ExprEval result = eval("regexp_extract(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertNull(result.value()); } @Test public void testEmptyStringPattern() { - final ExprEval result = eval("regexp_extract(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, '')", 
InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals(NullHandling.emptyToNullIfNeeded(""), result.value()); } @@ -111,14 +111,14 @@ public void testEmptyStringPattern() public void testNumericPattern() { expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); - eval("regexp_extract(a, 1)", Parser.withMap(ImmutableMap.of("a", "foo"))); + eval("regexp_extract(a, 1)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test public void testNonLiteralPattern() { expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); - eval("regexp_extract(a, a)", Parser.withMap(ImmutableMap.of("a", "foo"))); + eval("regexp_extract(a, a)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test @@ -128,14 +128,14 @@ public void testNullPatternOnNull() expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); } - final ExprEval result = eval("regexp_extract(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_extract(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertNull(result.value()); } @Test public void testEmptyStringPatternOnNull() { - final ExprEval result = eval("regexp_extract(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_extract(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertNull(result.value()); } } diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java index b57db64b2327..fb6d99f7c72b 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java +++ 
b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java @@ -22,7 +22,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -37,20 +37,20 @@ public RegexpLikeExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[regexp_like] must have 2 arguments"); - eval("regexp_like()", Parser.withMap(ImmutableMap.of())); + eval("regexp_like()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[regexp_like] must have 2 arguments"); - eval("regexp_like('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("regexp_like('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("regexp_like(a, 'f.o')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, 'f.o')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -60,7 +60,7 @@ public void testMatch() @Test public void testNoMatch() { - final ExprEval result = eval("regexp_like(a, 'f.x')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, 'f.x')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(false).value(), result.value() @@ -74,7 +74,7 @@ public void testNullPattern() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, null)", 
InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -84,7 +84,7 @@ public void testNullPattern() @Test public void testEmptyStringPattern() { - final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -98,7 +98,7 @@ public void testNullPatternOnEmptyString() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("regexp_like(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -108,7 +108,7 @@ public void testNullPatternOnEmptyString() @Test public void testEmptyStringPatternOnEmptyString() { - final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -122,7 +122,7 @@ public void testNullPatternOnNull() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_like(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -132,7 +132,7 @@ public void testNullPatternOnNull() @Test public void testEmptyStringPatternOnNull() { - final ExprEval result = eval("regexp_like(a, '')", 
Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(), result.value() diff --git a/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java b/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java index f23afaa99bd3..ce8c9e3b41a6 100644 --- a/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java +++ b/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; @@ -30,7 +31,7 @@ public class LookupExprMacroTest extends InitializedNullHandlingTest { - private static final Expr.ObjectBinding BINDINGS = Parser.withMap( + private static final Expr.ObjectBinding BINDINGS = InputBindings.withMap( ImmutableMap.builder() .put("x", "foo") .build() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java new file mode 100644 index 000000000000..96e51bb6db45 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.aggregation.builtin; + +import com.google.common.collect.ImmutableSet; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Optionality; +import org.apache.druid.java.util.common.HumanReadableBytes; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; +import org.apache.druid.segment.VirtualColumn; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import 
org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class ArraySqlAggregator implements SqlAggregator +{ + private static final String NAME = "ARRAY_AGG"; + private static final SqlAggFunction FUNCTION = new ArrayAggFunction(); + + @Override + public SqlAggFunction calciteFunction() + { + return FUNCTION; + } + + @Nullable + @Override + public Aggregation toDruidAggregation( + PlannerContext plannerContext, + RowSignature rowSignature, + VirtualColumnRegistry virtualColumnRegistry, + RexBuilder rexBuilder, + String name, + AggregateCall aggregateCall, + Project project, + List existingAggregations, + boolean finalizeAggregations + ) + { + + final List arguments = aggregateCall + .getArgList() + .stream() + .map(i -> Expressions.fromFieldAccess(rowSignature, project, i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .collect(Collectors.toList()); + + if (arguments.stream().anyMatch(Objects::isNull)) { + return null; + } + + Integer maxSizeBytes = null; + if (arguments.size() > 1) { + RexNode maxBytes = Expressions.fromFieldAccess( + rowSignature, + project, + aggregateCall.getArgList().get(1) + ); + if (!maxBytes.isA(SqlKind.LITERAL)) { + // maxBytes must be a literal + return null; + } + maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); + } + final DruidExpression arg = arguments.get(0); + final ExprMacroTable macroTable = plannerContext.getExprMacroTable(); + + final String fieldName; + final String initialvalue; + final ValueType elementType; + final ValueType druidType = Calcites.getValueTypeForRelDataTypeFull(aggregateCall.getType()); + if (druidType == null) { + 
initialvalue = "[]"; + elementType = ValueType.STRING; + } else { + switch (druidType) { + case LONG_ARRAY: + initialvalue = "[]"; + elementType = ValueType.LONG; + break; + case DOUBLE_ARRAY: + initialvalue = "[]"; + elementType = ValueType.DOUBLE; + break; + default: + initialvalue = "[]"; + elementType = ValueType.STRING; + break; + } + } + List virtualColumns = new ArrayList<>(); + + if (arg.isDirectColumnAccess()) { + fieldName = arg.getDirectColumn(); + } else { + VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType); + virtualColumns.add(vc); + fieldName = vc.getOutputName(); + } + + if (aggregateCall.isDistinct()) { + return Aggregation.create( + virtualColumns, + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + initialvalue, + null, + StringUtils.format("array_set_add(\"__acc\", \"%s\")", fieldName), + StringUtils.format("array_set_add_all(\"__acc\", \"%s\")", name), + null, + "if(array_length(o) == 0, null, o)", + maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, + macroTable + ) + ); + } else { + return Aggregation.create( + virtualColumns, + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + initialvalue, + null, + StringUtils.format("array_append(\"__acc\", \"%s\")", fieldName), + StringUtils.format("array_concat(\"__acc\", \"%s\")", name), + null, + "if(array_length(o) == 0, null, o)", + maxSizeBytes != null ? 
new HumanReadableBytes(maxSizeBytes) : null, + macroTable + ) + ); + } + } + + static class ArrayAggReturnTypeInference implements SqlReturnTypeInference + { + @Override + public RelDataType inferReturnType(SqlOperatorBinding sqlOperatorBinding) + { + RelDataType type = sqlOperatorBinding.getOperandType(0); + if (SqlTypeUtil.isArray(type)) { + throw new ISE("Cannot ARRAY_AGG on array inputs %s", type); + } + return Calcites.createSqlArrayTypeWithNullability( + sqlOperatorBinding.getTypeFactory(), + type.getSqlTypeName(), + true + ); + } + } + + private static class ArrayAggFunction extends SqlAggFunction + { + private static final ArrayAggReturnTypeInference RETURN_TYPE_INFERENCE = new ArrayAggReturnTypeInference(); + + ArrayAggFunction() + { + super( + NAME, + null, + SqlKind.OTHER_FUNCTION, + RETURN_TYPE_INFERENCE, + InferTypes.ANY_NULLABLE, + OperandTypes.or( + OperandTypes.ANY, + OperandTypes.and( + OperandTypes.sequence(StringUtils.format("'%s'(expr, maxSizeBytes)", NAME), OperandTypes.ANY, OperandTypes.LITERAL), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + ) + ), + SqlFunctionCategory.USER_DEFINED_FUNCTION, + false, + false, + Optionality.IGNORED + ); + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index 0f38dd72210b..87b731713031 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -125,8 +125,28 @@ public static String escapeStringLiteral(final String s) } + /** + * Convert {@link RelDataType} to the most appropriate {@link ValueType}, coercing all ARRAY types to STRING (until + * the time is right and we are more comfortable handling Druid ARRAY types in all parts of the engine). 
+ * + * Callers who are not scared of ARRAY types should instead call {@link #getValueTypeForRelDataTypeFull(RelDataType)}, + * which returns the most accurate conversion of {@link RelDataType} to {@link ValueType}. + */ @Nullable public static ValueType getValueTypeForRelDataType(final RelDataType type) + { + ValueType valueType = getValueTypeForRelDataTypeFull(type); + if (ValueType.isArray(valueType)) { + return ValueType.STRING; + } + return valueType; + } + + /** + * Convert {@link RelDataType} to the most appropriate {@link ValueType} + */ + @Nullable + public static ValueType getValueTypeForRelDataTypeFull(final RelDataType type) { final SqlTypeName sqlTypeName = type.getSqlTypeName(); if (SqlTypeName.FLOAT == sqlTypeName) { @@ -142,15 +162,12 @@ public static ValueType getValueTypeForRelDataType(final RelDataType type) } else if (sqlTypeName == SqlTypeName.ARRAY) { SqlTypeName componentType = type.getComponentType().getSqlTypeName(); if (isDoubleType(componentType)) { - // in the future return ValueType.DOUBLE_ARRAY; - return ValueType.STRING; + return ValueType.DOUBLE_ARRAY; } if (isLongType(componentType)) { - // in the future we will return ValueType.LONG_ARRAY; - return ValueType.STRING; + return ValueType.LONG_ARRAY; } - // in the future we will return ValueType.STRING_ARRAY; - return ValueType.STRING; + return ValueType.STRING_ARRAY; } else { return null; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 7f5ee4a3c51e..f0135b7c0a5a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import 
org.apache.druid.sql.calcite.aggregation.builtin.ApproxCountDistinctSqlAggregator; +import org.apache.druid.sql.calcite.aggregation.builtin.ArraySqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; @@ -133,6 +134,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new SumSqlAggregator()) .add(new SumZeroSqlAggregator()) .add(new GroupingSqlAggregator()) + .add(new ArraySqlAggregator()) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 0ce2d2bd901f..5df9570af423 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -22,11 +22,13 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.apache.calcite.plan.RelOptPlanner; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.JodaUtils; @@ -52,6 +54,7 @@ import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import 
org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; @@ -80,11 +83,13 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; +import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.extraction.RegexDimExtractionFn; import org.apache.druid.query.extraction.SubstringDimExtractionFn; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.BoundDimFilter; import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.filter.NotDimFilter; @@ -95,6 +100,7 @@ import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.ResultRow; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; +import org.apache.druid.query.groupby.orderby.NoopLimitSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec.Direction; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; @@ -17903,4 +17909,625 @@ public void testRoundFuc() throws Exception ) ); } + + @Test + public void testArrayAgg() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(dim1), ARRAY_AGG(DISTINCT dim1), ARRAY_AGG(DISTINCT dim1) FILTER(WHERE dim1 = 'shazbot') FROM foo WHERE dim1 is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + 
"a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"dim1\")", + "array_concat(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a2\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + selector("dim1", "shazbot", null) + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? 
new Object[]{"[\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"1\",\"2\",\"abc\",\"def\",\"10.1\"]", null} + : new Object[]{"[\"\",\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"\",\"1\",\"2\",\"abc\",\"def\",\"10.1\"]", null} + ) + ); + } + + @Test + public void testArrayAggMultiValue() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(dim3), ARRAY_AGG(DISTINCT dim3) FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim3"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"dim3\")", + "array_concat(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("dim3"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim3\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? 
new Object[]{"[\"a\",\"b\",\"b\",\"c\",\"d\",null,null,null]", "[null,\"a\",\"b\",\"c\",\"d\"]"} + : new Object[]{"[\"a\",\"b\",\"b\",\"c\",\"d\",\"\",null,null]", "[\"\",null,\"a\",\"b\",\"c\",\"d\"]"} + ) + ); + } + + /** Plans ARRAY_AGG and ARRAY_AGG(DISTINCT ...) over l1, d1 and f1 of numfoo: the expected native query is a timeseries with six ExpressionLambdaAggregatorFactory instances, using array_append and array_concat for the plain form and array_set_add and array_set_add_all for DISTINCT. */ @Test + public void testArrayAggNumeric() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(l1), ARRAY_AGG(DISTINCT l1), ARRAY_AGG(d1), ARRAY_AGG(DISTINCT d1), ARRAY_AGG(f1), ARRAY_AGG(DISTINCT f1) FROM numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"l1\")", + "array_concat(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"l1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("d1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"d1\")", + "array_concat(\"__acc\", \"a2\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a3", + ImmutableSet.of("d1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"d1\")", + "array_set_add_all(\"__acc\", \"a3\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a4", + ImmutableSet.of("f1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"f1\")", + 
"array_concat(\"__acc\", \"a4\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a5", + ImmutableSet.of("f1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"f1\")", + "array_set_add_all(\"__acc\", \"a5\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + /* the useDefault variant of the expected row contains no null entries; the other variant does */ useDefault + ? new Object[]{ + "[7,325323,0,0,0,0]", + "[0,7,325323]", + "[1.0,1.7,0.0,0.0,0.0,0.0]", + "[1.0,0.0,1.7]", + "[1.0,0.10000000149011612,0.0,0.0,0.0,0.0]", + "[1.0,0.10000000149011612,0.0]" + } + : new Object[]{ + "[7,325323,0,null,null,null]", + "[0,null,7,325323]", + "[1.0,1.7,0.0,null,null,null]", + "[1.0,0.0,null,1.7]", + "[1.0,0.10000000149011612,0.0,null,null,null]", + "[1.0,0.10000000149011612,0.0,null]" + } + ) + ); + } + + /** ARRAY_TO_STRING over ARRAY_AGG(DISTINCT dim1) on foo, restricted to non-null dim1: the array_to_string call is expected as expression post-aggregator p0 on top of aggregator a0. */ @Test + public void testArrayAggToString() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_TO_STRING(ARRAY_AGG(DISTINCT dim1), ',') FROM foo WHERE dim1 is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .postAggregators(expressionPostAgg("p0", "array_to_string(\"a0\",',')")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault ? 
new Object[]{"1,2,abc,def,10.1"} : new Object[]{",1,2,abc,def,10.1"} + ) + ); + } + + /** ARRAY_AGG(DISTINCT ...) whose input is CONCAT(dim1, dim2): the concat is expected as string virtual column v0, and the aggregator factory references v0 instead of dim1 or dim2. */ @Test + public void testArrayAggExpression() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_TO_STRING(ARRAY_AGG(DISTINCT CONCAT(dim1, dim2)), ',') FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .virtualColumns( + expressionVirtualColumn("v0", "concat(\"dim1\",\"dim2\")", ValueType.STRING) + ) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("v0"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"v0\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .postAggregators(expressionPostAgg("p0", "array_to_string(\"a0\",',')")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault ? 
new Object[]{"1a,a,2,abc,10.1,defabc"} : new Object[]{"null,1a,a,2,defabc"} + ) + ); + } + + /** ARRAY_AGG with an explicit second argument of 128: both expected aggregator factories carry new HumanReadableBytes(128) where the other ARRAY_AGG tests here carry 1024. */ @Test + public void testArrayAggMaxBytes() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(l1, 128), ARRAY_AGG(DISTINCT l1, 128) FROM numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"l1\")", + "array_concat(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(128), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"l1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(128), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? 
new Object[]{"[7,325323,0,0,0,0]", "[0,7,325323]"} + : new Object[]{"[7,325323,0,null,null,null]", "[0,null,7,325323]"} + ) + ); + } + + /** Inner join of numfoo against a subquery that computes ARRAY_AGG(DISTINCT dim1) per dim4: the expected plan is a scan over a join whose right side is a GroupByQuery, with array_to_string over j0.a0 as virtual column v0. */ @Test + public void testArrayAggAsArrayFromJoin() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"} + ); + } + testQuery( + "SELECT numfoo.dim4, j.arr, ARRAY_TO_STRING(j.arr, ',') FROM numfoo INNER JOIN (SELECT dim4, ARRAY_AGG(DISTINCT dim1) as arr FROM numfoo WHERE dim1 is not null GROUP BY 1) as j ON numfoo.dim4 = j.dim4", + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE3), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(not(selector("dim1", null, null))) + .setDimensions(new DefaultDimensionSpec("dim4", "_d0")) + .setAggregatorSpecs( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + 
TestExprMacroTable.INSTANCE + ) + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "(\"dim4\" == \"j0._d0\")", + JoinType.INNER, + null + ) + ) + .virtualColumns( + expressionVirtualColumn("v0", "array_to_string(\"j0.a0\",',')", ValueType.STRING) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim4", "j0.a0", "v0") + .context(QUERY_CONTEXT_DEFAULT) + .resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + + ), + expectedResults + ); + } + + /** Grouping on the array column produced by an ARRAY_AGG subquery: no native query or result row is listed because the test expects a RuntimeException with the message set below. */ @Test + public void testArrayAggGroupByArrayAggFromSubquery() throws Exception + { + cannotVectorize(); + // grouping on array types is not supported yet, so this query is expected to fail + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Cannot create query type helper from invalid type [STRING_ARRAY]"); + testQuery( + "SELECT dim2, arr, COUNT(*) FROM (SELECT dim2, ARRAY_AGG(DISTINCT dim1) as arr FROM foo WHERE dim1 is not null GROUP BY 1 LIMIT 5) GROUP BY 1,2", + ImmutableList.of(), + ImmutableList.of() + ); + } + + /** ARRAY_CONTAINS with a scalar ARRAY_AGG subquery as its first argument: the expected plan is a scan over a LEFT join (condition 1) with the timeseries subquery, filtered by an array_contains expression on j0.a0 and dim1. */ @Test + public void testArrayAggArrayContainsSubquery() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"10.1", ""}, + new Object[]{"2", ""}, + new Object[]{"1", "a"}, + new Object[]{"def", "abc"}, + new Object[]{"abc", ""} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{"", "a"}, + new Object[]{"10.1", null}, + new Object[]{"2", ""}, + new Object[]{"1", "a"}, + new Object[]{"def", "abc"}, + new Object[]{"abc", null} + ); + } + testQuery( + "SELECT dim1,dim2 FROM foo WHERE ARRAY_CONTAINS((SELECT ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null), dim1)", + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + 
.intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "1", + JoinType.LEFT, + null + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + new ExpressionDimFilter( + "array_contains(\"j0.a0\",\"dim1\")", + TestExprMacroTable.INSTANCE + ) + ) + .columns("dim1", "dim2") + .context(QUERY_CONTEXT_DEFAULT) + .resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + + ), + expectedResults + ); + } + + /** Same ARRAY_CONTAINS filter as testArrayAggArrayContainsSubquery, but grouped by dim2 with a row count: the expected plan is a GroupByQuery over the LEFT join with a CountAggregatorFactory. */ @Test + public void testArrayAggGroupByArrayContainsSubquery() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"", 3L}, + new Object[]{"a", 1L}, + new Object[]{"abc", 1L} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{null, 2L}, + new Object[]{"", 1L}, + new Object[]{"a", 2L}, + new Object[]{"abc", 1L} + ); + } + testQuery( + "SELECT dim2, COUNT(*) FROM foo WHERE ARRAY_CONTAINS((SELECT ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null), dim1) GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + 
"array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "1", + JoinType.LEFT, + null + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimFilter( + new ExpressionDimFilter( + "array_contains(\"j0.a0\",\"dim1\")", + TestExprMacroTable.INSTANCE + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + + ), + expectedResults + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index ebd3ecdf2799..dc56fe7bd711 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -31,6 +31,7 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.data.input.MapBasedRow; import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.ValueMatcher; @@ -246,7 +247,7 @@ void testExpression( Assert.assertEquals("Expression for: " + rexNode, expectedExpression, expression); ExprEval result = Parser.parse(expression.getExpression(), PLANNER_CONTEXT.getExprMacroTable()) - .eval(Parser.withMap(bindings)); + .eval(InputBindings.withMap(bindings)); Assert.assertEquals("Result for: " + rexNode, expectedResult, result.value()); } diff --git a/website/.spelling b/website/.spelling index d717fbf9ecf7..5bebca07f0d8 
100644 --- a/website/.spelling +++ b/website/.spelling @@ -1497,6 +1497,7 @@ file2 - ../docs/querying/sql.md APPROX_COUNT_DISTINCT APPROX_QUANTILE +ARRAY_AGG BIGINT CATALOG_NAME CHARACTER_MAXIMUM_LENGTH