Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 172 additions & 94 deletions core/src/main/java/org/apache/druid/math/expr/Expr.java

Large diffs are not rendered by default.

72 changes: 61 additions & 11 deletions core/src/main/java/org/apache/druid/math/expr/ExprListenerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.druid.annotations.UsedInGeneratedCode;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.antlr.ExprBaseListener;
import org.apache.druid.math.expr.antlr.ExprParser;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Implementation of antlr parse tree listener, transforms {@link ParseTree} to {@link Expr}, based on the grammar
Expand All @@ -44,11 +47,17 @@ public class ExprListenerImpl extends ExprBaseListener
private final ExprMacroTable macroTable;
private final ParseTree rootNodeKey;

private final Set<String> lambdaIdentifiers;
private final Set<String> uniqueIdentifiers;
private int uniqueCounter = 0;

/**
 * Stores the supplied root node key and macro table, and starts the node map and both identifier sets out as
 * empty collections.
 */
ExprListenerImpl(ParseTree rootNodeKey, ExprMacroTable macroTable)
{
  // bookkeeping collections, all initially empty
  nodes = new HashMap<>();
  lambdaIdentifiers = new HashSet<>();
  uniqueIdentifiers = new HashSet<>();

  // caller-supplied collaborators
  this.macroTable = macroTable;
  this.rootNodeKey = rootNodeKey;
}

Expr getAST()
Expand Down Expand Up @@ -347,14 +356,19 @@ public void exitFunctionExpr(ExprParser.FunctionExprContext ctx)
@Override
public void exitIdentifierExpr(ExprParser.IdentifierExprContext ctx)
{
String text = ctx.getText();
if (text.charAt(0) == '"' && text.charAt(text.length() - 1) == '"') {
text = StringEscapeUtils.unescapeJava(text.substring(1, text.length() - 1));
final String text = sanitizeIdentifierString(ctx.getText());
nodes.put(ctx, createIdentifierExpr(text));
}

@Override
public void enterLambda(ExprParser.LambdaContext ctx)
{
// mark lambda identifiers on enter, for reference later when creating the IdentifierExpr inside of the lambdas
for (int i = 0; i < ctx.IDENTIFIER().size(); i++) {
String text = ctx.IDENTIFIER(i).getText();
text = sanitizeIdentifierString(text);
this.lambdaIdentifiers.add(text);
}
nodes.put(
ctx,
new IdentifierExpr(text)
);
}

@Override
Expand All @@ -363,10 +377,10 @@ public void exitLambda(ExprParser.LambdaContext ctx)
List<IdentifierExpr> identifiers = new ArrayList<>(ctx.IDENTIFIER().size());
for (int i = 0; i < ctx.IDENTIFIER().size(); i++) {
String text = ctx.IDENTIFIER(i).getText();
if (text.charAt(0) == '"' && text.charAt(text.length() - 1) == '"') {
text = StringEscapeUtils.unescapeJava(text.substring(1, text.length() - 1));
}
identifiers.add(i, new IdentifierExpr(text));
text = sanitizeIdentifierString(text);
identifiers.add(i, createIdentifierExpr(text));
// clean up lambda identifier references on exit
lambdaIdentifiers.remove(text);
}

nodes.put(ctx, new LambdaExpr(identifiers, (Expr) nodes.get(ctx.expr())));
Expand Down Expand Up @@ -405,6 +419,42 @@ public void exitEmptyArray(ExprParser.EmptyArrayContext ctx)
nodes.put(ctx, new StringArrayExpr(new String[0]));
}

/**
 * Build the {@link IdentifierExpr} for an identifier whose binding name is {@code binding}.
 *
 * Every {@link IdentifierExpr} that is *not* bound to a {@link LambdaExpr} identifier receives a unique
 * {@link IdentifierExpr#identifier} value, which may or may not match its {@link IdentifierExpr#bindingIdentifier}
 * value. {@link LambdaExpr} identifiers, on the other hand, always have {@link IdentifierExpr#identifier} equal to
 * {@link IdentifierExpr#bindingIdentifier} because they have synthetic bindings set at evaluation time. This is
 * done to aid in analysis needed for the automatic expression translation which maps scalar expressions to
 * multi-value inputs. See {@link Parser#applyUnappliedIdentifiers(Expr, Expr.BindingDetails, List)} for additional
 * details.
 */
private IdentifierExpr createIdentifierExpr(String binding)
{
  if (lambdaIdentifiers.contains(binding)) {
    // currently tracked as a lambda identifier: keep the name as-is
    return new IdentifierExpr(binding);
  }

  // Set#add returns false while the candidate is already taken, so keep generating "<binding>_<counter>" names
  // until one is accepted (and thereby recorded)
  String candidate = binding;
  while (!uniqueIdentifiers.add(candidate)) {
    candidate = StringUtils.format("%s_%s", binding, uniqueCounter++);
  }
  return new IdentifierExpr(candidate, binding);
}

