From 4951327f7f00f5967e3897f250c3c95d9adbe852 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 23 Apr 2021 02:12:05 -0700 Subject: [PATCH 1/7] ARRAY_AGG sql aggregator function --- docs/querying/sql.md | 6 +- .../builtin/ArraySqlAggregator.java | 221 +++++++ .../druid/sql/calcite/planner/Calcites.java | 22 +- .../calcite/planner/DruidOperatorTable.java | 2 + .../druid/sql/calcite/CalciteQueryTest.java | 611 ++++++++++++++++++ 5 files changed, 855 insertions(+), 7 deletions(-) create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 7dece55b74c2..8da8e2d1410f 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -311,7 +311,7 @@ Aggregation functions can appear in the SELECT clause of any query. Any aggregat `AGG(expr) FILTER(WHERE whereExpr)`. Filtered aggregators will only aggregate rows that match their filter. It's possible for two aggregators in the same SQL query to have different filters. -Only the COUNT aggregation can accept DISTINCT. +Only the COUNT and ARRAY_AGG aggregations can accept DISTINCT. > The order of aggregation operations across segments is not deterministic. This means that non-commutative aggregation > functions can produce inconsistent results across the same query. @@ -353,6 +353,10 @@ Only the COUNT aggregation can accept DISTINCT. |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)| |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.| +|`ARRAY_AGG(expr)`|Collects all values of `expr` into an ARRAY, including null values, with the default limit on aggregation size of 1024 bytes. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| +|`ARRAY_AGG(DISTINCT expr)`|Collects all distinct values of `expr` into an ARRAY, including null values, with the default limit on aggregation size of 1024 bytes per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| +|`ARRAY_AGG(expr, maxSizeBytes)`|Collects all values of `expr` into an ARRAY, including null values, with specified maximum byte size per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| +|`ARRAY_AGG(DISTINCT expr, maxSizeBytes)`|Collects all distinct values of `expr` into an ARRAY, including null values, with specified maximum byte size per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx). diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java new file mode 100644 index 000000000000..d429323a0ab6 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.aggregation.builtin; + +import com.google.common.collect.ImmutableSet; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Optionality; +import org.apache.druid.java.util.common.HumanReadableBytes; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; +import org.apache.druid.segment.VirtualColumn; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class ArraySqlAggregator implements SqlAggregator +{ + private static final String NAME = "ARRAY_AGG"; + + @Override + public SqlAggFunction calciteFunction() + { + return new ArrayAggFunction(); + } + + @Nullable + @Override + public Aggregation toDruidAggregation( + PlannerContext plannerContext, + RowSignature rowSignature, + VirtualColumnRegistry virtualColumnRegistry, + RexBuilder rexBuilder, + String name, + AggregateCall aggregateCall, + Project project, + List existingAggregations, + boolean finalizeAggregations + ) + { + + final List arguments = aggregateCall + .getArgList() + .stream() + .map(i -> Expressions.fromFieldAccess(rowSignature, project, i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .collect(Collectors.toList()); + + if (arguments.stream().anyMatch(Objects::isNull)) { + return null; + } + + Integer maxSizeBytes = null; + if (arguments.size() > 1) { + RexNode maxBytes = Expressions.fromFieldAccess( + rowSignature, + project, + aggregateCall.getArgList().get(1) + ); + if (!maxBytes.isA(SqlKind.LITERAL)) { + // maxBytes must be a literal + return null; + } + maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); + } + final DruidExpression arg = arguments.get(0); + final ExprMacroTable macroTable = plannerContext.getExprMacroTable(); + + final String fieldName; + final String initialvalue; + final ValueType elementType; + final ValueType druidType = Calcites.getValueTypeForRelDataTypeFull(aggregateCall.getType()); + switch (druidType) { + case LONG_ARRAY: + initialvalue = "[]"; + elementType = ValueType.LONG; + break; + case DOUBLE_ARRAY: + initialvalue = "[]"; + elementType = ValueType.DOUBLE; + break; + default: + initialvalue = "[]"; + elementType = ValueType.STRING; + break; + } + List virtualColumns = new ArrayList<>(); + + if (arg.isDirectColumnAccess()) { + fieldName = arg.getDirectColumn(); + } else { + VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, elementType); + virtualColumns.add(vc); + fieldName = vc.getOutputName(); + } + + if (aggregateCall.isDistinct()) { + return Aggregation.create( + virtualColumns, + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + initialvalue, + null, + StringUtils.format("array_set_add(\"__acc\", \"%s\")", fieldName), + StringUtils.format("array_set_add_all(\"__acc\", \"%s\")", name), + null, + null, + maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, + macroTable + ) + ); + } else { + return Aggregation.create( + virtualColumns, + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + initialvalue, + null, + StringUtils.format("array_append(\"__acc\", \"%s\")", fieldName), + StringUtils.format("array_concat(\"__acc\", \"%s\")", name), + null, + null, + maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, + macroTable + ) + ); + } + } + + static class ArrayAggReturnTypeInference implements SqlReturnTypeInference + { + @Override + public RelDataType inferReturnType(SqlOperatorBinding sqlOperatorBinding) + { + RelDataType type = sqlOperatorBinding.getOperandType(0); + if (SqlTypeUtil.isArray(type)) { + throw new ISE("Cannot ARRAY_AGG on array inputs %s", type); + } + return Calcites.createSqlArrayTypeWithNullability( + sqlOperatorBinding.getTypeFactory(), + type.getSqlTypeName(), + true + ); + } + } + + private static class ArrayAggFunction extends SqlAggFunction + { + private static final ArrayAggReturnTypeInference RETURN_TYPE_INFERENCE = new ArrayAggReturnTypeInference(); + + ArrayAggFunction() + { + super( + NAME, + null, + SqlKind.OTHER_FUNCTION, + RETURN_TYPE_INFERENCE, + InferTypes.ANY_NULLABLE, + OperandTypes.or( + OperandTypes.ANY, + OperandTypes.and( + OperandTypes.sequence(StringUtils.format("'%s'(expr, maxSizeBytes)", NAME), OperandTypes.ANY, OperandTypes.LITERAL), + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC) + ) + ), + SqlFunctionCategory.USER_DEFINED_FUNCTION, + false, + false, + Optionality.IGNORED + ); + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index 0f38dd72210b..d4bb1d26b1a3 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -127,6 +127,19 @@ public static String escapeStringLiteral(final String s) @Nullable public static ValueType getValueTypeForRelDataType(final RelDataType type) + { + ValueType valueType = getValueTypeForRelDataTypeFull(type); + if (ValueType.isArray(valueType)) { + return ValueType.STRING; + } + return valueType; + } + + /** + * Convert {@link RelDataType} to the most appropriate {@link ValueType} + */ + @Nullable + public static ValueType getValueTypeForRelDataTypeFull(final RelDataType type) { final SqlTypeName sqlTypeName = type.getSqlTypeName(); if (SqlTypeName.FLOAT == sqlTypeName) { @@ -142,15 +155,12 @@ public static ValueType getValueTypeForRelDataType(final RelDataType type) } else if (sqlTypeName == SqlTypeName.ARRAY) { SqlTypeName componentType = type.getComponentType().getSqlTypeName(); if (isDoubleType(componentType)) { - // in the future return ValueType.DOUBLE_ARRAY; - return ValueType.STRING; + return ValueType.DOUBLE_ARRAY; } if (isLongType(componentType)) { - // in the future we will return ValueType.LONG_ARRAY; - return ValueType.STRING; + return ValueType.LONG_ARRAY; } - // in the future we will return ValueType.STRING_ARRAY; - return ValueType.STRING; + return ValueType.STRING_ARRAY; } else { return null; } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 7f5ee4a3c51e..f0135b7c0a5a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.ApproxCountDistinctSqlAggregator; +import org.apache.druid.sql.calcite.aggregation.builtin.ArraySqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; @@ -133,6 +134,7 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new SumSqlAggregator()) .add(new SumZeroSqlAggregator()) .add(new GroupingSqlAggregator()) + .add(new ArraySqlAggregator()) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index e47a89a5f784..d2de64c7cd3f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -22,11 +22,13 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import junitparams.JUnitParamsRunner; import junitparams.Parameters; import org.apache.calcite.plan.RelOptPlanner; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.JodaUtils; @@ -52,6 +54,7 @@ import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatMaxAggregatorFactory; import org.apache.druid.query.aggregation.FloatMinAggregatorFactory; @@ -80,11 +83,13 @@ import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; +import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.extraction.RegexDimExtractionFn; import org.apache.druid.query.extraction.SubstringDimExtractionFn; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.BoundDimFilter; import org.apache.druid.query.filter.DimFilter; +import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.filter.NotDimFilter; @@ -95,6 +100,7 @@ import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.ResultRow; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; +import org.apache.druid.query.groupby.orderby.NoopLimitSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec.Direction; import org.apache.druid.query.lookup.RegisteredLookupExtractionFn; @@ -17847,4 +17853,609 @@ public void testRoundFuc() throws Exception ) ); } + + @Test + public void testArrayAgg() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(dim1), ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"dim1\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? new Object[]{"[\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"1\",\"2\",\"abc\",\"def\",\"10.1\"]"} + : new Object[]{"[\"\",\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"\",\"1\",\"2\",\"abc\",\"def\",\"10.1\"]"} + ) + ); + } + + @Test + public void testArrayAggMultiValue() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(dim3), ARRAY_AGG(DISTINCT dim3) FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim3"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"dim3\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("dim3"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim3\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? new Object[]{"[\"a\",\"b\",\"b\",\"c\",\"d\",null,null,null]", "[null,\"a\",\"b\",\"c\",\"d\"]"} + : new Object[]{"[\"a\",\"b\",\"b\",\"c\",\"d\",\"\",null,null]", "[\"\",null,\"a\",\"b\",\"c\",\"d\"]"} + ) + ); + } + + @Test + public void testArrayAggNumeric() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(l1), ARRAY_AGG(DISTINCT l1), ARRAY_AGG(d1), ARRAY_AGG(DISTINCT d1), ARRAY_AGG(f1), ARRAY_AGG(DISTINCT f1) FROM numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"l1\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"l1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("d1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"d1\")", + "array_concat(\"__acc\", \"a2\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a3", + ImmutableSet.of("d1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"d1\")", + "array_set_add_all(\"__acc\", \"a3\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a4", + ImmutableSet.of("f1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"f1\")", + "array_concat(\"__acc\", \"a4\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a5", + ImmutableSet.of("f1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"f1\")", + "array_set_add_all(\"__acc\", \"a5\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? new Object[]{ + "[7,325323,0,0,0,0]", + "[0,7,325323]", + "[1.0,1.7,0.0,0.0,0.0,0.0]", + "[1.0,0.0,1.7]", + "[1.0,0.10000000149011612,0.0,0.0,0.0,0.0]", + "[1.0,0.10000000149011612,0.0]" + } + : new Object[]{ + "[7,325323,0,null,null,null]", + "[0,null,7,325323]", + "[1.0,1.7,0.0,null,null,null]", + "[1.0,0.0,null,1.7]", + "[1.0,0.10000000149011612,0.0,null,null,null]", + "[1.0,0.10000000149011612,0.0,null]" + } + ) + ); + } + + @Test + public void testArrayAggToString() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_TO_STRING(ARRAY_AGG(DISTINCT dim1), ',') FROM foo WHERE dim1 is not null", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .postAggregators(expressionPostAgg("p0", "array_to_string(\"a0\",',')")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault ? new Object[]{"1,2,abc,def,10.1"} : new Object[]{",1,2,abc,def,10.1"} + ) + ); + } + + @Test + public void testArrayAggExpression() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_TO_STRING(ARRAY_AGG(DISTINCT CONCAT(dim1, dim2)), ',') FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .virtualColumns( + expressionVirtualColumn("v0", "concat(\"dim1\",\"dim2\")", ValueType.STRING) + ) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("v0"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"v0\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .postAggregators(expressionPostAgg("p0", "array_to_string(\"a0\",',')")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault ? new Object[]{"1a,a,2,abc,10.1,defabc"} : new Object[]{"null,1a,a,2,defabc"} + ) + ); + } + + @Test + public void testArrayAggMaxBytes() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT ARRAY_AGG(l1, 128), ARRAY_AGG(DISTINCT l1, 128) FROM numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_append(\"__acc\", \"l1\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(128), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"l1\")", + "array_set_add_all(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(128), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? new Object[]{"[7,325323,0,0,0,0]", "[0,7,325323]"} + : new Object[]{"[7,325323,0,null,null,null]", "[0,null,7,325323]"} + ) + ); + } + + @Test + public void testArrayAggAsArrayFromJoin() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"a", "[\"2\",\"10.1\"]", "2,10.1"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"a", "[\"\",\"2\",\"10.1\"]", ",2,10.1"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"}, + new Object[]{"b", "[\"1\",\"abc\",\"def\"]", "1,abc,def"} + ); + } + testQuery( + "SELECT numfoo.dim4, j.arr, ARRAY_TO_STRING(j.arr, ',') FROM numfoo INNER JOIN (SELECT dim4, ARRAY_AGG(DISTINCT dim1) as arr FROM numfoo WHERE dim1 is not null GROUP BY 1) as j ON numfoo.dim4 = j.dim4", + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE3), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(not(selector("dim1", null, null))) + .setDimensions(new DefaultDimensionSpec("dim4", "_d0")) + .setAggregatorSpecs( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "(\"dim4\" == \"j0._d0\")", + JoinType.INNER, + null + ) + ) + .virtualColumns( + expressionVirtualColumn("v0", "array_to_string(\"j0.a0\",',')", ValueType.STRING) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim4", "j0.a0", "v0") + .context(QUERY_CONTEXT_DEFAULT) + .resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + + ), + expectedResults + ); + } + + @Test + public void testArrayAggGroupByArrayAggFromSubquery() throws Exception + { + cannotVectorize(); + // yo, can't group on array types right now so expect failure + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("Cannot create query type helper from invalid type [STRING_ARRAY]"); + testQuery( + "SELECT dim2, arr, COUNT(*) FROM (SELECT dim2, ARRAY_AGG(DISTINCT dim1) as arr FROM foo WHERE dim1 is not null GROUP BY 1 LIMIT 5) GROUP BY 1,2", + ImmutableList.of(), + ImmutableList.of() + ); + } + + @Test + public void testArrayAggArrayContainsSubquery() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"10.1", ""}, + new Object[]{"2", ""}, + new Object[]{"1", "a"}, + new Object[]{"def", "abc"}, + new Object[]{"abc", ""} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{"", "a"}, + new Object[]{"10.1", null}, + new Object[]{"2", ""}, + new Object[]{"1", "a"}, + new Object[]{"def", "abc"}, + new Object[]{"abc", null} + ); + } + testQuery( + "SELECT dim1,dim2 FROM foo WHERE ARRAY_CONTAINS((SELECT ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null), dim1)", + ImmutableList.of( + Druids.newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "1", + JoinType.LEFT, + null + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters( + new ExpressionDimFilter( + "array_contains(\"j0.a0\",\"dim1\")", + TestExprMacroTable.INSTANCE + ) + ) + .columns("dim1", "dim2") + .context(QUERY_CONTEXT_DEFAULT) + .resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + + ), + expectedResults + ); + } + + @Test + public void testArrayAggGroupByArrayContainsSubquery() throws Exception + { + cannotVectorize(); + List expectedResults; + if (useDefault) { + expectedResults = ImmutableList.of( + new Object[]{"", 3L}, + new Object[]{"a", 1L}, + new Object[]{"abc", 1L} + ); + } else { + expectedResults = ImmutableList.of( + new Object[]{null, 2L}, + new Object[]{"", 1L}, + new Object[]{"a", 2L}, + new Object[]{"abc", 1L} + ); + } + testQuery( + "SELECT dim2, COUNT(*) FROM foo WHERE ARRAY_CONTAINS((SELECT ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null), dim1) GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("dim1", null, null))) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ) + ) + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "1", + JoinType.LEFT, + null + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimFilter( + new ExpressionDimFilter( + "array_contains(\"j0.a0\",\"dim1\")", + TestExprMacroTable.INSTANCE + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setGranularity(Granularities.ALL) + .setLimitSpec(NoopLimitSpec.instance()) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + + ), + expectedResults + ); + } } From 8cb659ca8162bed3de77346db7cea2c7da196711 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 23 Apr 2021 02:32:05 -0700 Subject: [PATCH 2/7] add javadoc --- .../org/apache/druid/sql/calcite/planner/Calcites.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index d4bb1d26b1a3..87b731713031 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -125,6 +125,13 @@ public static String escapeStringLiteral(final String s) } + /** + * Convert {@link RelDataType} to the most appropriate {@link ValueType}, coercing all ARRAY types to STRING (until + * the time is right and we are more comfortable handling Druid ARRAY types in all parts of the engine). + * + * Callers who are not scared of ARRAY types should isntead call {@link #getValueTypeForRelDataTypeFull(RelDataType)}, + * which returns the most accurate conversion of {@link RelDataType} to {@link ValueType}. + */ @Nullable public static ValueType getValueTypeForRelDataType(final RelDataType type) { From 5cbe533e501a4fe6867ac3cf623fabb837849749 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 23 Apr 2021 12:22:33 -0700 Subject: [PATCH 3/7] spelling --- website/.spelling | 1 + 1 file changed, 1 insertion(+) diff --git a/website/.spelling b/website/.spelling index d717fbf9ecf7..5bebca07f0d8 100644 --- a/website/.spelling +++ b/website/.spelling @@ -1497,6 +1497,7 @@ file2 - ../docs/querying/sql.md APPROX_COUNT_DISTINCT APPROX_QUANTILE +ARRAY_AGG BIGINT CATALOG_NAME CHARACTER_MAXIMUM_LENGTH From 7c9d09817c31ffe4e6f3a453a8db9ff1e709b7e5 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 30 Apr 2021 01:52:31 -0700 Subject: [PATCH 4/7] review stuff, return null instead of empty when nil input --- docs/querying/sql.md | 6 +- .../ExpressionLambdaAggregatorFactory.java | 8 ++- .../builtin/ArraySqlAggregator.java | 4 +- .../druid/sql/calcite/CalciteQueryTest.java | 56 ++++++++++++------- 4 files changed, 45 insertions(+), 29 deletions(-) diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 8da8e2d1410f..9c1ad4f31b9e 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -353,10 +353,8 @@ Only the COUNT and ARRAY_AGG aggregations can accept DISTINCT. |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)| |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.| -|`ARRAY_AGG(expr)`|Collects all values of `expr` into an ARRAY, including null values, with the default limit on aggregation size of 1024 bytes. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| -|`ARRAY_AGG(DISTINCT expr)`|Collects all distinct values of `expr` into an ARRAY, including null values, with the default limit on aggregation size of 1024 bytes per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| -|`ARRAY_AGG(expr, maxSizeBytes)`|Collects all values of `expr` into an ARRAY, including null values, with specified maximum byte size per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| -|`ARRAY_AGG(DISTINCT expr, maxSizeBytes)`|Collects all distinct values of `expr` into an ARRAY, including null values, with specified maximum byte size per aggregate. `ORDER BY` on the `ARRAY_AGG` expression is not currently supported.| +|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the outuput array may vary depending on processing order.| +|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the outuput array may vary depending on processing order.| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx). diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index 2da1abde8c0d..b2ff468a0103 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -363,9 +363,11 @@ public ValueType getFinalizedType() Expr finalizeExpr = finalizeExpression.get(); ExprEval initialVal = initialCombineValue.get(); if (finalizeExpr != null) { - return ExprType.toValueType( - finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)).type() - ); + ExprEval eval = finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)); + // this might be wrong from time to time, but if evaluating the finalizer on the initial value produces null + // then we cannot safely assume the type since non-vectorized expressions might report all null values as string + // typed, so just assume it preserves the initial value + return ExprType.toValueType(eval.value() == null ? initialValue.get().type() : eval.type()); } return ExprType.toValueType(initialVal.type()); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index d429323a0ab6..b6ea5c9887a6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -150,7 +150,7 @@ public Aggregation toDruidAggregation( StringUtils.format("array_set_add(\"__acc\", \"%s\")", fieldName), StringUtils.format("array_set_add_all(\"__acc\", \"%s\")", name), null, - null, + "if(array_length(o) == 0, null, o)", maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, macroTable ) @@ -167,7 +167,7 @@ public Aggregation toDruidAggregation( StringUtils.format("array_append(\"__acc\", \"%s\")", fieldName), StringUtils.format("array_concat(\"__acc\", \"%s\")", name), null, - null, + "if(array_length(o) == 0, null, o)", maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, macroTable ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 0c6875ce9065..5df9570af423 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -17915,7 +17915,7 @@ public void testArrayAgg() throws Exception { cannotVectorize(); testQuery( - "SELECT ARRAY_AGG(dim1), ARRAY_AGG(DISTINCT dim1) FROM foo WHERE dim1 is not null", + "SELECT ARRAY_AGG(dim1), ARRAY_AGG(DISTINCT dim1), ARRAY_AGG(DISTINCT dim1) FILTER(WHERE dim1 = 'shazbot') FROM foo WHERE dim1 is not null", ImmutableList.of( Druids.newTimeseriesQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -17933,7 +17933,7 @@ public void testArrayAgg() throws Exception "array_append(\"__acc\", \"dim1\")", "array_concat(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -17946,9 +17946,25 @@ public void testArrayAgg() throws Exception "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + "array_set_add(\"__acc\", \"dim1\")", + "array_set_add_all(\"__acc\", \"a2\")", + null, + "if(array_length(o) == 0, null, o)", + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + selector("dim1", "shazbot", null) ) ) ) @@ -17957,8 +17973,8 @@ public void testArrayAgg() throws Exception ), ImmutableList.of( useDefault - ? new Object[]{"[\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"1\",\"2\",\"abc\",\"def\",\"10.1\"]"} - : new Object[]{"[\"\",\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"\",\"1\",\"2\",\"abc\",\"def\",\"10.1\"]"} + ? new Object[]{"[\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"1\",\"2\",\"abc\",\"def\",\"10.1\"]", null} + : new Object[]{"[\"\",\"10.1\",\"2\",\"1\",\"def\",\"abc\"]", "[\"\",\"1\",\"2\",\"abc\",\"def\",\"10.1\"]", null} ) ); } @@ -17985,7 +18001,7 @@ public void testArrayAggMultiValue() throws Exception "array_append(\"__acc\", \"dim3\")", "array_concat(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -17998,7 +18014,7 @@ public void testArrayAggMultiValue() throws Exception "array_set_add(\"__acc\", \"dim3\")", "array_set_add_all(\"__acc\", \"a1\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18037,7 +18053,7 @@ public void testArrayAggNumeric() throws Exception "array_append(\"__acc\", \"l1\")", "array_concat(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -18050,7 +18066,7 @@ public void testArrayAggNumeric() throws Exception "array_set_add(\"__acc\", \"l1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -18063,7 +18079,7 @@ public void testArrayAggNumeric() throws Exception "array_append(\"__acc\", \"d1\")", "array_concat(\"__acc\", \"a2\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -18076,7 +18092,7 @@ public void testArrayAggNumeric() throws Exception "array_set_add(\"__acc\", \"d1\")", "array_set_add_all(\"__acc\", \"a3\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -18089,7 +18105,7 @@ public void testArrayAggNumeric() throws Exception "array_append(\"__acc\", \"f1\")", "array_concat(\"__acc\", \"a4\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -18102,7 +18118,7 @@ public void testArrayAggNumeric() throws Exception "array_set_add(\"__acc\", \"f1\")", "array_set_add_all(\"__acc\", \"a5\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18156,7 +18172,7 @@ public void testArrayAggToString() throws Exception "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18197,7 +18213,7 @@ public void testArrayAggExpression() throws Exception "array_set_add(\"__acc\", \"v0\")", "array_set_add_all(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18235,7 +18251,7 @@ public void testArrayAggMaxBytes() throws Exception "array_append(\"__acc\", \"l1\")", "array_concat(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(128), TestExprMacroTable.INSTANCE ), @@ -18248,7 +18264,7 @@ public void testArrayAggMaxBytes() throws Exception "array_set_add(\"__acc\", \"l1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(128), TestExprMacroTable.INSTANCE ) @@ -18314,7 +18330,7 @@ public void testArrayAggAsArrayFromJoin() throws Exception "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18405,7 +18421,7 @@ public void testArrayAggArrayContainsSubquery() throws Exception "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -18481,7 +18497,7 @@ public void testArrayAggGroupByArrayContainsSubquery() throws Exception "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - null, + "if(array_length(o) == 0, null, o)", new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) From 9c1a7b9577a74b61ff76d1c89370ac98b82436c0 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 30 Apr 2021 02:00:29 -0700 Subject: [PATCH 5/7] review stuff --- .../builtin/ArraySqlAggregator.java | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index b6ea5c9887a6..96e51bb6db45 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -61,11 +61,12 @@ public class ArraySqlAggregator implements SqlAggregator { private static final String NAME = "ARRAY_AGG"; + private static final SqlAggFunction FUNCTION = new ArrayAggFunction(); @Override public SqlAggFunction calciteFunction() { - return new ArrayAggFunction(); + return FUNCTION; } @Nullable @@ -114,19 +115,24 @@ public Aggregation toDruidAggregation( final String initialvalue; final ValueType elementType; final ValueType druidType = Calcites.getValueTypeForRelDataTypeFull(aggregateCall.getType()); - switch (druidType) { - case LONG_ARRAY: - initialvalue = "[]"; - elementType = ValueType.LONG; - break; - case DOUBLE_ARRAY: - initialvalue = "[]"; - elementType = ValueType.DOUBLE; - break; - default: - initialvalue = "[]"; - elementType = ValueType.STRING; - break; + if (druidType == null) { + initialvalue = "[]"; + elementType = ValueType.STRING; + } else { + switch (druidType) { + case LONG_ARRAY: + initialvalue = "[]"; + elementType = ValueType.LONG; + break; + case DOUBLE_ARRAY: + initialvalue = "[]"; + elementType = ValueType.DOUBLE; + break; + default: + initialvalue = "[]"; + elementType = ValueType.STRING; + break; + } } List virtualColumns = new ArrayList<>(); From b84c69c4ded6b71f9a78c01b4917aef128462ee6 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Fri, 30 Apr 2021 12:19:34 -0700 Subject: [PATCH 6/7] Update sql.md --- docs/querying/sql.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 9c1ad4f31b9e..da89ed4f8969 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -353,8 +353,8 @@ Only the COUNT and ARRAY_AGG aggregations can accept DISTINCT. |`ANY_VALUE(expr)`|Returns any value of `expr` including null. `expr` must be numeric. This aggregator can simplify and optimize the performance by returning the first encountered value (including null)| |`ANY_VALUE(expr, maxBytesPerString)`|Like `ANY_VALUE(expr)`, but for strings. The `maxBytesPerString` parameter determines how much aggregation space to allocate per string. Strings longer than this limit will be truncated. This parameter should be set as low as possible, since high values will lead to wasted memory.| |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.| -|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the outuput array may vary depending on processing order.| -|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the outuput array may vary depending on processing order.| +|`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.| +|`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx). From 04dcefe02cfb12518802a269da3c150792eee7d9 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Sat, 1 May 2021 17:06:48 -0700 Subject: [PATCH 7/7] use type inference for finalize, refactor some things --- .../org/apache/druid/math/expr/ExprEval.java | 60 +++++++++++++---- .../apache/druid/math/expr/InputBindings.java | 64 +++++++++++++++++++ .../org/apache/druid/math/expr/Parser.java | 19 ------ .../druid/math/expr/UnaryOperatorExpr.java | 2 +- .../druid/math/expr/ApplyFunctionTest.java | 2 +- .../org/apache/druid/math/expr/EvalTest.java | 6 +- .../apache/druid/math/expr/FunctionTest.java | 2 +- .../apache/druid/math/expr/ParserTest.java | 8 +-- .../ExpressionLambdaAggregatorFactory.java | 19 ++++-- .../post/ExpressionPostAggregator.java | 3 +- .../QueryableIndexColumnSelectorFactory.java | 2 +- .../segment/QueryableIndexStorageAdapter.java | 10 +-- ...yableIndexVectorColumnSelectorFactory.java | 2 +- .../segment/virtual/ExpressionSelectors.java | 4 +- ...ExpressionLambdaAggregatorFactoryTest.java | 3 +- .../CaseInsensitiveExprMacroTest.java | 42 ++++++------ .../expression/ContainsExprMacroTest.java | 38 +++++------ .../druid/query/expression/ExprMacroTest.java | 3 +- .../RegexpExtractExprMacroTest.java | 30 ++++----- .../expression/RegexpLikeExprMacroTest.java | 22 +++---- .../query/expression/LookupExprMacroTest.java | 3 +- .../expression/ExpressionTestHelper.java | 3 +- 22 files changed, 216 insertions(+), 131 deletions(-) create mode 100644 core/src/main/java/org/apache/druid/math/expr/InputBindings.java diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 1a9a96287055..a2ef91f0cfa4 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -374,21 +374,11 @@ private static Class convertType(@Nullable Class existing, Class next) throw new UOE("Invalid array expression type: %s", next); } - public static ExprEval ofLong(@Nullable Number longValue) - { - return new LongExprEval(longValue); - } - public static ExprEval of(long longValue) { return new LongExprEval(longValue); } - public static ExprEval ofDouble(@Nullable Number doubleValue) - { - return new DoubleExprEval(doubleValue); - } - public static ExprEval of(double doubleValue) { return new DoubleExprEval(doubleValue); @@ -402,22 +392,50 @@ public static ExprEval of(@Nullable String stringValue) return new StringExprEval(stringValue); } + public static ExprEval ofLong(@Nullable Number longValue) + { + if (longValue == null) { + return LongExprEval.OF_NULL; + } + return new LongExprEval(longValue); + } + + public static ExprEval ofDouble(@Nullable Number doubleValue) + { + if (doubleValue == null) { + return DoubleExprEval.OF_NULL; + } + return new DoubleExprEval(doubleValue); + } + public static ExprEval ofLongArray(@Nullable Long[] longValue) { + if (longValue == null) { + return LongArrayExprEval.OF_NULL; + } return new LongArrayExprEval(longValue); } public static ExprEval ofDoubleArray(@Nullable Double[] doubleValue) { + if (doubleValue == null) { + return DoubleArrayExprEval.OF_NULL; + } return new DoubleArrayExprEval(doubleValue); } public static ExprEval ofStringArray(@Nullable String[] stringValue) { + if (stringValue == null) { + return StringArrayExprEval.OF_NULL; + } return new StringArrayExprEval(stringValue); } - public static ExprEval of(boolean value, ExprType type) + /** + * Convert a boolean back into native expression type + */ + public static ExprEval ofBoolean(boolean value, ExprType type) { switch (type) { case DOUBLE: @@ -431,11 +449,17 @@ public static ExprEval of(boolean value, ExprType type) } } + /** + * Convert a boolean into a long expression type + */ public static ExprEval ofLongBoolean(boolean value) { return ExprEval.of(Evals.asLong(value)); } + /** + * Examine java type to find most appropriate expression type + */ public static ExprEval bestEffortOf(@Nullable Object val) { if (val instanceof ExprEval) { @@ -631,6 +655,8 @@ public boolean isNumericNull() private static class DoubleExprEval extends NumericExprEval { + private static final DoubleExprEval OF_NULL = new DoubleExprEval(null); + private DoubleExprEval(@Nullable Number value) { super(value == null ? NullHandling.defaultDoubleValue() : (Double) value.doubleValue()); @@ -691,6 +717,8 @@ public Expr toExpr() private static class LongExprEval extends NumericExprEval { + private static final LongExprEval OF_NULL = new LongExprEval(null); + private LongExprEval(@Nullable Number value) { super(value == null ? NullHandling.defaultLongValue() : (Long) value.longValue()); @@ -758,6 +786,8 @@ public Expr toExpr() private static class StringExprEval extends ExprEval { + private static final StringExprEval OF_NULL = new StringExprEval(null); + // Cached primitive values. private boolean intValueValid = false; private boolean longValueValid = false; @@ -768,8 +798,6 @@ private static class StringExprEval extends ExprEval private double doubleValue; private boolean booleanValue; - private static final StringExprEval OF_NULL = new StringExprEval(null); - @Nullable private Number numericVal; @@ -1014,6 +1042,8 @@ public T getIndex(int index) private static class LongArrayExprEval extends ArrayExprEval { + private static final LongArrayExprEval OF_NULL = new LongArrayExprEval(null); + private LongArrayExprEval(@Nullable Long[] value) { super(value); @@ -1073,6 +1103,8 @@ public Expr toExpr() private static class DoubleArrayExprEval extends ArrayExprEval { + private static final DoubleArrayExprEval OF_NULL = new DoubleArrayExprEval(null); + private DoubleArrayExprEval(@Nullable Double[] value) { super(value); @@ -1132,6 +1164,8 @@ public Expr toExpr() private static class StringArrayExprEval extends ArrayExprEval { + private static final StringArrayExprEval OF_NULL = new StringArrayExprEval(null); + private boolean longValueValid = false; private boolean doubleValueValid = false; private Long[] longValues; diff --git a/core/src/main/java/org/apache/druid/math/expr/InputBindings.java b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java new file mode 100644 index 000000000000..9862bcaaf872 --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/InputBindings.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.base.Supplier; + +import javax.annotation.Nullable; +import java.util.Map; + +public class InputBindings +{ + /** + * Create an {@link Expr.InputBindingInspector} backed by a map of binding identifiers to their {@link ExprType} + */ + public static Expr.InputBindingInspector inspectorFromTypeMap(final Map types) + { + return new Expr.InputBindingInspector() + { + @Nullable + @Override + public ExprType getType(String name) + { + return types.get(name); + } + }; + } + + /** + * Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr} + */ + public static Expr.ObjectBinding withMap(final Map bindings) + { + return bindings::get; + } + + /** + * Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate + * {@link Expr} + */ + public static Expr.ObjectBinding withSuppliers(final Map> bindings) + { + return (String name) -> { + Supplier supplier = bindings.get(name); + return supplier == null ? null : supplier.get(); + }; + } +} diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index b0c923c025fe..0a42bfa5d487 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -601,23 +601,4 @@ public static void validateExpr(Expr expression, Expr.BindingAnalysis bindingAna } } - /** - * Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr} - */ - public static Expr.ObjectBinding withMap(final Map bindings) - { - return bindings::get; - } - - /** - * Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate - * {@link Expr} - */ - public static Expr.ObjectBinding withSuppliers(final Map> bindings) - { - return (String name) -> { - Supplier supplier = bindings.get(name); - return supplier == null ? null : supplier.get(); - }; - } } diff --git a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java index e05cc08f6559..6993ed75c127 100644 --- a/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java +++ b/core/src/main/java/org/apache/druid/math/expr/UnaryOperatorExpr.java @@ -166,7 +166,7 @@ public ExprEval eval(ObjectBinding bindings) } // conforming to other boolean-returning binary operators ExprType retType = ret.type() == ExprType.DOUBLE ? ExprType.DOUBLE : ExprType.LONG; - return ExprEval.of(!ret.asBoolean(), retType); + return ExprEval.ofBoolean(!ret.asBoolean(), retType); } @Nullable diff --git a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java index 8dea860b33cb..d352bfd67725 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java @@ -48,7 +48,7 @@ public void setup() builder.put("d", new String[] {null}); builder.put("e", new String[] {null, "foo", "bar"}); builder.put("f", new String[0]); - bindings = Parser.withMap(builder.build()); + bindings = InputBindings.withMap(builder.build()); } @Test diff --git a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java index 732e744fd8fe..a80edba8f2db 100644 --- a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -51,7 +51,7 @@ private ExprEval eval(String x, Expr.ObjectBinding bindings) @Test public void testDoubleEval() { - Expr.ObjectBinding bindings = Parser.withMap(ImmutableMap.of("x", 2.0d)); + Expr.ObjectBinding bindings = InputBindings.withMap(ImmutableMap.of("x", 2.0d)); Assert.assertEquals(2.0, evalDouble("x", bindings), 0.0001); Assert.assertEquals(2.0, evalDouble("\"x\"", bindings), 0.0001); Assert.assertEquals(304.0, evalDouble("300 + \"x\" * 2", bindings), 0.0001); @@ -89,7 +89,7 @@ public void testDoubleEval() @Test public void testLongEval() { - Expr.ObjectBinding bindings = Parser.withMap(ImmutableMap.of("x", 9223372036854775807L)); + Expr.ObjectBinding bindings = InputBindings.withMap(ImmutableMap.of("x", 9223372036854775807L)); Assert.assertEquals(9223372036854775807L, evalLong("x", bindings)); Assert.assertEquals(9223372036854775807L, evalLong("\"x\"", bindings)); @@ -147,7 +147,7 @@ public void testLongEval() @Test public void testBooleanReturn() { - Expr.ObjectBinding bindings = Parser.withMap( + Expr.ObjectBinding bindings = InputBindings.withMap( ImmutableMap.of("x", 100L, "y", 100L, "z", 100D, "w", 100D) ); ExprEval eval = Parser.parse("x==y", ExprMacroTable.nil()).eval(bindings); diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 1bd423f57fb0..bc283d31844c 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -59,7 +59,7 @@ public void setup() .put("a", new String[] {"foo", "bar", "baz", "foobar"}) .put("b", new Long[] {1L, 2L, 3L, 4L, 5L}) .put("c", new Double[] {3.1, 4.2, 5.3}); - bindings = Parser.withMap(builder.build()); + bindings = InputBindings.withMap(builder.build()); } @Test diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index 51f991f3f8f0..12dcfee0b894 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -741,7 +741,7 @@ private void validateConstantExpression(String expression, Object expected) Assert.assertEquals( expression, expected, - parsed.eval(Parser.withMap(ImmutableMap.of())).value() + parsed.eval(InputBindings.withMap(ImmutableMap.of())).value() ); final Expr parsedNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false); @@ -749,7 +749,7 @@ private void validateConstantExpression(String expression, Object expected) Assert.assertEquals( expression, expected, - parsedRoundTrip.eval(Parser.withMap(ImmutableMap.of())).value() + parsedRoundTrip.eval(InputBindings.withMap(ImmutableMap.of())).value() ); Assert.assertEquals(parsed.stringify(), parsedRoundTrip.stringify()); } @@ -757,7 +757,7 @@ private void validateConstantExpression(String expression, Object expected) private void validateConstantExpression(String expression, Object[] expected) { Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); - Object evaluated = parsed.eval(Parser.withMap(ImmutableMap.of())).value(); + Object evaluated = parsed.eval(InputBindings.withMap(ImmutableMap.of())).value(); Assert.assertArrayEquals( expression, expected, @@ -770,7 +770,7 @@ private void validateConstantExpression(String expression, Object[] expected) Assert.assertArrayEquals( expression, expected, - (Object[]) roundTrip.eval(Parser.withMap(ImmutableMap.of())).value() + (Object[]) roundTrip.eval(InputBindings.withMap(ImmutableMap.of())).value() ); Assert.assertEquals(parsed.stringify(), roundTrip.stringify()); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index b2ff468a0103..be8b1004e55b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -25,6 +25,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.StringUtils; @@ -33,6 +34,7 @@ import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprType; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.math.expr.SettableObjectBinding; import org.apache.druid.query.cache.CacheKeyBuilder; @@ -94,6 +96,7 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory Suppliers.memoize(() -> new SettableObjectBinding(2)); private final Supplier finalizeBindings = Suppliers.memoize(() -> new SettableObjectBinding(1)); + private final Supplier finalizeInspector; @JsonCreator public ExpressionLambdaAggregatorFactory( @@ -148,9 +151,15 @@ public ExpressionLambdaAggregatorFactory( this.foldExpression = Parser.lazyParse(foldExpressionString, macroTable); this.combineExpression = Parser.lazyParse(combineExpressionString, macroTable); this.compareExpression = Parser.lazyParse(compareExpressionString, macroTable); + this.finalizeInspector = Suppliers.memoize( + () -> InputBindings.inspectorFromTypeMap( + ImmutableMap.of(FINALIZE_IDENTIFIER, this.initialCombineValue.get().type()) + ) + ); this.finalizeExpression = Parser.lazyParse(finalizeExpressionString, macroTable); this.maxSizeBytes = maxSizeBytes != null ? maxSizeBytes : DEFAULT_MAX_SIZE_BYTES; Preconditions.checkArgument(this.maxSizeBytes.getBytesInInt() >= MIN_SIZE_BYTES); + } @JsonProperty @@ -363,11 +372,11 @@ public ValueType getFinalizedType() Expr finalizeExpr = finalizeExpression.get(); ExprEval initialVal = initialCombineValue.get(); if (finalizeExpr != null) { - ExprEval eval = finalizeExpr.eval(finalizeBindings.get().withBinding(FINALIZE_IDENTIFIER, initialVal)); - // this might be wrong from time to time, but if evaluating the finalizer on the initial value produces null - // then we cannot safely assume the type since non-vectorized expressions might report all null values as string - // typed, so just assume it preserves the initial value - return ExprType.toValueType(eval.value() == null ? initialValue.get().type() : eval.type()); + ExprType type = finalizeExpr.getOutputType(finalizeInspector.get()); + if (type == null) { + type = initialVal.type(); + } + return ExprType.toValueType(type); } return ExprType.toValueType(initialVal.type()); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java index 34cbb16d1636..b05ecbe35089 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java @@ -31,6 +31,7 @@ import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.PostAggregator; @@ -170,7 +171,7 @@ public Object compute(Map values) } ); - return parsed.get().eval(Parser.withMap(finalizedValues)).value(); + return parsed.get().eval(InputBindings.withMap(finalizedValues)).value(); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java b/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java index 1f9b1dac6b92..2f05a2fa2dc8 100644 --- a/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java +++ b/processing/src/main/java/org/apache/druid/segment/QueryableIndexColumnSelectorFactory.java @@ -190,7 +190,7 @@ public ColumnCapabilities getColumnCapabilities(String columnName) { if (virtualColumns.exists(columnName)) { return virtualColumns.getColumnCapabilities( - baseColumnName -> QueryableIndexStorageAdapter.getColumnCapabilities(index, baseColumnName), + QueryableIndexStorageAdapter.getColumnInspectorForIndex(index), columnName ); } diff --git a/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java index cbfe79b84367..84cd638d09a6 100644 --- a/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/QueryableIndexStorageAdapter.java @@ -315,15 +315,7 @@ public static ColumnCapabilities getColumnCapabilities(ColumnSelector index, Str public static ColumnInspector getColumnInspectorForIndex(ColumnSelector index) { - return new ColumnInspector() - { - @Nullable - @Override - public ColumnCapabilities getColumnCapabilities(String column) - { - return QueryableIndexStorageAdapter.getColumnCapabilities(index, column); - } - }; + return column -> getColumnCapabilities(index, column); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java b/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java index 6d1b9a89489f..40d2f950dcff 100644 --- a/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java +++ b/processing/src/main/java/org/apache/druid/segment/vector/QueryableIndexVectorColumnSelectorFactory.java @@ -267,7 +267,7 @@ public ColumnCapabilities getColumnCapabilities(final String columnName) { if (virtualColumns.exists(columnName)) { return virtualColumns.getColumnCapabilities( - baseColumnName -> QueryableIndexStorageAdapter.getColumnCapabilities(index, baseColumnName), + QueryableIndexStorageAdapter.getColumnInspectorForIndex(index), columnName ); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index d910d7b54f79..fc6ddfff6064 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -26,7 +26,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.expression.ExprUtils; import org.apache.druid.query.extraction.ExtractionFn; @@ -308,7 +308,7 @@ public static Expr.ObjectBinding createBindings( return supplier.get(); }; } else { - return Parser.withSuppliers(suppliers); + return InputBindings.withSuppliers(suppliers); } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java index 5b5c2960af1b..a4143794791a 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java @@ -84,7 +84,8 @@ public void testEqualsAndHashCode() "finalizeExpression", "compareBindings", "combineBindings", - "finalizeBindings" + "finalizeBindings", + "finalizeInspector" ) .verify(); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java index 1722ad32fe08..54a15393a7c3 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/CaseInsensitiveExprMacroTest.java @@ -23,7 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprType; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -38,22 +38,22 @@ public CaseInsensitiveExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[icontains_string] must have 2 arguments"); - eval("icontains_string()", Parser.withMap(ImmutableMap.of())); + eval("icontains_string()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[icontains_string] must have 2 arguments"); - eval("icontains_string('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("icontains_string('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatchSearchLowerCase() { - final ExprEval result = eval("icontains_string(a, 'OBA')", Parser.withMap(ImmutableMap.of("a", "foobar"))); + final ExprEval result = eval("icontains_string(a, 'OBA')", InputBindings.withMap(ImmutableMap.of("a", "foobar"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -61,9 +61,9 @@ public void testMatchSearchLowerCase() @Test public void testMatchSearchUpperCase() { - final ExprEval result = eval("icontains_string(a, 'oba')", Parser.withMap(ImmutableMap.of("a", "FOOBAR"))); + final ExprEval result = eval("icontains_string(a, 'oba')", InputBindings.withMap(ImmutableMap.of("a", "FOOBAR"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -71,9 +71,9 @@ public void testMatchSearchUpperCase() @Test public void testNoMatch() { - final ExprEval result = eval("icontains_string(a, 'bar')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, 'bar')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofBoolean(false, ExprType.LONG).value(), result.value() ); } @@ -85,9 +85,9 @@ public void testNullSearch() expectException(IllegalArgumentException.class, "Function[icontains_string] substring must be a string literal"); } - final ExprEval result = eval("icontains_string(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -95,9 +95,9 @@ public void testNullSearch() @Test public void testEmptyStringSearch() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -109,9 +109,9 @@ public void testNullSearchOnEmptyString() expectException(IllegalArgumentException.class, "Function[icontains_string] substring must be a string literal"); } - final ExprEval result = eval("icontains_string(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("icontains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -119,9 +119,9 @@ public void testNullSearchOnEmptyString() @Test public void testEmptyStringSearchOnEmptyString() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -135,10 +135,10 @@ public void testNullSearchOnNull() final ExprEval result = eval( "icontains_string(a, null)", - Parser.withSuppliers(ImmutableMap.of("a", () -> null)) + InputBindings.withSuppliers(ImmutableMap.of("a", () -> null)) ); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -146,9 +146,9 @@ public void testNullSearchOnNull() @Test public void testEmptyStringSearchOnNull() { - final ExprEval result = eval("icontains_string(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("icontains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(!NullHandling.sqlCompatible(), ExprType.LONG).value(), + ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(), result.value() ); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java index bfbff7d0ab56..decd899a727b 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/ContainsExprMacroTest.java @@ -23,7 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprType; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -38,22 +38,22 @@ public ContainsExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[contains_string] must have 2 arguments"); - eval("contains_string()", Parser.withMap(ImmutableMap.of())); + eval("contains_string()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[contains_string] must have 2 arguments"); - eval("contains_string('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("contains_string('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("contains_string(a, 'oba')", Parser.withMap(ImmutableMap.of("a", "foobar"))); + final ExprEval result = eval("contains_string(a, 'oba')", InputBindings.withMap(ImmutableMap.of("a", "foobar"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -61,9 +61,9 @@ public void testMatch() @Test public void testNoMatch() { - final ExprEval result = eval("contains_string(a, 'bar')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, 'bar')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(false, ExprType.LONG).value(), + ExprEval.ofBoolean(false, ExprType.LONG).value(), result.value() ); } @@ -75,9 +75,9 @@ public void testNullSearch() expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -85,9 +85,9 @@ public void testNullSearch() @Test public void testEmptyStringSearch() { - final ExprEval result = eval("contains_string(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -99,9 +99,9 @@ public void testNullSearchOnEmptyString() expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -109,9 +109,9 @@ public void testNullSearchOnEmptyString() @Test public void testEmptyStringSearchOnEmptyString() { - final ExprEval result = eval("contains_string(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -123,9 +123,9 @@ public void testNullSearchOnNull() expectException(IllegalArgumentException.class, "Function[contains_string] substring must be a string literal"); } - final ExprEval result = eval("contains_string(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("contains_string(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(true, ExprType.LONG).value(), + ExprEval.ofBoolean(true, ExprType.LONG).value(), result.value() ); } @@ -133,9 +133,9 @@ public void testNullSearchOnNull() @Test public void testEmptyStringSearchOnNull() { - final ExprEval result = eval("contains_string(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("contains_string(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( - ExprEval.of(!NullHandling.sqlCompatible(), ExprType.LONG).value(), + ExprEval.ofBoolean(!NullHandling.sqlCompatible(), ExprType.LONG).value(), result.value() ); } diff --git a/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java index 5f6b59214820..77e17493388a 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/ExprMacroTest.java @@ -23,6 +23,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.junit.Assert; import org.junit.BeforeClass; @@ -34,7 +35,7 @@ public class ExprMacroTest { private static final String IPV4_STRING = "192.168.0.1"; private static final long IPV4_LONG = 3232235521L; - private static final Expr.ObjectBinding BINDINGS = Parser.withMap( + private static final Expr.ObjectBinding BINDINGS = InputBindings.withMap( ImmutableMap.builder() .put("t", DateTimes.of("2000-02-03T04:05:06").getMillis()) .put("t1", DateTimes.of("2000-02-03T00:00:00").getMillis()) diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java index 2f811d788b4f..8d7a322efc05 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpExtractExprMacroTest.java @@ -22,7 +22,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -37,34 +37,34 @@ public RegexpExtractExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[regexp_extract] must have 2 to 3 arguments"); - eval("regexp_extract()", Parser.withMap(ImmutableMap.of())); + eval("regexp_extract()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorFourArguments() { expectException(IllegalArgumentException.class, "Function[regexp_extract] must have 2 to 3 arguments"); - eval("regexp_extract('a', 'b', 'c', 'd')", Parser.withMap(ImmutableMap.of())); + eval("regexp_extract('a', 'b', 'c', 'd')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("foo", result.value()); } @Test public void testMatchGroup0() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 0)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)', 0)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("foo", result.value()); } @Test public void testMatchGroup1() { - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 1)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)', 1)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("oo", result.value()); } @@ -72,20 +72,20 @@ public void testMatchGroup1() public void testMatchGroup2() { expectedException.expectMessage("No group 2"); - final ExprEval result = eval("regexp_extract(a, 'f(.o)', 2)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.o)', 2)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test public void testNoMatch() { - final ExprEval result = eval("regexp_extract(a, 'f(.x)')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, 'f(.x)')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertNull(result.value()); } @Test public void testMatchInMiddle() { - final ExprEval result = eval("regexp_extract(a, '.o$')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, '.o$')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals("oo", result.value()); } @@ -96,14 +96,14 @@ public void testNullPattern() expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); } - final ExprEval result = eval("regexp_extract(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertNull(result.value()); } @Test public void testEmptyStringPattern() { - final ExprEval result = eval("regexp_extract(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_extract(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals(NullHandling.emptyToNullIfNeeded(""), result.value()); } @@ -111,14 +111,14 @@ public void testEmptyStringPattern() public void testNumericPattern() { expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); - eval("regexp_extract(a, 1)", Parser.withMap(ImmutableMap.of("a", "foo"))); + eval("regexp_extract(a, 1)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test public void testNonLiteralPattern() { expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); - eval("regexp_extract(a, a)", Parser.withMap(ImmutableMap.of("a", "foo"))); + eval("regexp_extract(a, a)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); } @Test @@ -128,14 +128,14 @@ public void testNullPatternOnNull() expectException(IllegalArgumentException.class, "Function[regexp_extract] pattern must be a string literal"); } - final ExprEval result = eval("regexp_extract(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_extract(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertNull(result.value()); } @Test public void testEmptyStringPatternOnNull() { - final ExprEval result = eval("regexp_extract(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_extract(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertNull(result.value()); } } diff --git a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java index b57db64b2327..fb6d99f7c72b 100644 --- a/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java +++ b/processing/src/test/java/org/apache/druid/query/expression/RegexpLikeExprMacroTest.java @@ -22,7 +22,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; +import org.apache.druid.math.expr.InputBindings; import org.junit.Assert; import org.junit.Test; @@ -37,20 +37,20 @@ public RegexpLikeExprMacroTest() public void testErrorZeroArguments() { expectException(IllegalArgumentException.class, "Function[regexp_like] must have 2 arguments"); - eval("regexp_like()", Parser.withMap(ImmutableMap.of())); + eval("regexp_like()", InputBindings.withMap(ImmutableMap.of())); } @Test public void testErrorThreeArguments() { expectException(IllegalArgumentException.class, "Function[regexp_like] must have 2 arguments"); - eval("regexp_like('a', 'b', 'c')", Parser.withMap(ImmutableMap.of())); + eval("regexp_like('a', 'b', 'c')", InputBindings.withMap(ImmutableMap.of())); } @Test public void testMatch() { - final ExprEval result = eval("regexp_like(a, 'f.o')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, 'f.o')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -60,7 +60,7 @@ public void testMatch() @Test public void testNoMatch() { - final ExprEval result = eval("regexp_like(a, 'f.x')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, 'f.x')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(false).value(), result.value() @@ -74,7 +74,7 @@ public void testNullPattern() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, null)", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -84,7 +84,7 @@ public void testNullPattern() @Test public void testEmptyStringPattern() { - final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", "foo"))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withMap(ImmutableMap.of("a", "foo"))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -98,7 +98,7 @@ public void testNullPatternOnEmptyString() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("regexp_like(a, null)", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -108,7 +108,7 @@ public void testNullPatternOnEmptyString() @Test public void testEmptyStringPatternOnEmptyString() { - final ExprEval result = eval("regexp_like(a, '')", Parser.withMap(ImmutableMap.of("a", ""))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withMap(ImmutableMap.of("a", ""))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -122,7 +122,7 @@ public void testNullPatternOnNull() expectException(IllegalArgumentException.class, "Function[regexp_like] pattern must be a string literal"); } - final ExprEval result = eval("regexp_like(a, null)", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_like(a, null)", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( ExprEval.ofLongBoolean(true).value(), result.value() @@ -132,7 +132,7 @@ public void testNullPatternOnNull() @Test public void testEmptyStringPatternOnNull() { - final ExprEval result = eval("regexp_like(a, '')", Parser.withSuppliers(ImmutableMap.of("a", () -> null))); + final ExprEval result = eval("regexp_like(a, '')", InputBindings.withSuppliers(ImmutableMap.of("a", () -> null))); Assert.assertEquals( ExprEval.ofLongBoolean(NullHandling.replaceWithDefault()).value(), result.value() diff --git a/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java b/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java index f23afaa99bd3..ce8c9e3b41a6 100644 --- a/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java +++ b/server/src/test/java/org/apache/druid/query/expression/LookupExprMacroTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; @@ -30,7 +31,7 @@ public class LookupExprMacroTest extends InitializedNullHandlingTest { - private static final Expr.ObjectBinding BINDINGS = Parser.withMap( + private static final Expr.ObjectBinding BINDINGS = InputBindings.withMap( ImmutableMap.builder() .put("x", "foo") .build() diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index ebd3ecdf2799..dc56fe7bd711 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -31,6 +31,7 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.data.input.MapBasedRow; import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.InputBindings; import org.apache.druid.math.expr.Parser; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.ValueMatcher; @@ -246,7 +247,7 @@ void testExpression( Assert.assertEquals("Expression for: " + rexNode, expectedExpression, expression); ExprEval result = Parser.parse(expression.getExpression(), PLANNER_CONTEXT.getExprMacroTable()) - .eval(Parser.withMap(bindings)); + .eval(InputBindings.withMap(bindings)); Assert.assertEquals("Result for: " + rexNode, expectedResult, result.value()); }