diff --git a/LICENSE b/LICENSE index e7f791dba913..a9850e12a120 100644 --- a/LICENSE +++ b/LICENSE @@ -257,6 +257,8 @@ SOURCE/JAVA-CORE * core/src/main/java/org/apache/druid/java/util/common/parsers/DelimitedParser.java DirectExecutorService class: * core/src/main/java/org/apache/druid/java/util/common/concurrent/DirectExecutorService.java + CartesianList class: + * core/src/main/java/org/apache/druid/math/expr/CartesianList.java This product contains modified versions of the Dockerfile, scripts, and related configuration files used for building SequenceIQ's Hadoop Docker image, copyright SequenceIQ, Inc. (https://github.com/sequenceiq/hadoop-docker/) diff --git a/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4 b/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4 index 348b5037a5d1..aacbbe9d4290 100644 --- a/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4 +++ b/core/src/main/antlr4/org/apache/druid/math/expr/antlr/Expr.g4 @@ -23,13 +23,21 @@ expr : 'null' # null | expr ('<'|'<='|'>'|'>='|'=='|'!=') expr # logicalOpExpr | expr ('&&'|'||') expr # logicalAndOrExpr | '(' expr ')' # nestedExpr + | IDENTIFIER '(' lambda ',' fnArgs ')' # applyFunctionExpr | IDENTIFIER '(' fnArgs? 
')' # functionExpr | IDENTIFIER # identifierExpr | DOUBLE # doubleExpr | LONG # longExpr | STRING # string + | '[' DOUBLE (',' DOUBLE)* ']' # doubleArray + | '[' LONG (',' LONG)* ']' # longArray + | '[' STRING (',' STRING)* ']' # stringArray + | '[]' # emptyArray ; +lambda : (IDENTIFIER | '(' ')' | '(' IDENTIFIER (',' IDENTIFIER)* ')') '->' expr + ; + fnArgs : expr (',' expr)* # functionArgs ; diff --git a/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java new file mode 100644 index 000000000000..f50fe8eb4b42 --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/ApplyFunction.java @@ -0,0 +1,820 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.math.expr; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; +import it.unimi.dsi.fastutil.objects.Object2IntArrayMap; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.StringUtils; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; + +/** + * Base interface describing the mechanism used to evaluate an {@link ApplyFunctionExpr}, which 'applies' a + * {@link LambdaExpr} to one or more array {@link Expr}. All {@link ApplyFunction} implementations are immutable. + */ +public interface ApplyFunction +{ + /** + * Name of the function + */ + String name(); + + /** + * Apply {@link LambdaExpr} to argument list of {@link Expr} given a set of outer {@link Expr.ObjectBinding}. These + * outer bindings will be used to form the scope for the bindings used to evaluate the {@link LambdaExpr}, which use + * the array inputs to supply scalar values to use as bindings for {@link IdentifierExpr} in the lambda body. 
+ */ + ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings); + + /** + * Get list of input arguments which must evaluate to an array {@link ExprType} + */ + Set getArrayInputs(List args); + + void validateArguments(LambdaExpr lambdaExpr, List args); + + /** + * Base class for "map" functions, which are a class of {@link ApplyFunction} which take a lambda function that is + * mapped to the values of an {@link IndexableMapLambdaObjectBinding} which is created from the outer + * {@link Expr.ObjectBinding} and the values of the array {@link Expr} argument(s) + */ + abstract class BaseMapFunction implements ApplyFunction + { + /** + * Evaluate {@link LambdaExpr} against every index position of an {@link IndexableMapLambdaObjectBinding} + */ + ExprEval applyMap(LambdaExpr expr, IndexableMapLambdaObjectBinding bindings) + { + final int length = bindings.getLength(); + String[] stringsOut = null; + Long[] longsOut = null; + Double[] doublesOut = null; + + ExprType elementType = null; + for (int i = 0; i < length; i++) { + + ExprEval evaluated = expr.eval(bindings.withIndex(i)); + if (elementType == null) { + elementType = evaluated.type(); + switch (elementType) { + case STRING: + stringsOut = new String[length]; + break; + case LONG: + longsOut = new Long[length]; + break; + case DOUBLE: + doublesOut = new Double[length]; + break; + default: + throw new RE("Unhandled map function output type [%s]", elementType); + } + } + + switch (elementType) { + case STRING: + stringsOut[i] = evaluated.asString(); + break; + case LONG: + longsOut[i] = evaluated.asLong(); + break; + case DOUBLE: + doublesOut[i] = evaluated.asDouble(); + break; + } + } + + switch (elementType) { + case STRING: + return ExprEval.ofStringArray(stringsOut); + case LONG: + return ExprEval.ofLongArray(longsOut); + case DOUBLE: + return ExprEval.ofDoubleArray(doublesOut); + default: + throw new RE("Unhandled map function output type [%s]", elementType); + } + } + } + + /** + * Map 
the scalar values of a single array input {@link Expr} to a single argument {@link LambdaExpr} + */ + class MapFunction extends BaseMapFunction + { + static final String NAME = "map"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + Expr arrayExpr = argsExpr.get(0); + ExprEval arrayEval = arrayExpr.eval(bindings); + + Object[] array = arrayEval.asArray(); + if (array == null) { + return ExprEval.of(null); + } + if (array.length == 0) { + return arrayEval; + } + + MapLambdaBinding lambdaBinding = new MapLambdaBinding(array, lambdaExpr, bindings); + return applyMap(lambdaExpr, lambdaBinding); + } + + @Override + public Set getArrayInputs(List args) + { + if (args.size() == 1) { + return ImmutableSet.of(args.get(0)); + } + return Collections.emptySet(); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument(args.size() == 1); + if (lambdaExpr.identifierCount() > 0) { + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + } + } + + /** + * Map the cartesian product of 'n' array input arguments to an 'n' argument {@link LambdaExpr} + */ + class CartesianMapFunction extends BaseMapFunction + { + static final String NAME = "cartesian_map"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + List> arrayInputs = new ArrayList<>(); + boolean hadNull = false; + boolean hadEmpty = false; + for (Expr expr : argsExpr) { + ExprEval arrayEval = expr.eval(bindings); + Object[] array = arrayEval.asArray(); + if (array == null) { + hadNull = true; + continue; + } + if (array.length == 0) { + hadEmpty = true; + continue; + } + 
arrayInputs.add(Arrays.asList(array)); + } + if (hadNull) { + return ExprEval.of(null); + } + if (hadEmpty) { + return ExprEval.ofStringArray(new String[0]); + } + + List> product = CartesianList.create(arrayInputs); + CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(product, lambdaExpr, bindings); + return applyMap(lambdaExpr, lambdaBinding); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument(args.size() > 0); + if (lambdaExpr.identifierCount() > 0) { + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + } + } + + /** + * Base class for family of {@link ApplyFunction} which aggregate a scalar or array value given one or more array + * input {@link Expr} arguments and an array or scalar "accumulator" argument with an initial value + */ + abstract class BaseFoldFunction implements ApplyFunction + { + /** + * Accumulate a value by evaluating a {@link LambdaExpr} for each index position of an + * {@link IndexableFoldLambdaBinding} + */ + ExprEval applyFold(LambdaExpr lambdaExpr, Object accumulator, IndexableFoldLambdaBinding bindings) + { + for (int i = 0; i < bindings.getLength(); i++) { + ExprEval evaluated = lambdaExpr.eval(bindings.accumulateWithIndex(i, accumulator)); + accumulator = evaluated.value(); + } + return ExprEval.bestEffortOf(accumulator); + } + } + + /** + * Accumulate a value for a single array input with a 2 argument {@link LambdaExpr}. The 'array' input expression is + * the first argument, the initial value for the accumulator expression is the 2nd argument.
+ */ + class FoldFunction extends BaseFoldFunction + { + static final String NAME = "fold"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + Expr arrayExpr = argsExpr.get(0); + Expr accExpr = argsExpr.get(1); + + ExprEval arrayEval = arrayExpr.eval(bindings); + ExprEval accEval = accExpr.eval(bindings); + + Object[] array = arrayEval.asArray(); + if (array == null) { + return ExprEval.of(null); + } + Object accumlator = accEval.value(); + + FoldLambdaBinding lambdaBinding = new FoldLambdaBinding(array, accumlator, lambdaExpr, bindings); + return applyFold(lambdaExpr, accumlator, lambdaBinding); + } + + @Override + public Set getArrayInputs(List args) + { + // accumulator argument cannot currently be inferred, so ignore it until we think of something better to do + return ImmutableSet.of(args.get(0)); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument(args.size() == 2); + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + } + + /** + * Accumulate a value for the cartesian product of 'n' array input arguments with an 'n + 1' argument + * {@link LambdaExpr}. The 'array' input expressions are the first 'n' arguments, the initial value for the + * accumulator expression is the final argument.
+ */ + class CartesianFoldFunction extends BaseFoldFunction + { + static final String NAME = "cartesian_fold"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + List> arrayInputs = new ArrayList<>(); + boolean hadNull = false; + boolean hadEmpty = false; + for (int i = 0; i < argsExpr.size() - 1; i++) { + Expr expr = argsExpr.get(i); + ExprEval arrayEval = expr.eval(bindings); + Object[] array = arrayEval.asArray(); + if (array == null) { + hadNull = true; + continue; + } + if (array.length == 0) { + hadEmpty = true; + continue; + } + arrayInputs.add(Arrays.asList(array)); + } + if (hadNull) { + return ExprEval.of(null); + } + if (hadEmpty) { + return ExprEval.ofStringArray(new String[0]); + } + Expr accExpr = argsExpr.get(argsExpr.size() - 1); + + List> product = CartesianList.create(arrayInputs); + + ExprEval accEval = accExpr.eval(bindings); + + Object accumlator = accEval.value(); + + CartesianFoldLambdaBinding lambdaBindings = + new CartesianFoldLambdaBinding(product, accumlator, lambdaExpr, bindings); + return applyFold(lambdaExpr, accumlator, lambdaBindings); + } + + @Override + public Set getArrayInputs(List args) + { + // accumulator argument cannot be inferred, so ignore it until we think of something better to do + return ImmutableSet.copyOf(args.subList(0, args.size() - 1)); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + } + + /** + * Filter an array to all elements that evaluate to a 'truthy' value for a {@link LambdaExpr} + */ + class FilterFunction implements ApplyFunction + { + static final String NAME = "filter"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval 
apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + Expr arrayExpr = argsExpr.get(0); + ExprEval arrayEval = arrayExpr.eval(bindings); + + Object[] array = arrayEval.asArray(); + if (array == null) { + return ExprEval.of(null); + } + + SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings); + switch (arrayEval.type()) { + case STRING: + case STRING_ARRAY: + String[] filteredString = + this.filter(arrayEval.asStringArray(), lambdaExpr, lambdaBinding).toArray(String[]::new); + return ExprEval.ofStringArray(filteredString); + case LONG: + case LONG_ARRAY: + Long[] filteredLong = + this.filter(arrayEval.asLongArray(), lambdaExpr, lambdaBinding).toArray(Long[]::new); + return ExprEval.ofLongArray(filteredLong); + case DOUBLE: + case DOUBLE_ARRAY: + Double[] filteredDouble = + this.filter(arrayEval.asDoubleArray(), lambdaExpr, lambdaBinding).toArray(Double[]::new); + return ExprEval.ofDoubleArray(filteredDouble); + default: + throw new RE("Unhandled filter function input type [%s]", arrayEval.type()); + } + } + + @Override + public Set getArrayInputs(List args) + { + if (args.size() != 1) { + throw new IAE("ApplyFunction[%s] needs 1 argument", name()); + } + + return ImmutableSet.of(args.get(0)); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument(args.size() == 1); + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + + private Stream filter(T[] array, LambdaExpr expr, SettableLambdaBinding binding) + { + return Arrays.stream(array).filter(s -> expr.eval(binding.withBinding(expr.getIdentifier(), s)).asBoolean()); + } + } + + /** + * Base class for family of {@link ApplyFunction} which evaluate elements of a single array input against + * a {@link LambdaExpr} to evaluate to a final 'truthy' value + */ +
abstract class MatchFunction implements ApplyFunction + { + @Override + public ExprEval apply(LambdaExpr lambdaExpr, List argsExpr, Expr.ObjectBinding bindings) + { + Expr arrayExpr = argsExpr.get(0); + ExprEval arrayEval = arrayExpr.eval(bindings); + + final Object[] array = arrayEval.asArray(); + if (array == null) { + return ExprEval.bestEffortOf(false); + } + + SettableLambdaBinding lambdaBinding = new SettableLambdaBinding(lambdaExpr, bindings); + return match(array, lambdaExpr, lambdaBinding); + } + + @Override + public Set getArrayInputs(List args) + { + if (args.size() != 1) { + throw new IAE("ApplyFunction[%s] needs 1 argument", name()); + } + + return ImmutableSet.of(args.get(0)); + } + + @Override + public void validateArguments(LambdaExpr lambdaExpr, List args) + { + Preconditions.checkArgument(args.size() == 1); + Preconditions.checkArgument( + args.size() == lambdaExpr.identifierCount(), + StringUtils.format("lambda expression argument count does not match %s argument count", name()) + ); + } + + public abstract ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bindings); + } + + /** + * Evaluates to true if any element of the array input {@link Expr} causes the {@link LambdaExpr} to evaluate to a + * 'truthy' value + */ + class AnyMatchFunction extends MatchFunction + { + static final String NAME = "any"; + + @Override + public String name() + { + return NAME; + } + + @Override + public ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bindings) + { + boolean anyMatch = Arrays.stream(values) + .anyMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); + return ExprEval.bestEffortOf(anyMatch); + } + } + + /** + * Evaluates to true if all elements of the array input {@link Expr} cause the {@link LambdaExpr} to evaluate to a + * 'truthy' value + */ + class AllMatchFunction extends MatchFunction + { + static final String NAME = "all"; + + @Override + public String name() + { + return
NAME; + } + + @Override + public ExprEval match(Object[] values, LambdaExpr expr, SettableLambdaBinding bindings) + { + boolean allMatch = Arrays.stream(values) + .allMatch(o -> expr.eval(bindings.withBinding(expr.getIdentifier(), o)).asBoolean()); + return ExprEval.bestEffortOf(allMatch); + } + } + + /** + * Simple, mutable, {@link Expr.ObjectBinding} for a {@link LambdaExpr} which provides a {@link Map} for storing + * arbitrary values to use as values for {@link IdentifierExpr} in the body of the lambda that are arguments to the + * lambda + */ + class SettableLambdaBinding implements Expr.ObjectBinding + { + private final Expr.ObjectBinding bindings; + private final Map lambdaBindings; + + SettableLambdaBinding(LambdaExpr expr, Expr.ObjectBinding bindings) + { + this.lambdaBindings = new HashMap<>(); + for (String lambdaIdentifier : expr.getIdentifiers()) { + lambdaBindings.put(lambdaIdentifier, null); + } + this.bindings = bindings != null ? bindings : Collections.emptyMap()::get; + } + + @Nullable + @Override + public Object get(String name) + { + if (lambdaBindings.containsKey(name)) { + return lambdaBindings.get(name); + } + return bindings.get(name); + } + + SettableLambdaBinding withBinding(String key, Object value) + { + this.lambdaBindings.put(key, value); + return this; + } + } + + /** + * {@link Expr.ObjectBinding} which can be iterated by an integer index position for {@link BaseMapFunction}. + * Evaluating an {@link IdentifierExpr} against these bindings will return the value(s) of the array at the current + * index for any lambda identifiers, and fall through to the base {@link Expr.ObjectBinding} for all bindings provided + * by an outer scope. 
+ */ + interface IndexableMapLambdaObjectBinding extends Expr.ObjectBinding + { + /** + * Total number of bindings in this binding + */ + int getLength(); + + /** + * Update index position + */ + IndexableMapLambdaObjectBinding withIndex(int index); + } + + /** + * {@link IndexableMapLambdaObjectBinding} for a {@link MapFunction}. Lambda argument binding is stored in an object + * array, retrieving binding values for the lambda identifier returns the value at the current index. + */ + class MapLambdaBinding implements IndexableMapLambdaObjectBinding + { + private final Expr.ObjectBinding bindings; + @Nullable + private final String lambdaIdentifier; + private final Object[] arrayValues; + private int index = 0; + private final boolean scoped; + + MapLambdaBinding(Object[] arrayValues, LambdaExpr expr, Expr.ObjectBinding bindings) + { + this.lambdaIdentifier = expr.getIdentifier(); + this.arrayValues = arrayValues; + this.bindings = bindings != null ? bindings : Collections.emptyMap()::get; + this.scoped = lambdaIdentifier != null; + } + + @Nullable + @Override + public Object get(String name) + { + if (scoped && name.equals(lambdaIdentifier)) { + return arrayValues[index]; + } + return bindings.get(name); + } + + @Override + public int getLength() + { + return arrayValues.length; + } + + @Override + public MapLambdaBinding withIndex(int index) + { + this.index = index; + return this; + } + } + + /** + * {@link IndexableMapLambdaObjectBinding} for a {@link CartesianMapFunction}. 
Lambda argument bindings stored as a + * cartesian product in the form of a list of lists of objects, where the inner list is the in order list of values + * for each {@link LambdaExpr} argument + */ + class CartesianMapLambdaBinding implements IndexableMapLambdaObjectBinding + { + private final Expr.ObjectBinding bindings; + private final Object2IntMap lambdaIdentifiers; + private final List> lambdaInputs; + private final boolean scoped; + private int index = 0; + + CartesianMapLambdaBinding(List> inputs, LambdaExpr expr, Expr.ObjectBinding bindings) + { + this.lambdaInputs = inputs; + List ids = expr.getIdentifiers(); + this.scoped = ids.size() > 0; + this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size()); + for (int i = 0; i < ids.size(); i++) { + lambdaIdentifiers.put(ids.get(i), i); + } + + this.bindings = bindings != null ? bindings : Collections.emptyMap()::get; + } + + @Nullable + @Override + public Object get(String name) + { + if (scoped && lambdaIdentifiers.containsKey(name)) { + return lambdaInputs.get(index).get(lambdaIdentifiers.getInt(name)); + } + return bindings.get(name); + } + + @Override + public int getLength() + { + return lambdaInputs.size(); + } + + @Override + public CartesianMapLambdaBinding withIndex(int index) + { + this.index = index; + return this; + } + } + + /** + * {@link Expr.ObjectBinding} which can be iterated by an integer index position for {@link BaseFoldFunction}. + * Evaluating an {@link IdentifierExpr} against these bindings will return the value(s) of the array at the current + * index for any lambda array identifiers, the value of the 'accumulator' for the lambda accumulator identifier, + * and fall through to the base {@link Expr.ObjectBinding} for all bindings provided by an outer scope. 
+ */ + interface IndexableFoldLambdaBinding extends Expr.ObjectBinding + { + /** + * Total number of bindings in this binding + */ + int getLength(); + + /** + * Update the index and accumulator value + */ + IndexableFoldLambdaBinding accumulateWithIndex(int index, Object accumulator); + } + + /** + * {@link IndexableFoldLambdaBinding} for a {@link FoldFunction}. Like {@link MapLambdaBinding} + * but with additional information to track and provide binding values for an accumulator. + */ + class FoldLambdaBinding implements IndexableFoldLambdaBinding + { + private final Expr.ObjectBinding bindings; + private final String elementIdentifier; + private final Object[] arrayValues; + private final String accumulatorIdentifier; + private Object accumulatorValue; + private int index; + + FoldLambdaBinding(Object[] arrayValues, Object initialAccumulator, LambdaExpr expr, Expr.ObjectBinding bindings) + { + List ids = expr.getIdentifiers(); + this.elementIdentifier = ids.get(0); + this.accumulatorIdentifier = ids.get(1); + this.arrayValues = arrayValues; + this.accumulatorValue = initialAccumulator; + this.bindings = bindings != null ? bindings : Collections.emptyMap()::get; + } + + @Nullable + @Override + public Object get(String name) + { + if (name.equals(elementIdentifier)) { + return arrayValues[index]; + } else if (name.equals(accumulatorIdentifier)) { + return accumulatorValue; + } + return bindings.get(name); + } + + @Override + public int getLength() + { + return arrayValues.length; + } + + @Override + public FoldLambdaBinding accumulateWithIndex(int index, Object acc) + { + this.index = index; + this.accumulatorValue = acc; + return this; + } + } + + /** + * {@link IndexableFoldLambdaBinding} for a {@link CartesianFoldFunction}. Like {@link CartesianMapLambdaBinding} + * but with additional information to track and provide binding values for an accumulator. 
+ */ + class CartesianFoldLambdaBinding implements IndexableFoldLambdaBinding + { + private final Expr.ObjectBinding bindings; + private final Object2IntMap lambdaIdentifiers; + private final List> lambdaInputs; + private final String accumulatorIdentifier; + private Object accumulatorValue; + private int index = 0; + + CartesianFoldLambdaBinding(List> inputs, Object accumulatorValue, LambdaExpr expr, Expr.ObjectBinding bindings) + { + this.lambdaInputs = inputs; + List ids = expr.getIdentifiers(); + this.lambdaIdentifiers = new Object2IntArrayMap<>(ids.size()); + for (int i = 0; i < ids.size() - 1; i++) { + lambdaIdentifiers.put(ids.get(i), i); + } + this.accumulatorIdentifier = ids.get(ids.size() - 1); + this.bindings = bindings != null ? bindings : Collections.emptyMap()::get; + this.accumulatorValue = accumulatorValue; + } + + @Nullable + @Override + public Object get(String name) + { + if (lambdaIdentifiers.containsKey(name)) { + return lambdaInputs.get(index).get(lambdaIdentifiers.getInt(name)); + } else if (accumulatorIdentifier.equals(name)) { + return accumulatorValue; + } + return bindings.get(name); + } + + @Override + public int getLength() + { + return lambdaInputs.size(); + } + + @Override + public CartesianFoldLambdaBinding accumulateWithIndex(int index, Object acc) + { + this.index = index; + this.accumulatorValue = acc; + return this; + } + } +} diff --git a/core/src/main/java/org/apache/druid/math/expr/CartesianList.java b/core/src/main/java/org/apache/druid/math/expr/CartesianList.java new file mode 100644 index 000000000000..d1373bd62a4b --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/CartesianList.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.math.expr; + +import com.google.common.base.Preconditions; +import com.google.common.math.IntMath; + +import javax.annotation.Nullable; +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.ListIterator; +import java.util.RandomAccess; + +/** + * {@link CartesianList} computes the cartesian product of n lists. It is adapted from and is *nearly* identical to one + * Guava CartesianList which comes from a version from "the future" that we don't yet have, with the key difference that + * it is not {@link com.google.common.collect.ImmutableList} based, so it can hold null values to be compatible with the + * evaluation and handling of cartesian products of expression arrays with null elements, e.g. 
['a', 'b', null] + */ + +public final class CartesianList extends AbstractList> implements RandomAccess +{ + private final transient List> axes; + private final transient int[] axesSizeProduct; + + public static List> create(List> lists) + { + List> axesBuilder = new ArrayList<>(lists.size()); + for (List list : lists) { + if (list.isEmpty()) { + return Collections.emptyList(); + } + axesBuilder.add(new ArrayList<>(list)); + } + return new CartesianList(axesBuilder); + } + + CartesianList(List> axes) + { + this.axes = axes; + int[] axesSizeProduct = new int[axes.size() + 1]; + axesSizeProduct[axes.size()] = 1; + try { + for (int i = axes.size() - 1; i >= 0; i--) { + axesSizeProduct[i] = IntMath.checkedMultiply(axesSizeProduct[i + 1], axes.get(i).size()); + } + } + catch (ArithmeticException e) { + throw new IllegalArgumentException( + "Cartesian product too large; must have size at most Integer.MAX_VALUE"); + } + this.axesSizeProduct = axesSizeProduct; + } + + private int getAxisIndexForProductIndex(int index, int axis) + { + return (index / axesSizeProduct[axis + 1]) % axes.get(axis).size(); + } + + @Override + public int indexOf(Object o) + { + if (!(o instanceof List)) { + return -1; + } + List list = (List) o; + if (list.size() != axes.size()) { + return -1; + } + ListIterator itr = list.listIterator(); + int computedIndex = 0; + while (itr.hasNext()) { + int axisIndex = itr.nextIndex(); + int elemIndex = axes.get(axisIndex).indexOf(itr.next()); + if (elemIndex == -1) { + return -1; + } + computedIndex += elemIndex * axesSizeProduct[axisIndex + 1]; + } + return computedIndex; + } + + @Override + public List get(final int index) + { + Preconditions.checkElementIndex(index, size()); + return new AbstractList() + { + @Override + public int size() + { + return axes.size(); + } + + @Override + public E get(int axis) + { + Preconditions.checkElementIndex(axis, size()); + int axisIndex = getAxisIndexForProductIndex(index, axis); + return axes.get(axis).get(axisIndex); 
+ } + }; + } + + @Override + public int size() + { + return axesSizeProduct[0]; + } + + @Override + public boolean contains(@Nullable Object o) + { + return indexOf(o) != -1; + } +} diff --git a/core/src/main/java/org/apache/druid/math/expr/Expr.java b/core/src/main/java/org/apache/druid/math/expr/Expr.java index cce15b662c24..ee5ff9689e77 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Expr.java +++ b/core/src/main/java/org/apache/druid/math/expr/Expr.java @@ -20,22 +20,36 @@ package org.apache.druid.math.expr; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.common.math.LongMath; import com.google.common.primitives.Ints; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Comparators; -import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; /** + * Base interface of Druid expression language abstract syntax tree nodes. All {@link Expr} implementations are + * immutable. */ public interface Expr { + /** + * Indicates expression is a constant whose literal value can be extracted by {@link Expr#getLiteralValue()}, + * making evaluating with arguments and bindings unnecessary + */ default boolean isLiteral() { // Overridden by things that are literals. @@ -45,7 +59,7 @@ default boolean isLiteral() /** * Returns the value of expr if expr is a literal, or throws an exception otherwise.
* - * @return expr's literal value + * @return {@link ConstantExpr}'s literal value * * @throws IllegalStateException if expr is not a literal */ @@ -56,23 +70,179 @@ default Object getLiteralValue() throw new ISE("Not a literal"); } - @Nonnull + /** + * Returns the string identifier of an {@link IdentifierExpr}, else null + */ + @Nullable + default String getIdentifierIfIdentifier() + { + // overridden by things that are identifiers + return null; + } + + /** + * Evaluate the {@link Expr} with the bindings which supply {@link IdentifierExpr} with their values, producing an + * {@link ExprEval} with the result. + */ ExprEval eval(ObjectBinding bindings); + /** + * Programmatically inspect the {@link Expr} tree with a {@link Visitor}. Each {@link Expr} is responsible for + * ensuring the {@link Visitor} can visit all of its {@link Expr} children before visiting itself. + */ + void visit(Visitor visitor); + + /** + * Programmatically rewrite the {@link Expr} tree with a {@link Shuttle}. Each {@link Expr} is responsible for + * ensuring the {@link Shuttle} can visit all of its {@link Expr} children, as well as updating its children + * {@link Expr} with the results from the {@link Shuttle}, before finally visiting an updated form of itself.
+ */ + Expr visit(Shuttle shuttle); + + /** + * Examine the usage of {@link IdentifierExpr} children of an {@link Expr}, constructing a {@link BindingDetails} + */ + BindingDetails analyzeInputs(); + + /** + * Mechanism to supply values to back {@link IdentifierExpr} during expression evaluation + */ interface ObjectBinding { + /** + * Get value binding for string identifier of {@link IdentifierExpr} + */ @Nullable Object get(String name); } - void visit(Visitor visitor); - + /** + * Mechanism to inspect an {@link Expr}, implementing a {@link Visitor} allows visiting all children of an + * {@link Expr} + */ interface Visitor { + /** + * Provide the {@link Visitor} with an {@link Expr} to inspect + */ void visit(Expr expr); } + + /** + * Mechanism to rewrite an {@link Expr}, implementing a {@link Shuttle} allows visiting all children of an + * {@link Expr}, and replacing them as desired. + */ + interface Shuttle + { + /** + * Provide the {@link Shuttle} with an {@link Expr} to inspect and potentially rewrite. + */ + Expr visit(Expr expr); + } + + /** + * Information about the context in which {@link IdentifierExpr} are used in a greater {@link Expr}, listing + * the 'free variables' (total set of required input columns or values) and distinguishing between which identifiers + * are used as scalar values and which are used as array values. 
+ */ + class BindingDetails + { + private final ImmutableSet freeVariables; + private final ImmutableSet scalarVariables; + private final ImmutableSet arrayVariables; + + public BindingDetails() + { + this(Collections.emptySet(), Collections.emptySet(), Collections.emptySet()); + } + + public BindingDetails(String identifier) + { + this(ImmutableSet.of(identifier), Collections.emptySet(), Collections.emptySet()); + } + + public BindingDetails(Set freeVariables, Set scalarVariables, Set arrayVariables) + { + this.freeVariables = ImmutableSet.copyOf(freeVariables); + this.scalarVariables = ImmutableSet.copyOf(scalarVariables); + this.arrayVariables = ImmutableSet.copyOf(arrayVariables); + } + + /** + * Get the list of required column inputs to evaluate an expression + */ + public ImmutableList getRequiredColumns() + { + return ImmutableList.copyOf(freeVariables); + } + + /** + * Total set of 'free' identifiers of an {@link Expr}, that are not supplied by a {@link LambdaExpr} binding + */ + public ImmutableSet getFreeVariables() + { + return freeVariables; + } + + /** + * Set of identifiers which are used with scalar operators and functions + */ + public ImmutableSet getScalarVariables() + { + return scalarVariables; + } + + /** + * Set of identifiers which are used with array typed functions and apply functions. 
+ */ + public ImmutableSet getArrayVariables() + { + return arrayVariables; + } + + public BindingDetails merge(BindingDetails other) + { + return new BindingDetails( + Sets.union(freeVariables, other.freeVariables), + Sets.union(scalarVariables, other.scalarVariables), + Sets.union(arrayVariables, other.arrayVariables) + ); + } + + public BindingDetails mergeWith(Set moreScalars, Set moreArrays) + { + return new BindingDetails( + Sets.union(freeVariables, Sets.union(moreScalars, moreArrays)), + Sets.union(scalarVariables, moreScalars), + Sets.union(arrayVariables, moreArrays) + ); + } + + public BindingDetails mergeWithScalars(Set moreScalars) + { + return new BindingDetails( + Sets.union(freeVariables, moreScalars), + Sets.union(scalarVariables, moreScalars), + arrayVariables + ); + } + + public BindingDetails mergeWithArrays(Set moreArrays) + { + return new BindingDetails( + Sets.union(freeVariables, moreArrays), + scalarVariables, + Sets.union(arrayVariables, moreArrays) + ); + } + } } +/** + * Base type for all constant expressions. {@link ConstantExpr} allow for direct value extraction without evaluating + * {@link Expr.ObjectBinding}. {@link ConstantExpr} are terminal nodes of an expression tree, and have no children + * {@link Expr}. 
+ */ abstract class ConstantExpr implements Expr { @Override @@ -86,18 +256,29 @@ public void visit(Visitor visitor) { visitor.visit(this); } + + @Override + public Expr visit(Shuttle shuttle) + { + return shuttle.visit(this); + } + + @Override + public BindingDetails analyzeInputs() + { + return new BindingDetails(); + } } class LongExpr extends ConstantExpr { private final Long value; - public LongExpr(Long value) + LongExpr(Long value) { this.value = Preconditions.checkNotNull(value, "value"); } - @Nonnull @Override public Object getLiteralValue() { @@ -110,7 +291,6 @@ public String toString() return String.valueOf(value); } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { @@ -118,11 +298,40 @@ public ExprEval eval(ObjectBinding bindings) } } +class LongArrayExpr extends ConstantExpr +{ + private final Long[] value; + + LongArrayExpr(Long[] value) + { + this.value = Preconditions.checkNotNull(value, "value"); + } + + @Override + public Object getLiteralValue() + { + return value; + } + + @Override + public String toString() + { + return Arrays.toString(value); + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + return ExprEval.ofLongArray(value); + } +} + class StringExpr extends ConstantExpr { + @Nullable private final String value; - public StringExpr(String value) + StringExpr(@Nullable String value) { this.value = NullHandling.emptyToNullIfNeeded(value); } @@ -134,13 +343,13 @@ public Object getLiteralValue() return value; } + @Nullable @Override public String toString() { return value; } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { @@ -148,16 +357,43 @@ public ExprEval eval(ObjectBinding bindings) } } +class StringArrayExpr extends ConstantExpr +{ + private final String[] value; + + StringArrayExpr(String[] value) + { + this.value = Preconditions.checkNotNull(value, "value"); + } + + @Override + public Object getLiteralValue() + { + return value; + } + + @Override + public String toString() + { + 
return Arrays.toString(value); + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + return ExprEval.ofStringArray(value); + } +} + class DoubleExpr extends ConstantExpr { private final Double value; - public DoubleExpr(Double value) + DoubleExpr(Double value) { this.value = Preconditions.checkNotNull(value, "value"); } - @Nonnull @Override public Object getLiteralValue() { @@ -170,7 +406,6 @@ public String toString() return String.valueOf(value); } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { @@ -178,26 +413,71 @@ public ExprEval eval(ObjectBinding bindings) } } +class DoubleArrayExpr extends ConstantExpr +{ + private final Double[] value; + + DoubleArrayExpr(Double[] value) + { + this.value = Preconditions.checkNotNull(value, "value"); + } + + @Override + public Object getLiteralValue() + { + return value; + } + + @Override + public String toString() + { + return Arrays.toString(value); + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + return ExprEval.ofDoubleArray(value); + } +} + +/** + * This {@link Expr} node is used to represent a variable in the expression language. At evaluation time, the string + * identifier will be used to retrieve the runtime value for the variable from {@link Expr.ObjectBinding}. + * {@link IdentifierExpr} are terminal nodes of an expression tree, and have no children {@link Expr}. 
+ */ class IdentifierExpr implements Expr { - private final String value; + private final String identifier; - public IdentifierExpr(String value) + IdentifierExpr(String value) { - this.value = value; + this.identifier = value; } @Override public String toString() { - return value; + return identifier; + } + + @Nullable + @Override + public String getIdentifierIfIdentifier() + { + return identifier; + } + + @Override + public BindingDetails analyzeInputs() + { + return new BindingDetails(identifier); } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { - return ExprEval.bestEffortOf(bindings.get(value)); + return ExprEval.bestEffortOf(bindings.get(identifier)); } @Override @@ -205,28 +485,121 @@ public void visit(Visitor visitor) { visitor.visit(this); } + + @Override + public Expr visit(Shuttle shuttle) + { + return shuttle.visit(this); + } } +class LambdaExpr implements Expr +{ + private final ImmutableList args; + private final Expr expr; + + LambdaExpr(List args, Expr expr) + { + this.args = ImmutableList.copyOf(args); + this.expr = expr; + } + + @Override + public String toString() + { + return StringUtils.format("(%s -> %s)", args, expr); + } + + public int identifierCount() + { + return args.size(); + } + + @Nullable + public String getIdentifier() + { + Preconditions.checkState(args.size() < 2, "LambdaExpr has multiple arguments"); + if (args.size() == 1) { + return args.get(0).toString(); + } + return null; + } + + public List getIdentifiers() + { + return args.stream().map(IdentifierExpr::toString).collect(Collectors.toList()); + } + + public ImmutableList getIdentifierExprs() + { + return args; + } + + public Expr getExpr() + { + return expr; + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + return expr.eval(bindings); + } + + @Override + public void visit(Visitor visitor) + { + expr.visit(visitor); + visitor.visit(this); + } + + @Override + public Expr visit(Shuttle shuttle) + { + List newArgs = + 
args.stream().map(arg -> (IdentifierExpr) shuttle.visit(arg)).collect(Collectors.toList()); + Expr newBody = expr.visit(shuttle); + return shuttle.visit(new LambdaExpr(newArgs, newBody)); + } + + @Override + public BindingDetails analyzeInputs() + { + final Set lambdaArgs = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet()); + BindingDetails bodyDetails = expr.analyzeInputs(); + return new BindingDetails( + Sets.difference(bodyDetails.getFreeVariables(), lambdaArgs), + Sets.difference(bodyDetails.getScalarVariables(), lambdaArgs), + Sets.difference(bodyDetails.getArrayVariables(), lambdaArgs) + ); + } +} + +/** + * {@link Expr} node for a {@link Function} call. {@link FunctionExpr} has children {@link Expr} in the form of the + * list of arguments that are passed to the {@link Function} along with the {@link Expr.ObjectBinding} when it is + * evaluated. + */ class FunctionExpr implements Expr { final Function function; final String name; - final List args; + final ImmutableList args; - public FunctionExpr(Function function, String name, List args) + FunctionExpr(Function function, String name, List args) { this.function = function; this.name = name; - this.args = args; + this.args = ImmutableList.copyOf(args); + function.validateArguments(args); } @Override public String toString() { - return "(" + name + " " + args + ")"; + return StringUtils.format("(%s %s)", name, args); } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { @@ -241,8 +614,131 @@ public void visit(Visitor visitor) } visitor.visit(this); } + + @Override + public Expr visit(Shuttle shuttle) + { + List newArgs = args.stream().map(shuttle::visit).collect(Collectors.toList()); + return shuttle.visit(new FunctionExpr(function, name, newArgs)); + } + + @Override + public BindingDetails analyzeInputs() + { + final Set scalarVariables = new HashSet<>(); + final Set arrayVariables = new HashSet<>(); + final Set scalarArgs = function.getScalarInputs(args); + final Set 
arrayArgs = function.getArrayInputs(args); + BindingDetails accumulator = new BindingDetails(); + + for (Expr arg : args) { + accumulator = accumulator.merge(arg.analyzeInputs()); + } + for (Expr arg : scalarArgs) { + String s = arg.getIdentifierIfIdentifier(); + if (s != null) { + scalarVariables.add(s); + } + } + for (Expr arg : arrayArgs) { + String s = arg.getIdentifierIfIdentifier(); + if (s != null) { + arrayVariables.add(s); + } + } + return accumulator.mergeWith(scalarVariables, arrayVariables); + } } +/** + * This {@link Expr} node is representative of an {@link ApplyFunction}, and has children in the form of a + * {@link LambdaExpr} and the list of {@link Expr} arguments that are combined with {@link Expr.ObjectBinding} to + * evaluate the {@link LambdaExpr}. + */ +class ApplyFunctionExpr implements Expr +{ + final ApplyFunction function; + final String name; + final LambdaExpr lambdaExpr; + final ImmutableList argsExpr; + final BindingDetails bindingDetails; + final BindingDetails lambdaBindingDetails; + final ImmutableList argsBindingDetails; + + ApplyFunctionExpr(ApplyFunction function, String name, LambdaExpr expr, List args) + { + this.function = function; + this.name = name; + this.argsExpr = ImmutableList.copyOf(args); + this.lambdaExpr = expr; + + function.validateArguments(expr, args); + + // apply function expressions are examined during expression selector creation, so precompute and cache the + // binding details of children + ImmutableList.Builder argBindingDetailsBuilder = ImmutableList.builder(); + BindingDetails accumulator = new BindingDetails(); + for (Expr arg : argsExpr) { + BindingDetails argDetails = arg.analyzeInputs(); + argBindingDetailsBuilder.add(argDetails); + accumulator = accumulator.merge(argDetails); + } + + final Set arrayVariables = new HashSet<>(); + Set arrayArgs = function.getArrayInputs(argsExpr); + + for (Expr arg : arrayArgs) { + String s = arg.getIdentifierIfIdentifier(); + if (s != null) { + arrayVariables.add(s); 
+ } + } + + lambdaBindingDetails = lambdaExpr.analyzeInputs(); + bindingDetails = accumulator.merge(lambdaBindingDetails).mergeWithArrays(arrayVariables); + argsBindingDetails = argBindingDetailsBuilder.build(); + } + + @Override + public String toString() + { + return StringUtils.format("(%s %s, %s)", name, lambdaExpr, argsExpr); + } + + @Override + public ExprEval eval(ObjectBinding bindings) + { + return function.apply(lambdaExpr, argsExpr, bindings); + } + + @Override + public void visit(Visitor visitor) + { + lambdaExpr.visit(visitor); + for (Expr arg : argsExpr) { + arg.visit(visitor); + } + visitor.visit(this); + } + + @Override + public Expr visit(Shuttle shuttle) + { + LambdaExpr newLambda = (LambdaExpr) lambdaExpr.visit(shuttle); + List newArgs = argsExpr.stream().map(shuttle::visit).collect(Collectors.toList()); + return shuttle.visit(new ApplyFunctionExpr(function, name, newLambda, newArgs)); + } + + @Override + public BindingDetails analyzeInputs() + { + return bindingDetails; + } +} + +/** + * Base type for all single argument operators, with a single {@link Expr} child for the operand. 
+ */ abstract class UnaryExpr implements Expr { final Expr expr; @@ -252,12 +748,38 @@ abstract class UnaryExpr implements Expr this.expr = expr; } + abstract UnaryExpr copy(Expr expr); + @Override public void visit(Visitor visitor) { expr.visit(visitor); visitor.visit(this); } + + @Override + public Expr visit(Shuttle shuttle) + { + Expr newExpr = expr.visit(shuttle); + if (newExpr != expr) { + return shuttle.visit(copy(newExpr)); + } + return shuttle.visit(this); + } + + @Override + public BindingDetails analyzeInputs() + { + // currently all unary operators only operate on scalar inputs + final Set scalars; + final String identifierMaybe = expr.getIdentifierIfIdentifier(); + if (identifierMaybe != null) { + scalars = ImmutableSet.of(identifierMaybe); + } else { + scalars = Collections.emptySet(); + } + return expr.analyzeInputs().mergeWithScalars(scalars); + } } class UnaryMinusExpr extends UnaryExpr @@ -267,7 +789,12 @@ class UnaryMinusExpr extends UnaryExpr super(expr); } - @Nonnull + @Override + UnaryExpr copy(Expr expr) + { + return new UnaryMinusExpr(expr); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -284,17 +811,10 @@ public ExprEval eval(ObjectBinding bindings) throw new IAE("unsupported type " + ret.type()); } - @Override - public void visit(Visitor visitor) - { - expr.visit(visitor); - visitor.visit(this); - } - @Override public String toString() { - return "-" + expr; + return StringUtils.format("-%s", expr); } } @@ -305,7 +825,12 @@ class UnaryNotExpr extends UnaryExpr super(expr); } - @Nonnull + @Override + UnaryExpr copy(Expr expr) + { + return new UnaryNotExpr(expr); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -321,19 +846,24 @@ public ExprEval eval(ObjectBinding bindings) @Override public String toString() { - return "!" 
+ expr; + return StringUtils.format("!%s", expr); } } -// all concrete subclass of this should have constructor with the form of (String, Expr, Expr) -// if it's not possible, just be sure Evals.binaryOp() can handle that +/** + * Base type for all binary operators, this {@link Expr} has two children {@link Expr} for the left and right side + * operands. + * + * Note: all concrete subclasses of this should have a constructor of the form (String, Expr, Expr); + * if that is not possible, just be sure Evals.binaryOp() can handle that + */ abstract class BinaryOpExprBase implements Expr { protected final String op; protected final Expr left; protected final Expr right; - public BinaryOpExprBase(String op, Expr left, Expr right) + BinaryOpExprBase(String op, Expr left, Expr right) { this.op = op; this.left = left; @@ -348,21 +878,55 @@ public void visit(Visitor visitor) visitor.visit(this); } + @Override + public Expr visit(Shuttle shuttle) + { + Expr newLeft = left.visit(shuttle); + Expr newRight = right.visit(shuttle); + if (left != newLeft || right != newRight) { + return shuttle.visit(copy(newLeft, newRight)); + } + return shuttle.visit(this); + } + @Override public String toString() { - return "(" + op + " " + left + " " + right + ")"; + return StringUtils.format("(%s %s %s)", op, left, right); + } + + protected abstract BinaryOpExprBase copy(Expr left, Expr right); + + @Override + public BindingDetails analyzeInputs() + { + // currently all binary operators operate on scalar inputs + final Set scalars = new HashSet<>(); + final String leftIdentifer = left.getIdentifierIfIdentifier(); + final String rightIdentifier = right.getIdentifierIfIdentifier(); + if (leftIdentifer != null) { + scalars.add(leftIdentifer); + } + if (rightIdentifier != null) { + scalars.add(rightIdentifier); + } + return left.analyzeInputs() + .merge(right.analyzeInputs()) + .mergeWithScalars(scalars); } } +/** + * Base class for numerical binary operators, with additional methods defined to 
evaluate primitive values directly + * instead of wrapped with {@link ExprEval} + */ abstract class BinaryEvalOpExprBase extends BinaryOpExprBase { - public BinaryEvalOpExprBase(String op, Expr left, Expr right) + BinaryEvalOpExprBase(String op, Expr left, Expr right) { super(op, left, right); } - @Nonnull @Override public ExprEval eval(ObjectBinding bindings) { @@ -375,6 +939,7 @@ public ExprEval eval(ObjectBinding bindings) return ExprEval.of(null); } + if (leftVal.type() == ExprType.STRING && rightVal.type() == ExprType.STRING) { return evalString(leftVal.asString(), rightVal.asString()); } else if (leftVal.type() == ExprType.LONG && rightVal.type() == ExprType.LONG) { @@ -407,6 +972,12 @@ class BinMinusExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinMinusExpr(op, left, right); + } + @Override protected final long evalLong(long left, long right) { @@ -427,6 +998,12 @@ class BinPowExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinPowExpr(op, left, right); + } + @Override protected final long evalLong(long left, long right) { @@ -447,6 +1024,12 @@ class BinMulExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinMulExpr(op, left, right); + } + @Override protected final long evalLong(long left, long right) { @@ -467,6 +1050,12 @@ class BinDivExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinDivExpr(op, left, right); + } + @Override protected final long evalLong(long left, long right) { @@ -487,6 +1076,12 @@ class BinModuloExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new 
BinModuloExpr(op, left, right); + } + @Override protected final long evalLong(long left, long right) { @@ -507,6 +1102,12 @@ class BinPlusExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinPlusExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -534,6 +1135,12 @@ class BinLtExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinLtExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -561,6 +1168,12 @@ class BinLeqExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinLeqExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -588,6 +1201,12 @@ class BinGtExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinGtExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -615,6 +1234,12 @@ class BinGeqExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinGeqExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -642,6 +1267,12 @@ class BinEqExpr extends BinaryEvalOpExprBase super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinEqExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -668,6 +1299,12 @@ class BinNeqExpr extends BinaryEvalOpExprBase 
super(op, left, right); } + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinNeqExpr(op, left, right); + } + @Override protected ExprEval evalString(@Nullable String left, @Nullable String right) { @@ -694,7 +1331,12 @@ class BinAndExpr extends BinaryOpExprBase super(op, left, right); } - @Nonnull + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinAndExpr(op, left, right); + } + @Override public ExprEval eval(ObjectBinding bindings) { @@ -710,7 +1352,12 @@ class BinOrExpr extends BinaryOpExprBase super(op, left, right); } - @Nonnull + @Override + protected BinaryOpExprBase copy(Expr left, Expr right) + { + return new BinOrExpr(op, left, right); + } + @Override public ExprEval eval(ObjectBinding bindings) { diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 4dad8100c952..1cafb17a750c 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -25,15 +25,14 @@ import org.apache.druid.java.util.common.IAE; import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.stream.Collectors; /** + * Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow */ public abstract class ExprEval { - // Cached String values. Protected so they can be used by subclasses. 
- private boolean stringValueValid = false; - private String stringValue; - public static ExprEval ofLong(@Nullable Number longValue) { return new LongExprEval(longValue); @@ -62,6 +61,21 @@ public static ExprEval of(@Nullable String stringValue) return new StringExprEval(stringValue); } + public static ExprEval ofLongArray(@Nullable Long[] longValue) + { + return new LongArrayExprEval(longValue); + } + + public static ExprEval ofDoubleArray(@Nullable Double[] doubleValue) + { + return new DoubleArrayExprEval(doubleValue); + } + + public static ExprEval ofStringArray(@Nullable String[] stringValue) + { + return new StringArrayExprEval(stringValue); + } + public static ExprEval of(boolean value, ExprType type) { switch (type) { @@ -87,9 +101,27 @@ public static ExprEval bestEffortOf(@Nullable Object val) } return new LongExprEval((Number) val); } + if (val instanceof Long[]) { + return new LongArrayExprEval((Long[]) val); + } + if (val instanceof Double[]) { + return new DoubleArrayExprEval((Double[]) val); + } + if (val instanceof Float[]) { + return new DoubleArrayExprEval(Arrays.stream((Float[]) val).map(Float::doubleValue).toArray(Double[]::new)); + } + if (val instanceof String[]) { + return new StringArrayExprEval((String[]) val); + } + return new StringExprEval(val == null ? null : String.valueOf(val)); } + // Cached String values + private boolean stringValueValid = false; + @Nullable + private String stringValue; + @Nullable final T value; @@ -100,22 +132,12 @@ private ExprEval(@Nullable T value) public abstract ExprType type(); - public Object value() + @Nullable + public T value() { return value; } - /** - * returns true if numeric primitive value for this ExprEval is null, otherwise false. 
- */ - public abstract boolean isNumericNull(); - - public abstract int asInt(); - - public abstract long asLong(); - - public abstract double asDouble(); - @Nullable public String asString() { @@ -132,8 +154,36 @@ public String asString() return stringValue; } + /** + * returns true if numeric primitive value for this ExprEval is null, otherwise false. + */ + public abstract boolean isNumericNull(); + + public boolean isArray() + { + return false; + } + + public abstract int asInt(); + + public abstract long asLong(); + + public abstract double asDouble(); + public abstract boolean asBoolean(); + @Nullable + public abstract Object[] asArray(); + + @Nullable + public abstract String[] asStringArray(); + + @Nullable + public abstract Long[] asLongArray(); + + @Nullable + public abstract Double[] asDoubleArray(); + public abstract ExprEval castTo(ExprType castTo); public abstract Expr toExpr(); @@ -163,6 +213,27 @@ public final double asDouble() return value.doubleValue(); } + @Nullable + @Override + public String[] asStringArray() + { + return isNumericNull() ? null : new String[] {value.toString()}; + } + + @Nullable + @Override + public Long[] asLongArray() + { + return isNumericNull() ? null : new Long[] {value.longValue()}; + } + + @Nullable + @Override + public Double[] asDoubleArray() + { + return isNumericNull() ? 
null : new Double[] {value.doubleValue()}; + } + @Override public boolean isNumericNull() { @@ -189,6 +260,13 @@ public final boolean asBoolean() return Evals.asBoolean(asDouble()); } + @Nullable + @Override + public Object[] asArray() + { + return asDoubleArray(); + } + @Override public final ExprEval castTo(ExprType castTo) { @@ -203,6 +281,12 @@ public final ExprEval castTo(ExprType castTo) } case STRING: return ExprEval.of(asString()); + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray(asDoubleArray()); + case LONG_ARRAY: + return ExprEval.ofLongArray(asLongArray()); + case STRING_ARRAY: + return ExprEval.ofStringArray(asStringArray()); } throw new IAE("invalid type " + castTo); } @@ -233,6 +317,20 @@ public final boolean asBoolean() return Evals.asBoolean(asLong()); } + @Nullable + @Override + public Object[] asArray() + { + return asLongArray(); + } + + @Nullable + @Override + public Long[] asLongArray() + { + return isNumericNull() ? null : new Long[]{value.longValue()}; + } + @Override public final ExprEval castTo(ExprType castTo) { @@ -247,6 +345,12 @@ public final ExprEval castTo(ExprType castTo) return this; case STRING: return ExprEval.of(asString()); + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray(asDoubleArray()); + case LONG_ARRAY: + return ExprEval.ofLongArray(asLongArray()); + case STRING_ARRAY: + return ExprEval.ofStringArray(asStringArray()); } throw new IAE("invalid type " + castTo); } @@ -256,6 +360,7 @@ public Expr toExpr() { return new LongExpr(value.longValue()); } + } private static class StringExprEval extends ExprEval @@ -325,6 +430,13 @@ public String asString() return value; } + @Nullable + @Override + public Object[] asArray() + { + return asStringArray(); + } + private int computeInt() { Number number = computeNumber(); @@ -395,6 +507,27 @@ public final boolean asBoolean() return booleanValue; } + @Nullable + @Override + public String[] asStringArray() + { + return value == null ? 
null : new String[] {value}; + } + + @Nullable + @Override + public Long[] asLongArray() + { + return value == null ? null : new Long[] {computeLong()}; + } + + @Nullable + @Override + public Double[] asDoubleArray() + { + return value == null ? null : new Double[] {computeDouble()}; + } + @Override public final ExprEval castTo(ExprType castTo) { @@ -405,6 +538,12 @@ public final ExprEval castTo(ExprType castTo) return ExprEval.ofLong(computeNumber()); case STRING: return this; + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray(asDoubleArray()); + case LONG_ARRAY: + return ExprEval.ofLongArray(asLongArray()); + case STRING_ARRAY: + return ExprEval.ofStringArray(asStringArray()); } throw new IAE("invalid type " + castTo); } @@ -415,4 +554,283 @@ public Expr toExpr() return new StringExpr(value); } } + + abstract static class ArrayExprEval extends ExprEval + { + private ArrayExprEval(@Nullable T[] value) + { + super(value); + } + + @Override + public boolean isNumericNull() + { + return false; + } + + @Override + public boolean isArray() + { + return true; + } + + @Override + public int asInt() + { + return 0; + } + + @Override + public long asLong() + { + return 0; + } + + @Override + public double asDouble() + { + return 0; + } + + @Override + public boolean asBoolean() + { + return false; + } + + @Nullable + @Override + public T[] asArray() + { + return value; + } + + @Nullable + public T getIndex(int index) + { + return value == null ? null : value[index]; + } + } + + private static class LongArrayExprEval extends ArrayExprEval + { + private LongArrayExprEval(@Nullable Long[] value) + { + super(value); + } + + @Override + public ExprType type() + { + return ExprType.LONG_ARRAY; + } + + @Nullable + @Override + public String[] asStringArray() + { + return value == null ? 
null : Arrays.stream(value).map(String::valueOf).toArray(String[]::new); + } + + @Nullable + @Override + public Long[] asLongArray() + { + return value; + } + + @Nullable + @Override + public Double[] asDoubleArray() + { + return value == null ? null : Arrays.stream(value).map(Long::doubleValue).toArray(Double[]::new); + } + + @Override + public ExprEval castTo(ExprType castTo) + { + if (value == null) { + return StringExprEval.OF_NULL; + } + switch (castTo) { + case STRING: + return ExprEval.of(Arrays.stream(value).map(String::valueOf).collect(Collectors.joining(", "))); + case LONG_ARRAY: + return this; + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray(asDoubleArray()); + case STRING_ARRAY: + return ExprEval.ofStringArray(asStringArray()); + } + + throw new IAE("invalid type " + castTo); + } + + @Override + public Expr toExpr() + { + return new LongArrayExpr(value); + } + } + + private static class DoubleArrayExprEval extends ArrayExprEval + { + private DoubleArrayExprEval(@Nullable Double[] value) + { + super(value); + } + + @Override + public ExprType type() + { + return ExprType.DOUBLE_ARRAY; + } + + @Nullable + @Override + public String[] asStringArray() + { + return value == null ? null : Arrays.stream(value).map(String::valueOf).toArray(String[]::new); + } + + @Nullable + @Override + public Long[] asLongArray() + { + return value == null ? 
null : Arrays.stream(value).map(Double::longValue).toArray(Long[]::new); + } + + @Nullable + @Override + public Double[] asDoubleArray() + { + return value; + } + + @Override + public ExprEval castTo(ExprType castTo) + { + if (value == null) { + return StringExprEval.OF_NULL; + } + switch (castTo) { + case STRING: + return ExprEval.of(Arrays.stream(value).map(String::valueOf).collect(Collectors.joining(", "))); + case LONG_ARRAY: + return ExprEval.ofLongArray(asLongArray()); + case DOUBLE_ARRAY: + return this; + case STRING_ARRAY: + return ExprEval.ofStringArray(asStringArray()); + } + + throw new IAE("invalid type " + castTo); + } + + @Override + public Expr toExpr() + { + return new DoubleArrayExpr(value); + } + } + + private static class StringArrayExprEval extends ArrayExprEval + { + private boolean longValueValid = false; + private boolean doubleValueValid = false; + private Long[] longValues; + private Double[] doubleValues; + + private StringArrayExprEval(@Nullable String[] value) + { + super(value); + } + + @Override + public ExprType type() + { + return ExprType.STRING_ARRAY; + } + + @Nullable + @Override + public String[] asStringArray() + { + return value; + } + + @Nullable + @Override + public Long[] asLongArray() + { + if (!longValueValid) { + longValues = computeLongs(); + longValueValid = true; + } + return longValues; + } + + @Nullable + @Override + public Double[] asDoubleArray() + { + if (!doubleValueValid) { + doubleValues = computeDoubles(); + doubleValueValid = true; + } + return doubleValues; + } + + @Override + public ExprEval castTo(ExprType castTo) + { + if (value == null) { + return StringExprEval.OF_NULL; + } + switch (castTo) { + case STRING: + return ExprEval.of(Arrays.stream(value).map(String::valueOf).collect(Collectors.joining(", "))); + case STRING_ARRAY: + return this; + case LONG_ARRAY: + return ExprEval.ofLongArray(asLongArray()); + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray(asDoubleArray()); + } + throw new IAE("invalid 
type " + castTo); + } + + @Override + public Expr toExpr() + { + return new StringArrayExpr(value); + } + + @Nullable + private Long[] computeLongs() + { + if (value == null) { + return null; + } + return Arrays.stream(value).map(value -> { + Long lv = GuavaUtils.tryParseLong(value); + if (lv == null) { + Double d = Doubles.tryParse(value); + if (d != null) { + lv = d.longValue(); + } + } + return lv; + }).toArray(Long[]::new); + } + + @Nullable + private Double[] computeDoubles() + { + if (value == null) { + return null; + } + return Arrays.stream(value).map(Doubles::tryParse).toArray(Double[]::new); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java index 8b33224867d3..b4dc961e61f9 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java @@ -34,6 +34,9 @@ import java.util.Map; /** + * Implementation of antlr parse tree listener, transforms {@link ParseTree} to {@link Expr}, based on the grammar + * defined in Expr.g4. All + * {@link Expr} are created on 'exit' so that children {@link Expr} are already constructed. */ public class ExprListenerImpl extends ExprBaseListener { @@ -69,6 +72,22 @@ public void exitUnaryOpExpr(ExprParser.UnaryOpExprContext ctx) } } + @Override + public void exitApplyFunctionExpr(ExprParser.ApplyFunctionExprContext ctx) + { + String fnName = ctx.getChild(0).getText(); + // Built-in functions. 
+ final ApplyFunction function = Parser.getApplyFunction(fnName); + if (function == null) { + throw new RE("function '%s' is not defined.", fnName); + } + + nodes.put( + ctx, + new ApplyFunctionExpr(function, fnName, (LambdaExpr) nodes.get(ctx.lambda()), (List) nodes.get(ctx.fnArgs())) + ); + } + @Override public void exitDoubleExpr(ExprParser.DoubleExprContext ctx) { @@ -78,6 +97,16 @@ public void exitDoubleExpr(ExprParser.DoubleExprContext ctx) ); } + @Override + public void exitDoubleArray(ExprParser.DoubleArrayContext ctx) + { + Double[] values = new Double[ctx.DOUBLE().size()]; + for (int i = 0; i < values.length; i++) { + values[i] = Double.parseDouble(ctx.DOUBLE(i).getText()); + } + nodes.put(ctx, new DoubleArrayExpr(values)); + } + @Override public void exitAddSubExpr(ExprParser.AddSubExprContext ctx) { @@ -147,6 +176,16 @@ public void exitLogicalAndOrExpr(ExprParser.LogicalAndOrExprContext ctx) } } + @Override + public void exitLongArray(ExprParser.LongArrayContext ctx) + { + Long[] values = new Long[ctx.LONG().size()]; + for (int i = 0; i < values.length; i++) { + values[i] = Long.parseLong(ctx.LONG(i).getText()); + } + nodes.put(ctx, new LongArrayExpr(values)); + } + @Override public void exitNestedExpr(ExprParser.NestedExprContext ctx) { @@ -156,10 +195,7 @@ public void exitNestedExpr(ExprParser.NestedExprContext ctx) @Override public void exitString(ExprParser.StringContext ctx) { - String text = ctx.getText(); - String unquoted = text.substring(1, text.length() - 1); - String unescaped = unquoted.indexOf('\\') >= 0 ? 
StringEscapeUtils.unescapeJava(unquoted) : unquoted; - nodes.put(ctx, new StringExpr(unescaped)); + nodes.put(ctx, new StringExpr(escapeStringLiteral(ctx.getText()))); } @Override @@ -321,16 +357,27 @@ public void exitIdentifierExpr(ExprParser.IdentifierExprContext ctx) ); } + @Override + public void exitLambda(ExprParser.LambdaContext ctx) + { + List identifiers = new ArrayList<>(ctx.IDENTIFIER().size()); + for (int i = 0; i < ctx.IDENTIFIER().size(); i++) { + String text = ctx.IDENTIFIER(i).getText(); + if (text.charAt(0) == '"' && text.charAt(text.length() - 1) == '"') { + text = StringEscapeUtils.unescapeJava(text.substring(1, text.length() - 1)); + } + identifiers.add(i, new IdentifierExpr(text)); + } + + nodes.put(ctx, new LambdaExpr(identifiers, (Expr) nodes.get(ctx.expr()))); + } + @Override public void exitFunctionArgs(ExprParser.FunctionArgsContext ctx) { List args = new ArrayList<>(); - args.add((Expr) nodes.get(ctx.getChild(0))); - - if (ctx.getChildCount() > 1) { - for (int i = 1; i <= ctx.getChildCount() / 2; i++) { - args.add((Expr) nodes.get(ctx.getChild(2 * i))); - } + for (ParseTree exprCtx : ctx.expr()) { + args.add((Expr) nodes.get(exprCtx)); } nodes.put(ctx, args); @@ -341,4 +388,26 @@ public void exitNull(ExprParser.NullContext ctx) { nodes.put(ctx, new StringExpr(null)); } + + @Override + public void exitStringArray(ExprParser.StringArrayContext ctx) + { + String[] values = new String[ctx.STRING().size()]; + for (int i = 0; i < values.length; i++) { + values[i] = escapeStringLiteral(ctx.STRING(i).getText()); + } + nodes.put(ctx, new StringArrayExpr(values)); + } + + @Override + public void exitEmptyArray(ExprParser.EmptyArrayContext ctx) + { + nodes.put(ctx, new StringArrayExpr(new String[0])); + } + + private static String escapeStringLiteral(String text) + { + String unquoted = text.substring(1, text.length() - 1); + return unquoted.indexOf('\\') >= 0 ? 
StringEscapeUtils.unescapeJava(unquoted) : unquoted; + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java index c4b475832aaf..370c5a1633ce 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java @@ -20,14 +20,23 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.apache.druid.java.util.common.StringUtils; import javax.annotation.Nullable; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +/** + * Mechanism by which Druid expressions can define new functions for the Druid expression language. When + * {@link ExprListenerImpl} is creating a {@link FunctionExpr}, {@link ExprMacroTable} will first be checked to find + * the function by name, falling back to {@link Parser#getFunction(String)} to map to a built-in {@link Function} if + * none is defined in the macro table. 
+ */ public class ExprMacroTable { private static final ExprMacroTable NIL = new ExprMacroTable(Collections.emptyList()); @@ -80,4 +89,72 @@ public interface ExprMacro Expr apply(List args); } + + /** + * Base class for single argument {@link ExprMacro} function {@link Expr} + */ + public abstract static class BaseScalarUnivariateMacroFunctionExpr implements Expr + { + protected final Expr arg; + + public BaseScalarUnivariateMacroFunctionExpr(Expr arg) + { + this.arg = arg; + } + + @Override + public void visit(final Visitor visitor) + { + arg.visit(visitor); + visitor.visit(this); + } + + @Override + public BindingDetails analyzeInputs() + { + final String identifier = arg.getIdentifierIfIdentifier(); + if (identifier == null) { + return arg.analyzeInputs(); + } + return arg.analyzeInputs().mergeWithScalars(ImmutableSet.of(identifier)); + } + } + + /** + * Base class for multi-argument {@link ExprMacro} function {@link Expr} + */ + public abstract static class BaseScalarMacroFunctionExpr implements Expr + { + protected final List args; + + public BaseScalarMacroFunctionExpr(final List args) + { + this.args = args; + } + + + @Override + public void visit(final Visitor visitor) + { + for (Expr arg : args) { + arg.visit(visitor); + } + visitor.visit(this); + } + + @Override + public BindingDetails analyzeInputs() + { + Set scalars = new HashSet<>(); + BindingDetails accumulator = new BindingDetails(); + for (Expr arg : args) { + final String identifier = arg.getIdentifierIfIdentifier(); + if (identifier != null) { + scalars.add(identifier); + } + accumulator = accumulator.merge(arg.analyzeInputs()); + } + return accumulator.mergeWithScalars(scalars); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprType.java b/core/src/main/java/org/apache/druid/math/expr/ExprType.java index 050cc6100008..0bc1573bef56 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprType.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprType.java @@ 
-20,8 +20,14 @@ package org.apache.druid.math.expr; /** + * Base 'value' types of Druid expression language, all {@link Expr} must evaluate to one of these types. */ public enum ExprType { - DOUBLE, LONG, STRING + DOUBLE, + LONG, + STRING, + DOUBLE_ARRAY, + LONG_ARRAY, + STRING_ARRAY } diff --git a/core/src/main/java/org/apache/druid/math/expr/Function.java b/core/src/main/java/org/apache/druid/math/expr/Function.java index 14aa44b3cfaf..65643e226d77 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Function.java +++ b/core/src/main/java/org/apache/druid/math/expr/Function.java @@ -19,9 +19,11 @@ package org.apache.druid.math.expr; +import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -29,26 +31,72 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** + * Base interface describing the mechanism used to evaluate a {@link FunctionExpr}. All {@link Function} implementations + * are immutable. + * * Do NOT remove "unused" members in this class. They are used by generated Antlr */ @SuppressWarnings("unused") interface Function { + /** + * Name of the function. + */ String name(); + /** + * Evaluate the function, given a list of arguments and a set of bindings to provide values for {@link IdentifierExpr}. 
+ */ ExprEval apply(List args, Expr.ObjectBinding bindings); - abstract class SingleParam implements Function + /** + * Given a list of arguments to this {@link Function}, get the set of arguments that must evaluate to a scalar value + */ + default Set getScalarInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + /** + * Given a list of arguments to this {@link Function}, get the set of arguments that must evaluate to an array + * value + */ + default Set getArrayInputs(List args) + { + return Collections.emptySet(); + } + + /** + * Validate function arguments + */ + void validateArguments(List args); + + /** + * Base class for a single variable input {@link Function} implementation + */ + abstract class UnivariateFunction implements Function { @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) + public void validateArguments(List args) { if (args.size() != 1) { throw new IAE("Function[%s] needs 1 argument", name()); } + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { Expr expr = args.get(0); return eval(expr.eval(bindings)); } @@ -56,14 +104,22 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) protected abstract ExprEval eval(ExprEval param); } - abstract class DoubleParam implements Function + /** + * Base class for a 2 variable input {@link Function} implementation + */ + abstract class BivariateFunction implements Function { @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) + public void validateArguments(List args) { if (args.size() != 2) { throw new IAE("Function[%s] needs 2 arguments", name()); } + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { Expr expr1 = args.get(0); Expr expr2 = args.get(1); return eval(expr1.eval(bindings), expr2.eval(bindings)); @@ -72,7 +128,11 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) protected abstract ExprEval eval(ExprEval x, ExprEval y); } - abstract class 
SingleParamMath extends SingleParam + /** + * Base class for a single variable input mathematical {@link Function}, with specialized 'eval' implementations that + * operate on primitive number types + */ + abstract class UnivariateMathFunction extends UnivariateFunction { @Override protected final ExprEval eval(ExprEval param) @@ -99,7 +159,11 @@ protected ExprEval eval(double param) } } - abstract class DoubleParamMath extends DoubleParam + /** + * Base class for a 2 variable input mathematical {@link Function}, with specialized 'eval' implementations that + * operate on primitive number types + */ + abstract class BivariateMathFunction extends BivariateFunction { @Override protected final ExprEval eval(ExprEval x, ExprEval y) @@ -125,7 +189,11 @@ protected ExprEval eval(double x, double y) } } - abstract class DoubleParamString extends DoubleParam + /** + * Base class for a 2 variable input {@link Function} whose first argument is a {@link ExprType#STRING} and second + * argument is {@link ExprType#LONG} + */ + abstract class StringLongFunction extends BivariateFunction { @Override protected final ExprEval eval(ExprEval x, ExprEval y) @@ -142,6 +210,88 @@ protected final ExprEval eval(ExprEval x, ExprEval y) protected abstract ExprEval eval(String x, int y); } + /** + * {@link Function} that takes 1 array operand and 1 scalar operand + */ + abstract class ArrayScalarFunction implements Function + { + @Override + public void validateArguments(List args) + { + if (args.size() != 2) { + throw new IAE("Function[%s] needs 2 argument", name()); + } + } + + @Override + public Set getScalarInputs(List args) + { + return ImmutableSet.of(args.get(1)); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.of(args.get(0)); + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arrayExpr = args.get(0).eval(bindings); + final ExprEval scalarExpr = args.get(1).eval(bindings); + if
(arrayExpr.asArray() == null) { + return ExprEval.of(null); + } + return doApply(arrayExpr, scalarExpr); + } + + abstract ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr); + } + + /** + * {@link Function} that takes 2 array operands + */ + abstract class ArraysFunction implements Function + { + @Override + public void validateArguments(List args) + { + if (args.size() != 2) { + throw new IAE("Function[%s] needs 2 argument", name()); + } + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval arrayExpr1 = args.get(0).eval(bindings); + final ExprEval arrayExpr2 = args.get(1).eval(bindings); + + if (arrayExpr1.asArray() == null || arrayExpr2.asArray() == null) { + return ExprEval.of(null); + } + + return doApply(arrayExpr1, arrayExpr2); + } + + abstract ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr); + } + + // ------------------------------ implementations ------------------------------ + class ParseLong implements Function { @Override @@ -151,16 +301,17 @@ public String name() } @Override - public ExprEval apply(List args, Expr.ObjectBinding bindings) + public void validateArguments(List args) { - final int radix; - if (args.size() == 1) { - radix = 10; - } else if (args.size() == 2) { - radix = args.get(1).eval(bindings).asInt(); - } else { + if (args.size() != 1 && args.size() != 2) { throw new IAE("Function[%s] needs 1 or 2 arguments", name()); } + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final int radix = args.size() == 1 ? 
10 : args.get(1).eval(bindings).asInt(); final String input = NullHandling.nullToEmptyIfNeeded(args.get(0).eval(bindings).asString()); if (input == null) { @@ -197,15 +348,19 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() >= 1) { + return ExprEval.of(PI); + } + + @Override + public void validateArguments(List args) + { + if (args.size() > 0) { throw new IAE("Function[%s] needs 0 argument", name()); } - - return ExprEval.of(PI); } } - class Abs extends SingleParamMath + class Abs extends UnivariateMathFunction { @Override public String name() @@ -226,7 +381,7 @@ protected ExprEval eval(double param) } } - class Acos extends SingleParamMath + class Acos extends UnivariateMathFunction { @Override public String name() @@ -241,7 +396,7 @@ protected ExprEval eval(double param) } } - class Asin extends SingleParamMath + class Asin extends UnivariateMathFunction { @Override public String name() @@ -256,7 +411,7 @@ protected ExprEval eval(double param) } } - class Atan extends SingleParamMath + class Atan extends UnivariateMathFunction { @Override public String name() @@ -271,7 +426,7 @@ protected ExprEval eval(double param) } } - class Cbrt extends SingleParamMath + class Cbrt extends UnivariateMathFunction { @Override public String name() @@ -286,7 +441,7 @@ protected ExprEval eval(double param) } } - class Ceil extends SingleParamMath + class Ceil extends UnivariateMathFunction { @Override public String name() @@ -301,7 +456,7 @@ protected ExprEval eval(double param) } } - class Cos extends SingleParamMath + class Cos extends UnivariateMathFunction { @Override public String name() @@ -316,7 +471,7 @@ protected ExprEval eval(double param) } } - class Cosh extends SingleParamMath + class Cosh extends UnivariateMathFunction { @Override public String name() @@ -331,7 +486,7 @@ protected ExprEval eval(double param) } } - class Cot extends SingleParamMath + class Cot extends UnivariateMathFunction { @Override 
public String name() @@ -346,7 +501,7 @@ protected ExprEval eval(double param) } } - class Div extends DoubleParamMath + class Div extends BivariateMathFunction { @Override public String name() @@ -367,7 +522,7 @@ protected ExprEval eval(final double x, final double y) } } - class Exp extends SingleParamMath + class Exp extends UnivariateMathFunction { @Override public String name() @@ -382,7 +537,7 @@ protected ExprEval eval(double param) } } - class Expm1 extends SingleParamMath + class Expm1 extends UnivariateMathFunction { @Override public String name() @@ -397,7 +552,7 @@ protected ExprEval eval(double param) } } - class Floor extends SingleParamMath + class Floor extends UnivariateMathFunction { @Override public String name() @@ -412,7 +567,7 @@ protected ExprEval eval(double param) } } - class GetExponent extends SingleParamMath + class GetExponent extends UnivariateMathFunction { @Override public String name() @@ -427,7 +582,7 @@ protected ExprEval eval(double param) } } - class Log extends SingleParamMath + class Log extends UnivariateMathFunction { @Override public String name() @@ -442,7 +597,7 @@ protected ExprEval eval(double param) } } - class Log10 extends SingleParamMath + class Log10 extends UnivariateMathFunction { @Override public String name() @@ -457,7 +612,7 @@ protected ExprEval eval(double param) } } - class Log1p extends SingleParamMath + class Log1p extends UnivariateMathFunction { @Override public String name() @@ -472,7 +627,7 @@ protected ExprEval eval(double param) } } - class NextUp extends SingleParamMath + class NextUp extends UnivariateMathFunction { @Override public String name() @@ -487,7 +642,7 @@ protected ExprEval eval(double param) } } - class Rint extends SingleParamMath + class Rint extends UnivariateMathFunction { @Override public String name() @@ -513,10 +668,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 1 && args.size() != 2) { - throw new 
IAE("Function[%s] needs 1 or 2 arguments", name()); - } - ExprEval value1 = args.get(0).eval(bindings); if (value1.type() != ExprType.LONG && value1.type() != ExprType.DOUBLE) { throw new IAE("The first argument to the function[%s] should be integer or double type but get the %s type", name(), value1.type()); @@ -533,6 +684,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } } + @Override + public void validateArguments(List args) + { + if (args.size() != 1 && args.size() != 2) { + throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + } + } + private ExprEval eval(ExprEval param) { return eval(param, 0); @@ -550,7 +709,7 @@ private ExprEval eval(ExprEval param, int scale) } } - class Signum extends SingleParamMath + class Signum extends UnivariateMathFunction { @Override public String name() @@ -565,7 +724,7 @@ protected ExprEval eval(double param) } } - class Sin extends SingleParamMath + class Sin extends UnivariateMathFunction { @Override public String name() @@ -580,7 +739,7 @@ protected ExprEval eval(double param) } } - class Sinh extends SingleParamMath + class Sinh extends UnivariateMathFunction { @Override public String name() @@ -595,7 +754,7 @@ protected ExprEval eval(double param) } } - class Sqrt extends SingleParamMath + class Sqrt extends UnivariateMathFunction { @Override public String name() @@ -610,7 +769,7 @@ protected ExprEval eval(double param) } } - class Tan extends SingleParamMath + class Tan extends UnivariateMathFunction { @Override public String name() @@ -625,7 +784,7 @@ protected ExprEval eval(double param) } } - class Tanh extends SingleParamMath + class Tanh extends UnivariateMathFunction { @Override public String name() @@ -640,7 +799,7 @@ protected ExprEval eval(double param) } } - class ToDegrees extends SingleParamMath + class ToDegrees extends UnivariateMathFunction { @Override public String name() @@ -655,7 +814,7 @@ protected ExprEval eval(double param) } } - class ToRadians extends SingleParamMath + 
class ToRadians extends UnivariateMathFunction { @Override public String name() @@ -670,7 +829,7 @@ protected ExprEval eval(double param) } } - class Ulp extends SingleParamMath + class Ulp extends UnivariateMathFunction { @Override public String name() @@ -685,7 +844,7 @@ protected ExprEval eval(double param) } } - class Atan2 extends DoubleParamMath + class Atan2 extends BivariateMathFunction { @Override public String name() @@ -700,7 +859,7 @@ protected ExprEval eval(double y, double x) } } - class CopySign extends DoubleParamMath + class CopySign extends BivariateMathFunction { @Override public String name() @@ -715,7 +874,7 @@ protected ExprEval eval(double x, double y) } } - class Hypot extends DoubleParamMath + class Hypot extends BivariateMathFunction { @Override public String name() @@ -730,7 +889,7 @@ protected ExprEval eval(double x, double y) } } - class Remainder extends DoubleParamMath + class Remainder extends BivariateMathFunction { @Override public String name() @@ -745,7 +904,7 @@ protected ExprEval eval(double x, double y) } } - class Max extends DoubleParamMath + class Max extends BivariateMathFunction { @Override public String name() @@ -766,7 +925,7 @@ protected ExprEval eval(double x, double y) } } - class Min extends DoubleParamMath + class Min extends BivariateMathFunction { @Override public String name() @@ -787,7 +946,7 @@ protected ExprEval eval(double x, double y) } } - class NextAfter extends DoubleParamMath + class NextAfter extends BivariateMathFunction { @Override public String name() @@ -802,7 +961,7 @@ protected ExprEval eval(double x, double y) } } - class Pow extends DoubleParamMath + class Pow extends BivariateMathFunction { @Override public String name() @@ -817,7 +976,7 @@ protected ExprEval eval(double x, double y) } } - class Scalb extends DoubleParam + class Scalb extends BivariateFunction { @Override public String name() @@ -842,13 +1001,17 @@ public String name() @Override public ExprEval apply(List args, 
Expr.ObjectBinding bindings) + { + ExprEval x = args.get(0).eval(bindings); + return x.asBoolean() ? args.get(1).eval(bindings) : args.get(2).eval(bindings); + } + + @Override + public void validateArguments(List args) { if (args.size() != 3) { throw new IAE("Function[%s] needs 3 arguments", name()); } - - ExprEval x = args.get(0).eval(bindings); - return x.asBoolean() ? args.get(1).eval(bindings) : args.get(2).eval(bindings); } } @@ -866,10 +1029,6 @@ public String name() @Override public ExprEval apply(final List args, final Expr.ObjectBinding bindings) { - if (args.size() < 2) { - throw new IAE("Function[%s] must have at least 2 arguments", name()); - } - for (int i = 0; i < args.size(); i += 2) { if (i == args.size() - 1) { // ELSE else_result. @@ -882,6 +1041,14 @@ public ExprEval apply(final List args, final Expr.ObjectBinding bindings) return ExprEval.of(null); } + + @Override + public void validateArguments(List args) + { + if (args.size() < 2) { + throw new IAE("Function[%s] must have at least 2 arguments", name()); + } + } } /** @@ -898,10 +1065,6 @@ public String name() @Override public ExprEval apply(final List args, final Expr.ObjectBinding bindings) { - if (args.size() < 3) { - throw new IAE("Function[%s] must have at least 3 arguments", name()); - } - for (int i = 1; i < args.size(); i += 2) { if (i == args.size() - 1) { // ELSE else_result. 
@@ -914,9 +1077,17 @@ public ExprEval apply(final List args, final Expr.ObjectBinding bindings) return ExprEval.of(null); } + + @Override + public void validateArguments(List args) + { + if (args.size() < 3) { + throw new IAE("Function[%s] must have at least 3 arguments", name()); + } + } } - class CastFunc extends DoubleParam + class CastFunc extends BivariateFunction { @Override public String name() @@ -939,6 +1110,42 @@ protected ExprEval eval(ExprEval x, ExprEval y) } return x.castTo(castTo); } + + @Override + public Set getScalarInputs(List args) + { + if (args.get(1).isLiteral()) { + ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + switch (castTo) { + case LONG_ARRAY: + case DOUBLE_ARRAY: + case STRING_ARRAY: + return Collections.emptySet(); + default: + return ImmutableSet.of(args.get(0)); + } + } + // unknown cast, can't safely assume either way + return Collections.emptySet(); + } + + @Override + public Set getArrayInputs(List args) + { + if (args.get(1).isLiteral()) { + ExprType castTo = ExprType.valueOf(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())); + switch (castTo) { + case LONG: + case DOUBLE: + case STRING: + return Collections.emptySet(); + default: + return ImmutableSet.of(args.get(0)); + } + } + // unknown cast, can't safely assume either way + return Collections.emptySet(); + } } class TimestampFromEpochFunc implements Function @@ -952,9 +1159,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 1 && args.size() != 2) { - throw new IAE("Function[%s] needs 1 or 2 arguments", name()); - } ExprEval value = args.get(0).eval(bindings); if (value.type() != ExprType.STRING) { throw new IAE("first argument should be string type but got %s type", value.type()); @@ -978,6 +1182,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return toValue(date); } + @Override + public void 
validateArguments(List args) + { + if (args.size() != 1 && args.size() != 2) { + throw new IAE("Function[%s] needs 1 or 2 arguments", name()); + } + } + protected ExprEval toValue(DateTime date) { return ExprEval.of(date.getMillis()); @@ -1009,12 +1221,17 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval eval = args.get(0).eval(bindings); + return eval.value() == null ? args.get(1).eval(bindings) : eval; + } + + @Override + public void validateArguments(List args) { if (args.size() != 2) { throw new IAE("Function[%s] needs 2 arguments", name()); } - final ExprEval eval = args.get(0).eval(bindings); - return eval.value() == null ? args.get(1).eval(bindings) : eval; } } @@ -1053,6 +1270,12 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return ExprEval.of(builder.toString()); } } + + @Override + public void validateArguments(List args) + { + // anything goes + } } class StrlenFunc implements Function @@ -1065,13 +1288,17 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final String arg = args.get(0).eval(bindings).asString(); + return arg == null ? ExprEval.ofLong(NullHandling.defaultLongValue()) : ExprEval.of(arg.length()); + } + + @Override + public void validateArguments(List args) { if (args.size() != 1) { throw new IAE("Function[%s] needs 1 argument", name()); } - - final String arg = args.get(0).eval(bindings).asString(); - return arg == null ? 
ExprEval.ofLong(NullHandling.defaultLongValue()) : ExprEval.of(arg.length()); } } @@ -1086,10 +1313,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() < 1) { - throw new IAE("Function[%s] needs 1 or more arguments", name()); - } - final String formatString = NullHandling.nullToEmptyIfNeeded(args.get(0).eval(bindings).asString()); if (formatString == null) { @@ -1103,6 +1326,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return ExprEval.of(StringUtils.nonStrictFormat(formatString, formatArgs)); } + + @Override + public void validateArguments(List args) + { + if (args.size() < 1) { + throw new IAE("Function[%s] needs 1 or more arguments", name()); + } + } } class StrposFunc implements Function @@ -1116,10 +1347,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() < 2 || args.size() > 3) { - throw new IAE("Function[%s] needs 2 or 3 arguments", name()); - } - final String haystack = NullHandling.nullToEmptyIfNeeded(args.get(0).eval(bindings).asString()); final String needle = NullHandling.nullToEmptyIfNeeded(args.get(1).eval(bindings).asString()); @@ -1137,6 +1364,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return ExprEval.of(haystack.indexOf(needle, fromIndex)); } + + @Override + public void validateArguments(List args) + { + if (args.size() < 2 || args.size() > 3) { + throw new IAE("Function[%s] needs 2 or 3 arguments", name()); + } + } } class SubstringFunc implements Function @@ -1150,10 +1385,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - final String arg = args.get(0).eval(bindings).asString(); if (arg == null) { @@ -1176,9 +1407,17 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) return 
ExprEval.of(NullHandling.defaultStringValue()); } } + + @Override + public void validateArguments(List args) + { + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } + } } - class RightFunc extends DoubleParamString + class RightFunc extends StringLongFunction { @Override public String name() @@ -1200,7 +1439,7 @@ protected ExprEval eval(String x, int y) } } - class LeftFunc extends DoubleParamString + class LeftFunc extends StringLongFunction { @Override public String name() @@ -1232,10 +1471,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - final String arg = args.get(0).eval(bindings).asString(); final String pattern = NullHandling.nullToEmptyIfNeeded(args.get(1).eval(bindings).asString()); final String replacement = NullHandling.nullToEmptyIfNeeded(args.get(2).eval(bindings).asString()); @@ -1244,6 +1479,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } return ExprEval.of(StringUtils.replace(arg, pattern, replacement)); } + + @Override + public void validateArguments(List args) + { + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } + } } class LowerFunc implements Function @@ -1257,16 +1500,20 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); - } - final String arg = args.get(0).eval(bindings).asString(); if (arg == null) { return ExprEval.of(NullHandling.defaultStringValue()); } return ExprEval.of(StringUtils.toLowerCase(arg)); } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } } class UpperFunc implements Function @@ -1280,19 +1527,23 @@ public String name() @Override public ExprEval apply(List 
args, Expr.ObjectBinding bindings) { - if (args.size() != 1) { - throw new IAE("Function[%s] needs 1 argument", name()); - } - final String arg = args.get(0).eval(bindings).asString(); if (arg == null) { return ExprEval.of(NullHandling.defaultStringValue()); } return ExprEval.of(StringUtils.toUpperCase(arg)); } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } } - class ReverseFunc extends SingleParam + class ReverseFunc extends UnivariateFunction { @Override public String name() @@ -1314,7 +1565,7 @@ protected ExprEval eval(ExprEval param) } } - class RepeatFunc extends DoubleParamString + class RepeatFunc extends StringLongFunction { @Override public String name() @@ -1339,13 +1590,17 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.of(expr.value() == null, ExprType.LONG); + } + + @Override + public void validateArguments(List args) { if (args.size() != 1) { throw new IAE("Function[%s] needs 1 argument", name()); } - - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() == null, ExprType.LONG); } } @@ -1359,13 +1614,17 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + return ExprEval.of(expr.value() != null, ExprType.LONG); + } + + @Override + public void validateArguments(List args) { if (args.size() != 1) { throw new IAE("Function[%s] needs 1 argument", name()); } - - final ExprEval expr = args.get(0).eval(bindings); - return ExprEval.of(expr.value() != null, ExprType.LONG); } } @@ -1380,10 +1639,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - String base = 
args.get(0).eval(bindings).asString(); int len = args.get(1).eval(bindings).asInt(); String pad = args.get(2).eval(bindings).asString(); @@ -1395,6 +1650,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } } + + @Override + public void validateArguments(List args) + { + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } + } } class RpadFunc implements Function @@ -1408,10 +1671,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - String base = args.get(0).eval(bindings).asString(); int len = args.get(1).eval(bindings).asInt(); String pad = args.get(2).eval(bindings).asString(); @@ -1423,6 +1682,14 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } } + + @Override + public void validateArguments(List args) + { + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } + } } class SubMonthFunc implements Function @@ -1436,10 +1703,6 @@ public String name() @Override public ExprEval apply(List args, Expr.ObjectBinding bindings) { - if (args.size() != 3) { - throw new IAE("Function[%s] needs 3 arguments", name()); - } - Long left = args.get(0).eval(bindings).asLong(); Long right = args.get(1).eval(bindings).asLong(); DateTimeZone timeZone = DateTimes.inferTzFromString(args.get(2).eval(bindings).asString()); @@ -1451,6 +1714,352 @@ public ExprEval apply(List args, Expr.ObjectBinding bindings) } } + + @Override + public void validateArguments(List args) + { + if (args.size() != 3) { + throw new IAE("Function[%s] needs 3 arguments", name()); + } + } + } + + class ArrayLengthFunction implements Function + { + @Override + public String name() + { + return "array_length"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + final Object[] array = 
expr.asArray(); + if (array == null) { + return ExprEval.of(null); + } + + return ExprEval.ofLong(array.length); + } + + + @Override + public Set getArrayInputs(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + return ImmutableSet.of(args.get(0)); + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 1) { + throw new IAE("Function[%s] needs 1 argument", name()); + } + } + + @Override + public Set getScalarInputs(List args) + { + return Collections.emptySet(); + } + } + + class StringToArrayFunction implements Function + { + @Override + public String name() + { + return "string_to_array"; + } + + @Override + public void validateArguments(List args) + { + if (args.size() != 2) { + throw new IAE("Function[%s] needs 2 arguments", name()); + } + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval expr = args.get(0).eval(bindings); + final String arrayString = expr.asString(); + if (arrayString == null) { + return ExprEval.of(null); + } + + final String split = args.get(1).eval(bindings).asString(); + return ExprEval.ofStringArray(arrayString.split(split != null ? split : "")); + } + + @Override + public Set getScalarInputs(List args) + { + return ImmutableSet.copyOf(args); + } + } + + class ArrayToStringFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_to_string"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + final String join = scalarExpr.asString(); + return ExprEval.of( + Arrays.stream(arrayExpr.asArray()).map(String::valueOf).collect(Collectors.joining(join != null ? 
join : "")) + ); + } + } + + class ArrayOffsetFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_offset"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + final Object[] array = arrayExpr.asArray(); + final int position = scalarExpr.asInt(); + + if (position >= 0 && array.length > position) { + return ExprEval.bestEffortOf(array[position]); + } + return ExprEval.of(null); // position out of range (negative or past the end) + } + } + + class ArrayOrdinalFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_ordinal"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + final Object[] array = arrayExpr.asArray(); + final int position = scalarExpr.asInt() - 1; // ordinals are 1-based + + if (position >= 0 && array.length > position) { + return ExprEval.bestEffortOf(array[position]); + } + return ExprEval.of(null); // ordinal out of range (less than 1 or past the end) + } + } + + class ArrayOffsetOfFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_offset_of"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + final Object[] array = arrayExpr.asArray(); + + switch (scalarExpr.type()) { + case STRING: + case LONG: + case DOUBLE: + int index = -1; + for (int i = 0; i < array.length; i++) { + if (Objects.equals(array[i], scalarExpr.value())) { + index = i; + break; + } + } + return index < 0 ? 
ExprEval.of(null) : ExprEval.ofLong(index); + default: + throw new IAE("Function[%s] 2nd argument must be a scalar type", name()); + } + } + } + + class ArrayOrdinalOfFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_ordinal_of"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + final Object[] array = arrayExpr.asArray(); + switch (scalarExpr.type()) { + case STRING: + case LONG: + case DOUBLE: + int index = -1; + for (int i = 0; i < array.length; i++) { + if (Objects.equals(array[i], scalarExpr.value())) { + index = i; + break; + } + } + return index < 0 ? ExprEval.of(null) : ExprEval.ofLong(index + 1); + default: + throw new IAE("Function[%s] 2nd argument must be a scalar type", name()); + } + } } + class ArrayAppendFunction extends ArrayScalarFunction + { + @Override + public String name() + { + return "array_append"; + } + + @Override + ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr) + { + switch (arrayExpr.type()) { + case STRING: + case STRING_ARRAY: + return ExprEval.ofStringArray(this.append(arrayExpr.asStringArray(), scalarExpr.asString()).toArray(String[]::new)); + case LONG: + case LONG_ARRAY: + return ExprEval.ofLongArray( + this.append( + arrayExpr.asLongArray(), + scalarExpr.isNumericNull() ? null : scalarExpr.asLong()).toArray(Long[]::new + ) + ); + case DOUBLE: + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray( + this.append( + arrayExpr.asDoubleArray(), + scalarExpr.isNumericNull() ? 
null : scalarExpr.asDouble()).toArray(Double[]::new + ) + ); + } + + throw new RE("Unable to append to unknown type %s", arrayExpr.type()); + } + + private Stream append(T[] array, T val) + { + List l = new ArrayList<>(Arrays.asList(array)); + l.add(val); + return l.stream(); + } + } + + class ArrayConcatFunction extends ArraysFunction + { + @Override + public String name() + { + return "array_concat"; + } + + @Override + ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) + { + final Object[] array1 = lhsExpr.asArray(); + final Object[] array2 = rhsExpr.asArray(); + + if (array1 == null) { + return ExprEval.of(null); + } + if (array2 == null) { + return lhsExpr; + } + + switch (lhsExpr.type()) { + case STRING: + case STRING_ARRAY: + return ExprEval.ofStringArray( + cat(lhsExpr.asStringArray(), rhsExpr.asStringArray()).toArray(String[]::new) + ); + case LONG: + case LONG_ARRAY: + return ExprEval.ofLongArray( + cat(lhsExpr.asLongArray(), rhsExpr.asLongArray()).toArray(Long[]::new) + ); + case DOUBLE: + case DOUBLE_ARRAY: + return ExprEval.ofDoubleArray( + cat(lhsExpr.asDoubleArray(), rhsExpr.asDoubleArray()).toArray(Double[]::new) + ); + } + throw new RE("Unable to concatenate to unknown type %s", lhsExpr.type()); + } + + private Stream cat(T[] array1, T[] array2) + { + List l = new ArrayList<>(Arrays.asList(array1)); + l.addAll(Arrays.asList(array2)); + return l.stream(); + } + + @Override + public Set getArrayInputs(List args) + { + return ImmutableSet.copyOf(args); + } + } + + class ArrayContainsFunction extends ArraysFunction + { + @Override + public String name() + { + return "array_contains"; + } + + @Override + ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) + { + final Object[] array1 = lhsExpr.asArray(); + final Object[] array2 = rhsExpr.asArray(); + return ExprEval.bestEffortOf(Arrays.asList(array1).containsAll(Arrays.asList(array2))); + } + } + + class ArrayOverlapFunction extends ArraysFunction + { + @Override + public String name() + { + return 
"array_overlap"; + } + + @Override + ExprEval doApply(ExprEval lhsExpr, ExprEval rhsExpr) + { + final Object[] array1 = lhsExpr.asArray(); + final List array2 = Arrays.asList(rhsExpr.asArray()); + boolean any = false; + for (Object check : array1) { + any |= array2.contains(check); + } + return ExprEval.bestEffortOf(any); + } + } } diff --git a/core/src/main/java/org/apache/druid/math/expr/Parser.java b/core/src/main/java/org/apache/druid/math/expr/Parser.java index 861fc3c2d92b..59ed0dcbe8ab 100644 --- a/core/src/main/java/org/apache/druid/math/expr/Parser.java +++ b/core/src/main/java/org/apache/druid/math/expr/Parser.java @@ -22,29 +22,32 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import org.antlr.v4.runtime.ANTLRInputStream; import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.ParseTreeWalker; +import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.math.expr.antlr.ExprLexer; import org.apache.druid.math.expr.antlr.ExprParser; -import javax.annotation.Nullable; import java.lang.reflect.Modifier; +import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; public class Parser { private static final Logger log = new Logger(Parser.class); private static final Map FUNCTIONS; + private static final Map APPLY_FUNCTIONS; static { Map functionMap = new HashMap<>(); @@ -55,18 +58,50 @@ public class Parser functionMap.put(StringUtils.toLowerCase(function.name()), function); } catch (Exception e) { - log.info("failed to 
instantiate " + clazz.getName() + ".. ignoring", e); + log.error(e, "failed to instantiate %s.. ignoring", clazz.getName()); } } } FUNCTIONS = ImmutableMap.copyOf(functionMap); + + Map applyFunctionMap = new HashMap<>(); + for (Class clazz : ApplyFunction.class.getClasses()) { + if (!Modifier.isAbstract(clazz.getModifiers()) && ApplyFunction.class.isAssignableFrom(clazz)) { + try { + ApplyFunction function = (ApplyFunction) clazz.newInstance(); + applyFunctionMap.put(StringUtils.toLowerCase(function.name()), function); + } + catch (Exception e) { + log.error(e, "failed to instantiate %s.. ignoring", clazz.getName()); + } + } + } + APPLY_FUNCTIONS = ImmutableMap.copyOf(applyFunctionMap); } + /** + * Get {@link Function} by {@link Function#name()} + */ public static Function getFunction(String name) { return FUNCTIONS.get(StringUtils.toLowerCase(name)); } + /** + * Get {@link ApplyFunction} by {@link ApplyFunction#name()} + */ + public static ApplyFunction getApplyFunction(String name) + { + return APPLY_FUNCTIONS.get(StringUtils.toLowerCase(name)); + } + + /** + * Parse a string into a flattened {@link Expr}. There is some overhead to this, and these objects are all immutable, + * so re-use instead of re-creating whenever possible. + * @param in expression to parse + * @param macroTable additional extensions to expression language + * @return + */ public static Expr parse(String in, ExprMacroTable macroTable) { return parse(in, macroTable, true); @@ -86,83 +121,256 @@ static Expr parse(String in, ExprMacroTable macroTable, boolean withFlatten) return withFlatten ? flatten(listener.getAST()) : listener.getAST(); } + /** + * Flatten an {@link Expr}, evaluating expressions on constants where possible to simplify the {@link Expr}. 
+ */ public static Expr flatten(Expr expr) { - if (expr instanceof BinaryOpExprBase) { - BinaryOpExprBase binary = (BinaryOpExprBase) expr; - Expr left = flatten(binary.left); - Expr right = flatten(binary.right); - if (Evals.isAllConstants(left, right)) { - expr = expr.eval(null).toExpr(); - } else if (left != binary.left || right != binary.right) { - return Evals.binaryOp(binary, left, right); - } - } else if (expr instanceof UnaryExpr) { - UnaryExpr unary = (UnaryExpr) expr; - Expr eval = flatten(unary.expr); - if (eval instanceof ConstantExpr) { - expr = expr.eval(null).toExpr(); - } else if (eval != unary.expr) { - if (expr instanceof UnaryMinusExpr) { - expr = new UnaryMinusExpr(eval); - } else if (expr instanceof UnaryNotExpr) { - expr = new UnaryNotExpr(eval); - } else { - expr = unary; // unknown type.. + return expr.visit(childExpr -> { + if (childExpr instanceof BinaryOpExprBase) { + BinaryOpExprBase binary = (BinaryOpExprBase) childExpr; + if (Evals.isAllConstants(binary.left, binary.right)) { + return childExpr.eval(null).toExpr(); + } + } else if (childExpr instanceof UnaryExpr) { + UnaryExpr unary = (UnaryExpr) childExpr; + + if (unary.expr instanceof ConstantExpr) { + return childExpr.eval(null).toExpr(); + } + } else if (childExpr instanceof FunctionExpr) { + FunctionExpr functionExpr = (FunctionExpr) childExpr; + List args = functionExpr.args; + if (Evals.isAllConstants(args)) { + return childExpr.eval(null).toExpr(); + } + } else if (childExpr instanceof ApplyFunctionExpr) { + ApplyFunctionExpr applyFunctionExpr = (ApplyFunctionExpr) childExpr; + List args = applyFunctionExpr.argsExpr; + if (Evals.isAllConstants(args)) { + if (applyFunctionExpr.analyzeInputs().getFreeVariables().size() == 0) { + return childExpr.eval(null).toExpr(); + } } } - } else if (expr instanceof FunctionExpr) { - FunctionExpr functionExpr = (FunctionExpr) expr; - List args = functionExpr.args; - boolean flattened = false; - List flattening = 
Lists.newArrayListWithCapacity(args.size()); - for (Expr arg : args) { - Expr flatten = flatten(arg); - flattened |= flatten != arg; - flattening.add(flatten); - } - if (Evals.isAllConstants(flattening)) { - expr = expr.eval(null).toExpr(); - } else if (flattened) { - expr = new FunctionExpr(functionExpr.function, functionExpr.name, flattening); - } - } - return expr; + return childExpr; + }); } - public static List findRequiredBindings(Expr expr) + /** + * Applies a transformation to an {@link Expr} given a list of known (or unknown) multi-value input columns that are + * used in a scalar manner, walking the {@link Expr} tree and lifting array variables into the {@link LambdaExpr} of + * {@link ApplyFunctionExpr} and transforming the arguments of {@link FunctionExpr} + * @param expr expression to visit and rewrite + * @param toApply identifiers of the multi-value inputs to lift + * @return the rewritten expression, or {@code expr} itself when {@code toApply} is empty + */ + public static Expr applyUnappliedIdentifiers(Expr expr, Expr.BindingDetails bindingDetails, List toApply) { - final Set found = new LinkedHashSet<>(); - expr.visit( - new Expr.Visitor() - { - @Override - public void visit(Expr expr) - { - if (expr instanceof IdentifierExpr) { - found.add(expr.toString()); + if (toApply.size() == 0) { + return expr; + } + List unapplied = toApply.stream() + .filter(x -> bindingDetails.getFreeVariables().contains(x)) + .collect(Collectors.toList()); + + ApplyFunction fn; + final LambdaExpr lambdaExpr; + final List args; + + // any unapplied identifiers that are inside a lambda expression need that lambda expression to be rewritten + Expr newExpr = expr.visit( + childExpr -> { + if (childExpr instanceof ApplyFunctionExpr) { + // try to lift unapplied arguments into the apply function lambda + return liftApplyLambda((ApplyFunctionExpr) childExpr, unapplied); + } else if (childExpr instanceof FunctionExpr) { + // check array function arguments for unapplied identifiers to transform if necessary + FunctionExpr fnExpr = (FunctionExpr) childExpr; + Set arrayInputs = 
fnExpr.function.getArrayInputs(fnExpr.args); + List newArgs = new ArrayList<>(); + for (Expr arg : fnExpr.args) { + if (arg.getIdentifierIfIdentifier() == null && arrayInputs.contains(arg)) { + Expr newArg = applyUnappliedIdentifiers(arg, bindingDetails, unapplied); + newArgs.add(newArg); + } else { + newArgs.add(arg); + } } + + FunctionExpr newFnExpr = new FunctionExpr(fnExpr.function, fnExpr.function.name(), newArgs); + return newFnExpr; } + return childExpr; } ); - return Lists.newArrayList(found); + + Expr.BindingDetails newExprBindings = newExpr.analyzeInputs(); + final Set expectedArrays = newExprBindings.getArrayVariables(); + List remainingUnappliedArgs = + unapplied.stream().filter(x -> !expectedArrays.contains(x)).collect(Collectors.toList()); + + // if lifting the lambdas got rid of all missing bindings, return the transformed expression + if (remainingUnappliedArgs.size() == 0) { + return newExpr; + } + + // else, it *should be safe* to wrap in either map or cartesian_map because we still have missing bindings that + // were *not* referenced in a lambda body + if (remainingUnappliedArgs.size() == 1) { + fn = new ApplyFunction.MapFunction(); + IdentifierExpr lambdaArg = new IdentifierExpr(remainingUnappliedArgs.iterator().next()); + lambdaExpr = new LambdaExpr(ImmutableList.of(lambdaArg), newExpr); + args = ImmutableList.of(lambdaArg); + } else { + fn = new ApplyFunction.CartesianMapFunction(); + List identifiers = new ArrayList<>(remainingUnappliedArgs.size()); + args = new ArrayList<>(remainingUnappliedArgs.size()); + for (String remainingUnappliedArg : remainingUnappliedArgs) { + IdentifierExpr arg = new IdentifierExpr(remainingUnappliedArg); + identifiers.add(arg); + args.add(arg); + } + lambdaExpr = new LambdaExpr(identifiers, newExpr); + } + + Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, args); + return magic; } - @Nullable - public static String getIdentifierIfIdentifier(Expr expr) + /** + * Performs partial lifting of free 
identifiers of the lambda expression of an {@link ApplyFunctionExpr}, constrained + * by a list of "unapplied" identifiers, and translating them into arguments of a new {@link LambdaExpr} and + * {@link ApplyFunctionExpr} as appropriate. + * + * The "unapplied" identifiers list is used to allow say only lifting array identifiers and adding it to the cartesian + * product to allow "magical" translation of multi-value string dimensions which are expressed as single value + * dimensions to function correctly and as expected. + */ + private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List unappliedArgs) { - if (expr instanceof IdentifierExpr) { - return expr.toString(); - } else { - return null; + + // recursively evaluate arguments to ensure they are properly transformed into arrays as necessary + List unappliedInThisApply = + unappliedArgs.stream() + .filter(u -> !expr.bindingDetails.getArrayVariables().contains(u)) + .collect(Collectors.toList()); + + List newArgs = new ArrayList<>(); + for (int i = 0; i < expr.argsExpr.size(); i++) { + newArgs.add(applyUnappliedIdentifiers( + expr.argsExpr.get(i), + expr.argsBindingDetails.get(i), + unappliedInThisApply) + ); + } + + // this will _not_ include the lambda identifiers.. 
anything in this list needs to be applied + List unappliedLambdaBindings = expr.lambdaBindingDetails.getFreeVariables() + .stream() + .filter(unappliedArgs::contains) + .map(IdentifierExpr::new) + .collect(Collectors.toList()); + + if (unappliedLambdaBindings.size() == 0) { + return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, newArgs); + } + + final ApplyFunction newFn; + final ApplyFunctionExpr newExpr; + + newArgs.addAll(unappliedLambdaBindings); + + switch (expr.function.name()) { + case ApplyFunction.MapFunction.NAME: + case ApplyFunction.CartesianMapFunction.NAME: + // map(x -> x + y, x) => + // cartesian_map((x, y) -> x + y, x, y) + // cartesian_map((x, y) -> x + y + z, x, y) => + // cartesian_map((x, y, z) -> x + y + z, x, y, z) + final List lambdaIds = + new ArrayList<>(expr.lambdaExpr.getIdentifiers().size() + unappliedArgs.size()); + lambdaIds.addAll(expr.lambdaExpr.getIdentifierExprs()); + lambdaIds.addAll(unappliedLambdaBindings); + final LambdaExpr newLambda = new LambdaExpr(lambdaIds, expr.lambdaExpr.getExpr()); + newFn = new ApplyFunction.CartesianMapFunction(); + newExpr = new ApplyFunctionExpr(newFn, newFn.name(), newLambda, newArgs); + break; + case ApplyFunction.AllMatchFunction.NAME: + case ApplyFunction.AnyMatchFunction.NAME: + case ApplyFunction.FilterFunction.NAME: + // 'cartesian_filter', 'cartesian_any', and 'cartesian_all' do not exist, so take the match expression's + // lambda, translate it into a 'cartesian_map', and apply the match function to its result with a new identity + // lambda, since that result is an array of boolean expression results (or should be..)
+ // filter(x -> x > y, x) => + // filter(x -> x, cartesian_map((x,y) -> x > y, x, y)) + // any(x -> x > y, x) => + // any(x -> x, cartesian_map((x, y) -> x > y, x, y)) + // all(x -> x > y, x) => + // all(x -> x, cartesian_map((x, y) -> x > y, x, y)) + ApplyFunction newArrayFn = new ApplyFunction.CartesianMapFunction(); + IdentifierExpr identityExprIdentifier = new IdentifierExpr("_"); + LambdaExpr identityExpr = new LambdaExpr(ImmutableList.of(identityExprIdentifier), identityExprIdentifier); + ApplyFunctionExpr arrayExpr = new ApplyFunctionExpr(newArrayFn, newArrayFn.name(), identityExpr, newArgs); + newExpr = new ApplyFunctionExpr(expr.function, expr.function.name(), identityExpr, ImmutableList.of(arrayExpr)); + break; + case ApplyFunction.FoldFunction.NAME: + case ApplyFunction.CartesianFoldFunction.NAME: + // fold((x, acc) -> acc + x + y, x, acc) => + // cartesian_fold((x, y, acc) -> acc + x + y, x, y, acc) + // cartesian_fold((x, y, acc) -> acc + x + y + z, x, y, acc) => + // cartesian_fold((x, y, z, acc) -> acc + x + y + z, x, y, z, acc) + + final List newFoldArgs = new ArrayList<>(expr.argsExpr.size() + unappliedLambdaBindings.size()); + final List newFoldLambdaIdentifiers = + new ArrayList<>(expr.lambdaExpr.getIdentifiers().size() + unappliedLambdaBindings.size()); + final List existingFoldLambdaIdentifiers = expr.lambdaExpr.getIdentifierExprs(); + // accumulator argument is last argument, slice it off when constructing new arg list and lambda args identifiers + for (int i = 0; i < expr.argsExpr.size() - 1; i++) { + newFoldArgs.add(expr.argsExpr.get(i)); + newFoldLambdaIdentifiers.add(existingFoldLambdaIdentifiers.get(i)); + } + newFoldArgs.addAll(unappliedLambdaBindings); + newFoldLambdaIdentifiers.addAll(unappliedLambdaBindings); + // add accumulator last + newFoldLambdaIdentifiers.add(existingFoldLambdaIdentifiers.get(existingFoldLambdaIdentifiers.size() - 1)); + newFoldArgs.add(expr.argsExpr.get(expr.argsExpr.size() - 1)); + final LambdaExpr 
newFoldLambda = new LambdaExpr(newFoldLambdaIdentifiers, expr.lambdaExpr.getExpr()); + + newFn = new ApplyFunction.CartesianFoldFunction(); + newExpr = new ApplyFunctionExpr(newFn, newFn.name(), newFoldLambda, newFoldArgs); + break; + default: + throw new RE("Unable to transform apply function:[%s]", expr.function.name()); + } + + return newExpr; + } + + /** + * Validate that an expression uses input bindings in a type consistent manner. + */ + public static void validateExpr(Expr expression, Expr.BindingDetails bindingDetails) + { + final Set conflicted = + Sets.intersection(bindingDetails.getScalarVariables(), bindingDetails.getArrayVariables()); + if (conflicted.size() != 0) { + throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted); } } + /** + * Create {@link Expr.ObjectBinding} backed by {@link Map} to provide values for identifiers to evaluate {@link Expr} + */ public static Expr.ObjectBinding withMap(final Map bindings) { return bindings::get; } + /** + * Create {@link Expr.ObjectBinding} backed by map of {@link Supplier} to provide values for identifiers to evaluate + * {@link Expr} + */ public static Expr.ObjectBinding withSuppliers(final Map> bindings) { return (String name) -> { diff --git a/core/src/main/java/org/apache/druid/math/expr/package-info.java b/core/src/main/java/org/apache/druid/math/expr/package-info.java new file mode 100644 index 000000000000..d7c92f963400 --- /dev/null +++ b/core/src/main/java/org/apache/druid/math/expr/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@EverythingIsNonnullByDefault +package org.apache.druid.math.expr; + +import org.apache.druid.annotations.EverythingIsNonnullByDefault; diff --git a/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java new file mode 100644 index 000000000000..57c937df6ca3 --- /dev/null +++ b/core/src/test/java/org/apache/druid/math/expr/ApplyFunctionTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.math.expr; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.common.config.NullHandling; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +public class ApplyFunctionTest +{ + private Expr.ObjectBinding bindings; + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Before + public void setup() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put("x", "foo"); + builder.put("y", 2); + builder.put("z", 3.1); + builder.put("a", new String[] {"foo", "bar", "baz", "foobar"}); + builder.put("b", new Long[] {1L, 2L, 3L, 4L, 5L}); + builder.put("c", new Double[] {3.1, 4.2, 5.3}); + builder.put("d", new String[] {null}); + builder.put("e", new String[] {null, "foo", "bar"}); + builder.put("f", new String[0]); + bindings = Parser.withMap(builder.build()); + } + + @Test + public void testMap() + { + assertExpr("map((x) -> concat(x, 'foo'), ['foo', 'bar', 'baz', 'foobar'])", new String[] {"foofoo", "barfoo", "bazfoo", "foobarfoo"}); + assertExpr("map((x) -> concat(x, 'foo'), a)", new String[] {"foofoo", "barfoo", "bazfoo", "foobarfoo"}); + + assertExpr("map((x) -> x + 1, [1, 2, 3, 4, 5])", new Long[] {2L, 3L, 4L, 5L, 6L}); + assertExpr("map((x) -> x + 1, b)", new Long[] {2L, 3L, 4L, 5L, 6L}); + + assertExpr("map((c) -> c + z, [3.1, 4.2, 5.3])", new Double[]{6.2, 7.3, 8.4}); + assertExpr("map((c) -> c + z, c)", new Double[]{6.2, 7.3, 8.4}); + + assertExpr("map((x) -> x + 1, map((x) -> x + 1, [1, 2, 3, 4, 5]))", new Long[] {3L, 4L, 5L, 6L, 7L}); + assertExpr("map((x) -> x + 1, map((x) -> x + 1, b))", new Long[] {3L, 4L, 5L, 6L, 7L}); + assertExpr("map(() -> 1, [1, 2, 3, 4, 5])", new Long[] {1L, 1L, 1L, 1L, 1L}); + } + + @Test + public void testCartesianMap() + { + assertExpr("cartesian_map((x, y) -> concat(x, y), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'])", new 
String[] {"foobar", "foobaz", "barbar", "barbaz", "bazbar", "bazbaz", "foobarbar", "foobarbaz"}); + assertExpr("cartesian_map((x, y, z) -> concat(concat(x, y), z), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'], ['omg'])", new String[] {"foobaromg", "foobazomg", "barbaromg", "barbazomg", "bazbaromg", "bazbazomg", "foobarbaromg", "foobarbazomg"}); + assertExpr("cartesian_map(() -> 1, [1, 2], [1, 2, 3])", new Long[] {1L, 1L, 1L, 1L, 1L, 1L}); + assertExpr("cartesian_map((x, y) -> concat(x, y), d, d)", new String[] {null}); + assertExpr("cartesian_map((x, y) -> concat(x, y), d, f)", new String[0]); + if (NullHandling.replaceWithDefault()) { + assertExpr("cartesian_map((x, y) -> concat(x, y), d, e)", new String[]{null, "foo", "bar"}); + assertExpr("cartesian_map((x, y) -> concat(x, y), e, e)", new String[] {null, "foo", "bar", "foo", "foofoo", "foobar", "bar", "barfoo", "barbar"}); + } else { + assertExpr("cartesian_map((x, y) -> concat(x, y), d, e)", new String[]{null, null, null}); + assertExpr("cartesian_map((x, y) -> concat(x, y), e, e)", new String[] {null, null, null, null, "foofoo", "foobar", null, "barfoo", "barbar"}); + } + } + + @Test + public void testFilter() + { + assertExpr("filter((x) -> strlen(x) > 3, ['foo', 'bar', 'baz', 'foobar'])", new String[] {"foobar"}); + assertExpr("filter((x) -> strlen(x) > 3, a)", new String[] {"foobar"}); + + assertExpr("filter((x) -> x > 2, [1, 2, 3, 4, 5])", new Long[] {3L, 4L, 5L}); + assertExpr("filter((x) -> x > 2, b)", new Long[] {3L, 4L, 5L}); + } + + @Test + public void testFold() + { + assertExpr("fold((x, y) -> x + y, [1, 1, 1, 1, 1], 0)", 5L); + assertExpr("fold((b, acc) -> b * acc, map((b) -> b * 2, filter(b -> b > 3, b)), 1)", 80L); + assertExpr("fold((a, acc) -> concat(a, acc), a, '')", "foobarbazbarfoo"); + assertExpr("fold((a, acc) -> array_append(acc, a), a, [])", new String[]{"foo", "bar", "baz", "foobar"}); + assertExpr("fold((a, acc) -> array_append(acc, a), b, cast([], 'LONG_ARRAY'))", new Long[]{1L, 
2L, 3L, 4L, 5L}); + } + + @Test + public void testCartesianFold() + { + assertExpr("cartesian_fold((x, y, acc) -> x + y + acc, [1, 1, 1, 1, 1], [1, 1], 0)", 20L); + } + + @Test + public void testAnyMatch() + { + assertExpr("any(x -> x > 3, [1, 2, 3, 4])", "true"); + assertExpr("any(x -> x > 3, [1, 2, 3])", "false"); + assertExpr("any(x -> x, map(x -> x > 3, [1, 2, 3, 4]))", "true"); + assertExpr("any(x -> x, map(x -> x > 3, [1, 2, 3]))", "false"); + } + + @Test + public void testAllMatch() + { + assertExpr("all(x -> x > 0, [1, 2, 3, 4])", "true"); + assertExpr("all(x -> x > 1, [1, 2, 3, 4])", "false"); + assertExpr("all(x -> x, map(x -> x > 0, [1, 2, 3, 4]))", "true"); + assertExpr("all(x -> x, map(x -> x > 1, [1, 2, 3, 4]))", "false"); + } + + @Test + public void testScoping() + { + assertExpr("map(b -> b + 1, b)", new Long[]{2L, 3L, 4L, 5L, 6L}); + assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), 0)", 20L); + assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), fold((b, acc) -> acc + b, map(b -> b + 1, b), 0))", 40L); + assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), 0) + fold((b, acc) -> acc + b, map(b -> b + 1, b), 0)", 40L); + assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), fold((b, acc) -> acc + b, map(b -> b + 1, b), 0) + fold((b, acc) -> acc + b, map(b -> b + 1, b), 0))", 60L); + } + + // One test per invalid call: with the ExpectedException rule a test ends at the first throw, so statements placed after it in the same method never run. + @Test + public void testInvalidArgCount() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("lambda expression argument count does not match fold argument count"); + assertExpr("fold(() -> 1, [1, 1, 1, 1, 1], 0)", null); + } + + @Test + public void testInvalidCartesianFoldArgCount() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("lambda expression argument count does not match cartesian_fold argument count"); + assertExpr("cartesian_fold(() -> 1, [1, 1, 1, 1, 1], [1, 1], 0)", null); + } + + @Test + public void testInvalidAnyArgCount() + { + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("lambda expression argument count does not match any argument count"); + assertExpr("any(() -> 1, [1, 2, 3, 4])", null); + } + + @Test + public void testInvalidAllArgCount() + { + expectedException.expect(IllegalArgumentException.class); +
expectedException.expectMessage("lambda expression argument count does not match all argument count"); + assertExpr("all(() -> 0, [1, 2, 3, 4])", null); + } + + private void assertExpr(final String expression, final Object expectedResult) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); + Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value()); + } + + private void assertExpr(final String expression, final Object[] expectedResult) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); + final Object[] result = expr.eval(bindings).asArray(); + if (expectedResult.length != 0 || result == null || result.length != 0) { + Assert.assertArrayEquals(expression, expectedResult, result); + } + } + + private void assertExpr(final String expression, final Double[] expectedResult) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); + Double[] result = (Double[]) expr.eval(bindings).value(); + Assert.assertEquals(expectedResult.length, result.length); + for (int i = 0; i < result.length; i++) { + Assert.assertEquals(expression, expectedResult[i], result[i], 0.00001); // something is lame somewhere..
+ } + } +} diff --git a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java index 2a5bfc95cc07..ec0884cefcf4 100644 --- a/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/FunctionTest.java @@ -22,17 +22,25 @@ import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; public class FunctionTest { - private final Expr.ObjectBinding bindings = Parser.withMap( - ImmutableMap.of( - "x", "foo", - "y", 2, - "z", 3.1 - ) - ); + private Expr.ObjectBinding bindings; + + @Before + public void setup() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put("x", "foo"); + builder.put("y", 2); + builder.put("z", 3.1); + builder.put("a", new String[] {"foo", "bar", "baz", "foobar"}); + builder.put("b", new Long[] {1L, 2L, 3L, 4L, 5L}); + builder.put("c", new Double[] {3.1, 4.2, 5.3}); + bindings = Parser.withMap(builder.build()); + } @Test public void testCaseSimple() @@ -115,12 +123,6 @@ public void testUpper() assertExpr("upper(x)", "FOO"); } - private void assertExpr(final String expression, final Object expectedResult) - { - final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); - Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value()); - } - @Test public void testIsNull() { @@ -156,4 +158,117 @@ public void testRpad() assertExpr("rpad(x, 5, null)", null); assertExpr("rpad(null, 5, x)", null); } + + @Test + public void testArrayLength() + { + assertExpr("array_length([1,2,3])", 3L); + assertExpr("array_length(a)", 4); + } + + @Test + public void testArrayOffset() + { + assertExpr("array_offset([1, 2, 3], 2)", 3L); + assertExpr("array_offset([1, 2, 3], 3)", null); + assertExpr("array_offset(a, 2)", "baz"); + } + + @Test + public void testArrayOrdinal() + { + 
assertExpr("array_ordinal([1, 2, 3], 3)", 3L); + assertExpr("array_ordinal([1, 2, 3], 4)", null); + assertExpr("array_ordinal(a, 3)", "baz"); + } + + @Test + public void testArrayOffsetOf() + { + assertExpr("array_offset_of([1, 2, 3], 3)", 2L); + assertExpr("array_offset_of([1, 2, 3], 4)", null); + assertExpr("array_offset_of(a, 'baz')", 2); + } + + @Test + public void testArrayOrdinalOf() + { + assertExpr("array_ordinal_of([1, 2, 3], 3)", 3L); + assertExpr("array_ordinal_of([1, 2, 3], 4)", null); + assertExpr("array_ordinal_of(a, 'baz')", 3); + } + + @Test + public void testArrayContains() + { + assertExpr("array_contains([1, 2, 3], 2)", "true"); + assertExpr("array_contains([1, 2, 3], 4)", "false"); + assertExpr("array_contains([1, 2, 3], [2, 3])", "true"); + assertExpr("array_contains([1, 2, 3], [3, 4])", "false"); + assertExpr("array_contains(b, [3, 4])", "true"); + } + + @Test + public void testArrayOverlap() + { + assertExpr("array_overlap([1, 2, 3], [2, 4, 6])", "true"); + assertExpr("array_overlap([1, 2, 3], [4, 5, 6])", "false"); + } + + @Test + public void testArrayAppend() + { + assertExpr("array_append([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertExpr("array_append([1, 2, 3], 'bar')", new Long[]{1L, 2L, 3L, null}); + assertExpr("array_append([], 1)", new String[]{"1"}); + assertExpr("array_append(cast([], 'LONG_ARRAY'), 1)", new Long[]{1L}); + } + + @Test + public void testArrayConcat() + { + assertExpr("array_concat([1, 2, 3], [2, 4, 6])", new Long[]{1L, 2L, 3L, 2L, 4L, 6L}); + assertExpr("array_concat([1, 2, 3], 4)", new Long[]{1L, 2L, 3L, 4L}); + assertExpr("array_concat(0, [1, 2, 3])", new Long[]{0L, 1L, 2L, 3L}); + assertExpr("array_concat(map(y -> y * 3, b), [1, 2, 3])", new Long[]{3L, 6L, 9L, 12L, 15L, 1L, 2L, 3L}); + assertExpr("array_concat(0, 1)", new Long[]{0L, 1L}); + } + + @Test + public void testArrayToString() + { + assertExpr("array_to_string([1, 2, 3], ',')", "1,2,3"); + assertExpr("array_to_string([1], '|')", "1"); + 
assertExpr("array_to_string(a, '|')", "foo|bar|baz|foobar"); + } + + @Test + public void testStringToArray() + { + assertExpr("string_to_array('1,2,3', ',')", new String[]{"1", "2", "3"}); + assertExpr("string_to_array('1', ',')", new String[]{"1"}); + assertExpr("string_to_array(array_to_string(a, ','), ',')", new String[]{"foo", "bar", "baz", "foobar"}); + } + + @Test + public void testArrayCast() + { + assertExpr("cast([1, 2, 3], 'STRING_ARRAY')", new String[]{"1", "2", "3"}); + assertExpr("cast([1, 2, 3], 'DOUBLE_ARRAY')", new Double[]{1.0, 2.0, 3.0}); + assertExpr("cast(c, 'LONG_ARRAY')", new Long[]{3L, 4L, 5L}); + assertExpr("cast(string_to_array(array_to_string(b, ','), ','), 'LONG_ARRAY')", new Long[]{1L, 2L, 3L, 4L, 5L}); + assertExpr("cast(['1.0', '2.0', '3.0'], 'LONG_ARRAY')", new Long[]{1L, 2L, 3L}); + } + + private void assertExpr(final String expression, final Object expectedResult) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); + Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value()); + } + + private void assertExpr(final String expression, final Object[] expectedResult) + { + final Expr expr = Parser.parse(expression, ExprMacroTable.nil()); + Assert.assertArrayEquals(expression, expectedResult, expr.eval(bindings).asArray()); + } } diff --git a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java index c108d8b1053d..aa40eb51ccfa 100644 --- a/core/src/test/java/org/apache/druid/math/expr/ParserTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/ParserTest.java @@ -21,12 +21,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.junit.Assert; import org.junit.Test; +import java.util.Collections; import java.util.List; +import java.util.Set; /** + * */ public class ParserTest { @@ -163,10 +167,10 @@ public void 
testMixed() @Test public void testIdentifiers() { - validateParser("foo", "foo", ImmutableList.of("foo")); - validateParser("\"foo\"", "foo", ImmutableList.of("foo")); - validateParser("\"foo bar\"", "foo bar", ImmutableList.of("foo bar")); - validateParser("\"foo\\\"bar\"", "foo\"bar", ImmutableList.of("foo\"bar")); + validateParser("foo", "foo", ImmutableList.of("foo"), ImmutableSet.of()); + validateParser("\"foo\"", "foo", ImmutableList.of("foo"), ImmutableSet.of()); + validateParser("\"foo bar\"", "foo bar", ImmutableList.of("foo bar"), ImmutableSet.of()); + validateParser("\"foo\\\"bar\"", "foo\"bar", ImmutableList.of("foo\"bar"), ImmutableSet.of()); } @Test @@ -179,13 +183,224 @@ public void testLiterals() validateConstantExpression("\'f\\u000Ao \\'b\\\\\\\"ar\'", "f\no 'b\\\"ar"); } + @Test + public void testLiteralArrays() + { + validateConstantExpression("[1.0, 2.345]", new Double[]{1.0, 2.345}); + validateConstantExpression("[1, 3]", new Long[]{1L, 3L}); + validateConstantExpression("[\'hello\', \'world\']", new String[]{"hello", "world"}); + } + @Test public void testFunctions() { validateParser("sqrt(x)", "(sqrt [x])", ImmutableList.of("x")); validateParser("if(cond,then,else)", "(if [cond, then, else])", ImmutableList.of("cond", "then", "else")); + validateParser("cast(x, 'STRING')", "(cast [x, STRING])", ImmutableList.of("x")); + validateParser("cast(x, 'LONG')", "(cast [x, LONG])", ImmutableList.of("x")); + validateParser("cast(x, 'DOUBLE')", "(cast [x, DOUBLE])", ImmutableList.of("x")); + validateParser( + "cast(x, 'STRING_ARRAY')", + "(cast [x, STRING_ARRAY])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "cast(x, 'LONG_ARRAY')", + "(cast [x, LONG_ARRAY])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "cast(x, 'DOUBLE_ARRAY')", + "(cast [x, DOUBLE_ARRAY])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + 
"array_length(x)", + "(array_length [x])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "array_concat(x, y)", + "(array_concat [x, y])", + ImmutableList.of("x", "y"), + ImmutableSet.of(), + ImmutableSet.of("x", "y") + ); + validateParser( + "array_append(x, y)", + "(array_append [x, y])", + ImmutableList.of("x", "y"), + ImmutableSet.of("y"), + ImmutableSet.of("x") + ); + + validateFlatten("sqrt(4)", "(sqrt [4])", "2.0"); + validateFlatten("array_concat([1, 2], [3, 4])", "(array_concat [[1, 2], [3, 4]])", "[1, 2, 3, 4]"); + } + + @Test + public void testApplyFunctions() + { + validateParser( + "map(() -> 1, x)", + "(map ([] -> 1), [x])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "map((x) -> x + 1, x)", + "(map ([x] -> (+ x 1)), [x])", + ImmutableList.of("x"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "x + map((x) -> x + 1, y)", + "(+ x (map ([x] -> (+ x 1)), [y]))", + ImmutableList.of("x", "y"), + ImmutableSet.of("x"), + ImmutableSet.of("y") + ); + validateParser( + "x + map((x) -> x + 1, x)", + "(+ x (map ([x] -> (+ x 1)), [x]))", + ImmutableList.of("x"), + ImmutableSet.of("x"), + ImmutableSet.of("x") + ); + validateParser( + "map((x) -> concat(x, y), z)", + "(map ([x] -> (concat [x, y])), [z])", + ImmutableList.of("z", "y"), + ImmutableSet.of("y"), + ImmutableSet.of("z") + ); + // 'y' is accumulator, and currently unknown + validateParser( + "fold((x, acc) -> acc + x, x, y)", + "(fold ([x, acc] -> (+ acc x)), [x, y])", + ImmutableList.of("x", "y"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + + validateParser( + "fold((x, acc) -> acc + x, map((x) -> x + 1, x), y)", + "(fold ([x, acc] -> (+ acc x)), [(map ([x] -> (+ x 1)), [x]), y])", + ImmutableList.of("x", "y"), + ImmutableSet.of(), + ImmutableSet.of("x") + ); + validateParser( + "array_append(z, fold((x, acc) -> acc + x, map((x) -> x + 1, x), y))", + "(array_append [z, (fold ([x, 
acc] -> (+ acc x)), [(map ([x] -> (+ x 1)), [x]), y])])", + ImmutableList.of("z", "x", "y"), + ImmutableSet.of(), + ImmutableSet.of("x", "z") + ); + validateParser( + "map(z -> z + 1, array_append(z, fold((x, acc) -> acc + x, map((x) -> x + 1, x), y)))", + "(map ([z] -> (+ z 1)), [(array_append [z, (fold ([x, acc] -> (+ acc x)), [(map ([x] -> (+ x 1)), [x]), y])])])", + ImmutableList.of("z", "x", "y"), + ImmutableSet.of(), + ImmutableSet.of("x", "z") + ); + + validateParser( + "array_append(map(z -> z + 1, array_append(z, fold((x, acc) -> acc + x, map((x) -> x + 1, x), y))), a)", + "(array_append [(map ([z] -> (+ z 1)), [(array_append [z, (fold ([x, acc] -> (+ acc x)), [(map ([x] -> (+ x 1)), [x]), y])])]), a])", + ImmutableList.of("z", "x", "y", "a"), + ImmutableSet.of("a"), + ImmutableSet.of("x", "z") + ); + + validateFlatten("map((x) -> x + 1, [1, 2, 3, 4])", "(map ([x] -> (+ x 1)), [[1, 2, 3, 4]])", "[2, 3, 4, 5]"); + validateFlatten( + "map((x) -> x + z, [1, 2, 3, 4])", + "(map ([x] -> (+ x z)), [[1, 2, 3, 4]])", + "(map ([x] -> (+ x z)), [[1, 2, 3, 4]])" + ); + } + + @Test + public void testApplyUnapplied() + { + validateApplyUnapplied("x + 1", "(+ x 1)", "(+ x 1)", ImmutableList.of()); + validateApplyUnapplied("x + 1", "(+ x 1)", "(+ x 1)", ImmutableList.of("z")); + validateApplyUnapplied("x + y", "(+ x y)", "(map ([x] -> (+ x y)), [x])", ImmutableList.of("x")); + validateApplyUnapplied( + "x + y", + "(+ x y)", + "(cartesian_map ([x, y] -> (+ x y)), [x, y])", + ImmutableList.of("x", "y") + ); + + validateApplyUnapplied( + "map(x -> x + y, x)", + "(map ([x] -> (+ x y)), [x])", + "(cartesian_map ([x, y] -> (+ x y)), [x, y])", + ImmutableList.of("y") + ); + validateApplyUnapplied( + "map(x -> x + 1, x + 1)", + "(map ([x] -> (+ x 1)), [(+ x 1)])", + "(map ([x] -> (+ x 1)), [(map ([x] -> (+ x 1)), [x])])", + ImmutableList.of("x") + ); + validateApplyUnapplied( + "fold((x, acc) -> acc + x + y, x, 0)", + "(fold ([x, acc] -> (+ (+ acc x) y)), [x, 0])", + 
"(cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0])", + ImmutableList.of("y") + ); + validateApplyUnapplied( + "z + fold((x, acc) -> acc + x + y, x, 0)", + "(+ z (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))", + "(+ z (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))", + ImmutableList.of("y") + ); + validateApplyUnapplied( + "z + fold((x, acc) -> acc + x + y, x, 0)", + "(+ z (fold ([x, acc] -> (+ (+ acc x) y)), [x, 0]))", + "(map ([z] -> (+ z (cartesian_fold ([x, y, acc] -> (+ (+ acc x) y)), [x, y, 0]))), [z])", + ImmutableList.of("y", "z") + ); + validateApplyUnapplied( + "array_to_string(concat(x, 'hello'), ',')", + "(array_to_string [(concat [x, hello]), ,])", + "(array_to_string [(map ([x] -> (concat [x, hello])), [x]), ,])", + ImmutableList.of("x", "y") + ); + validateApplyUnapplied( + "cast(x, 'LONG')", + "(cast [x, LONG])", + "(map ([x] -> (cast [x, LONG])), [x])", + ImmutableList.of("x") + ); + validateApplyUnapplied( + "cartesian_map((x,y) -> x + y, x, y)", + "(cartesian_map ([x, y] -> (+ x y)), [x, y])", + "(cartesian_map ([x, y] -> (+ x y)), [x, y])", + ImmutableList.of("y") + ); + validateApplyUnapplied( + "cast(x, 'LONG_ARRAY')", + "(cast [x, LONG_ARRAY])", + "(cast [x, LONG_ARRAY])", + ImmutableList.of("x") + ); } + private void validateFlatten(String expression, String withoutFlatten, String withFlatten) { Assert.assertEquals(expression, withoutFlatten, Parser.parse(expression, ExprMacroTable.nil(), false).toString()); @@ -193,10 +408,44 @@ private void validateFlatten(String expression, String withoutFlatten, String wi } private void validateParser(String expression, String expected, List identifiers) + { + validateParser(expression, expected, identifiers, ImmutableSet.copyOf(identifiers), Collections.emptySet()); + } + + private void validateParser(String expression, String expected, List identifiers, Set scalars) + { + validateParser(expression, expected, identifiers, scalars, Collections.emptySet()); + } + + private void 
validateParser( + String expression, + String expected, + List identifiers, + Set scalars, + Set arrays + ) { final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); + final Expr.BindingDetails deets = parsed.analyzeInputs(); Assert.assertEquals(expression, expected, parsed.toString()); - Assert.assertEquals(expression, identifiers, Parser.findRequiredBindings(parsed)); + Assert.assertEquals(expression, identifiers, deets.getRequiredColumns()); + Assert.assertEquals(expression, scalars, deets.getScalarVariables()); + Assert.assertEquals(expression, arrays, deets.getArrayVariables()); + } + + private void validateApplyUnapplied( + String expression, + String unapplied, + String applied, + List identifiers + ) + { + final Expr parsed = Parser.parse(expression, ExprMacroTable.nil()); + Expr.BindingDetails deets = parsed.analyzeInputs(); + Parser.validateExpr(parsed, deets); + final Expr transformed = Parser.applyUnappliedIdentifiers(parsed, deets, identifiers); + Assert.assertEquals(expression, unapplied, parsed.toString()); + Assert.assertEquals(expression, applied, transformed.toString()); } private void validateConstantExpression(String expression, Object expected) @@ -207,4 +456,13 @@ private void validateConstantExpression(String expression, Object expected) Parser.parse(expression, ExprMacroTable.nil()).eval(Parser.withMap(ImmutableMap.of())).value() ); } + + private void validateConstantExpression(String expression, Object[] expected) + { + Assert.assertArrayEquals( + expression, + expected, + (Object[]) Parser.parse(expression, ExprMacroTable.nil()).eval(Parser.withMap(ImmutableMap.of())).value() + ); + } } diff --git a/docs/content/misc/math-expr.md b/docs/content/misc/math-expr.md index c207f01ed231..57427a988c35 100644 --- a/docs/content/misc/math-expr.md +++ b/docs/content/misc/math-expr.md @@ -25,7 +25,8 @@ title: "Apache Druid (incubating) Expressions" # Apache Druid (incubating) Expressions
-This feature is still experimental. It has not been optimized for performance yet, and its implementation is known to have significant inefficiencies. +This feature is still experimental. It has not been optimized for performance yet, and its implementation is known to + have significant inefficiencies.
This expression language supports the following operators (listed in decreasing order of precedence). @@ -39,14 +40,29 @@ This expression language supports the following operators (listed in decreasing |<, <=, >, >=, ==, !=|Binary Comparison| |&&, ||Binary Logical AND, OR| -Long, double, and string data types are supported. If a number contains a dot, it is interpreted as a double, otherwise it is interpreted as a long. That means, always add a '.' to your number if you want it interpreted as a double value. String literals should be quoted by single quotation marks. +Long, double, and string data types are supported. If a number contains a dot, it is interpreted as a double, otherwise +it is interpreted as a long. That means, always add a '.' to your number if you want it interpreted as a double value. +String literals should be quoted by single quotation marks. -Multi-value types are not fully supported yet. Expressions may behave inconsistently on multi-value types, and you -should not rely on the behavior in this case to stay the same in future releases. +Additionally, the expression language supports long, double, and string arrays. Array literals are created by wrapping +square brackets around a list of scalar literal values delimited by commas. All values in an array +literal must be the same type. -Expressions can contain variables. Variable names may contain letters, digits, '\_' and '$'. Variable names must not begin with a digit. To escape other special characters, you can quote it with double quotation marks. +Expressions can contain variables. Variable names may contain letters, digits, '\_' and '$'. Variable names must not +begin with a digit. To escape other special characters, you can quote it with double quotation marks. + +For logical operators, a number is true if and only if it is positive (0 or negative value means false). For string +type, it's the evaluation result of 'Boolean.valueOf(string)'.
+ +Multi-value string dimensions are supported and may be treated as either scalar or array typed values. When treated as +a scalar type, an expression will automatically be transformed to apply the scalar operation across all values of the +multi-valued type, to mimic Druid's native behavior. Values that result in arrays will be coerced back into the native +Druid string type for aggregation. Druid aggregations on multi-value string dimensions operate on the individual values, _not_ +the 'array', behaving similarly to the `unnest` operator available in many SQL dialects. However, by using the +`array_to_string` function, aggregations may be done on a stringified version of the complete array, allowing the +complete row to be preserved. Using `string_to_array` in an expression post-aggregator allows transforming the +stringified dimension back into the true native array type. -For logical operators, a number is true if and only if it is positive (0 or negative value means false). For string type, it's the evaluation result of 'Boolean.valueOf(string)'. The following built-in functions are available. @@ -54,7 +70,7 @@ The following built-in functions are available. |name|description| |----|-----------| -|cast|cast(expr,'LONG' or 'DOUBLE' or 'STRING') returns expr with specified type. exception can be thrown | +|cast|cast(expr,'LONG' or 'DOUBLE' or 'STRING' or 'LONG_ARRAY' or 'DOUBLE_ARRAY' or 'STRING_ARRAY') returns expr with specified type. exception can be thrown. Scalar types may be cast to array types and will take the form of a single element list (null will still be null). | |if|if(predicate,then,else) returns 'then' if 'predicate' evaluates to a positive number, otherwise it returns 'else' | |nvl|nvl(expr,expr-for-null) returns 'expr-for-null' if 'expr' is null (or empty string for string type) | |like|like(expr, pattern[, escape]) is equivalent to SQL `expr LIKE pattern`| @@ -146,3 +162,33 @@ See javadoc of java.lang.Math for detailed explanation for each function.
|todegrees|todegrees(x) converts an angle measured in radians to an approximately equivalent angle measured in degrees| |toradians|toradians(x) converts an angle measured in degrees to an approximately equivalent angle measured in radians| |ulp|ulp(x) would return the size of an ulp of the argument x| + + +## Array Functions + +| function | description | +| --- | --- | +| `array_length(arr)` | returns length of array expression | +| `array_offset(arr,long)` | returns the array element at the 0 based index supplied, or null for an out of range index| +| `array_ordinal(arr,long)` | returns the array element at the 1 based index supplied, or null for an out of range index | +| `array_contains(arr,expr)` | returns true if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array | +| `array_overlap(arr1,arr2)` | returns true if arr1 and arr2 have any elements in common | +| `array_offset_of(arr,expr)` | returns the 0 based index of the first occurrence of expr in the array, or `null` if no matching elements exist in the array. | +| `array_ordinal_of(arr,expr)` | returns the 1 based index of the first occurrence of expr in the array, or `null` if no matching elements exist in the array. 
| +| `array_append(arr,expr)` | appends expr to arr, the resulting array type determined by the type of arr | +| `array_concat(arr1,arr2)` | concatenates 2 arrays, the resulting array type determined by the type of the first array | +| `array_to_string(arr,str)` | joins all elements of arr by the delimiter specified by str | +| `string_to_array(str1,str2)` | splits str1 into an array on the delimiter specified by str2 | + + +## Apply Functions + +| function | description | +| --- | --- | +| `map(lambda,arr)` | applies a transform specified by a single argument lambda expression to all elements of arr, returning a new array | +| `cartesian_map(lambda,arr1,arr2,...)` | applies a transform specified by a multi argument lambda expression to all elements of the cartesian product of all input arrays, returning a new array; the number of lambda arguments and array inputs must be the same | +| `filter(lambda,arr)` | filters arr by a single argument lambda, returning a new array with all matching elements, or null if no elements match | +| `fold(lambda,arr,acc)` | folds a 2 argument lambda across arr, using acc as the initial accumulator value. The first argument of the lambda is the array element and the second the accumulator, returning a single accumulated value. | +| `cartesian_fold(lambda,arr1,arr2,...,acc)` | folds a multi argument lambda across the cartesian product of all input arrays, using acc as the initial accumulator value. The first arguments of the lambda are the array elements and the last is the accumulator, returning a single accumulated value.
| +| `any(lambda,arr)` | returns true if any element in the array matches the lambda expression | +| `all(lambda,arr)` | returns true if all elements in the array matches the lambda expression | diff --git a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java index ec6007a37e5c..4328a5835e92 100644 --- a/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java +++ b/extensions-core/druid-bloom-filter/src/main/java/org/apache/druid/query/expressions/BloomFilterExprMacro.java @@ -67,8 +67,13 @@ public Expr apply(List args) throw new RuntimeException("Failed to deserialize bloom filter", ioe); } - class BloomExpr implements Expr + class BloomExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private BloomExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -111,14 +116,15 @@ private boolean nullMatch() return filter.testBytes(null, 0, 0); } + @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new BloomExpr(newArg)); } } - return new BloomExpr(); + return new BloomExpr(arg); } } diff --git a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/sampler/FirehoseSamplerTest.java b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/sampler/FirehoseSamplerTest.java index 9c3da5b2c46a..782247d9cd2f 100644 --- a/indexing-service/src/test/java/org/apache/druid/indexing/overlord/sampler/FirehoseSamplerTest.java +++ b/indexing-service/src/test/java/org/apache/druid/indexing/overlord/sampler/FirehoseSamplerTest.java @@ -606,7 +606,7 @@ public void testWithTransformsAutoDimensions() GranularitySpec granularitySpec = 
new UniformGranularitySpec(Granularities.DAY, Granularities.HOUR, true, null); TransformSpec transformSpec = new TransformSpec( null, - ImmutableList.of(new ExpressionTransform("dim1PlusBar", "concat(dim1 + 'bar')", TestExprMacroTable.INSTANCE)) + ImmutableList.of(new ExpressionTransform("dim1PlusBar", "concat(dim1, 'bar')", TestExprMacroTable.INSTANCE)) ); DataSchema dataSchema = new DataSchema( diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java index bb3cfa7ac386..d20b07ac139b 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleDoubleAggregatorFactory.java @@ -121,7 +121,7 @@ public List requiredFields() { return fieldName != null ? Collections.singletonList(fieldName) - : Parser.findRequiredBindings(fieldExpression.get()); + : fieldExpression.get().analyzeInputs().getRequiredColumns(); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java index 6b43113c5eba..92dbc972f77c 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleFloatAggregatorFactory.java @@ -115,7 +115,7 @@ public List requiredFields() { return fieldName != null ? 
Collections.singletonList(fieldName) - : Parser.findRequiredBindings(fieldExpression.get()); + : fieldExpression.get().analyzeInputs().getRequiredColumns(); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java index f53a57df0f09..3a77e3ce7e29 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/SimpleLongAggregatorFactory.java @@ -111,7 +111,7 @@ public List requiredFields() { return fieldName != null ? Collections.singletonList(fieldName) - : Parser.findRequiredBindings(fieldExpression.get()); + : fieldExpression.get().analyzeInputs().getRequiredColumns(); } @Override diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java index 8b35a0344f1f..84421953b867 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/post/ExpressionPostAggregator.java @@ -27,7 +27,6 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.math.expr.Expr; @@ -119,7 +118,7 @@ private ExpressionPostAggregator( macroTable, finalizers, parsed, - Suppliers.memoize(() -> ImmutableSet.copyOf(Parser.findRequiredBindings(parsed.get())))); + Suppliers.memoize(() -> parsed.get().analyzeInputs().getFreeVariables())); } private ExpressionPostAggregator( diff --git 
a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java index f6329100a1f2..a1c980465eff 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LikeExprMacro.java @@ -67,8 +67,13 @@ public Expr apply(final List args) escapeChar ); - class LikeExtractExpr implements Expr + class LikeExtractExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private LikeExtractExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -77,13 +82,13 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new LikeExtractExpr(newArg)); } } - return new LikeExtractExpr(); + return new LikeExtractExpr(arg); } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java index 3a5e40ee53a6..88e3ce44e578 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/LookupExprMacro.java @@ -71,8 +71,13 @@ public Expr apply(final List args) null ); - class LookupExpr implements Expr + class LookupExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private LookupExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -81,13 +86,13 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = 
arg.visit(shuttle); + return shuttle.visit(new LookupExpr(newArg)); } } - return new LookupExpr(); + return new LookupExpr(arg); } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java index 305c03e4ed51..df8f1f955a4a 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/RegexpExtractExprMacro.java @@ -57,8 +57,14 @@ public Expr apply(final List args) final Pattern pattern = Pattern.compile(String.valueOf(patternExpr.getLiteralValue())); final int index = indexExpr == null ? 0 : ((Number) indexExpr.getLiteralValue()).intValue(); - class RegexpExtractExpr implements Expr + + class RegexpExtractExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private RegexpExtractExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -70,12 +76,12 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new RegexpExtractExpr(newArg)); } } - return new RegexpExtractExpr(); + return new RegexpExtractExpr(arg); } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java index 37b35d809dea..bb2f5af5e1e9 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampCeilExprMacro.java @@ -30,6 +30,7 @@ import javax.annotation.Nonnull; import java.util.List; +import java.util.stream.Collectors; public class TimestampCeilExprMacro 
implements ExprMacroTable.ExprMacro { @@ -53,14 +54,13 @@ public Expr apply(final List args) } } - private static class TimestampCeilExpr implements Expr + private static class TimestampCeilExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final Expr arg; private final Granularity granularity; - public TimestampCeilExpr(final List args) + TimestampCeilExpr(final List args) { - this.arg = args.get(0); + super(args); this.granularity = getGranularity(args, ExprUtils.nilBindings()); } @@ -68,12 +68,12 @@ public TimestampCeilExpr(final List args) @Override public ExprEval eval(final ObjectBinding bindings) { - ExprEval eval = arg.eval(bindings); + ExprEval eval = args.get(0).eval(bindings); if (eval.isNumericNull()) { // Return null if the argument if null. return ExprEval.of(null); } - DateTime argTime = DateTimes.utc(arg.eval(bindings).asLong()); + DateTime argTime = DateTimes.utc(eval.asLong()); DateTime bucketStartTime = granularity.bucketStart(argTime); if (argTime.equals(bucketStartTime)) { return ExprEval.of(bucketStartTime.getMillis()); @@ -82,10 +82,10 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new TimestampCeilExpr(newArgs)); } } @@ -99,13 +99,11 @@ private static PeriodGranularity getGranularity(final List args, final Exp ); } - private static class TimestampCeilDynamicExpr implements Expr + private static class TimestampCeilDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final List args; - - public TimestampCeilDynamicExpr(final List args) + TimestampCeilDynamicExpr(final List args) { - this.args = args; + super(args); } @Nonnull @@ -122,12 +120,10 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + 
public Expr visit(Shuttle shuttle) { - for (Expr arg : args) { - arg.visit(visitor); - } - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new TimestampCeilDynamicExpr(newArgs)); } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java index 3f1f6836bdc9..48ae86caadf8 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampExtractExprMacro.java @@ -82,8 +82,13 @@ public Expr apply(final List args) final ISOChronology chronology = ISOChronology.getInstance(timeZone); - class TimestampExtractExpr implements Expr + class TimestampExtractExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private TimestampExtractExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -123,13 +128,13 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new TimestampExtractExpr(newArg)); } } - return new TimestampExtractExpr(); + return new TimestampExtractExpr(arg); } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java index ae0f16b2151a..00cbb547ac64 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFloorExprMacro.java @@ -28,6 +28,7 @@ import javax.annotation.Nonnull; import java.util.List; +import 
java.util.stream.Collectors; public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro { @@ -61,14 +62,13 @@ private static PeriodGranularity computeGranularity(final List args, final ); } - public static class TimestampFloorExpr implements Expr + public static class TimestampFloorExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final Expr arg; private final PeriodGranularity granularity; - public TimestampFloorExpr(final List args) + TimestampFloorExpr(final List args) { - this.arg = args.get(0); + super(args); this.granularity = computeGranularity(args, ExprUtils.nilBindings()); } @@ -77,7 +77,7 @@ public TimestampFloorExpr(final List args) */ public Expr getArg() { - return arg; + return args.get(0); } /** @@ -92,7 +92,7 @@ public PeriodGranularity getGranularity() @Override public ExprEval eval(final ObjectBinding bindings) { - ExprEval eval = arg.eval(bindings); + ExprEval eval = args.get(0).eval(bindings); if (eval.isNumericNull()) { // Return null if the argument if null. 
return ExprEval.of(null); @@ -101,20 +101,19 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + + return shuttle.visit(new TimestampFloorExpr(newArgs)); } } - public static class TimestampFloorDynamicExpr implements Expr + public static class TimestampFloorDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final List args; - - public TimestampFloorDynamicExpr(final List args) + TimestampFloorDynamicExpr(final List args) { - this.args = args; + super(args); } @Nonnull @@ -126,12 +125,10 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - for (Expr arg : args) { - arg.visit(visitor); - } - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new TimestampFloorDynamicExpr(newArgs)); } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java index 196b61ea42c3..f82b8b83f71f 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampFormatExprMacro.java @@ -68,8 +68,13 @@ public Expr apply(final List args) ? 
ISODateTimeFormat.dateTime() : DateTimeFormat.forPattern(formatString).withZone(timeZone); - class TimestampFormatExpr implements Expr + class TimestampFormatExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private TimestampFormatExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -83,13 +88,13 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new TimestampFormatExpr(newArg)); } } - return new TimestampFormatExpr(); + return new TimestampFormatExpr(arg); } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java index 2b65fb54653e..b2079f4ac194 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampParseExprMacro.java @@ -64,8 +64,13 @@ public Expr apply(final List args) ? 
createDefaultParser(timeZone) : DateTimes.wrapFormatter(DateTimeFormat.forPattern(formatString).withZone(timeZone)); - class TimestampParseExpr implements Expr + class TimestampParseExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { + private TimestampParseExpr(Expr arg) + { + super(arg); + } + @Nonnull @Override public ExprEval eval(final ObjectBinding bindings) @@ -86,14 +91,14 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + Expr newArg = arg.visit(shuttle); + return shuttle.visit(new TimestampParseExpr(newArg)); } } - return new TimestampParseExpr(); + return new TimestampParseExpr(arg); } /** diff --git a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java index c76803809d7b..872c89fcc00b 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TimestampShiftExprMacro.java @@ -30,6 +30,7 @@ import javax.annotation.Nonnull; import java.util.List; +import java.util.stream.Collectors; public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro { @@ -70,17 +71,16 @@ private static int getStep(final List args, final Expr.ObjectBinding bindi return args.get(2).eval(bindings).asInt(); } - private static class TimestampShiftExpr implements Expr + private static class TimestampShiftExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final Expr arg; private final Chronology chronology; private final Period period; private final int step; - public TimestampShiftExpr(final List args) + TimestampShiftExpr(final List args) { + super(args); final PeriodGranularity granularity = getGranularity(args, ExprUtils.nilBindings()); - arg = args.get(0); period = 
granularity.getPeriod(); chronology = ISOChronology.getInstance(granularity.getTimeZone()); step = getStep(args, ExprUtils.nilBindings()); @@ -90,24 +90,22 @@ public TimestampShiftExpr(final List args) @Override public ExprEval eval(final ObjectBinding bindings) { - return ExprEval.of(chronology.add(period, arg.eval(bindings).asLong(), step)); + return ExprEval.of(chronology.add(period, args.get(0).eval(bindings).asLong(), step)); } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - arg.visit(visitor); - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new TimestampShiftExpr(newArgs)); } } - private static class TimestampShiftDynamicExpr implements Expr + private static class TimestampShiftDynamicExpr extends ExprMacroTable.BaseScalarMacroFunctionExpr { - private final List args; - - public TimestampShiftDynamicExpr(final List args) + TimestampShiftDynamicExpr(final List args) { - this.args = args; + super(args); } @Nonnull @@ -122,12 +120,10 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - for (Expr arg : args) { - arg.visit(visitor); - } - visitor.visit(this); + List newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); + return shuttle.visit(new TimestampShiftDynamicExpr(newArgs)); } } } diff --git a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java index 777cb90b07a1..e3b49db8ca1a 100644 --- a/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java +++ b/processing/src/main/java/org/apache/druid/query/expression/TrimExprMacro.java @@ -25,7 +25,9 @@ import org.apache.druid.math.expr.ExprMacroTable; import javax.annotation.Nonnull; +import java.util.HashSet; import java.util.List; 
+import java.util.Set; public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro { @@ -94,16 +96,15 @@ public Expr apply(final List args) } } - private static class TrimStaticCharsExpr implements Expr + private static class TrimStaticCharsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr { private final TrimMode mode; - private final Expr stringExpr; private final char[] chars; public TrimStaticCharsExpr(final TrimMode mode, final Expr stringExpr, final char[] chars) { + super(stringExpr); this.mode = mode; - this.stringExpr = stringExpr; this.chars = chars; } @@ -111,7 +112,7 @@ public TrimStaticCharsExpr(final TrimMode mode, final Expr stringExpr, final cha @Override public ExprEval eval(final ObjectBinding bindings) { - final ExprEval stringEval = stringExpr.eval(bindings); + final ExprEval stringEval = arg.eval(bindings); if (chars.length == 0 || stringEval.value() == null) { return stringEval; @@ -150,10 +151,10 @@ public ExprEval eval(final ObjectBinding bindings) } @Override - public void visit(final Visitor visitor) + public Expr visit(Shuttle shuttle) { - stringExpr.visit(visitor); - visitor.visit(this); + Expr newStringExpr = arg.visit(shuttle); + return shuttle.visit(new TrimStaticCharsExpr(mode, newStringExpr, chars)); } } @@ -226,6 +227,29 @@ public void visit(final Visitor visitor) charsExpr.visit(visitor); visitor.visit(this); } + + @Override + public Expr visit(Shuttle shuttle) + { + Expr newStringExpr = stringExpr.visit(shuttle); + Expr newCharsExpr = charsExpr.visit(shuttle); + return shuttle.visit(new TrimDynamicCharsExpr(mode, newStringExpr, newCharsExpr)); + } + + @Override + public BindingDetails analyzeInputs() + { + final String stringIdentifier = stringExpr.getIdentifierIfIdentifier(); + final Set scalars = new HashSet<>(); + if (stringIdentifier != null) { + scalars.add(stringIdentifier); + } + final String charsIdentifier = charsExpr.getIdentifierIfIdentifier(); + if (charsIdentifier != null) { + 
scalars.add(charsIdentifier); + } + return stringExpr.analyzeInputs().merge(charsExpr.analyzeInputs()).mergeWithScalars(scalars); + } } private static boolean arrayContains(char[] array, char c) diff --git a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java index 9a348006a9e9..4e731e0c4b5e 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/ExpressionDimFilter.java @@ -77,7 +77,7 @@ public RangeSet getDimensionRangeSet(final String dimension) @Override public HashSet getRequiredColumns() { - return Sets.newHashSet(Parser.findRequiredBindings(parsed.get())); + return Sets.newHashSet(parsed.get().analyzeInputs().getFreeVariables()); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/IndexMergerV9.java b/processing/src/main/java/org/apache/druid/segment/IndexMergerV9.java index 950eb5fcd0a2..de6284926fa5 100644 --- a/processing/src/main/java/org/apache/druid/segment/IndexMergerV9.java +++ b/processing/src/main/java/org/apache/druid/segment/IndexMergerV9.java @@ -712,11 +712,11 @@ private void mergeCapabilities( for (IndexableAdapter adapter : adapters) { for (String dimension : adapter.getDimensionNames()) { ColumnCapabilities capabilities = adapter.getCapabilities(dimension); - capabilitiesMap.computeIfAbsent(dimension, d -> new ColumnCapabilitiesImpl()).merge(capabilities); + capabilitiesMap.computeIfAbsent(dimension, d -> new ColumnCapabilitiesImpl().setIsComplete(true)).merge(capabilities); } for (String metric : adapter.getMetricNames()) { ColumnCapabilities capabilities = adapter.getCapabilities(metric); - capabilitiesMap.computeIfAbsent(metric, m -> new ColumnCapabilitiesImpl()).merge(capabilities); + capabilitiesMap.computeIfAbsent(metric, m -> new ColumnCapabilitiesImpl().setIsComplete(true)).merge(capabilities); 
metricsValueTypes.put(metric, capabilities.getType()); metricTypeNames.put(metric, adapter.getMetricType(metric)); } diff --git a/processing/src/main/java/org/apache/druid/segment/column/ColumnBuilder.java b/processing/src/main/java/org/apache/druid/segment/column/ColumnBuilder.java index 7ef78c674e55..ce081dffe9a6 100644 --- a/processing/src/main/java/org/apache/druid/segment/column/ColumnBuilder.java +++ b/processing/src/main/java/org/apache/druid/segment/column/ColumnBuilder.java @@ -109,6 +109,7 @@ public ColumnHolder build() .setHasBitmapIndexes(bitmapIndex != null) .setHasSpatialIndexes(spatialIndex != null) .setHasMultipleValues(hasMultipleValues) + .setIsComplete(true) .setFilterable(filterable), columnSupplier, bitmapIndex, diff --git a/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilities.java b/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilities.java index f3bf54efac6a..4e1902d87f8d 100644 --- a/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilities.java +++ b/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilities.java @@ -31,4 +31,12 @@ public interface ColumnCapabilities boolean hasSpatialIndexes(); boolean hasMultipleValues(); boolean isFilterable(); + + /** + * This property indicates that this {@link ColumnCapabilities} is "complete" in that all properties can be expected + * to supply valid responses. Not all {@link ColumnCapabilities} are created equal. Some, such as those provided by + * {@link org.apache.druid.query.groupby.RowBasedColumnSelectorFactory} only have type information, if even that, and + * cannot supply information like {@link ColumnCapabilities#hasMultipleValues}, and will report as false. 
+ */ + boolean isComplete(); } diff --git a/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilitiesImpl.java b/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilitiesImpl.java index 65c94ae091ab..5141dba7308a 100644 --- a/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilitiesImpl.java +++ b/processing/src/main/java/org/apache/druid/segment/column/ColumnCapabilitiesImpl.java @@ -38,6 +38,9 @@ public class ColumnCapabilitiesImpl implements ColumnCapabilities @JsonIgnore private boolean filterable; + @JsonIgnore + private boolean complete = false; + @Override @JsonProperty public ValueType getType() @@ -114,6 +117,12 @@ public boolean isFilterable() filterable; } + @Override + public boolean isComplete() + { + return complete; + } + public ColumnCapabilitiesImpl setFilterable(boolean filterable) { this.filterable = filterable; @@ -126,6 +135,12 @@ public ColumnCapabilitiesImpl setHasMultipleValues(boolean hasMultipleValues) return this; } + public ColumnCapabilitiesImpl setIsComplete(boolean complete) + { + this.complete = complete; + return this; + } + public void merge(ColumnCapabilities other) { if (other == null) { @@ -145,6 +160,7 @@ public void merge(ColumnCapabilities other) this.hasInvertedIndexes |= other.hasBitmapIndexes(); this.hasSpatialIndexes |= other.hasSpatialIndexes(); this.hasMultipleValues |= other.hasMultipleValues(); + this.complete &= other.isComplete(); // these should always be the same? 
this.filterable &= other.isFilterable(); } } diff --git a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java index f1ccf82a2718..5816c0ce0d61 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/ExpressionFilter.java @@ -21,13 +21,11 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import org.apache.druid.common.config.NullHandling; import org.apache.druid.math.expr.Evals; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; import org.apache.druid.query.BitmapResultFactory; import org.apache.druid.query.expression.ExprUtils; import org.apache.druid.query.filter.BitmapIndexSelector; @@ -39,6 +37,7 @@ import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.virtual.ExpressionSelectors; +import java.util.Arrays; import java.util.Set; public class ExpressionFilter implements Filter @@ -49,7 +48,7 @@ public class ExpressionFilter implements Filter public ExpressionFilter(final Supplier expr) { this.expr = expr; - this.requiredBindings = Suppliers.memoize(() -> ImmutableSet.copyOf(Parser.findRequiredBindings(expr.get()))); + this.requiredBindings = Suppliers.memoize(() -> expr.get().analyzeInputs().getFreeVariables()); } @Override @@ -64,7 +63,23 @@ public boolean matches() if (NullHandling.sqlCompatible() && selector.isNull()) { return false; } - return Evals.asBoolean(selector.getLong()); + ExprEval eval = selector.getObject(); + if (eval == null) { + return false; + } + switch (eval.type()) { + case LONG_ARRAY: + Long[] lResult = eval.asLongArray(); + return Arrays.stream(lResult).anyMatch(Evals::asBoolean); + case STRING_ARRAY: + 
String[] sResult = eval.asStringArray(); + return Arrays.stream(sResult).anyMatch(Evals::asBoolean); + case DOUBLE_ARRAY: + Double[] dResult = eval.asDoubleArray(); + return Arrays.stream(dResult).anyMatch(Evals::asBoolean); + default: + return Evals.asBoolean(selector.getLong()); + } } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java index 3ed3da048c73..515c47571f42 100644 --- a/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java +++ b/processing/src/main/java/org/apache/druid/segment/incremental/IncrementalIndex.java @@ -317,7 +317,7 @@ protected IncrementalIndex( } //__time capabilities - ColumnCapabilitiesImpl timeCapabilities = new ColumnCapabilitiesImpl(); + ColumnCapabilitiesImpl timeCapabilities = new ColumnCapabilitiesImpl().setIsComplete(true); timeCapabilities.setType(ValueType.LONG); columnCapabilities.put(ColumnHolder.TIME_COLUMN_NAME, timeCapabilities); @@ -654,6 +654,7 @@ IncrementalIndexRowResult toIncrementalIndexRow(InputRow row) capabilities.setType(ValueType.STRING); capabilities.setDictionaryEncoded(true); capabilities.setHasBitmapIndexes(true); + capabilities.setIsComplete(true); columnCapabilities.put(dimension, capabilities); } DimensionHandler handler = DimensionHandlerUtils.getHandlerFromCapabilities(dimension, capabilities, null); @@ -912,6 +913,7 @@ private ColumnCapabilitiesImpl makeCapabilitiesFromValueType(ValueType type) capabilities.setDictionaryEncoded(type == ValueType.STRING); capabilities.setHasBitmapIndexes(type == ValueType.STRING); capabilities.setType(type); + capabilities.setIsComplete(true); return capabilities; } @@ -1106,7 +1108,7 @@ public MetricDesc(int index, AggregatorFactory factory) this.name = factory.getName(); String typeInfo = factory.getTypeName(); - this.capabilities = new ColumnCapabilitiesImpl(); + this.capabilities = new 
ColumnCapabilitiesImpl().setIsComplete(true); if ("float".equalsIgnoreCase(typeInfo)) { capabilities.setType(ValueType.FLOAT); this.type = typeInfo; diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelector.java index caffc48b9bd1..c0c3eafa1cac 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelector.java @@ -27,10 +27,14 @@ import javax.annotation.Nonnull; +/** + * Basic expression {@link ColumnValueSelector}. Evaluates {@link Expr} into {@link ExprEval} against + * {@link Expr.ObjectBinding} which are backed by the underlying expression input {@link ColumnValueSelector}s + */ public class ExpressionColumnValueSelector implements ColumnValueSelector { - private final Expr.ObjectBinding bindings; - private final Expr expression; + final Expr.ObjectBinding bindings; + final Expr expression; public ExpressionColumnValueSelector(Expr expression, Expr.ObjectBinding bindings) { diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java index 3e452406b060..c948bd0d183b 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionSelectors.java @@ -24,6 +24,7 @@ import com.google.common.base.Supplier; import com.google.common.collect.Iterables; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.Parser; @@ -43,11 +44,14 @@ import org.apache.druid.segment.column.ValueType; import 
org.apache.druid.segment.data.IndexedInts; -import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; public class ExpressionSelectors { @@ -131,7 +135,9 @@ public static ColumnValueSelector makeExprEvalSelector( Expr expression ) { - final List columns = Parser.findRequiredBindings(expression); + final Expr.BindingDetails exprDetails = expression.analyzeInputs(); + Parser.validateExpr(expression, exprDetails); + final List columns = exprDetails.getRequiredColumns(); if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -146,7 +152,10 @@ public static ColumnValueSelector makeExprEvalSelector( ); } else if (capabilities != null && capabilities.getType() == ValueType.STRING - && capabilities.isDictionaryEncoded()) { + && capabilities.isDictionaryEncoded() + && capabilities.isComplete() + && !capabilities.hasMultipleValues() + && !exprDetails.getArrayVariables().contains(column)) { // Optimization for expressions that hit one string column and nothing else. 
return new SingleStringInputCachingExpressionColumnValueSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -155,24 +164,58 @@ public static ColumnValueSelector makeExprEvalSelector( } } - final Expr.ObjectBinding bindings = createBindings(expression, columnSelectorFactory); + final Pair, Set> arrayUsage = + examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + final Set actualArrays = arrayUsage.lhs; + final Set unknownIfArrays = arrayUsage.rhs; + + final List needsApplied = + columns.stream() + .filter(c -> actualArrays.contains(c) && !exprDetails.getArrayVariables().contains(c)) + .collect(Collectors.toList()); + final Expr finalExpr; + if (needsApplied.size() > 0) { + finalExpr = Parser.applyUnappliedIdentifiers(expression, exprDetails, needsApplied); + } else { + finalExpr = expression; + } + + final Expr.ObjectBinding bindings = createBindings(exprDetails, columnSelectorFactory); if (bindings.equals(ExprUtils.nilBindings())) { // Optimization for constant expressions. return new ConstantExprEvalSelector(expression.eval(bindings)); } - // No special optimization. 
- return new ExpressionColumnValueSelector(expression, bindings); + // if any unknown column input types, fall back to an expression selector that examines input bindings on a + // per row basis + if (unknownIfArrays.size() > 0) { + return new RowBasedExpressionColumnValueSelector( + finalExpr, + exprDetails, + bindings, + unknownIfArrays + ); + } + + // generic expression value selector for fully known input types + return new ExpressionColumnValueSelector(finalExpr, bindings); } + /** + * Makes a single or multi-value {@link DimensionSelector} wrapper around a {@link ColumnValueSelector} created by + * {@link ExpressionSelectors#makeExprEvalSelector(ColumnSelectorFactory, Expr)} as appropriate + */ public static DimensionSelector makeDimensionSelector( final ColumnSelectorFactory columnSelectorFactory, final Expr expression, final ExtractionFn extractionFn ) { - final List columns = Parser.findRequiredBindings(expression); + final Expr.BindingDetails exprDetails = expression.analyzeInputs(); + Parser.validateExpr(expression, exprDetails); + final List columns = exprDetails.getRequiredColumns(); + if (columns.size() == 1) { final String column = Iterables.getOnlyElement(columns); @@ -180,7 +223,11 @@ public static DimensionSelector makeDimensionSelector( if (capabilities != null && capabilities.getType() == ValueType.STRING - && capabilities.isDictionaryEncoded()) { + && capabilities.isDictionaryEncoded() + && capabilities.isComplete() + && !capabilities.hasMultipleValues() + && !exprDetails.getArrayVariables().contains(column) + ) { // Optimization for dimension selectors that wrap a single underlying string column. 
return new SingleStringInputDimensionSelector( columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(column, column, ValueType.STRING)), @@ -189,7 +236,16 @@ public static DimensionSelector makeDimensionSelector( } } + final Pair, Set> arrayUsage = + examineColumnSelectorFactoryArrays(columnSelectorFactory, exprDetails, columns); + final Set actualArrays = arrayUsage.lhs; + final Set unknownIfArrays = arrayUsage.rhs; + + final ColumnValueSelector baseSelector = makeExprEvalSelector(columnSelectorFactory, expression); + final boolean multiVal = actualArrays.size() > 0 || + exprDetails.getArrayVariables().size() > 0 || + unknownIfArrays.size() > 0; if (baseSelector instanceof ConstantExprEvalSelector) { // Optimization for dimension selectors on constants. @@ -198,49 +254,108 @@ public static DimensionSelector makeDimensionSelector( // Optimization for null dimension selector. return DimensionSelector.constant(null); } else if (extractionFn == null) { - class DefaultExpressionDimensionSelector extends BaseSingleValueDimensionSelector - { - @Override - protected String getValue() - { - return NullHandling.emptyToNullIfNeeded(baseSelector.getObject().asString()); - } - @Override - public void inspectRuntimeShape(RuntimeShapeInspector inspector) + if (multiVal) { + return new MultiValueExpressionDimensionSelector(baseSelector); + } else { + class DefaultExpressionDimensionSelector extends BaseSingleValueDimensionSelector { - inspector.visit("baseSelector", baseSelector); + @Override + protected String getValue() + { + + return NullHandling.emptyToNullIfNeeded(baseSelector.getObject().asString()); + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("baseSelector", baseSelector); + } } + return new DefaultExpressionDimensionSelector(); } - return new DefaultExpressionDimensionSelector(); } else { - class ExtractionExpressionDimensionSelector extends BaseSingleValueDimensionSelector - { - @Override - 
protected String getValue() + if (multiVal) { + class ExtractionMultiValueDimensionSelector extends MultiValueExpressionDimensionSelector { - return extractionFn.apply(NullHandling.emptyToNullIfNeeded(baseSelector.getObject().asString())); + private ExtractionMultiValueDimensionSelector() + { + super(baseSelector); + } + + @Override + String getValue(ExprEval evaluated) + { + assert !evaluated.isArray(); + return extractionFn.apply(NullHandling.emptyToNullIfNeeded(evaluated.asString())); + } + + @Override + List getArray(ExprEval evaluated) + { + assert evaluated.isArray(); + return Arrays.stream(evaluated.asStringArray()) + .map(x -> extractionFn.apply(NullHandling.emptyToNullIfNeeded(x))) + .collect(Collectors.toList()); + } + + @Override + String getArrayValue(ExprEval evaluated, int i) + { + assert evaluated.isArray(); + String[] stringArray = evaluated.asStringArray(); + assert i < stringArray.length; + return extractionFn.apply(NullHandling.emptyToNullIfNeeded(stringArray[i])); + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("baseSelector", baseSelector); + inspector.visit("extractionFn", extractionFn); + } } + return new ExtractionMultiValueDimensionSelector(); - @Override - public void inspectRuntimeShape(RuntimeShapeInspector inspector) + } else { + class ExtractionExpressionDimensionSelector extends BaseSingleValueDimensionSelector { - inspector.visit("baseSelector", baseSelector); - inspector.visit("extractionFn", extractionFn); + @Override + protected String getValue() + { + return extractionFn.apply(NullHandling.emptyToNullIfNeeded(baseSelector.getObject().asString())); + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("baseSelector", baseSelector); + inspector.visit("extractionFn", extractionFn); + } } + return new ExtractionExpressionDimensionSelector(); } - return new ExtractionExpressionDimensionSelector(); } } - private static 
Expr.ObjectBinding createBindings(Expr expression, ColumnSelectorFactory columnSelectorFactory) + /** + * Create {@link Expr.ObjectBinding} given a {@link ColumnSelectorFactory} and {@link Expr.BindingDetails} which + * provides the set of identifiers which need a binding (list of required columns), and context of whether or not they + * are used as array or scalar inputs + */ + private static Expr.ObjectBinding createBindings( + Expr.BindingDetails bindingDetails, + ColumnSelectorFactory columnSelectorFactory + ) { final Map> suppliers = new HashMap<>(); - final List columns = Parser.findRequiredBindings(expression); + final List columns = bindingDetails.getRequiredColumns(); for (String columnName : columns) { final ColumnCapabilities columnCapabilities = columnSelectorFactory .getColumnCapabilities(columnName); final ValueType nativeType = columnCapabilities != null ? columnCapabilities.getType() : null; + final boolean multiVal = columnCapabilities != null && columnCapabilities.hasMultipleValues(); final Supplier supplier; if (nativeType == ValueType.FLOAT) { @@ -257,8 +372,8 @@ private static Expr.ObjectBinding createBindings(Expr expression, ColumnSelector supplier = makeNullableSupplier(selector, selector::getDouble); } else if (nativeType == ValueType.STRING) { supplier = supplierFromDimensionSelector( - columnSelectorFactory - .makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)) + columnSelectorFactory.makeDimensionSelector(new DefaultDimensionSpec(columnName, columnName)), + multiVal ); } else if (nativeType == null) { // Unknown ValueType. Try making an Object selector and see if that gives us anything useful. @@ -308,24 +423,39 @@ private static Supplier makeNullableSupplier( } } + /** + * Create a supplier to feed {@link Expr.ObjectBinding} for a dimension selector, coercing values to always appear as + * arrays if specified. 
+ */ @VisibleForTesting - @Nonnull - static Supplier supplierFromDimensionSelector(final DimensionSelector selector) + static Supplier supplierFromDimensionSelector(final DimensionSelector selector, boolean coerceArray) { Preconditions.checkNotNull(selector, "selector"); return () -> { final IndexedInts row = selector.getRow(); - if (row.size() == 1) { + if (row.size() == 1 && !coerceArray) { return selector.lookupName(row.get(0)); } else { - // Can't handle non-singly-valued rows in expressions. - // Treat them as nulls until we think of something better to do. - return null; + // column selector factories hate you and use [] and [null] interchangeably for nullish data + if (row.size() == 0) { + return new String[]{null}; + } + final String[] strings = new String[row.size()]; + // noinspection SSBasedInspection + for (int i = 0; i < row.size(); i++) { + strings[i] = selector.lookupName(row.get(i)); + } + return strings; } }; } + + /** + * Create a fallback supplier to feed {@link Expr.ObjectBinding} for a selector, used if column cannot be reliably + * detected as a primitive type + */ @Nullable static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSelector selector) { @@ -343,13 +473,67 @@ static Supplier supplierFromObjectSelector(final BaseObjectColumnValueSe final Object val = selector.getObject(); if (val instanceof Number || val instanceof String) { return val; + } else if (val instanceof List) { + return coerceListDimToStringArray((List) val); } else { return null; } }; + } else if (clazz.isAssignableFrom(List.class)) { + return () -> { + final Object val = selector.getObject(); + if (val != null) { + return coerceListDimToStringArray((List) val); + } + return null; + }; } else { // No numbers or strings. 
return null; } } + + /** + * Selectors are not consistent in treatment of null, [], and [null], so coerce [] to [null] + */ + private static Object coerceListDimToStringArray(List val) + { + Object[] arrayVal = val.stream().map(Object::toString).toArray(String[]::new); + if (arrayVal.length > 0) { + return arrayVal; + } + return new String[]{null}; + } + + /** + * Returns pair of columns which are definitely multi-valued, or 'actual' arrays, and those which we are unable to + * discern from the {@link ColumnSelectorFactory#getColumnCapabilities(String)}, or 'unknown' arrays. + */ + private static Pair, Set> examineColumnSelectorFactoryArrays( + ColumnSelectorFactory columnSelectorFactory, + Expr.BindingDetails exprDetails, + List columns + ) + { + final Set actualArrays = new HashSet<>(); + final Set unknownIfArrays = new HashSet<>(); + for (String column : columns) { + final ColumnCapabilities capabilities = columnSelectorFactory.getColumnCapabilities(column); + if (capabilities != null) { + if (capabilities.hasMultipleValues()) { + actualArrays.add(column); + } else if ( + !capabilities.isComplete() && + capabilities.getType().equals(ValueType.STRING) && + !exprDetails.getArrayVariables().contains(column) + ) { + unknownIfArrays.add(column); + } + } else { + unknownIfArrays.add(column); + } + } + + return new Pair<>(actualArrays, unknownIfArrays); + } } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java index d029f7f8311c..d71ac28f7db2 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/ExpressionVirtualColumn.java @@ -111,7 +111,7 @@ public ColumnCapabilities capabilities(String columnName) @Override public List requiredColumns() { - return Parser.findRequiredBindings(parsedExpression.get()); + return 
parsedExpression.get().analyzeInputs().getRequiredColumns(); } @Override diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/MultiValueExpressionDimensionSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/MultiValueExpressionDimensionSelector.java new file mode 100644 index 000000000000..e3b5734ee62f --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/virtual/MultiValueExpressionDimensionSelector.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.segment.virtual; + +import com.google.common.base.Predicate; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.RangeIndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; + +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * Basic multi-value dimension selector for an {@link org.apache.druid.math.expr.Expr} evaluating + * {@link ColumnValueSelector}. + */ +public class MultiValueExpressionDimensionSelector implements DimensionSelector +{ + private final ColumnValueSelector baseSelector; + + public MultiValueExpressionDimensionSelector(ColumnValueSelector baseSelector) + { + this.baseSelector = baseSelector; + } + + ExprEval getEvaluated() + { + return baseSelector.getObject(); + } + + String getValue(ExprEval evaluated) + { + assert !evaluated.isArray(); + return NullHandling.emptyToNullIfNeeded(evaluated.asString()); + } + + List getArray(ExprEval evaluated) + { + assert evaluated.isArray(); + return Arrays.stream(evaluated.asStringArray()) + .map(NullHandling::emptyToNullIfNeeded) + .collect(Collectors.toList()); + } + + String getArrayValue(ExprEval evaluated, int i) + { + assert evaluated.isArray(); + String[] stringArray = evaluated.asStringArray(); + assert i < stringArray.length; + return NullHandling.emptyToNullIfNeeded(stringArray[i]); + } + + @Override + public IndexedInts getRow() + { + ExprEval evaluated = getEvaluated(); + if (evaluated.isArray()) { + RangeIndexedInts ints = new 
RangeIndexedInts(); + ints.setSize(evaluated.asArray() != null ? evaluated.asArray().length : 0); + return ints; + } + return ZeroIndexedInts.instance(); + } + + @Override + public int getValueCardinality() + { + return CARDINALITY_UNKNOWN; + } + + @Nullable + @Override + public String lookupName(int id) + { + ExprEval evaluated = getEvaluated(); + if (evaluated.isArray()) { + return getArrayValue(evaluated, id); + } + assert id == 0; + return NullHandling.emptyToNullIfNeeded(evaluated.asString()); + } + + @Override + public ValueMatcher makeValueMatcher(@Nullable String value) + { + return new ValueMatcher() + { + @Override + public boolean matches() + { + ExprEval evaluated = getEvaluated(); + if (evaluated.isArray()) { + List array = getArray(evaluated); + return array.stream().anyMatch(x -> Objects.equals(x, value)); + } + return Objects.equals(getValue(evaluated), value); + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("selector", baseSelector); + } + }; + } + + @Override + public ValueMatcher makeValueMatcher(Predicate predicate) + { + return new ValueMatcher() + { + @Override + public boolean matches() + { + ExprEval evaluated = getEvaluated(); + if (evaluated.isArray()) { + List array = getArray(evaluated); + return array.stream().anyMatch(x -> predicate.apply(x)); + } + return predicate.apply(getValue(evaluated)); + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("selector", baseSelector); + inspector.visit("predicate", predicate); + } + }; + } + + @Override + public void inspectRuntimeShape(RuntimeShapeInspector inspector) + { + inspector.visit("baseSelector", baseSelector); + } + + @Override + public boolean nameLookupPossibleInAdvance() + { + return false; + } + + @Nullable + @Override + public IdLookup idLookup() + { + return null; + } + + @Nullable + @Override + public Object getObject() + { + ExprEval evaluated = getEvaluated(); + if 
(evaluated.isArray()) { + return getArray(evaluated); + } + return getValue(evaluated); + } + + @Override + public Class classOfObject() + { + return Object.class; + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java new file mode 100644 index 000000000000..e34f26a606e7 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/segment/virtual/RowBasedExpressionColumnValueSelector.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.segment.virtual; + +import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import org.apache.druid.math.expr.Expr; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.Parser; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Expression column value selector that examines a set of 'unknown' type input bindings on a row by row basis, + * transforming the expression to handle multi-value list typed inputs as they are encountered. + * + * Currently, string dimensions are the only bindings which might appear as a {@link String} or a {@link String[]}, so + * numbers are eliminated from the set of 'unknown' bindings to check as they are encountered. + */ +public class RowBasedExpressionColumnValueSelector extends ExpressionColumnValueSelector +{ + private final List unknownColumns; + private final Expr.BindingDetails baseExprBindingDetails; + private final Set ignoredColumns; + private final Int2ObjectMap transformedCache; + + public RowBasedExpressionColumnValueSelector( + Expr expression, + Expr.BindingDetails baseExprBindingDetails, + Expr.ObjectBinding bindings, + Set unknownColumnsSet + ) + { + super(expression, bindings); + this.unknownColumns = unknownColumnsSet.stream() + .filter(x -> !baseExprBindingDetails.getArrayVariables().contains(x)) + .collect(Collectors.toList()); + this.baseExprBindingDetails = baseExprBindingDetails; + this.ignoredColumns = new HashSet<>(); + this.transformedCache = new Int2ObjectArrayMap<>(unknownColumns.size()); + } + + @Override + public ExprEval getObject() + { + // check to find any arrays for this row + List arrayBindings = unknownColumns.stream().filter(this::isBindingArray).collect(Collectors.toList()); + + // eliminate anything that will never be an array + if (ignoredColumns.size() > 0) { + unknownColumns.removeAll(ignoredColumns); + 
ignoredColumns.clear(); + } + + // if there are arrays, we need to transform the expression to one that applies each value of the array to the + // base expression, we keep a cache of transformed expressions to minimize extra work + if (arrayBindings.size() > 0) { + final int key = arrayBindings.hashCode(); + if (transformedCache.containsKey(key)) { + return transformedCache.get(key).eval(bindings); + } + Expr transformed = Parser.applyUnappliedIdentifiers(expression, baseExprBindingDetails, arrayBindings); + transformedCache.put(key, transformed); + return transformed.eval(bindings); + } + // no arrays for this row, evaluate base expression + return expression.eval(bindings); + } + + /** + * Check if row value binding for identifier is an array, adding identifiers that retrieve {@link Number} to a set + * of 'unknowns' to eliminate by side effect + */ + private boolean isBindingArray(String x) + { + Object binding = bindings.get(x); + if (binding != null) { + if (binding instanceof String[] && ((String[]) binding).length > 1) { + return true; + } else if (binding instanceof Number) { + ignoredColumns.add(x); + } + } + return false; + } +} diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java index f05329bd367b..af71a9979b1e 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleLongInputCachingExpressionColumnValueSelector.java @@ -24,7 +24,6 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import 
org.apache.druid.segment.ColumnValueSelector; @@ -60,7 +59,7 @@ public SingleLongInputCachingExpressionColumnValueSelector( ) { // Verify expression has just one binding. - if (Parser.findRequiredBindings(expression).size() != 1) { + if (expression.analyzeInputs().getFreeVariables().size() != 1) { throw new ISE("WTF?! Expected expression with just one binding"); } diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java index 87c5df19d1f6..4d358e08a90b 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputCachingExpressionColumnValueSelector.java @@ -25,7 +25,6 @@ import org.apache.druid.java.util.common.ISE; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprEval; -import org.apache.druid.math.expr.Parser; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.DimensionSelector; @@ -55,14 +54,14 @@ public SingleStringInputCachingExpressionColumnValueSelector( ) { // Verify expression has just one binding. - if (Parser.findRequiredBindings(expression).size() != 1) { + if (expression.analyzeInputs().getFreeVariables().size() != 1) { throw new ISE("WTF?! 
Expected expression with just one binding"); } this.selector = Preconditions.checkNotNull(selector, "selector"); this.expression = Preconditions.checkNotNull(expression, "expression"); - final Supplier inputSupplier = ExpressionSelectors.supplierFromDimensionSelector(selector); + final Supplier inputSupplier = ExpressionSelectors.supplierFromDimensionSelector(selector, false); this.bindings = name -> inputSupplier.get(); if (selector.getValueCardinality() == DimensionSelector.CARDINALITY_UNKNOWN) { diff --git a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDimensionSelector.java b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDimensionSelector.java index ce49901553b3..275869a7b636 100644 --- a/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDimensionSelector.java +++ b/processing/src/main/java/org/apache/druid/segment/virtual/SingleStringInputDimensionSelector.java @@ -23,7 +23,6 @@ import com.google.common.base.Predicate; import org.apache.druid.java.util.common.ISE; import org.apache.druid.math.expr.Expr; -import org.apache.druid.math.expr.Parser; import org.apache.druid.query.filter.ValueMatcher; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.DimensionSelector; @@ -56,7 +55,7 @@ public SingleStringInputDimensionSelector( ) { // Verify expression has just one binding. - if (Parser.findRequiredBindings(expression).size() != 1) { + if (expression.analyzeInputs().getFreeVariables().size() != 1) { throw new ISE("WTF?! 
Expected expression with just one binding"); } diff --git a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java index 1ee0dafec31a..f0282c6ed7b0 100644 --- a/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java +++ b/processing/src/test/java/org/apache/druid/query/MultiValuedDimensionTest.java @@ -40,6 +40,7 @@ import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ListFilteredDimensionSpec; import org.apache.druid.query.dimension.RegexFilteredDimensionSpec; +import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; import org.apache.druid.query.groupby.GroupByQuery; @@ -59,14 +60,18 @@ import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.QueryableIndexSegment; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.incremental.IncrementalIndex; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.TmpFileSegmentWriteOutMediumFactory; import org.apache.druid.timeline.SegmentId; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -104,15 +109,18 @@ public static Collection constructorFeeder() private IncrementalIndex incrementalIndex; private QueryableIndex queryableIndex; - private File persistedSegmentDir; private IncrementalIndex incrementalIndexNullSampler; private QueryableIndex queryableIndexNullSampler; private File 
persistedSegmentDirNullSampler; + private final GroupByQueryConfig config; private final ImmutableMap context; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + public MultiValuedDimensionTest(final GroupByQueryConfig config, SegmentWriteOutMediumFactory segmentWriteOutMediumFactory, boolean forceHashAggregation) { helper = AggregationTestHelper.createGroupByQueryAggregationTestHelper( @@ -120,6 +128,7 @@ public MultiValuedDimensionTest(final GroupByQueryConfig config, SegmentWriteOut config, null ); + this.config = config; this.segmentWriteOutMediumFactory = segmentWriteOutMediumFactory; this.context = config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1) @@ -138,9 +147,9 @@ public void setup() throws Exception StringInputRowParser parser = new StringInputRowParser( new CSVParseSpec( new TimestampSpec("timestamp", "iso", null), - new DimensionsSpec(DimensionsSpec.getDefaultSchemas(ImmutableList.of("product", "tags")), null, null), + new DimensionsSpec(DimensionsSpec.getDefaultSchemas(ImmutableList.of("product", "tags", "othertags")), null, null), "\t", - ImmutableList.of("timestamp", "product", "tags"), + ImmutableList.of("timestamp", "product", "tags", "othertags"), false, 0 ), @@ -148,24 +157,23 @@ public void setup() throws Exception ); String[] rows = new String[]{ - "2011-01-12T00:00:00.000Z,product_1,t1\tt2\tt3", - "2011-01-13T00:00:00.000Z,product_2,t3\tt4\tt5", - "2011-01-14T00:00:00.000Z,product_3,t5\tt6\tt7", - "2011-01-14T00:00:00.000Z,product_4" + "2011-01-12T00:00:00.000Z,product_1,t1\tt2\tt3,u1\tu2", + "2011-01-13T00:00:00.000Z,product_2,t3\tt4\tt5,u3\tu4", + "2011-01-14T00:00:00.000Z,product_3,t5\tt6\tt7,u1\tu5", + "2011-01-14T00:00:00.000Z,product_4,,u2" }; for (String row : rows) { incrementalIndex.add(parser.parse(row)); } + persistedSegmentDir = Files.createTempDir(); TestHelper.getTestIndexMergerV9(segmentWriteOutMediumFactory) .persist(incrementalIndex, persistedSegmentDir, new IndexSpec(), 
null); - queryableIndex = TestHelper.getTestIndexIO().loadIndex(persistedSegmentDir); - StringInputRowParser parserNullSampler = new StringInputRowParser( new JSONParseSpec( new TimestampSpec("time", "iso", null), @@ -216,7 +224,6 @@ public void testGroupByNoFilter() .setGranularity(Granularities.ALL) .setDimensions(new DefaultDimensionSpec("tags", "tags")) .setAggregatorSpecs(new CountAggregatorFactory("count")) - .setContext(context) .build(); Sequence result = helper.runQueryOnSegmentsObjs( @@ -228,7 +235,13 @@ public void testGroupByNoFilter() ); List expectedResults = Arrays.asList( - GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tags", null, "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow( + "1970-01-01T00:00:00.000Z", + "tags", + NullHandling.replaceWithDefault() ? null : "", + "count", + 2L + ), GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tags", "t1", "count", 2L), GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tags", "t2", "count", 2L), GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tags", "t3", "count", 4L), @@ -376,6 +389,485 @@ public void testGroupByWithDimFilterAndWithFilteredDimSpec() TestHelper.assertExpectedObjects(expectedResults, result.toList(), "filteredDim"); } + @Test + public void testGroupByExpression() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "map(x -> concat(x, 'foo'), tags)", + ValueType.STRING, + 
TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t3foo", "count", 4L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t4foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t5foo", "count", 4L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t6foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t7foo", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr"); + } + + @Test + public void testGroupByExpressionMultiMulti() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "cartesian_map((x,y) -> 
concat(x, y), tags, othertags)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setLimit(5) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t3u1", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-multi-multi"); + } + + @Test + public void testGroupByExpressionMultiMultiAuto() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "map((x) -> concat(x, othertags), tags)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setLimit(5) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + 
ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t3u1", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-multi-multi-auto"); + } + + @Test + public void testGroupByExpressionMultiMultiAutoAuto() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "concat(tags, othertags)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setLimit(5) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + 
GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t1u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u1", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t2u2", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "texpr", "t3u1", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-multi-multi-auto-auto"); + } + + /** Groups by tt defined as concat of tags with the literal foo, without a limit, run over both test segments; if the default strategy is v1, a RuntimeException about unknown cardinality is expected instead. */ @Test + public void testGroupByExpressionAuto() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("tt", "tt")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "tt", + "concat(tags, 'foo')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t1foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t2foo", 
"count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t3foo", "count", 4L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t4foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t5foo", "count", 4L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t6foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t7foo", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-auto"); + } + + /** Groups by tt defined as array_to_string over an explicit map that prefixes each tags value with the literal foo; if the default strategy is v1, a RuntimeException about unknown cardinality is expected instead. */ @Test + public void testGroupByExpressionArrayFnArg() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("tt", "tt")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "tt", + "array_to_string(map(tags -> concat('foo', tags), tags), ', ')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot1, foot2, foot3", "count", 2L), + 
GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot3, foot4, foot5", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot5, foot6, foot7", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-array-fn"); + } + + /** Groups by tt defined as array_to_string over concat of the literal foo with tags, with no explicit map, run over both test segments; if the default strategy is v1, a RuntimeException about unknown cardinality is expected instead. */ @Test + public void testGroupByExpressionAutoArrayFnArg() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("tt", "tt")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "tt", + "array_to_string(concat('foo', tags), ', ')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot1, foot2, foot3", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot3, foot4, foot5", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot5, foot6, foot7", "count", 2L) + ); + + TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-arrayfn-auto"); + } + + @Test + public void 
testGroupByExpressionFoldArrayToString() + { /* Groups by tt computed by folding tags into one string with concat(acc, tag), starting from an empty string; if the default strategy is v1, a RuntimeException about unknown cardinality is expected instead. */ + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("tt", "tt")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "tt", + "fold((tag, acc) -> concat(acc, tag), tags, '')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + + /* the first expected row depends on the null-handling mode checked below */ List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow( + "1970-01-01T00:00:00.000Z", + "tt", + NullHandling.replaceWithDefault() ? 
null : "", + "count", + 2L + ), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t1t2t3", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t3t4t5", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "t5t6t7", "count", 2L) + ); + + /* own label for this test (was a copy of the auto array-fn test label) */ TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-fold-arrayfn"); + } + + /** Groups by tt computed with a fold that joins foo-prefixed tags using a comma-space separator, run over both test segments; if the default strategy is v1, a RuntimeException about unknown cardinality is expected instead. */ @Test + public void testGroupByExpressionFoldArrayToStringWithConcats() + { + if (config.getDefaultStrategy().equals(GroupByStrategySelector.STRATEGY_V1)) { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage("GroupBy v1 does not support dimension selectors with unknown cardinality."); + } + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("tt", "tt")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "tt", + "fold((tag, acc) -> concat(concat(acc, case_searched(acc == '', '', ', '), concat('foo', tag)))), tags, '')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + Sequence result = helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ); + + List expectedResults = Arrays.asList( + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foo", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot1, foot2, foot3", "count", 2L), + GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot3, foot4, foot5", "count", 2L), + 
GroupByQueryRunnerTestHelper.createExpectedRow("1970-01-01T00:00:00.000Z", "tt", "foot5, foot6, foot7", "count", 2L) + ); + + /* own label for this test (was a copy of the auto array-fn test label) */ TestHelper.assertExpectedObjects(expectedResults, result.toList(), "expr-fold-arrayfn-concats"); + } + + + /** Expects a RuntimeException because the expression uses tags both as the array input of map and as a plain argument of the outer concat. */ @Test + public void testGroupByExpressionMultiConflicting() + { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage( + "Invalid expression: (concat [(map ([x] -> (concat [x, othertags])), [tags]), tags]); [tags] used as both scalar and array variables" + ); + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "concat(map((x) -> concat(x, othertags), tags), tags)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setLimit(5) + .setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ).toList(); + } + + /** Expects a RuntimeException reporting that tags is used as both scalar and array variables in the array_concat and array_append expression. */ @Test + public void testGroupByExpressionMultiConflictingAlso() + { + expectedException.expect(RuntimeException.class); + expectedException.expectMessage( + "Invalid expression: (array_concat [tags, (array_append [othertags, tags])]); [tags] used as both scalar and array variables" + ); + GroupByQuery query = GroupByQuery + .builder() + .setDataSource("xx") + .setQuerySegmentSpec(new LegacySegmentSpec("1970/3000")) + .setGranularity(Granularities.ALL) + .setDimensions(new DefaultDimensionSpec("texpr", "texpr")) + .setVirtualColumns( + new ExpressionVirtualColumn( + "texpr", + "array_concat(tags, (array_append(othertags, tags)))", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .setLimit(5) + 
.setAggregatorSpecs(new CountAggregatorFactory("count")) + .setContext(context) + .build(); + + helper.runQueryOnSegmentsObjs( + ImmutableList.of( + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + new IncrementalIndexSegment(incrementalIndex, SegmentId.dummy("sid2")) + ), + query + ).toList(); + } + @Test public void testTopNWithDimFilterAndWithFilteredDimSpec() { @@ -427,6 +919,129 @@ public void testTopNWithDimFilterAndWithFilteredDimSpec() } } + /** Runs a topN with threshold 15 on metric count over the texpr virtual column, an explicit map that appends the literal foo to each tags value, against the queryableIndex segment only. */ @Test + public void testTopNExpression() + { + TopNQuery query = new TopNQueryBuilder() + .dataSource("xx") + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("texpr", "texpr")) + .virtualColumns( + new ExpressionVirtualColumn( + "texpr", + "map(x -> concat(x, 'foo'), tags)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .metric("count") + .intervals(QueryRunnerTestHelper.fullOnIntervalSpec) + .aggregators(Collections.singletonList(new CountAggregatorFactory("count"))) + .threshold(15) + .build(); + + try (CloseableStupidPool pool = TestQueryRunners.createDefaultNonBlockingPool()) { + QueryRunnerFactory factory = new TopNQueryRunnerFactory( + pool, + new TopNQueryQueryToolChest( + new TopNQueryConfig(), + QueryRunnerTestHelper.noopIntervalChunkingQueryRunnerDecorator() + ), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ); + QueryRunner> runner = QueryRunnerTestHelper.makeQueryRunner( + factory, + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + null + ); + Map context = new HashMap<>(); + Sequence> result = runner.run(QueryPlus.wrap(query), context); + List> expected = + ImmutableList.>builder() + .add(ImmutableMap.of("texpr", "t3foo", "count", 2L)) + .add(ImmutableMap.of("texpr", "t5foo", "count", 2L)) + .add(ImmutableMap.of("texpr", "foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t1foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t2foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t4foo", "count", 1L)) + 
.add(ImmutableMap.of("texpr", "t6foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t7foo", "count", 1L)) + .build(); + + List> expectedResults = Collections.singletonList( + new Result( + DateTimes.of("2011-01-12T00:00:00.000Z"), + new TopNResultValue( + expected + ) + ) + ); + /* own label for this test (the previous one presumably came from the filtered-dim-spec test) */ TestHelper.assertExpectedObjects(expectedResults, result.toList(), "topn-expr"); + } + } + + /** Runs a topN with threshold 15 on metric count over the texpr virtual column, defined by applying concat with the literal foo directly to tags with no explicit map, against the queryableIndex segment only. */ @Test + public void testTopNExpressionAutoTransform() + { + TopNQuery query = new TopNQueryBuilder() + .dataSource("xx") + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("texpr", "texpr")) + .virtualColumns( + new ExpressionVirtualColumn( + "texpr", + "concat(tags, 'foo')", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ) + ) + .metric("count") + .intervals(QueryRunnerTestHelper.fullOnIntervalSpec) + .aggregators(Collections.singletonList(new CountAggregatorFactory("count"))) + .threshold(15) + .build(); + + try (CloseableStupidPool pool = TestQueryRunners.createDefaultNonBlockingPool()) { + QueryRunnerFactory factory = new TopNQueryRunnerFactory( + pool, + new TopNQueryQueryToolChest( + new TopNQueryConfig(), + QueryRunnerTestHelper.noopIntervalChunkingQueryRunnerDecorator() + ), + QueryRunnerTestHelper.NOOP_QUERYWATCHER + ); + QueryRunner> runner = QueryRunnerTestHelper.makeQueryRunner( + factory, + new QueryableIndexSegment(queryableIndex, SegmentId.dummy("sid1")), + null + ); + Map context = new HashMap<>(); + Sequence> result = runner.run(QueryPlus.wrap(query), context); + + List> expected = + ImmutableList.>builder() + .add(ImmutableMap.of("texpr", "t3foo", "count", 2L)) + .add(ImmutableMap.of("texpr", "t5foo", "count", 2L)) + .add(ImmutableMap.of("texpr", "foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t1foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t2foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t4foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t6foo", "count", 1L)) + .add(ImmutableMap.of("texpr", "t7foo", "count", 1L)) + .build(); + + 
List> expectedResults = Collections.singletonList( + new Result( + DateTimes.of("2011-01-12T00:00:00.000Z"), + new TopNResultValue( + expected + ) + ) + ); + /* own label for this test (the previous one presumably came from the filtered-dim-spec test) */ TestHelper.assertExpectedObjects(expectedResults, result.toList(), "topn-expr-auto"); + } + } + @After public void cleanup() throws Exception { diff --git a/processing/src/test/java/org/apache/druid/segment/filter/ExpressionFilterTest.java b/processing/src/test/java/org/apache/druid/segment/filter/ExpressionFilterTest.java index f9d4a91d19ef..91b89652392e 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/ExpressionFilterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/ExpressionFilterTest.java @@ -138,16 +138,14 @@ public void testOneSingleValuedStringColumn() @Test public void testOneMultiValuedStringColumn() { - // Expressions currently treat multi-valued arrays as nulls. - // This test is just documenting the current behavior, not necessarily saying it makes sense. if (NullHandling.replaceWithDefault()) { - assertFilterMatches(edf("dim4 == ''"), ImmutableList.of("0", "1", "2", "4", "5", "6", "7", "8")); + assertFilterMatches(edf("dim4 == ''"), ImmutableList.of("1", "2", "6", "7", "8")); } else { assertFilterMatches(edf("dim4 == ''"), ImmutableList.of("2")); // AS per SQL standard null == null returns false. assertFilterMatches(edf("dim4 == null"), ImmutableList.of()); } - assertFilterMatches(edf("dim4 == '1'"), ImmutableList.of()); + assertFilterMatches(edf("dim4 == '1'"), ImmutableList.of("0")); assertFilterMatches(edf("dim4 == '3'"), ImmutableList.of("3")); } @@ -212,10 +210,7 @@ public void testCompareColumns() assertFilterMatches(edf("dim2 == dim3"), ImmutableList.of("2", "5", "8")); } - // String vs. multi-value string - // Expressions currently treat multi-valued arrays as nulls. - // This test is just documenting the current behavior, not necessarily saying it makes sense. 
- assertFilterMatches(edf("dim0 == dim4"), ImmutableList.of("3")); + assertFilterMatches(edf("dim0 == dim4"), ImmutableList.of("3", "4", "5")); } @Test diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelectorTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelectorTest.java index ae95dbe3b0e4..079f692fc653 100644 --- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelectorTest.java +++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionColumnValueSelectorTest.java @@ -20,6 +20,7 @@ package org.apache.druid.segment.virtual; import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; import org.apache.druid.common.guava.SettableSupplier; import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; import org.apache.druid.segment.BaseSingleValueDimensionSelector; @@ -38,7 +39,8 @@ public void testSupplierFromDimensionSelector() { final SettableSupplier settableSupplier = new SettableSupplier<>(); final Supplier supplier = ExpressionSelectors.supplierFromDimensionSelector( - dimensionSelectorFromSupplier(settableSupplier) + dimensionSelectorFromSupplier(settableSupplier), + false ); Assert.assertNotNull(supplier); @@ -120,8 +122,12 @@ public void testSupplierFromObjectSelectorList() objectSelectorFromSupplier(settableSupplier, List.class) ); - // List can't be a number, so supplierFromObjectSelector should return null. 
- Assert.assertNull(supplier); + Assert.assertNotNull(supplier); + /* nothing has been set on settableSupplier yet, so the value is asserted to be null */ Assert.assertNull(supplier.get()); + + settableSupplier.set(ImmutableList.of("1", "2", "3")); + Assert.assertArrayEquals(new String[]{"1", "2", "3"}, (Object[]) supplier.get()); + + } private static DimensionSelector dimensionSelectorFromSupplier( diff --git a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java index b28fd5b8af9b..66223bed66e8 100644 --- a/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java +++ b/processing/src/test/java/org/apache/druid/segment/virtual/ExpressionVirtualColumnTest.java @@ -30,6 +30,7 @@ import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.Parser; import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.extraction.BucketExtractionFn; @@ -70,6 +71,29 @@ public class ExpressionVirtualColumnTest ImmutableMap.of("x", 2L, "y", 3L, "z", "foobar") ); + /** Row with long x and y plus three-element string lists under a, b and c. */ private static final InputRow ROWMULTI = new MapBasedInputRow( + DateTimes.of("2000-01-02T01:00:00").getMillis(), + ImmutableList.of(), + ImmutableMap.of( + "x", 2L, + "y", 3L, + "a", ImmutableList.of("a", "b", "c"), + "b", ImmutableList.of("1", "2", "3"), + "c", ImmutableList.of("4", "5", "6") + ) + ); + /** Second multi-value row with the same timestamp and keys but different values. */ private static final InputRow ROWMULTI2 = new MapBasedInputRow( + DateTimes.of("2000-01-02T01:00:00").getMillis(), + ImmutableList.of(), + ImmutableMap.of( + "x", 3L, + "y", 4L, + "a", ImmutableList.of("d", "e", "f"), + "b", ImmutableList.of("3", "4", "5"), + "c", ImmutableList.of("7", "8", "9") + ) + ); + private static final ExpressionVirtualColumn X_PLUS_Y = new ExpressionVirtualColumn( "expr", "x + y", @@ 
-125,6 +149,20 @@ public class ExpressionVirtualColumnTest TestExprMacroTable.INSTANCE ); + /** Multiplies column b by 2 directly, with no explicit map. */ private static final ExpressionVirtualColumn SCALE_LIST_IMPLICIT = new ExpressionVirtualColumn( + "expr", + "b * 2", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ); + + /** Multiplies each element of column b by 2 through an explicit map. */ private static final ExpressionVirtualColumn SCALE_LIST_EXPLICIT = new ExpressionVirtualColumn( + "expr", + "map(b -> b * 2, b)", + ValueType.STRING, + TestExprMacroTable.INSTANCE + ); + private static final ThreadLocal CURRENT_ROW = new ThreadLocal<>(); private static final ColumnSelectorFactory COLUMN_SELECTOR_FACTORY = RowBasedColumnSelectorFactory.create( CURRENT_ROW, @@ -154,6 +192,24 @@ public void testObjectSelector() Assert.assertEquals(5L, selector.getObject()); } + /** Asserts that the implicit and the explicit scaling columns both return the doubled b values as decimal strings (2.0, 4.0, 6.0 and 6.0, 8.0, 10.0) for ROWMULTI and ROWMULTI2. */ @Test + public void testMultiObjectSelector() + { + DimensionSpec spec = new DefaultDimensionSpec("expr", "expr"); + + final BaseObjectColumnValueSelector selectorImplicit = SCALE_LIST_IMPLICIT.makeDimensionSelector(spec, COLUMN_SELECTOR_FACTORY); + CURRENT_ROW.set(ROWMULTI); + Assert.assertEquals(ImmutableList.of("2.0", "4.0", "6.0"), selectorImplicit.getObject()); + CURRENT_ROW.set(ROWMULTI2); + Assert.assertEquals(ImmutableList.of("6.0", "8.0", "10.0"), selectorImplicit.getObject()); + + final BaseObjectColumnValueSelector selectorExplicit = SCALE_LIST_EXPLICIT.makeDimensionSelector(spec, COLUMN_SELECTOR_FACTORY); + CURRENT_ROW.set(ROWMULTI); + Assert.assertEquals(ImmutableList.of("2.0", "4.0", "6.0"), selectorExplicit.getObject()); + CURRENT_ROW.set(ROWMULTI2); + Assert.assertEquals(ImmutableList.of("6.0", "8.0", "10.0"), selectorExplicit.getObject()); + } + @Test public void testLongSelector() { @@ -288,6 +344,22 @@ public void testDimensionSelector() Assert.assertEquals("5", selector.lookupName(selector.getRow().get(0))); } + /** Builds a dimension selector over X_PLUS_Y and asserts that its not-null value matcher does not match when ROW0 is the current row. */ @Test + public void testNullDimensionSelector() + { + final DimensionSelector selector = X_PLUS_Y.makeDimensionSelector( + new DefaultDimensionSpec("expr", "expr"), + COLUMN_SELECTOR_FACTORY + ); + + 
final ValueMatcher nonNullMatcher = selector.makeValueMatcher(Predicates.notNull()); + + CURRENT_ROW.set(ROW0); + Assert.assertEquals(false, nonNullMatcher.matches()); + + + } + @Test public void testDimensionSelectorUsingStringFunction() { @@ -374,7 +446,7 @@ public void testDimensionSelectorWithExtraction() Assert.assertEquals(false, nullMatcher.matches()); Assert.assertEquals(false, fiveMatcher.matches()); Assert.assertEquals(true, nonNullMatcher.matches()); - Assert.assertEquals("4", selector.lookupName(selector.getRow().get(0))); + Assert.assertEquals("4.0", selector.lookupName(selector.getRow().get(0))); } else { // y is null in row1 Assert.assertEquals(true, nullMatcher.matches()); @@ -387,7 +459,7 @@ public void testDimensionSelectorWithExtraction() Assert.assertEquals(false, nullMatcher.matches()); Assert.assertEquals(true, fiveMatcher.matches()); Assert.assertEquals(true, nonNullMatcher.matches()); - Assert.assertEquals("5", selector.lookupName(selector.getRow().get(0))); + Assert.assertEquals("5.1", selector.lookupName(selector.getRow().get(0))); CURRENT_ROW.set(ROW3); Assert.assertEquals(false, nullMatcher.matches()); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java index 2413cdb9ec5e..ae31f86d4d39 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java @@ -595,7 +595,7 @@ public static Granularity toQueryGranularity(final DruidExpression expression, f final Expr arg = expr.getArg(); final Granularity granularity = expr.getGranularity(); - if (ColumnHolder.TIME_COLUMN_NAME.equals(Parser.getIdentifierIfIdentifier(arg))) { + if (ColumnHolder.TIME_COLUMN_NAME.equals(arg.getIdentifierIfIdentifier())) { return granularity; } else { return null;