/**
 * Remove the enclosing double quotes from an identifier variable string and unescape what was between them,
 * returning the unquoted identifier. Text that is not wrapped in a pair of double quotes is returned unchanged.
 *
 * @param text identifier text, possibly wrapped in double quotes
 * @return the unescaped text between the quotes, or {@code text} itself when it is not quoted
 */
private static String sanitizeIdentifierString(String text)
{
  final int length = text.length();
  // an opening and a closing quote need at least two characters; checking the length first keeps empty input and
  // a lone '"' from failing in charAt/substring
  if (length >= 2 && text.charAt(0) == '"' && text.charAt(length - 1) == '"') {
    return StringEscapeUtils.unescapeJava(text.substring(1, length - 1));
  }
  return text;
}

/**
* Remove the enclosing single quotes from a string literal, returning the unquoted string value
*/
private static String escapeStringLiteral(String text)
{
String unquoted = text.substring(1, text.length() - 1);
Expand Down
17 changes: 5 additions & 12 deletions core/src/main/java/org/apache/druid/math/expr/ExprMacroTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,7 @@ public void visit(final Visitor visitor)
@Override
public BindingDetails analyzeInputs()
{
final String identifier = arg.getIdentifierIfIdentifier();
if (identifier == null) {
return arg.analyzeInputs();
}
return arg.analyzeInputs().mergeWithScalars(ImmutableSet.of(identifier));
return arg.analyzeInputs().withScalarArguments(ImmutableSet.of(arg));
}
}

Expand Down Expand Up @@ -145,16 +141,13 @@ public void visit(final Visitor visitor)
@Override
public BindingDetails analyzeInputs()
{
Set<String> scalars = new HashSet<>();
final Set<Expr> argSet = new HashSet<>(args.size());
BindingDetails accumulator = new BindingDetails();
for (Expr arg : args) {
final String identifier = arg.getIdentifierIfIdentifier();
if (identifier != null) {
scalars.add(identifier);
}
accumulator = accumulator.merge(arg.analyzeInputs());
accumulator = accumulator.with(arg);
argSet.add(arg);
}
return accumulator.mergeWithScalars(scalars);
return accumulator.withScalarArguments(argSet);
}
}
}
105 changes: 69 additions & 36 deletions core/src/main/java/org/apache/druid/math/expr/Parser.java
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,9 @@ public static Expr applyUnappliedIdentifiers(Expr expr, Expr.BindingDetails bind
return expr;
}
List<String> unapplied = toApply.stream()
.filter(x -> bindingDetails.getFreeVariables().contains(x))
.filter(x -> bindingDetails.getRequiredColumns().contains(x))
.collect(Collectors.toList());

ApplyFunction fn;
final LambdaExpr lambdaExpr;
final List<Expr> args;

// any unapplied identifiers that are inside a lambda expression need that lambda expression to be rewritten
Expr newExpr = expr.visit(
childExpr -> {
Expand Down Expand Up @@ -215,26 +211,52 @@ public static Expr applyUnappliedIdentifiers(Expr expr, Expr.BindingDetails bind
return newExpr;
}

// else, it *should be safe* to wrap in either map or cartesian_map because we still have missing bindings that
// were *not* referenced in a lambda body
if (remainingUnappliedArgs.size() == 1) {
return applyUnapplied(newExpr, remainingUnappliedArgs);
}

/**
* Translate an {@link Expr} into an {@link ApplyFunctionExpr}, using {@link ApplyFunction.MapFunction} when a single
* unbound argument is to be applied, or {@link ApplyFunction.CartesianMapFunction} if there are multiple
*/
private static Expr applyUnapplied(Expr expr, List<String> unapplied)
{
// wrap an expression in either map or cartesian_map to apply any unapplied identifiers
final Map<IdentifierExpr, IdentifierExpr> toReplace = new HashMap<>();
final List<IdentifierExpr> args = expr.analyzeInputs()
.getFreeVariables()
.stream()
.filter(x -> unapplied.contains(x.getBindingIdentifier()))
.collect(Collectors.toList());

final List<IdentifierExpr> lambdaArgs = new ArrayList<>();

// construct lambda args from list of args to apply
for (IdentifierExpr applyFnArg : args) {
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getIdentifier());
lambdaArgs.add(lambdaRewrite);
toReplace.put(applyFnArg, lambdaRewrite);
}

// rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
// are constructing
Expr newExpr = expr.visit(childExpr -> {
if (childExpr instanceof IdentifierExpr) {
if (toReplace.containsKey(childExpr)) {
return toReplace.get(childExpr);
}
}
return childExpr;
});

final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
final ApplyFunction fn;
if (args.size() == 1) {
fn = new ApplyFunction.MapFunction();
IdentifierExpr lambdaArg = new IdentifierExpr(remainingUnappliedArgs.iterator().next());
lambdaExpr = new LambdaExpr(ImmutableList.of(lambdaArg), newExpr);
args = ImmutableList.of(lambdaArg);
} else {
fn = new ApplyFunction.CartesianMapFunction();
List<IdentifierExpr> identifiers = new ArrayList<>(remainingUnappliedArgs.size());
args = new ArrayList<>(remainingUnappliedArgs.size());
for (String remainingUnappliedArg : remainingUnappliedArgs) {
IdentifierExpr arg = new IdentifierExpr(remainingUnappliedArg);
identifiers.add(arg);
args.add(arg);
}
lambdaExpr = new LambdaExpr(identifiers, newExpr);
}

Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, args);
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(args));
return magic;
}

