From b03a48a478f7afc93ccd5b93cd66acd8118ab78e Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Fri, 1 Mar 2024 15:27:27 +0000 Subject: [PATCH 1/3] Pull up literals in InputAccessor * pull up literals in `InputAccessor` * remove the need to pass `constants` of `Window` operator Fixes #15353 --- .../hll/sql/HllSketchSqlAggregatorTest.java | 16 ++++ .../druid/sql/calcite/rel/DruidQuery.java | 5 +- .../druid/sql/calcite/rel/InputAccessor.java | 84 ++++++++++++------- .../druid/sql/calcite/rel/Windowing.java | 5 +- 4 files changed, 76 insertions(+), 34 deletions(-) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 22204a5f9a4d..2c3356748a3d 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -32,6 +32,7 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.BaseQuery; import org.apache.druid.query.Druids; +import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -1268,6 +1269,7 @@ public void testEstimateStringAndDoubleAreDifferent() ); } + /** * This is a test in a similar vein to {@link #testEstimateStringAndDoubleAreDifferent()} except here we are * ensuring that float values and doubles values are considered equivalent. The expected initial inputs were @@ -1318,6 +1320,20 @@ public void testFloatAndDoubleAreConsideredTheSame() ); } + @Test + public void testDsHllOnTopOfNested() + { + // this query was not planable: https://github.com/apache/druid/issues/15353 + testBuilder() + .queryContext(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)) + .sql( + "SELECT d1,dim2,APPROX_COUNT_DISTINCT_DS_HLL(dim2, 18) as val" + + " FROM (select d1,dim1,dim2 from druid.foo group by d1,dim1,dim2 order by dim1 limit 3) t " + + " group by 1,2" + ) + .run(); + } + private ExpressionVirtualColumn makeSketchEstimateExpression(String outputName, String field) { return new ExpressionVirtualColumn( diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java index 6e0bab212771..f3ea896b842f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/DruidQuery.java @@ -591,10 +591,9 @@ private static List computeAggregations( virtualColumnRegistry, rexBuilder, InputAccessor.buildFor( - rexBuilder, - rowSignature, + aggregate, partialQuery.getSelectProject(), - null), + rowSignature), aggregations, aggName, aggCall, diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java index 12c81d887567..aeadab0fa9ab 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/InputAccessor.java @@ -20,12 +20,17 @@ package org.apache.druid.sql.calcite.rel; import com.google.common.collect.ImmutableList; +import org.apache.calcite.plan.RelOptPredicateList; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Window; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.druid.segment.column.RowSignature; -import org.apache.druid.sql.calcite.expression.Expressions; +import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; import java.util.List; @@ -38,43 +43,67 @@ */ public class InputAccessor { - private final Project project; - private final ImmutableList constants; - private final RexBuilder rexBuilder; + private final RelNode relNode; + @Nullable + private final Project flattenedProject; private final RowSignature inputRowSignature; + @Nullable + private final ImmutableList constants; + private final RelNode inputRelNode; + private final RelDataType inputRelRowType; + private final RelOptPredicateList predicates; private final int inputFieldCount; + private final RelDataType inputDruidRowType; public static InputAccessor buildFor( - RexBuilder rexBuilder, - RowSignature inputRowSignature, - @Nullable Project project, - @Nullable ImmutableList constants) + RelNode relNode, + @Nullable Project flattenedProject, + RowSignature rowSignature) { - return new InputAccessor(rexBuilder, inputRowSignature, project, constants); + return new InputAccessor( + relNode, + flattenedProject, + rowSignature + ); } private InputAccessor( - RexBuilder rexBuilder, - RowSignature inputRowSignature, - Project project, - ImmutableList constants) + RelNode relNode, + Project flattenedProject, + RowSignature rowSignature) { - this.rexBuilder = rexBuilder; - this.inputRowSignature = inputRowSignature; - this.project = project; - this.constants = constants; - this.inputFieldCount = project != null ? project.getRowType().getFieldCount() : inputRowSignature.size(); + this.relNode = relNode; + this.constants = getConstants(relNode); + this.inputRelNode = relNode.getInput(0).stripped(); + this.flattenedProject = flattenedProject; + this.inputRowSignature = rowSignature; + this.inputRelRowType = inputRelNode.getRowType(); + this.predicates = relNode.getCluster().getMetadataQuery().getPulledUpPredicates(inputRelNode); + this.inputFieldCount = inputRelRowType.getFieldCount(); + this.inputDruidRowType = RowSignatures.toRelDataType(inputRowSignature, getRexBuilder().getTypeFactory()); } - public RexNode getField(int argIndex) + private ImmutableList getConstants(RelNode relNode) { + if (relNode instanceof Window) { + return ((Window) relNode).constants; + } + return null; + } + public RexNode getField(int argIndex) + { if (argIndex < inputFieldCount) { - return Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - inputRowSignature, - project, - argIndex); + RexInputRef inputRef = RexInputRef.of(argIndex, inputRelRowType); + RexNode constant = predicates.constantMap.get(inputRef); + if (constant != null) { + return constant; + } + if (flattenedProject != null) { + return flattenedProject.getProjects().get(argIndex); + } else { + return RexInputRef.of(argIndex, inputDruidRowType); + } } else { return constants.get(argIndex - inputFieldCount); } @@ -90,18 +119,17 @@ public List getFields(List argList) public @Nullable Project getProject() { - return project; + return flattenedProject; } - public RexBuilder getRexBuilder() { - return rexBuilder; + return relNode.getCluster().getRexBuilder(); } - public RowSignature getInputRowSignature() { return inputRowSignature; } + } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java index 60f0f1d539d8..c96b3bdd39f5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rel/Windowing.java @@ -180,10 +180,9 @@ public static Windowing fromCalciteStuff( virtualColumnRegistry, rexBuilder, InputAccessor.buildFor( - rexBuilder, - sourceRowSignature, + window, partialQuery.getSelectProject(), - window.constants), + sourceRowSignature), Collections.emptyList(), aggName, aggregateCall, From cca7176e8ec6eee37f9740ff8603cf113f49440b Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Wed, 6 Mar 2024 16:55:05 +0000 Subject: [PATCH 2/3] update test --- .../datasketches/hll/sql/HllSketchSqlAggregatorTest.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 2c3356748a3d..680e09230911 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -32,7 +32,6 @@ import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.BaseQuery; import org.apache.druid.query.Druids; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -1325,12 +1324,18 @@ public void testDsHllOnTopOfNested() { // this query was not planable: https://github.com/apache/druid/issues/15353 testBuilder() - .queryContext(ImmutableMap.of(QueryContexts.ENABLE_DEBUG, true)) .sql( "SELECT d1,dim2,APPROX_COUNT_DISTINCT_DS_HLL(dim2, 18) as val" + " FROM (select d1,dim1,dim2 from druid.foo group by d1,dim1,dim2 order by dim1 limit 3) t " + " group by 1,2" ) + .expectedResults( + ImmutableList.of( + new Object[] {null, "a", 1L}, + new Object[] {"1.0", "a", 1L}, + new Object[] {"1.7", null, 0L} + ) + ) .run(); } From ca506dc3d7f331181a296780fab18cb60b39d283 Mon Sep 17 00:00:00 2001 From: Zoltan Haindrich Date: Wed, 6 Mar 2024 17:41:23 +0000 Subject: [PATCH 3/3] enable relax_nulls --- .../datasketches/hll/sql/HllSketchSqlAggregatorTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 680e09230911..538ca8171808 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -1330,6 +1330,7 @@ public void testDsHllOnTopOfNested() + " group by 1,2" ) .expectedResults( + ResultMatchMode.RELAX_NULLS, ImmutableList.of( new Object[] {null, "a", 1L}, new Object[] {"1.0", "a", 1L},