Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import it.unimi.dsi.fastutil.objects.Object2IntArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;

Expand Down Expand Up @@ -135,10 +136,28 @@ ExprEval applyMap(@Nullable ExpressionType arrayType, LambdaExpr expr, Indexable
{
final int length = bindings.getLength();
Object[] out = new Object[length];
final boolean computeArrayType = arrayType == null;
ExpressionType arrayElementType = arrayType != null
? (ExpressionType) arrayType.getElementType()
: null;
final ExprEval<?>[] outEval = computeArrayType ? new ExprEval[length] : null;
for (int i = 0; i < length; i++) {

ExprEval evaluated = expr.eval(bindings.withIndex(i));
arrayType = Function.ArrayConstructorFunction.setArrayOutput(arrayType, out, i, evaluated);
final ExprEval<?> eval = expr.eval(bindings.withIndex(i));
if (computeArrayType && outEval[i].value() != null) {
arrayElementType = ExpressionTypeConversion.leastRestrictiveType(arrayElementType, eval.type());
outEval[i] = eval;
} else {
out[i] = eval.castTo(arrayElementType).value();
}
}
if (arrayElementType == null) {
arrayElementType = NullHandling.sqlCompatible() ? ExpressionType.LONG : ExpressionType.STRING;
}
if (computeArrayType) {
arrayType = ExpressionTypeFactory.getInstance().ofArray(arrayElementType);
for (int i = 0; i < length; i++) {
out[i] = outEval[i].castTo(arrayElementType).value();
}
}
return ExprEval.ofArray(arrayType, out);
}
Expand Down Expand Up @@ -237,7 +256,7 @@ public ExprEval apply(LambdaExpr lambdaExpr, List<Expr> argsExpr, Expr.ObjectBin
List<List<Object>> product = CartesianList.create(arrayInputs);
CartesianMapLambdaBinding lambdaBinding = new CartesianMapLambdaBinding(elementType, product, lambdaExpr, bindings);
ExpressionType lambdaType = lambdaExpr.getOutputType(lambdaBinding);
return applyMap(ExpressionType.asArrayType(lambdaType), lambdaExpr, lambdaBinding);
return applyMap(lambdaType == null ? null : ExpressionTypeFactory.getInstance().ofArray(lambdaType), lambdaExpr, lambdaBinding);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ public Expr asSingleThreaded(InputBindingInspector inspector)
return new ExprEvalBasedConstantExpr<T>(realEval());
}

@Override
public <E> ExprVectorProcessor<E> asVectorProcessor(VectorInputBindingInspector inspector)
{
return VectorProcessors.constant(value, inspector.getMaxVectorSize(), outputType);
}
/**
* Constant expression based on a concreate ExprEval.
*
Expand Down Expand Up @@ -415,7 +420,7 @@ protected ExprEval realEval()
@Override
public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
{
return VectorProcessors.constant(value, inspector.getMaxVectorSize());
return VectorProcessors.constant(value, inspector.getMaxVectorSize(), ExpressionType.STRING);
}

@Override
Expand Down Expand Up @@ -459,12 +464,6 @@ protected ExprEval realEval()
return ExprEval.ofArray(outputType, value);
}

@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}

@Override
public String stringify()
{
Expand Down Expand Up @@ -547,12 +546,6 @@ protected ExprEval realEval()
return ExprEval.ofComplex(outputType, value);
}

@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return false;
}

@Override
public String stringify()
{
Expand Down
49 changes: 18 additions & 31 deletions processing/src/main/java/org/apache/druid/math/expr/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -2026,7 +2026,8 @@ public <T> ExprVectorProcessor<T> asVectorProcessor(Expr.VectorInputBindingInspe
{
return CastToTypeVectorProcessor.cast(
args.get(0).asVectorProcessor(inspector),
ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString()))
ExpressionType.fromString(StringUtils.toUpperCase(args.get(1).getLiteralValue().toString())),
inspector.getMaxVectorSize()
);
}
}
Expand Down Expand Up @@ -3357,19 +3358,24 @@ public String name()
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
// this is copied from 'BaseMapFunction.applyMap', need to find a better way to consolidate, or construct arrays,
// or.. something...
final int length = args.size();
Object[] out = new Object[length];

ExpressionType arrayType = null;

ExpressionType arrayElementType = null;
final ExprEval[] outEval = new ExprEval[length];
for (int i = 0; i < length; i++) {
ExprEval<?> evaluated = args.get(i).eval(bindings);
arrayType = setArrayOutput(arrayType, out, i, evaluated);
outEval[i] = args.get(i).eval(bindings);
if (outEval[i].value() != null) {
arrayElementType = ExpressionTypeConversion.leastRestrictiveType(arrayElementType, outEval[i].type());
}
}

return ExprEval.ofArray(arrayType, out);
if (arrayElementType == null) {
arrayElementType = NullHandling.sqlCompatible() ? ExpressionType.LONG : ExpressionType.STRING;
}
for (int i = 0; i < length; i++) {
out[i] = outEval[i].castTo(arrayElementType).value();
}
return ExprEval.ofArray(ExpressionTypeFactory.getInstance().ofArray(arrayElementType), out);
}

@Override
Expand All @@ -3394,28 +3400,6 @@ public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<E
}
return type == null ? null : ExpressionTypeFactory.getInstance().ofArray(type);
}

/**
* Set an array element to the output array, checking for null if the array is numeric. If the type of the evaluated
* array element does not match the array element type, this method will attempt to call {@link ExprEval#castTo}
* to the array element type, else will set the element as is. If the type of the array is unknown, it will be
* detected and defined from the first element. Returns the type of the array, which will be identical to the input
* type, unless the input type was null.
*/
static ExpressionType setArrayOutput(@Nullable ExpressionType arrayType, Object[] out, int i, ExprEval evaluated)
{
if (arrayType == null) {
arrayType = ExpressionTypeFactory.getInstance().ofArray(evaluated.type());
}
if (arrayType.getElementType().isNumeric() && evaluated.isNumericNull()) {
out[i] = null;
} else if (!evaluated.asArrayType().equals(arrayType)) {
out[i] = evaluated.castTo((ExpressionType) arrayType.getElementType()).value();
} else {
out[i] = evaluated.value();
}
return arrayType;
}
}

