From c4cfebe8d69685cb89bb6ec965060a623d665c97 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 15 Sep 2025 20:43:32 -0700 Subject: [PATCH] allow queries with filtered aggs to match unfiltered projection aggs as long as the filter required columns are present --- .../segment/projections/Projections.java | 115 ++++++++++++++++-- .../segment/CursorFactoryProjectionTest.java | 104 ++++++++++++++++ 2 files changed, 209 insertions(+), 10 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/segment/projections/Projections.java b/processing/src/main/java/org/apache/druid/segment/projections/Projections.java index c7f62759a6c9..bc3269e5a758 100644 --- a/processing/src/main/java/org/apache/druid/segment/projections/Projections.java +++ b/processing/src/main/java/org/apache/druid/segment/projections/Projections.java @@ -19,6 +19,7 @@ package org.apache.druid.segment.projections; +import com.google.common.collect.RangeSet; import org.apache.druid.data.input.impl.AggregateProjectionSpec; import org.apache.druid.error.InvalidInput; import org.apache.druid.java.util.common.granularity.Granularities; @@ -26,7 +27,9 @@ import org.apache.druid.java.util.common.granularity.PeriodGranularity; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.aggregation.AggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.cache.CacheKeyBuilder; +import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.Filter; import org.apache.druid.segment.AggregateProjectionMetadata; import org.apache.druid.segment.CursorBuildSpec; @@ -153,7 +156,7 @@ public static ProjectionMatch matchAggregateProjection( return null; } - matchBuilder = matchAggregators(projection, queryCursorBuildSpec, matchBuilder); + matchBuilder = matchAggregators(projection, queryCursorBuildSpec, physicalColumnChecker, matchBuilder); if (matchBuilder == null) { return null; } @@ -199,15 +202,7 @@ public static ProjectionMatchBuilder matchFilter( final Set originalRequired = queryFilter.getRequiredColumns(); // try to rewrite the query filter into a projection filter, if the rewrite is valid, we can proceed final Filter projectionFilter = projection.getFilter().toOptimizedFilter(false); - final Map filterRewrites = new HashMap<>(); - // start with identity - for (String required : queryFilter.getRequiredColumns()) { - filterRewrites.put(required, required); - } - // overlay projection rewrites - filterRewrites.putAll(matchBuilder.getRemapColumns()); - - final Filter remappedQueryFilter = queryFilter.rewriteRequiredColumns(filterRewrites); + final Filter remappedQueryFilter = remapFilterToProjection(matchBuilder, queryFilter); final Filter rewritten = ProjectionFilterMatch.rewriteFilter(projectionFilter, remappedQueryFilter); // if the filter does not contain the projection filter, we cannot match this projection @@ -288,6 +283,7 @@ public static ProjectionMatchBuilder matchGrouping( public static ProjectionMatchBuilder matchAggregators( AggregateProjectionMetadata.Schema projection, CursorBuildSpec queryCursorBuildSpec, + PhysicalColumnChecker physicalColumnChecker, ProjectionMatchBuilder matchBuilder ) { @@ -296,6 +292,10 @@ public static ProjectionMatchBuilder matchAggregators( } boolean allMatch = true; for (AggregatorFactory queryAgg : queryCursorBuildSpec.getAggregators()) { + AggregatorFactory filterAgg = null; + if (queryAgg instanceof FilteredAggregatorFactory) { + filterAgg = ((FilteredAggregatorFactory) queryAgg).getAggregator(); + } boolean foundMatch = false; for (AggregatorFactory projectionAgg : projection.getAggregators()) { final AggregatorFactory combining = queryAgg.substituteCombiningFactory(projectionAgg); @@ -306,6 +306,37 @@ public static ProjectionMatchBuilder matchAggregators( foundMatch = true; break; } + + if (filterAgg != null) { + final AggregatorFactory filteredCombining = filterAgg.substituteCombiningFactory(projectionAgg); + if (filteredCombining != null) { + FilteredAggregatorFactory filteredQueryAgg = (FilteredAggregatorFactory) queryAgg; + final Filter aggFilter = filteredQueryAgg.getFilter().toFilter(); + final Filter remappedAggFilter = remapFilterToProjection(matchBuilder, aggFilter); + for (String column : aggFilter.getRequiredColumns()) { + matchBuilder = matchRequiredColumn( + column, + projection, + queryCursorBuildSpec, + physicalColumnChecker, + matchBuilder + ); + if (matchBuilder == null) { + return null; + } + } + + final FilteredAggregatorFactory remappedFilteredAgg = new FilteredAggregatorFactory( + filteredCombining, + new RewrittenAggDimFilter(filteredQueryAgg.getFilter(), remappedAggFilter) + ); + matchBuilder.remapColumn(queryAgg.getName(), projectionAgg.getName()) + .addReferencedPhysicalColumn(projectionAgg.getName()) + .addPreAggregatedAggregator(remappedFilteredAgg); + foundMatch = true; + break; + } + } } allMatch = allMatch && foundMatch; } @@ -493,6 +524,20 @@ private static boolean isUnalignedInterval( return false; } + private static Filter remapFilterToProjection(ProjectionMatchBuilder matchBuilder, Filter aggFilter) + { + final Map filterRewrites = new HashMap<>(); + // start with identity + for (String required : aggFilter.getRequiredColumns()) { + filterRewrites.put(required, required); + } + // overlay projection rewrites + filterRewrites.putAll(matchBuilder.getRemapColumns()); + + final Filter remappedAggFilter = aggFilter.rewriteRequiredColumns(filterRewrites); + return remappedAggFilter; + } + /** * Returns true if column is defined in {@link AggregateProjectionSpec#getGroupingColumns()} OR if the column does not * exist in the base table. Part of determining if a projection can be used for a given {@link CursorBuildSpec}, @@ -505,6 +550,56 @@ public interface PhysicalColumnChecker boolean check(String projectionName, String columnName); } + private static final class RewrittenAggDimFilter implements DimFilter + { + private final DimFilter originalFilter; + private final Filter rewrittenFilter; + + private RewrittenAggDimFilter(DimFilter originalFilter, Filter rewrittenFilter) + { + this.originalFilter = originalFilter; + this.rewrittenFilter = rewrittenFilter; + } + + @Override + public DimFilter optimize(boolean mayIncludeUnknown) + { + return this; + } + + @Override + public Filter toOptimizedFilter(boolean mayIncludeUnknown) + { + return rewrittenFilter; + } + + @Override + public Filter toFilter() + { + return rewrittenFilter; + } + + @Nullable + @Override + public RangeSet getDimensionRangeSet(String dimension) + { + return null; + } + + @Override + public Set getRequiredColumns() + { + return rewrittenFilter.getRequiredColumns(); + } + + @Nullable + @Override + public byte[] getCacheKey() + { + return originalFilter.getCacheKey(); + } + } + private Projections() { // no instantiation diff --git a/processing/src/test/java/org/apache/druid/segment/CursorFactoryProjectionTest.java b/processing/src/test/java/org/apache/druid/segment/CursorFactoryProjectionTest.java index 0eb1c70a8722..1230a888a672 100644 --- a/processing/src/test/java/org/apache/druid/segment/CursorFactoryProjectionTest.java +++ b/processing/src/test/java/org/apache/druid/segment/CursorFactoryProjectionTest.java @@ -54,6 +54,7 @@ import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FloatSumAggregatorFactory; import org.apache.druid.query.aggregation.LongMaxAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; @@ -62,6 +63,8 @@ import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.expression.TimestampFloorExprMacro; import org.apache.druid.query.filter.EqualityFilter; +import org.apache.druid.query.filter.NullFilter; +import org.apache.druid.query.filter.OrDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.GroupByQueryMetrics; @@ -826,6 +829,107 @@ public void testProjectionSingleDimFilter() ); } + @Test + public void testProjectionSingleDimFilteredAgg() + { + final GroupByQuery query = + GroupByQuery.builder() + .setDataSource("test") + .setGranularity(Granularities.ALL) + .setInterval(new Interval(UTC_MIDNIGHT, UTC_MIDNIGHT.plusDays(1))) + .addDimension("a") + .addAggregator( + new FilteredAggregatorFactory( + new LongSumAggregatorFactory("c_sum", "c"), + new EqualityFilter("a", ColumnType.STRING, "a", null) + ) + ) + .build(); + final ExpectedProjectionGroupBy queryMetrics = + new ExpectedProjectionGroupBy("a_hourly_c_sum_with_count_latest"); + final CursorBuildSpec buildSpec = GroupingEngine.makeCursorBuildSpec(query, queryMetrics); + + assertCursorProjection(buildSpec, queryMetrics, 3); + + testGroupBy( + query, + queryMetrics, + List.of( + new Object[]{"a", 7L}, + new Object[]{"b", null} + ) + ); + } + + @Test + public void testProjectionSingleDimFilteredAggLessMatchy() + { + final GroupByQuery query = + GroupByQuery.builder() + .setDataSource("test") + .setGranularity(Granularities.ALL) + .setInterval(new Interval(UTC_MIDNIGHT, UTC_MIDNIGHT.plusDays(1))) + .addDimension("a") + .addAggregator( + new FilteredAggregatorFactory( + new LongSumAggregatorFactory("c_sum", "c"), + new EqualityFilter("b", ColumnType.STRING, "bb", null) + ) + ) + .build(); + final ExpectedProjectionGroupBy queryMetrics = + new ExpectedProjectionGroupBy("ab_hourly_cd_sum"); + final CursorBuildSpec buildSpec = GroupingEngine.makeCursorBuildSpec(query, queryMetrics); + + assertCursorProjection(buildSpec, queryMetrics, 7); + + testGroupBy( + query, + queryMetrics, + List.of( + new Object[]{"a", 1L}, + new Object[]{"b", 5L} + ) + ); + } + + @Test + public void testProjectionSingleDimFilteredAggNoMatchy() + { + final GroupByQuery query = + GroupByQuery.builder() + .setDataSource("test") + .setGranularity(Granularities.ALL) + .setInterval(new Interval(UTC_MIDNIGHT, UTC_MIDNIGHT.plusDays(1))) + .addDimension("a") + .addAggregator( + new FilteredAggregatorFactory( + new LongSumAggregatorFactory("c_sum", "c"), + new OrDimFilter( + List.of( + new EqualityFilter("b", ColumnType.STRING, "bb", null), + new NullFilter("e", null) + ) + ) + ) + ) + .build(); + final ExpectedProjectionGroupBy queryMetrics = + new ExpectedProjectionGroupBy(null); + final CursorBuildSpec buildSpec = GroupingEngine.makeCursorBuildSpec(query, queryMetrics); + + assertCursorProjection(buildSpec, queryMetrics, 8); + + testGroupBy( + query, + queryMetrics, + List.of( + new Object[]{"a", 2L}, + new Object[]{"b", 5L} + ) + ); + } + @Test public void testProjectionSingleDimFilterWithPartialIntervalAligned() {