From c82fcfea67364b54c260458e6492005a1f33738a Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 20 May 2021 21:21:59 -0700 Subject: [PATCH 1/4] bitwise aggregators, better nulls for expression agg --- .../org/apache/druid/math/expr/ExprEval.java | 3 +- docs/querying/sql.md | 3 + .../ExpressionLambdaAggregator.java | 12 +- .../ExpressionLambdaAggregatorFactory.java | 17 + .../ExpressionLambdaBufferAggregator.java | 15 + ...ExpressionLambdaAggregatorFactoryTest.java | 24 ++ .../query/groupby/GroupByQueryRunnerTest.java | 6 + .../timeseries/TimeseriesQueryRunnerTest.java | 5 + .../druid/query/topn/TopNQueryRunnerTest.java | 4 + .../builtin/ArraySqlAggregator.java | 6 +- .../builtin/BitwiseSqlAggregator.java | 189 +++++++++ .../calcite/planner/DruidOperatorTable.java | 4 + .../sql/calcite/CalciteArraysQueryTest.java | 54 ++- .../druid/sql/calcite/CalciteQueryTest.java | 390 +++++++++++++++++- 14 files changed, 701 insertions(+), 31 deletions(-) create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 38bb045f5ef9..3df4bde42b32 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -38,6 +38,7 @@ */ public abstract class ExprEval { + private static final byte TYPE_MASK = 0x0F; private static final int NULL_LENGTH = -1; /** @@ -48,7 +49,7 @@ public abstract class ExprEval public static ExprEval deserialize(ByteBuffer buffer, int position) { // | expression type (byte) | expression bytes | - ExprType type = ExprType.fromByte(buffer.get(position)); + ExprType type = ExprType.fromByte((byte) (buffer.get(position) & TYPE_MASK)); int offset = position + 1; switch (type) { case LONG: diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 6f837de8ec12..b88099c297e7 100644 --- 
a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -360,6 +360,9 @@ Only the COUNT and ARRAY_AGG aggregations can accept the DISTINCT keyword. |`GROUPING(expr, expr...)`|Returns a number to indicate which groupBy dimension is included in a row, when using `GROUPING SETS`. Refer to [additional documentation](aggregations.md#grouping-aggregator) on how to infer this number.|N/A| |`ARRAY_AGG(expr, [size])`|Collects all values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes). Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`| |`ARRAY_AGG(DISTINCT expr, [size])`|Collects all distinct values of `expr` into an ARRAY, including null values, with `size` in bytes limit on aggregation size (default of 1024 bytes) per aggregate. Use of `ORDER BY` within the `ARRAY_AGG` expression is not currently supported, and the ordering of results within the output array may vary depending on processing order.|`null`| +|`BIT_AND(expr)`|Performs a bitwise AND operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`| +|`BIT_OR(expr)`|Performs a bitwise OR operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`| +|`BIT_XOR(expr)`|Performs a bitwise XOR operation on all input values.|`null` if `druid.generic.useDefaultValueForNull=false`, otherwise `0`| For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.md#approx). 
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java index 59bd6f21507a..a5591b3a2a29 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java @@ -29,12 +29,19 @@ public class ExpressionLambdaAggregator implements Aggregator private final Expr lambda; private final ExpressionLambdaAggregatorInputBindings bindings; private final int maxSizeBytes; + private boolean uninitializedNullValue; - public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings, int maxSizeBytes) + public ExpressionLambdaAggregator( + final Expr lambda, + final ExpressionLambdaAggregatorInputBindings bindings, + final boolean initiallyNull, + final int maxSizeBytes + ) { this.lambda = lambda; this.bindings = bindings; this.maxSizeBytes = maxSizeBytes; + this.uninitializedNullValue = initiallyNull; } @Override @@ -43,13 +50,14 @@ public void aggregate() final ExprEval eval = lambda.eval(bindings); ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes); bindings.accumulate(eval); + uninitializedNullValue = false; } @Nullable @Override public Object get() { - return bindings.getAccumulator().value(); + return uninitializedNullValue ? 
null : bindings.getAccumulator().value(); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index e40000df3a83..5f2dc2946227 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -27,6 +27,7 @@ import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Comparators; @@ -74,6 +75,7 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory private final String foldExpressionString; private final String initialValueExpressionString; private final String initialCombineValueExpressionString; + private final boolean initiallyNull; private final String combineExpressionString; @Nullable @@ -105,6 +107,7 @@ public ExpressionLambdaAggregatorFactory( @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier, @JsonProperty("initialValue") final String initialValue, @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue, + @JsonProperty("initiallyNull") @Nullable final Boolean initiallyNull, @JsonProperty("fold") final String foldExpression, @JsonProperty("combine") @Nullable final String combineExpression, @JsonProperty("compare") @Nullable final String compareExpression, @@ -121,6 +124,7 @@ public ExpressionLambdaAggregatorFactory( this.initialValueExpressionString = initialValue; this.initialCombineValueExpressionString = initialCombineValue == null ? 
initialValue : initialCombineValue; + this.initiallyNull = initiallyNull == null ? NullHandling.sqlCompatible() : initiallyNull; this.foldExpressionString = foldExpression; if (combineExpression != null) { this.combineExpressionString = combineExpression; @@ -195,6 +199,12 @@ public String getInitialCombineValueExpressionString() return initialCombineValueExpressionString; } + @JsonProperty("initiallyNull") + public boolean getInitiallyNull() + { + return initiallyNull; + } + @JsonProperty("fold") public String getFoldExpressionString() { @@ -249,6 +259,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new ExpressionLambdaAggregator( thePlan.getExpression(), thePlan.getBindings(), + initiallyNull, maxSizeBytes.getBytesInInt() ); } @@ -261,6 +272,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) thePlan.getExpression(), thePlan.getInitialValue(), thePlan.getBindings(), + initiallyNull, maxSizeBytes.getBytesInInt() ); } @@ -329,6 +341,7 @@ public AggregatorFactory getCombiningFactory() accumulatorId, initialValueExpressionString, initialCombineValueExpressionString, + initiallyNull, foldExpressionString, combineExpressionString, compareExpressionString, @@ -348,6 +361,7 @@ public List getRequiredColumns() accumulatorId, initialValueExpressionString, initialCombineValueExpressionString, + initiallyNull, foldExpressionString, combineExpressionString, compareExpressionString, @@ -407,6 +421,7 @@ public boolean equals(Object o) && foldExpressionString.equals(that.foldExpressionString) && initialValueExpressionString.equals(that.initialValueExpressionString) && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString) + && initiallyNull == that.initiallyNull && combineExpressionString.equals(that.combineExpressionString) && Objects.equals(compareExpressionString, that.compareExpressionString) && Objects.equals(finalizeExpressionString, that.finalizeExpressionString); @@ -422,6 +437,7 
@@ public int hashCode() foldExpressionString, initialValueExpressionString, initialCombineValueExpressionString, + initiallyNull, combineExpressionString, compareExpressionString, finalizeExpressionString, @@ -439,6 +455,7 @@ public String toString() ", foldExpressionString='" + foldExpressionString + '\'' + ", initialValueExpressionString='" + initialValueExpressionString + '\'' + ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' + + ", initiallyNull='" + initiallyNull + '\'' + ", combineExpressionString='" + combineExpressionString + '\'' + ", compareExpressionString='" + compareExpressionString + '\'' + ", finalizeExpressionString='" + finalizeExpressionString + '\'' + diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java index 357dd4b7d6bc..c0a2b7629988 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java @@ -27,21 +27,27 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator { + private static final short NOT_AGGREGATED_BIT = 1 << 7; /* high bit of the byte at 'position' flags "nothing aggregated yet" */ + private static final short IS_AGGREGATED_MASK = 0x7F; /* byte-wide complement of NOT_AGGREGATED_BIT: clears only the flag bit */ private final Expr lambda; private final ExprEval initialValue; private final ExpressionLambdaAggregatorInputBindings bindings; private final int maxSizeBytes; + private final boolean initiallyNull; + public ExpressionLambdaBufferAggregator( Expr lambda, ExprEval initialValue, ExpressionLambdaAggregatorInputBindings bindings, + boolean initiallyNull, int maxSizeBytes ) { this.lambda = lambda; this.initialValue = initialValue; this.bindings = bindings; + this.initiallyNull = initiallyNull; this.maxSizeBytes = maxSizeBytes; } @@ -49,6 +55,10 @@ public ExpressionLambdaBufferAggregator( public void
init(ByteBuffer buf, int position) { ExprEval.serialize(buf, position, initialValue, maxSizeBytes); + // flag the byte at 'position' with NOT_AGGREGATED_BIT so get() can return null until aggregate() has run + if (initiallyNull) { + buf.put(position, (byte) (buf.get(position) | NOT_AGGREGATED_BIT)); + } } @Override @@ -58,12 +68,17 @@ public void aggregate(ByteBuffer buf, int position) bindings.setAccumulator(acc); ExprEval newAcc = lambda.eval(bindings); ExprEval.serialize(buf, position, newAcc, maxSizeBytes); + // clear the not-aggregated flag now that a value has been written for this position + buf.put(position, (byte) (buf.get(position) & IS_AGGREGATED_MASK)); } @Nullable @Override public Object get(ByteBuffer buf, int position) { + if (initiallyNull && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) { + return null; + } return ExprEval.deserialize(buf, position).value(); } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java index a4143794791a..f4e8a1485970 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java @@ -58,6 +58,7 @@ public void testSerde() throws IOException "customAccumulator", "0.0", "10.0", + true, "customAccumulator + some_column + some_other_column", "customAccumulator + expr_agg_name", "if (o1 > o2, if (o1 == o2, 0, 1), -1)", @@ -102,6 +103,7 @@ public void testInitialValueMustBeConstant() null, "x + y", null, + true, "__acc + some_column + some_other_column", "__acc + expr_agg_name", null, @@ -125,6 +127,7 @@ public void testInitialCombineValueMustBeConstant() null, "0.0", "x + y", + true, "__acc + some_column + some_other_column", "__acc + expr_agg_name", null, @@ -145,6 +148,7 @@ public void testSingleInputCombineExpressionIsOptional() null, "0", null, +
true, "__acc + x", null, null, @@ -165,6 +169,7 @@ public void testFinalizeCanDo() null, "0", null, + true, "__acc + x", null, null, @@ -185,6 +190,7 @@ public void testFinalizeCanDoArrays() null, "0", null, + true, "array_set_add(__acc, x)", "array_set_add_all(__acc, expr_agg_name)", null, @@ -206,6 +212,7 @@ public void testStringType() null, "''", "''", + true, "concat(__acc, some_column, some_other_column)", "concat(__acc, expr_agg_name)", null, @@ -228,6 +235,7 @@ public void testLongType() null, "0", null, + null, "__acc + some_column + some_other_column", "__acc + expr_agg_name", null, @@ -250,6 +258,7 @@ public void testDoubleType() null, "0.0", null, + null, "__acc + some_column + some_other_column", "__acc + expr_agg_name", null, @@ -272,6 +281,7 @@ public void testStringArrayType() null, "''", "[]", + null, "concat(__acc, some_column, some_other_column)", "array_set_add(__acc, expr_agg_name)", null, @@ -294,6 +304,7 @@ public void testStringArrayTypeFinalized() null, "''", "[]", + null, "concat(__acc, some_column, some_other_column)", "array_set_add(__acc, expr_agg_name)", null, @@ -316,6 +327,7 @@ public void testLongArrayType() null, "0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, expr_agg_name)", null, @@ -338,6 +350,7 @@ public void testLongArrayTypeFinalized() null, "0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, expr_agg_name)", null, @@ -360,6 +373,7 @@ public void testDoubleArrayType() null, "0.0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, expr_agg_name)", null, @@ -382,6 +396,7 @@ public void testDoubleArrayTypeFinalized() null, "0.0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, expr_agg_name)", null, @@ -410,6 +425,7 @@ public void testResultArraySignature() null, "''", "''", + null, "concat(__acc, some_column, some_other_column)", "concat(__acc, string_expr)", null, @@ -423,6 +439,7 @@ public void 
testResultArraySignature() null, "0.0", null, + null, "__acc + some_column + some_other_column", "__acc + double_expr", null, @@ -436,6 +453,7 @@ public void testResultArraySignature() null, "0", null, + null, "__acc + some_column + some_other_column", "__acc + long_expr", null, @@ -449,6 +467,7 @@ public void testResultArraySignature() null, "[]", "[]", + null, "array_set_add(__acc, concat(some_column, some_other_column))", "array_set_add_all(__acc, string_array_expr)", null, @@ -462,6 +481,7 @@ public void testResultArraySignature() null, "0.0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, double_array)", null, @@ -475,6 +495,7 @@ public void testResultArraySignature() null, "0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, long_array_expr)", null, @@ -488,6 +509,7 @@ public void testResultArraySignature() null, "''", "[]", + null, "concat(__acc, some_column, some_other_column)", "array_set_add(__acc, string_array_expr)", null, @@ -501,6 +523,7 @@ public void testResultArraySignature() null, "0.0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, double_array)", null, @@ -514,6 +537,7 @@ public void testResultArraySignature() null, "0", "[]", + null, "__acc + some_column + some_other_column", "array_set_add(__acc, long_array_expr)", null, diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java index b20fe0059e58..99853a189f08 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerTest.java @@ -11228,6 +11228,7 @@ public void testGroupByWithExpressionAggregator() null, "0", null, + false, "__acc + 1", "__acc + rows", null, @@ -11241,6 +11242,7 @@ public void testGroupByWithExpressionAggregator() null, "0.0", null, + null, 
"__acc + index", null, null, @@ -11463,6 +11465,7 @@ public void testGroupByWithExpressionAggregatorWithArrays() null, "0", null, + false, "__acc + 1", "__acc + rows", null, @@ -11476,6 +11479,7 @@ public void testGroupByWithExpressionAggregatorWithArrays() null, "0.0", null, + true, "__acc + index", null, null, @@ -11489,6 +11493,7 @@ public void testGroupByWithExpressionAggregatorWithArrays() "acc", "[]", null, + null, "array_set_add(acc, market)", "array_set_add_all(acc, array_agg_distinct)", null, @@ -11747,6 +11752,7 @@ public void testGroupByExpressionAggregatorArrayMultiValue() "acc", "[]", null, + null, "array_set_add(acc, placementish)", "array_set_add_all(acc, array_agg_distinct)", null, diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java index c1de43c812b7..63e5aa7ca27d 100644 --- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java @@ -3030,6 +3030,7 @@ public void testTimeseriesWithExpressionAggregator() null, "0", null, + false, "__acc + 1", "__acc + diy_count", null, @@ -3043,6 +3044,7 @@ public void testTimeseriesWithExpressionAggregator() null, "0.0", null, + null, "__acc + index", null, null, @@ -3056,6 +3058,7 @@ public void testTimeseriesWithExpressionAggregator() null, "0.0", "[]", + null, "__acc + index", "array_concat(__acc, diy_decomposed_sum)", null, @@ -3069,6 +3072,7 @@ public void testTimeseriesWithExpressionAggregator() "acc", "[]", null, + null, "array_set_add(acc, market)", "array_set_add_all(acc, array_agg_distinct)", null, @@ -3132,6 +3136,7 @@ public void testTimeseriesWithExpressionAggregatorTooBig() "acc", "[]", null, + null, "array_set_add(acc, market)", "array_set_add_all(acc, array_agg_distinct)", null, diff --git 
a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java index eb8709d4822b..15f23e4fb416 100644 --- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java +++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryRunnerTest.java @@ -5986,6 +5986,7 @@ public void testExpressionAggregator() null, "0", null, + false, "__acc + 1", "__acc + diy_count", null, @@ -5999,6 +6000,7 @@ public void testExpressionAggregator() null, "0.0", null, + null, "__acc + index", null, null, @@ -6012,6 +6014,7 @@ public void testExpressionAggregator() null, "0.0", "[]", + null, "__acc + index", "array_concat(__acc, diy_decomposed_sum)", null, @@ -6025,6 +6028,7 @@ public void testExpressionAggregator() "acc", "[]", null, + null, "array_set_add(acc, quality)", "array_set_add_all(acc, array_agg_distinct)", "if(array_length(o1) > array_length(o2), 1, if (array_length(o1) == array_length(o2), 0, -1))", diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index 0f80daa91fa0..6c0b4d036efd 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -148,10 +148,11 @@ public Aggregation toDruidAggregation( null, initialvalue, null, + true, StringUtils.format("array_set_add(\"__acc\", \"%s\")", fieldName), StringUtils.format("array_set_add_all(\"__acc\", \"%s\")", name), null, - "if(array_length(o) == 0, null, o)", + null, maxSizeBytes != null ? 
new HumanReadableBytes(maxSizeBytes) : null, macroTable ) @@ -164,10 +165,11 @@ public Aggregation toDruidAggregation( null, initialvalue, null, + true, StringUtils.format("array_append(\"__acc\", \"%s\")", fieldName), StringUtils.format("array_concat(\"__acc\", \"%s\")", name), null, - "if(array_length(o) == 0, null, o)", + null, maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, macroTable ) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java new file mode 100644 index 000000000000..40a4a61ee239 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.aggregation.builtin; + +import com.google.common.collect.ImmutableSet; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.InferTypes; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.Optionality; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.math.expr.ExprMacroTable; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; +import org.apache.druid.segment.VirtualColumn; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.aggregation.Aggregation; +import org.apache.druid.sql.calcite.aggregation.SqlAggregator; +import org.apache.druid.sql.calcite.expression.DruidExpression; +import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; + +import javax.annotation.Nullable; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class BitwiseSqlAggregator implements SqlAggregator +{ + private static final SqlAggFunction XOR_FUNCTION = new BitwiseXorSqlAggFunction(); + + public enum Op + { + AND { + @Override + SqlAggFunction getCalciteFunction() + { + return SqlStdOperatorTable.BIT_AND; + } + + @Override + String getDruidFunction() + { + return "bitwiseAnd"; + } + }, + OR { + @Override + SqlAggFunction getCalciteFunction() + { + return SqlStdOperatorTable.BIT_OR; + } + + 
@Override + String getDruidFunction() + { + return "bitwiseOr"; + } + }, + XOR { + @Override + SqlAggFunction getCalciteFunction() + { + // newer versions of calcite have this built-in so someday we can drop this... + return XOR_FUNCTION; + } + + @Override + String getDruidFunction() + { + return "bitwiseXor"; + } + }; + + abstract SqlAggFunction getCalciteFunction(); + abstract String getDruidFunction(); + } + + private final Op op; + + public BitwiseSqlAggregator(Op op) + { + this.op = op; + } + + @Override + public SqlAggFunction calciteFunction() + { + return op.getCalciteFunction(); + } + + @Nullable + @Override + public Aggregation toDruidAggregation( + PlannerContext plannerContext, + RowSignature rowSignature, + VirtualColumnRegistry virtualColumnRegistry, + RexBuilder rexBuilder, + String name, + AggregateCall aggregateCall, + Project project, + List<Aggregation> existingAggregations, + boolean finalizeAggregations + ) + { + final List<DruidExpression> arguments = aggregateCall + .getArgList() + .stream() + .map(i -> Expressions.fromFieldAccess(rowSignature, project, i)) + .map(rexNode -> Expressions.toDruidExpression(plannerContext, rowSignature, rexNode)) + .collect(Collectors.toList()); + + if (arguments.stream().anyMatch(Objects::isNull)) { + return null; + } + + final DruidExpression arg = arguments.get(0); + final ExprMacroTable macroTable = plannerContext.getExprMacroTable(); + + final String fieldName; + if (arg.isDirectColumnAccess()) { + fieldName = arg.getDirectColumn(); + } else { + VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, ValueType.LONG); + fieldName = vc.getOutputName(); + } + + return Aggregation.create( + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + "0", /* NOTE(review): 0 is an identity seed for bitwiseOr/bitwiseXor, but if bitwiseAnd is a plain bitwise AND then bitwiseAnd(0, x) is always 0 -- confirm whether BIT_AND needs an all-ones (-1) seed */ + null, + null, + StringUtils.format("%s(\"__acc\", \"%s\")", op.getDruidFunction(), fieldName), + null, + null, + null, + null, + macroTable + ) + ); + } + + private static class BitwiseXorSqlAggFunction extends
SqlAggFunction + { + BitwiseXorSqlAggFunction() + { + super( + "BIT_XOR", + null, + SqlKind.OTHER_FUNCTION, + ReturnTypes.explicit(SqlTypeName.BIGINT), + InferTypes.ANY_NULLABLE, + OperandTypes.EXACT_NUMERIC, + SqlFunctionCategory.NUMERIC, + false, + false, + Optionality.IGNORED + ); + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 8b99dda7642e..3e9288a775fb 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -36,6 +36,7 @@ import org.apache.druid.sql.calcite.aggregation.builtin.ApproxCountDistinctSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.ArraySqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.AvgSqlAggregator; +import org.apache.druid.sql.calcite.aggregation.builtin.BitwiseSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.CountSqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.EarliestLatestAnySqlAggregator; import org.apache.druid.sql.calcite.aggregation.builtin.GroupingSqlAggregator; @@ -136,6 +137,9 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new SumZeroSqlAggregator()) .add(new GroupingSqlAggregator()) .add(new ArraySqlAggregator()) + .add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.AND)) + .add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.OR)) + .add(new BitwiseSqlAggregator(BitwiseSqlAggregator.Op.XOR)) .build(); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index aded4f8c2836..00aebaf2b1ca 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java 
@@ -1156,10 +1156,11 @@ public void testArrayAgg() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"dim1\")", "array_concat(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1169,10 +1170,11 @@ public void testArrayAgg() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1183,10 +1185,11 @@ public void testArrayAgg() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a2\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1228,10 +1231,11 @@ public void testArrayAggMultiValue() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"dim3\")", "array_concat(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1241,10 +1245,11 @@ public void testArrayAggMultiValue() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim3\")", "array_set_add_all(\"__acc\", \"a1\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1280,10 +1285,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"l1\")", "array_concat(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1293,10 +1299,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"l1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - "if(array_length(o) == 0, null, o)", + null, new 
HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1306,10 +1313,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"d1\")", "array_concat(\"__acc\", \"a2\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1319,10 +1327,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"d1\")", "array_set_add_all(\"__acc\", \"a3\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1332,10 +1341,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"f1\")", "array_concat(\"__acc\", \"a4\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -1345,10 +1355,11 @@ public void testArrayAggNumeric() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"f1\")", "array_set_add_all(\"__acc\", \"a5\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1399,10 +1410,11 @@ public void testArrayAggToString() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1440,10 +1452,11 @@ public void testArrayAggExpression() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"v0\")", "array_set_add_all(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1478,10 +1491,11 @@ public void testArrayAggMaxBytes() throws Exception "__acc", "[]", "[]", + true, "array_append(\"__acc\", \"l1\")", "array_concat(\"__acc\", \"a0\")", null, - 
"if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(128), TestExprMacroTable.INSTANCE ), @@ -1491,10 +1505,11 @@ public void testArrayAggMaxBytes() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"l1\")", "array_set_add_all(\"__acc\", \"a1\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(128), TestExprMacroTable.INSTANCE ) @@ -1557,10 +1572,11 @@ public void testArrayAggAsArrayFromJoin() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1648,10 +1664,11 @@ public void testArrayAggArrayContainsSubquery() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -1724,10 +1741,11 @@ public void testArrayAggGroupByArrayContainsSubquery() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim1\")", "array_set_add_all(\"__acc\", \"a0\")", null, - "if(array_length(o) == 0, null, o)", + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 48eb30318fa1..41357694efdd 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -94,6 +94,7 @@ import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.RegexDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; +import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import 
org.apache.druid.query.groupby.GroupByQuery.Builder; import org.apache.druid.query.groupby.GroupByQueryConfig; @@ -12650,7 +12651,10 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValuesNonVectorized() thr + " EARLIEST(l1),\n" + " LATEST(dim1, 1024),\n" + " LATEST(l1),\n" - + " ARRAY_AGG(DISTINCT dim3)\n" + + " ARRAY_AGG(DISTINCT dim3),\n" + + " BIT_AND(l1),\n" + + " BIT_OR(l1),\n" + + " BIT_XOR(l1)\n" + "FROM druid.numfoo WHERE dim2 = 0", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -12672,10 +12676,53 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValuesNonVectorized() thr "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim3\")", "array_set_add_all(\"__acc\", \"a6\")", null, - "if(array_length(o) == 0, null, o)", + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a7", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a7\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a8", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a8\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a9", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a9\")", + null, + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ) @@ -12686,8 +12733,8 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValuesNonVectorized() thr ), ImmutableList.of( useDefault - ? new Object[]{"", 0L, "", 0L, "", 0L, null} - : new Object[]{null, null, null, null, null, null, null} + ? 
new Object[]{"", 0L, "", 0L, "", 0L, null, 0L, 0L, 0L} + : new Object[]{null, null, null, null, null, null, null, null, null, null} ) ); } @@ -12828,7 +12875,10 @@ public void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception + " EARLIEST(l1) FILTER(WHERE dim1 = 'nonexistent'),\n" + " LATEST(dim1, 1024) FILTER(WHERE dim1 = 'nonexistent'),\n" + " LATEST(l1) FILTER(WHERE dim1 = 'nonexistent'),\n" - + " ARRAY_AGG(DISTINCT dim3) FILTER(WHERE dim1 = 'nonexistent')" + + " ARRAY_AGG(DISTINCT dim3) FILTER(WHERE dim1 = 'nonexistent'),\n" + + " BIT_AND(l1) FILTER(WHERE dim1 = 'nonexistent'),\n" + + " BIT_OR(l1) FILTER(WHERE dim1 = 'nonexistent'),\n" + + " BIT_XOR(l1) FILTER(WHERE dim1 = 'nonexistent')\n" + "FROM druid.numfoo WHERE dim2 = 'a' GROUP BY dim2", ImmutableList.of( GroupByQuery.builder() @@ -12871,10 +12921,62 @@ public void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception "__acc", "[]", "[]", + true, "array_set_add(\"__acc\", \"dim3\")", "array_set_add_all(\"__acc\", \"a6\")", null, - "if(array_length(o) == 0, null, o)", + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + selector("dim1", "nonexistent", null) + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a7", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a7\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + selector("dim1", "nonexistent", null) + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a8", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a8\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + selector("dim1", "nonexistent", null) + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a9", + 
ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a9\")", + null, + null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), @@ -12887,8 +12989,8 @@ public void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception ), ImmutableList.of( useDefault - ? new Object[]{"a", "", 0L, "", 0L, "", 0L, null} - : new Object[]{"a", null, null, null, null, null, null, null} + ? new Object[]{"a", "", 0L, "", 0L, "", 0L, null, 0L, 0L, 0L} + : new Object[]{"a", null, null, null, null, null, null, null, null, null, null} ) ); } @@ -17491,4 +17593,276 @@ public void testEmptyGroupWithOffsetDoesntInfiniteLoop() throws Exception ImmutableList.of() ); } + + @Test + public void testBitwiseAggregatorsTimeseries() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT\n" + + " BIT_AND(l1),\n" + + " BIT_OR(l1),\n" + + " BIT_XOR(l1),\n" + + " BIT_AND(l1) FILTER(WHERE l1 IS NOT NULL),\n" + + " BIT_OR(l1) FILTER(WHERE l1 IS NOT NULL),\n" + + " BIT_XOR(l1) FILTER(WHERE l1 IS NOT NULL)\n" + + "FROM druid.numfoo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("l1"), + "__acc", + "0", + 
"0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a2\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a3", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a3\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a4", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a4\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a5", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a5\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ) + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + useDefault + ? 
new Object[]{0L, 325327L, 325324L, 0L, 325327L, 325324L} + : new Object[]{null, null, null, 0L, 325327L, 325324L} + ) + ); + } + + @Test + public void testBitwiseAggregatorsGroupBy() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT\n" + + " dim2,\n" + + " BIT_AND(l1),\n" + + " BIT_OR(l1),\n" + + " BIT_XOR(l1),\n" + + " BIT_AND(l1) FILTER(WHERE l1 IS NOT NULL),\n" + + " BIT_OR(l1) FILTER(WHERE l1 IS NOT NULL),\n" + + " BIT_XOR(l1) FILTER(WHERE l1 IS NOT NULL)\n" + + "FROM druid.numfoo GROUP BY 1 ORDER BY 6", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("dim2", "_d0", ValueType.STRING)) + .setAggregatorSpecs( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a0\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a1", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a1\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new ExpressionLambdaAggregatorFactory( + "a2", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a2\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a3", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a3\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + 
NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a4", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a4\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ), + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a5", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a5\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + ) + ) + ) + .setLimitSpec( + DefaultLimitSpec.builder() + .orderBy( + new OrderByColumnSpec( + "a4", + Direction.ASCENDING, + StringComparators.NUMERIC + ) + ) + .build() + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + useDefault + ? 
ImmutableList.of( + new Object[]{"abc", 0L, 0L, 0L, 0L, 0L, 0L}, + new Object[]{"a", 0L, 7L, 7L, 0L, 7L, 7L}, + new Object[]{"", 0L, 325323L, 325323L, 0L, 325323L, 325323L} + ) + : ImmutableList.of( + new Object[]{"abc", null, null, null, null, null, null}, + new Object[]{"", 0L, 0L, 0L, 0L, 0L, 0L}, + new Object[]{"a", null, null, null, 0L, 7L, 7L}, + new Object[]{null, null, null, null, 0L, 325323L, 325323L} + ) + ); + } } From e1544c1b405882bb66ecba227ae45d2c6a617657 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Thu, 20 May 2021 23:40:14 -0700 Subject: [PATCH 2/4] correct behavior --- .../builtin/BitwiseSqlAggregator.java | 32 ++- .../druid/sql/calcite/CalciteQueryTest.java | 246 ++++++------------ 2 files changed, 101 insertions(+), 177 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java index 40a4a61ee239..a6fe93cb2def 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/BitwiseSqlAggregator.java @@ -35,6 +35,9 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; +import org.apache.druid.query.filter.NotDimFilter; +import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ValueType; @@ -151,19 +154,22 @@ public Aggregation toDruidAggregation( } return Aggregation.create( - new ExpressionLambdaAggregatorFactory( - name, - ImmutableSet.of(fieldName), - null, - "0", - null, - null, - StringUtils.format("%s(\"__acc\", \"%s\")", 
op.getDruidFunction(), fieldName), - null, - null, - null, - null, - macroTable + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + name, + ImmutableSet.of(fieldName), + null, + "0", + null, + null, + StringUtils.format("%s(\"__acc\", \"%s\")", op.getDruidFunction(), fieldName), + null, + null, + null, + null, + macroTable + ), + new NotDimFilter(new SelectorDimFilter(fieldName, null, null)) ) ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 41357694efdd..f2e43311ec9d 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -94,7 +94,6 @@ import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.RegexDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; -import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQuery.Builder; import org.apache.druid.query.groupby.GroupByQueryConfig; @@ -12684,47 +12683,56 @@ public void testTimeseriesEmptyResultsAggregatorDefaultValuesNonVectorized() thr new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - new ExpressionLambdaAggregatorFactory( - "a7", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseAnd(\"__acc\", \"l1\")", - "bitwiseAnd(\"__acc\", \"a7\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a7", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseAnd(\"__acc\", \"l1\")", + "bitwiseAnd(\"__acc\", \"a7\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + not(selector("l1", null, null)) ), - new 
ExpressionLambdaAggregatorFactory( - "a8", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseOr(\"__acc\", \"l1\")", - "bitwiseOr(\"__acc\", \"a8\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a8", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseOr(\"__acc\", \"l1\")", + "bitwiseOr(\"__acc\", \"a8\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + not(selector("l1", null, null)) ), - new ExpressionLambdaAggregatorFactory( - "a9", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseXor(\"__acc\", \"l1\")", - "bitwiseXor(\"__acc\", \"a9\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a9", + ImmutableSet.of("l1"), + "__acc", + "0", + "0", + NullHandling.sqlCompatible(), + "bitwiseXor(\"__acc\", \"l1\")", + "bitwiseXor(\"__acc\", \"a9\")", + null, + null, + new HumanReadableBytes(1024), + TestExprMacroTable.INSTANCE + ), + not(selector("l1", null, null)) ) ) ) @@ -12946,7 +12954,7 @@ public void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - selector("dim1", "nonexistent", null) + and(not(selector("l1", null, null)), selector("dim1", "nonexistent", null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( @@ -12963,7 +12971,7 @@ public void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - selector("dim1", "nonexistent", null) + and(not(selector("l1", null, null)), selector("dim1", "nonexistent", null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( @@ -12980,7 +12988,7 @@ public 
void testGroupByAggregatorDefaultValuesNonVectorized() throws Exception new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - selector("dim1", "nonexistent", null) + and(not(selector("l1", null, null)), selector("dim1", "nonexistent", null)) ) ) ) @@ -17602,10 +17610,7 @@ public void testBitwiseAggregatorsTimeseries() throws Exception "SELECT\n" + " BIT_AND(l1),\n" + " BIT_OR(l1),\n" - + " BIT_XOR(l1),\n" - + " BIT_AND(l1) FILTER(WHERE l1 IS NOT NULL),\n" - + " BIT_OR(l1) FILTER(WHERE l1 IS NOT NULL),\n" - + " BIT_XOR(l1) FILTER(WHERE l1 IS NOT NULL)\n" + + " BIT_XOR(l1)\n" + "FROM druid.numfoo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() @@ -17614,98 +17619,56 @@ public void testBitwiseAggregatorsTimeseries() throws Exception .granularity(Granularities.ALL) .aggregators( aggregators( - new ExpressionLambdaAggregatorFactory( - "a0", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseAnd(\"__acc\", \"l1\")", - "bitwiseAnd(\"__acc\", \"a0\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), - new ExpressionLambdaAggregatorFactory( - "a1", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseOr(\"__acc\", \"l1\")", - "bitwiseOr(\"__acc\", \"a1\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), - new ExpressionLambdaAggregatorFactory( - "a2", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseXor(\"__acc\", \"l1\")", - "bitwiseXor(\"__acc\", \"a2\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a3", + "a0", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseAnd(\"__acc\", \"l1\")", - "bitwiseAnd(\"__acc\", \"a3\")", + "bitwiseAnd(\"__acc\", \"a0\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), 
- NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a4", + "a1", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseOr(\"__acc\", \"l1\")", - "bitwiseOr(\"__acc\", \"a4\")", + "bitwiseOr(\"__acc\", \"a1\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a5", + "a2", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseXor(\"__acc\", \"l1\")", - "bitwiseXor(\"__acc\", \"a5\")", + "bitwiseXor(\"__acc\", \"a2\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ) ) ) @@ -17714,8 +17677,8 @@ public void testBitwiseAggregatorsTimeseries() throws Exception ), ImmutableList.of( useDefault - ? new Object[]{0L, 325327L, 325324L, 0L, 325327L, 325324L} - : new Object[]{null, null, null, 0L, 325327L, 325324L} + ? 
new Object[]{0L, 325327L, 325324L} + : new Object[]{0L, 325327L, 325324L} ) ); } @@ -17729,11 +17692,8 @@ public void testBitwiseAggregatorsGroupBy() throws Exception + " dim2,\n" + " BIT_AND(l1),\n" + " BIT_OR(l1),\n" - + " BIT_XOR(l1),\n" - + " BIT_AND(l1) FILTER(WHERE l1 IS NOT NULL),\n" - + " BIT_OR(l1) FILTER(WHERE l1 IS NOT NULL),\n" - + " BIT_XOR(l1) FILTER(WHERE l1 IS NOT NULL)\n" - + "FROM druid.numfoo GROUP BY 1 ORDER BY 6", + + " BIT_XOR(l1)\n" + + "FROM druid.numfoo GROUP BY 1 ORDER BY 4", ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -17742,98 +17702,56 @@ public void testBitwiseAggregatorsGroupBy() throws Exception .setDimensions(new DefaultDimensionSpec("dim2", "_d0", ValueType.STRING)) .setAggregatorSpecs( aggregators( - new ExpressionLambdaAggregatorFactory( - "a0", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseAnd(\"__acc\", \"l1\")", - "bitwiseAnd(\"__acc\", \"a0\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), - new ExpressionLambdaAggregatorFactory( - "a1", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseOr(\"__acc\", \"l1\")", - "bitwiseOr(\"__acc\", \"a1\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), - new ExpressionLambdaAggregatorFactory( - "a2", - ImmutableSet.of("l1"), - "__acc", - "0", - "0", - NullHandling.sqlCompatible(), - "bitwiseXor(\"__acc\", \"l1\")", - "bitwiseXor(\"__acc\", \"a2\")", - null, - null, - new HumanReadableBytes(1024), - TestExprMacroTable.INSTANCE - ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a3", + "a0", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseAnd(\"__acc\", \"l1\")", - "bitwiseAnd(\"__acc\", \"a3\")", + "bitwiseAnd(\"__acc\", \"a0\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - NullHandling.sqlCompatible() 
? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a4", + "a1", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseOr(\"__acc\", \"l1\")", - "bitwiseOr(\"__acc\", \"a4\")", + "bitwiseOr(\"__acc\", \"a1\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ), new FilteredAggregatorFactory( new ExpressionLambdaAggregatorFactory( - "a5", + "a2", ImmutableSet.of("l1"), "__acc", "0", "0", NullHandling.sqlCompatible(), "bitwiseXor(\"__acc\", \"l1\")", - "bitwiseXor(\"__acc\", \"a5\")", + "bitwiseXor(\"__acc\", \"a2\")", null, null, new HumanReadableBytes(1024), TestExprMacroTable.INSTANCE ), - NullHandling.sqlCompatible() ? not(selector("l1", null, null)) : TrueDimFilter.instance() + not(selector("l1", null, null)) ) ) ) @@ -17841,7 +17759,7 @@ public void testBitwiseAggregatorsGroupBy() throws Exception DefaultLimitSpec.builder() .orderBy( new OrderByColumnSpec( - "a4", + "a2", Direction.ASCENDING, StringComparators.NUMERIC ) @@ -17853,15 +17771,15 @@ public void testBitwiseAggregatorsGroupBy() throws Exception ), useDefault ? 
ImmutableList.of( - new Object[]{"abc", 0L, 0L, 0L, 0L, 0L, 0L}, - new Object[]{"a", 0L, 7L, 7L, 0L, 7L, 7L}, - new Object[]{"", 0L, 325323L, 325323L, 0L, 325323L, 325323L} + new Object[]{"abc", 0L, 0L, 0L}, + new Object[]{"a", 0L, 7L, 7L}, + new Object[]{"", 0L, 325323L, 325323L} ) : ImmutableList.of( - new Object[]{"abc", null, null, null, null, null, null}, - new Object[]{"", 0L, 0L, 0L, 0L, 0L, 0L}, - new Object[]{"a", null, null, null, 0L, 7L, 7L}, - new Object[]{null, null, null, null, 0L, 325323L, 325323L} + new Object[]{"abc", null, null, null}, + new Object[]{"", 0L, 0L, 0L}, + new Object[]{"a", 0L, 7L, 7L}, + new Object[]{null, 0L, 325323L, 325323L} ) ); } From 1cea00ea89fb2368127ec1f61a81a4021dcee7c7 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Wed, 23 Jun 2021 04:04:25 -0700 Subject: [PATCH 3/4] rework deserialize, better names --- .../org/apache/druid/math/expr/ExprEval.java | 16 ++++++++--- .../apache/druid/math/expr/ExprEvalTest.java | 18 +++++++++++-- .../ExpressionLambdaAggregator.java | 10 +++---- .../ExpressionLambdaAggregatorFactory.java | 24 ++++++++--------- .../ExpressionLambdaBufferAggregator.java | 27 ++++++++++++------- 5 files changed, 62 insertions(+), 33 deletions(-) diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 591f70b6d43a..8803583c4c8c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -38,7 +38,6 @@ */ public abstract class ExprEval { - private static final byte TYPE_MASK = 0x0F; private static final int NULL_LENGTH = -1; /** @@ -48,9 +47,18 @@ public abstract class ExprEval */ public static ExprEval deserialize(ByteBuffer buffer, int position) { - // | expression type (byte) | expression bytes | - ExprType type = ExprType.fromByte((byte) (buffer.get(position) & TYPE_MASK)); - int offset = position + 1; + final ExprType type = 
ExprType.fromByte(buffer.get(position)); + return deserialize(buffer, position + 1, type); + } + + /** + * Deserialize an expression stored in a bytebuffer, e.g. for an agg. + * + * This should be refactored to be consolidated with some of the standard type handling of aggregators probably + */ + public static ExprEval deserialize(ByteBuffer buffer, int offset, ExprType type) + { + // | expression bytes | switch (type) { case LONG: // | expression type (byte) | is null (byte) | long bytes | diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java index 0f1289baf61c..fa3699d22afc 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java @@ -297,9 +297,23 @@ private void assertExpr(int position, ExprEval expected, int maxSizeBytes) { ExprEval.serialize(buffer, position, expected, maxSizeBytes); if (ExprType.isArray(expected.type())) { - Assert.assertArrayEquals(expected.asArray(), ExprEval.deserialize(buffer, position).asArray()); + Assert.assertArrayEquals( + expected.asArray(), + ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).asArray() + ); + Assert.assertArrayEquals( + expected.asArray(), + ExprEval.deserialize(buffer, position).asArray() + ); } else { - Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value()); + Assert.assertEquals( + expected.value(), + ExprEval.deserialize(buffer, position + 1, ExprType.fromByte(buffer.get(position))).value() + ); + Assert.assertEquals( + expected.value(), + ExprEval.deserialize(buffer, position).value() + ); } assertEstimatedBytes(expected, maxSizeBytes); } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java index a5591b3a2a29..3e56b5d4b243 
100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java @@ -29,19 +29,19 @@ public class ExpressionLambdaAggregator implements Aggregator private final Expr lambda; private final ExpressionLambdaAggregatorInputBindings bindings; private final int maxSizeBytes; - private boolean uninitializedNullValue; + private boolean hasValue; public ExpressionLambdaAggregator( final Expr lambda, final ExpressionLambdaAggregatorInputBindings bindings, - final boolean initiallyNull, + final boolean isNullUnlessAggregated, final int maxSizeBytes ) { this.lambda = lambda; this.bindings = bindings; this.maxSizeBytes = maxSizeBytes; - this.uninitializedNullValue = initiallyNull; + this.hasValue = !isNullUnlessAggregated; } @Override @@ -50,14 +50,14 @@ public void aggregate() final ExprEval eval = lambda.eval(bindings); ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes); bindings.accumulate(eval); - uninitializedNullValue = false; + hasValue = true; } @Nullable @Override public Object get() { - return uninitializedNullValue ? null : bindings.getAccumulator().value(); + return hasValue ? 
bindings.getAccumulator().value() : null; } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index 0bf54bf463ea..c7e974890b2c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -75,7 +75,7 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory private final String foldExpressionString; private final String initialValueExpressionString; private final String initialCombineValueExpressionString; - private final boolean initiallyNull; + private final boolean isNullUnlessAggregated; private final String combineExpressionString; @Nullable @@ -107,7 +107,7 @@ public ExpressionLambdaAggregatorFactory( @JsonProperty("accumulatorIdentifier") @Nullable final String accumulatorIdentifier, @JsonProperty("initialValue") final String initialValue, @JsonProperty("initialCombineValue") @Nullable final String initialCombineValue, - @JsonProperty("initiallyNull") @Nullable final Boolean initiallyNull, + @JsonProperty("isNullUnlessAggregated") @Nullable final Boolean isNullUnlessAggregated, @JsonProperty("fold") final String foldExpression, @JsonProperty("combine") @Nullable final String combineExpression, @JsonProperty("compare") @Nullable final String compareExpression, @@ -124,7 +124,7 @@ public ExpressionLambdaAggregatorFactory( this.initialValueExpressionString = initialValue; this.initialCombineValueExpressionString = initialCombineValue == null ? initialValue : initialCombineValue; - this.initiallyNull = initiallyNull == null ? NullHandling.sqlCompatible() : initiallyNull; + this.isNullUnlessAggregated = isNullUnlessAggregated == null ? 
NullHandling.sqlCompatible() : isNullUnlessAggregated; this.foldExpressionString = foldExpression; if (combineExpression != null) { this.combineExpressionString = combineExpression; @@ -200,9 +200,9 @@ public String getInitialCombineValueExpressionString() } @JsonProperty("initiallyNull") - public boolean getInitiallyNull() + public boolean getNullUnlessAggregated() { - return initiallyNull; + return isNullUnlessAggregated; } @JsonProperty("fold") @@ -259,7 +259,7 @@ public Aggregator factorize(ColumnSelectorFactory metricFactory) return new ExpressionLambdaAggregator( thePlan.getExpression(), thePlan.getBindings(), - initiallyNull, + isNullUnlessAggregated, maxSizeBytes.getBytesInInt() ); } @@ -272,7 +272,7 @@ public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) thePlan.getExpression(), thePlan.getInitialValue(), thePlan.getBindings(), - initiallyNull, + isNullUnlessAggregated, maxSizeBytes.getBytesInInt() ); } @@ -341,7 +341,7 @@ public AggregatorFactory getCombiningFactory() accumulatorId, initialValueExpressionString, initialCombineValueExpressionString, - initiallyNull, + isNullUnlessAggregated, foldExpressionString, combineExpressionString, compareExpressionString, @@ -361,7 +361,7 @@ public List getRequiredColumns() accumulatorId, initialValueExpressionString, initialCombineValueExpressionString, - initiallyNull, + isNullUnlessAggregated, foldExpressionString, combineExpressionString, compareExpressionString, @@ -421,7 +421,7 @@ public boolean equals(Object o) && foldExpressionString.equals(that.foldExpressionString) && initialValueExpressionString.equals(that.initialValueExpressionString) && initialCombineValueExpressionString.equals(that.initialCombineValueExpressionString) - && initiallyNull == that.initiallyNull + && isNullUnlessAggregated == that.isNullUnlessAggregated && combineExpressionString.equals(that.combineExpressionString) && Objects.equals(compareExpressionString, that.compareExpressionString) && 
Objects.equals(finalizeExpressionString, that.finalizeExpressionString); @@ -437,7 +437,7 @@ public int hashCode() foldExpressionString, initialValueExpressionString, initialCombineValueExpressionString, - initiallyNull, + isNullUnlessAggregated, combineExpressionString, compareExpressionString, finalizeExpressionString, @@ -455,7 +455,7 @@ public String toString() ", foldExpressionString='" + foldExpressionString + '\'' + ", initialValueExpressionString='" + initialValueExpressionString + '\'' + ", initialCombineValueExpressionString='" + initialCombineValueExpressionString + '\'' + - ", nullUnlessAggregated='" + initiallyNull + '\'' + + ", nullUnlessAggregated='" + isNullUnlessAggregated + '\'' + ", combineExpressionString='" + combineExpressionString + '\'' + ", compareExpressionString='" + compareExpressionString + '\'' + ", finalizeExpressionString='" + finalizeExpressionString + '\'' + diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java index c0a2b7629988..bb9948f1d2e5 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java @@ -21,6 +21,7 @@ import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprType; import javax.annotation.Nullable; import java.nio.ByteBuffer; @@ -29,25 +30,26 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator { private static final short NOT_AGGREGATED_BIT = 1 << 7; private static final short IS_AGGREGATED_MASK = 0x3F; + private static final byte TYPE_MASK = 0x0F; private final Expr lambda; private final ExprEval initialValue; private final ExpressionLambdaAggregatorInputBindings bindings; private final int maxSizeBytes; - private 
final boolean initiallyNull; + private final boolean isNullUnlessAggregated; public ExpressionLambdaBufferAggregator( Expr lambda, ExprEval initialValue, ExpressionLambdaAggregatorInputBindings bindings, - boolean initiallyNull, + boolean isNullUnlessAggregated, int maxSizeBytes ) { this.lambda = lambda; this.initialValue = initialValue; this.bindings = bindings; - this.initiallyNull = initiallyNull; + this.isNullUnlessAggregated = isNullUnlessAggregated; this.maxSizeBytes = maxSizeBytes; } @@ -56,7 +58,7 @@ public void init(ByteBuffer buf, int position) { ExprEval.serialize(buf, position, initialValue, maxSizeBytes); // set a bit to indicate we haven't aggregated on top of expression type (not going to lie this could be nicer) - if (initiallyNull) { + if (isNullUnlessAggregated) { buf.put(position, (byte) (buf.get(position) | NOT_AGGREGATED_BIT)); } } @@ -64,7 +66,7 @@ public void init(ByteBuffer buf, int position) @Override public void aggregate(ByteBuffer buf, int position) { - ExprEval acc = ExprEval.deserialize(buf, position); + ExprEval acc = ExprEval.deserialize(buf, position + 1, getType(buf, position)); bindings.setAccumulator(acc); ExprEval newAcc = lambda.eval(bindings); ExprEval.serialize(buf, position, newAcc, maxSizeBytes); @@ -76,28 +78,28 @@ public void aggregate(ByteBuffer buf, int position) @Override public Object get(ByteBuffer buf, int position) { - if (initiallyNull && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) { + if (isNullUnlessAggregated && (buf.get(position) & NOT_AGGREGATED_BIT) != 0) { return null; } - return ExprEval.deserialize(buf, position).value(); + return ExprEval.deserialize(buf, position + 1, getType(buf, position)).value(); } @Override public float getFloat(ByteBuffer buf, int position) { - return (float) ExprEval.deserialize(buf, position).asDouble(); + return (float) ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble(); } @Override public double getDouble(ByteBuffer buf, int position) { - return 
ExprEval.deserialize(buf, position).asDouble(); + return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asDouble(); } @Override public long getLong(ByteBuffer buf, int position) { - return ExprEval.deserialize(buf, position).asLong(); + return ExprEval.deserialize(buf, position + 1, getType(buf, position)).asLong(); } @Override @@ -105,4 +107,9 @@ public void close() { // nothing to close } + + private static ExprType getType(ByteBuffer buf, int position) + { + return ExprType.fromByte((byte) (buf.get(position) & TYPE_MASK)); + } } From 7318133036fa4aca4e67eff7f3d9c74fcf4bfaa4 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Wed, 23 Jun 2021 10:40:31 -0700 Subject: [PATCH 4/4] fix json, share mask --- .../query/aggregation/ExpressionLambdaAggregatorFactory.java | 4 ++-- .../query/aggregation/ExpressionLambdaBufferAggregator.java | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java index c7e974890b2c..5d3c79080a13 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java @@ -199,8 +199,8 @@ public String getInitialCombineValueExpressionString() return initialCombineValueExpressionString; } - @JsonProperty("initiallyNull") - public boolean getNullUnlessAggregated() + @JsonProperty("isNullUnlessAggregated") + public boolean getIsNullUnlessAggregated() { return isNullUnlessAggregated; } diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java index bb9948f1d2e5..82b954e10917 100644 --- 
a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaBufferAggregator.java @@ -30,14 +30,12 @@ public class ExpressionLambdaBufferAggregator implements BufferAggregator { private static final short NOT_AGGREGATED_BIT = 1 << 7; private static final short IS_AGGREGATED_MASK = 0x3F; - private static final byte TYPE_MASK = 0x0F; private final Expr lambda; private final ExprEval initialValue; private final ExpressionLambdaAggregatorInputBindings bindings; private final int maxSizeBytes; private final boolean isNullUnlessAggregated; - public ExpressionLambdaBufferAggregator( Expr lambda, ExprEval initialValue, @@ -110,6 +108,6 @@ public void close() private static ExprType getType(ByteBuffer buf, int position) { - return ExprType.fromByte((byte) (buf.get(position) & TYPE_MASK)); + return ExprType.fromByte((byte) (buf.get(position) & IS_AGGREGATED_MASK)); } }