class ArrayLengthFunction implements Function
Expand Down Expand Up @@ -3954,6 +3938,9 @@ public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
return ExprEval.ofLongBoolean(Arrays.asList(array1).containsAll(Arrays.asList(array2)));
} else {
final Object elem = rhsExpr.castTo((ExpressionType) array1Type.getElementType()).value();
if (elem == null && rhsExpr.value() != null) {
return ExprEval.ofLongBoolean(false);
}
return ExprEval.ofLongBoolean(Arrays.asList(array1).contains(elem));
}
}
Expand Down
141 changes: 13 additions & 128 deletions processing/src/main/java/org/apache/druid/math/expr/FunctionalExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.apache.druid.math.expr;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
Expand All @@ -30,136 +29,13 @@
import javax.annotation.Nullable;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

@SuppressWarnings("unused")
final class FunctionalExpr
{
// phony class to enable maven to track the compilation of this class
}

@SuppressWarnings("ClassName")
class LambdaExpr implements Expr
{
private final ImmutableList<IdentifierExpr> args;
private final Expr expr;

LambdaExpr(List<IdentifierExpr> args, Expr expr)
{
this.args = ImmutableList.copyOf(args);
this.expr = expr;
}

@Override
public String toString()
{
return StringUtils.format("(%s -> %s)", args, expr);
}

int identifierCount()
{
return args.size();
}

@Nullable
public String getIdentifier()
{
Preconditions.checkState(args.size() < 2, "LambdaExpr has multiple arguments, use getIdentifiers");
if (args.size() == 1) {
return args.get(0).toString();
}
return null;
}

public List<String> getIdentifiers()
{
return args.stream().map(IdentifierExpr::toString).collect(Collectors.toList());
}

public List<String> stringifyIdentifiers()
{
return args.stream().map(IdentifierExpr::stringify).collect(Collectors.toList());
}

ImmutableList<IdentifierExpr> getIdentifierExprs()
{
return args;
}

public Expr getExpr()
{
return expr;
}

@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return expr.canVectorize(inspector);
}

@Override
public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
{
return expr.asVectorProcessor(inspector);
}

@Override
public ExprEval eval(ObjectBinding bindings)
{
return expr.eval(bindings);
}

@Override
public String stringify()
{
return StringUtils.format("(%s) -> %s", ARG_JOINER.join(stringifyIdentifiers()), expr.stringify());
}

@Override
public Expr visit(Shuttle shuttle)
{
List<IdentifierExpr> newArgs =
args.stream().map(arg -> (IdentifierExpr) shuttle.visit(arg)).collect(Collectors.toList());
Expr newBody = expr.visit(shuttle);
return shuttle.visit(new LambdaExpr(newArgs, newBody));
}

@Override
public BindingAnalysis analyzeInputs()
{
final Set<String> lambdaArgs = args.stream().map(IdentifierExpr::toString).collect(Collectors.toSet());
BindingAnalysis bodyDetails = expr.analyzeInputs();
return bodyDetails.removeLambdaArguments(lambdaArgs);
}

@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return expr.getOutputType(inspector);
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
LambdaExpr that = (LambdaExpr) o;
return Objects.equals(args, that.args) &&
Objects.equals(expr, that.expr);
}

@Override
public int hashCode()
{
return Objects.hash(args, expr);
}
}

/**
* {@link Expr} node for a {@link Function} call. {@link FunctionExpr} has children {@link Expr} in the form of the
* list of arguments that are passed to the {@link Function} along with the {@link Expr.ObjectBinding} when it is
Expand Down Expand Up @@ -350,15 +226,24 @@ public ExprEval eval(ObjectBinding bindings)
@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return function.canVectorize(inspector, lambdaExpr, argsExpr) &&
lambdaExpr.canVectorize(inspector) &&
argsExpr.stream().allMatch(expr -> expr.canVectorize(inspector));
return canVectorizeNative(inspector) || (getOutputType(inspector) != null && inspector.canVectorize(argsExpr));
}

@Override
public <T> ExprVectorProcessor<T> asVectorProcessor(VectorInputBindingInspector inspector)
{
return function.asVectorProcessor(inspector, lambdaExpr, argsExpr);
if (canVectorizeNative(inspector)) {
return function.asVectorProcessor(inspector, lambdaExpr, argsExpr);
} else {
return FallbackVectorProcessor.create(function, lambdaExpr, argsExpr, inspector);
}
}

private boolean canVectorizeNative(InputBindingInspector inspector)
{
return function.canVectorize(inspector, lambdaExpr, argsExpr) &&
lambdaExpr.canVectorize(inspector) &&
argsExpr.stream().allMatch(expr -> expr.canVectorize(inspector));
}

@Override
Expand Down
Loading