Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions core/src/main/java/org/apache/druid/math/expr/Parser.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,6 @@ public static Expr applyUnappliedBindings(Expr expr, Expr.BindingDetails binding
*/
private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
{
final Map<IdentifierExpr, IdentifierExpr> toReplace = new HashMap<>();

// filter to get list of IdentifierExpr that are backed by the unapplied bindings
final List<IdentifierExpr> args = expr.analyzeInputs()
.getFreeVariables()
Expand All @@ -236,18 +234,23 @@ private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)

// construct lambda args from list of args to apply. Identifiers in a lambda body have artificial 'binding' values
// that is the same as the 'identifier', because the bindings are supplied by the wrapping apply function
// replacements are done by binding rather than identifier because repeats of the same input should not result
// in a cartesian product
final Map<String, IdentifierExpr> toReplace = new HashMap<>();
for (IdentifierExpr applyFnArg : args) {
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getIdentifier());
lambdaArgs.add(lambdaRewrite);
toReplace.put(applyFnArg, lambdaRewrite);
if (!toReplace.containsKey(applyFnArg.getBinding())) {
IdentifierExpr lambdaRewrite = new IdentifierExpr(applyFnArg.getBinding());
lambdaArgs.add(lambdaRewrite);
toReplace.put(applyFnArg.getBinding(), lambdaRewrite);
}
}

// rewrite identifiers in the expression which will become the lambda body, so they match the lambda identifiers we
// are constructing
Expr newExpr = expr.visit(childExpr -> {
if (childExpr instanceof IdentifierExpr) {
if (toReplace.containsKey(childExpr)) {
return toReplace.get(childExpr);
if (toReplace.containsKey(((IdentifierExpr) childExpr).getBinding())) {
return toReplace.get(((IdentifierExpr) childExpr).getBinding());
}
}
return childExpr;
Expand All @@ -257,13 +260,13 @@ private static Expr applyUnapplied(Expr expr, List<String> unappliedBindings)
// wrap an expression in either map or cartesian_map to apply any unapplied identifiers
final LambdaExpr lambdaExpr = new LambdaExpr(lambdaArgs, newExpr);
final ApplyFunction fn;
if (args.size() == 1) {
if (lambdaArgs.size() == 1) {
fn = new ApplyFunction.MapFunction();
} else {
fn = new ApplyFunction.CartesianMapFunction();
}

final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(args));
final Expr magic = new ApplyFunctionExpr(fn, fn.name(), lambdaExpr, ImmutableList.copyOf(lambdaArgs));
return magic;
}

Expand Down
13 changes: 10 additions & 3 deletions core/src/test/java/org/apache/druid/math/expr/ParserTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@ public void testApplyUnapplied()
"(cast [x, LONG_ARRAY])",
ImmutableList.of("x")
);

validateApplyUnapplied(
"case_searched((x == 'b'),'b',(x == 'g'),'g','Other')",
"(case_searched [(== x b), b, (== x g), g, Other])",
"(map ([x] -> (case_searched [(== x b), b, (== x g), g, Other])), [x])",
ImmutableList.of("x")
);
}

@Test
Expand All @@ -424,22 +431,22 @@ public void testUniquify()
validateApplyUnapplied(
"x + x",
"(+ x x)",
"(cartesian_map ([x, x_0] -> (+ x x_0)), [x, x])",
"(map ([x] -> (+ x x)), [x])",
ImmutableList.of("x")
);

validateApplyUnapplied(
"x + x + x",
"(+ (+ x x) x)",
"(cartesian_map ([x, x_0, x_1] -> (+ (+ x x_0) x_1)), [x, x, x])",
"(map ([x] -> (+ (+ x x) x)), [x])",
ImmutableList.of("x")
);

// heh
validateApplyUnapplied(
"x + x + x + y + y + y + y + z + z + z",
"(+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)",
"(cartesian_map ([x, x_0, x_1, y, y_2, y_3, y_4, z, z_5, z_6] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x_0) x_1) y) y_2) y_3) y_4) z) z_5) z_6)), [x, x, x, y, y, y, y, z, z, z])",
"(cartesian_map ([x, y, z] -> (+ (+ (+ (+ (+ (+ (+ (+ (+ x x) x) y) y) y) y) z) z) z)), [x, y, z])",
ImmutableList.of("x", "y", "z")
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.segment.writeout.SegmentWriteOutMediumFactory;
import org.apache.druid.segment.writeout.TmpFileSegmentWriteOutMediumFactory;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.apache.druid.timeline.SegmentId;
import org.junit.After;
import org.junit.Before;
Expand All @@ -89,7 +90,7 @@
/**
*/
@RunWith(Parameterized.class)
public class MultiValuedDimensionTest
public class MultiValuedDimensionTest extends InitializedNullHandlingTest
{
@Parameterized.Parameters(name = "groupby: {0} forceHashAggregation: {2} ({1})")
public static Collection<?> constructorFeeder()
Expand Down Expand Up @@ -609,8 +610,8 @@ public void testGroupByExpressionMultiMultiAutoAutoDupeIdentifier()
List<ResultRow> expectedResults = Arrays.asList(
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t3t3", "count", 4L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t5t5", "count", 4L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t1", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t1t2", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t4t4", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t2t2", "count", 2L),
GroupByQueryRunnerTestHelper.createExpectedRow(query, "1970", "texpr", "t7t7", "count", 2L)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8519,8 +8519,8 @@ public void testMultiValueStringWorksLikeStringSelfConcatScan() throws Exception
.build()
),
ImmutableList.of(
new Object[]{"[\"a-lol-a\",\"a-lol-b\",\"b-lol-a\",\"b-lol-b\"]"},
new Object[]{"[\"b-lol-b\",\"b-lol-c\",\"c-lol-b\",\"c-lol-c\"]"},
new Object[]{"[\"a-lol-a\",\"b-lol-b\"]"},
new Object[]{"[\"b-lol-b\",\"c-lol-c\"]"},
new Object[]{"[\"d-lol-d\"]"},
new Object[]{"[\"-lol-\"]"},
new Object[]{nullVal},
Expand Down