Expand All @@ -249,28 +271,38 @@ public static Expr applyUnappliedIdentifiers(Expr expr, Expr.BindingDetails bind
*/
private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List<String> unappliedArgs)
{

// recursively evaluate arguments to ensure they are properly transformed into arrays as necessary
List<String> unappliedInThisApply =
Set<String> unappliedInThisApply =
unappliedArgs.stream()
.filter(u -> !expr.bindingDetails.getArrayVariables().contains(u))
.collect(Collectors.toList());
.filter(u -> !expr.bindingDetails.getArrayColumns().contains(u))
.collect(Collectors.toSet());

List<String> unappliedIdentifiers =
expr.bindingDetails
.getFreeVariables()
.stream()
.filter(x -> unappliedInThisApply.contains(x.getIdentifierBindingIfIdentifier()))
.map(IdentifierExpr::getIdentifierIfIdentifier)
.collect(Collectors.toList());

List<Expr> newArgs = new ArrayList<>();
for (int i = 0; i < expr.argsExpr.size(); i++) {
newArgs.add(applyUnappliedIdentifiers(
expr.argsExpr.get(i),
expr.argsBindingDetails.get(i),
unappliedInThisApply)
newArgs.add(
applyUnappliedIdentifiers(
expr.argsExpr.get(i),
expr.argsBindingDetails.get(i),
unappliedIdentifiers
)
);
}

// this will _not_ include the lambda identifiers.. anything in this list needs to be applied
List<IdentifierExpr> unappliedLambdaBindings = expr.lambdaBindingDetails.getFreeVariables()
.stream()
.filter(unappliedArgs::contains)
.map(IdentifierExpr::new)
.collect(Collectors.toList());
List<IdentifierExpr> unappliedLambdaBindings =
expr.lambdaBindingDetails.getFreeVariables()
.stream()
.filter(x -> unappliedArgs.contains(x.getIdentifierBindingIfIdentifier()))
.map(x -> new IdentifierExpr(x.getIdentifier(), x.getBindingIdentifier()))
.collect(Collectors.toList());

if (unappliedLambdaBindings.size() == 0) {
return new ApplyFunctionExpr(expr.function, expr.name, expr.lambdaExpr, newArgs);
Expand Down Expand Up @@ -321,11 +353,12 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List<St
// cartesian_fold((x, y, acc) -> acc + x + y + z, x, y, acc) =>
// cartesian_fold((x, y, z, acc) -> acc + x + y + z, x, y, z, acc)

final List<Expr> newFoldArgs = new ArrayList<>(expr.argsExpr.size() + unappliedLambdaBindings.size());
final List<Expr> newFoldArgs =
new ArrayList<>(expr.argsExpr.size() + unappliedLambdaBindings.size());
final List<IdentifierExpr> newFoldLambdaIdentifiers =
new ArrayList<>(expr.lambdaExpr.getIdentifiers().size() + unappliedLambdaBindings.size());
final List<IdentifierExpr> existingFoldLambdaIdentifiers = expr.lambdaExpr.getIdentifierExprs();
// accumulator argument is last argument, slice it off when constructing new arg list and lambda args identifiers
// accumulator argument is last argument, slice it off when constructing new arg list and lambda args
for (int i = 0; i < expr.argsExpr.size() - 1; i++) {
newFoldArgs.add(expr.argsExpr.get(i));
newFoldLambdaIdentifiers.add(existingFoldLambdaIdentifiers.get(i));
Expand Down Expand Up @@ -353,7 +386,7 @@ private static ApplyFunctionExpr liftApplyLambda(ApplyFunctionExpr expr, List<St
public static void validateExpr(Expr expression, Expr.BindingDetails bindingDetails)
{
final Set<String> conflicted =
Sets.intersection(bindingDetails.getScalarVariables(), bindingDetails.getArrayVariables());
Sets.intersection(bindingDetails.getScalarColumns(), bindingDetails.getArrayColumns());
if (conflicted.size() != 0) {
throw new RE("Invalid expression: %s; %s used as both scalar and array variables", expression, conflicted);
}
Expand Down
Loading