diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java index 1ed1c37bc3a1..c2dd8ae9996a 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/IndexedTableJoinCursorBenchmark.java @@ -186,6 +186,7 @@ public void setup() enableFilterPushdown, enableFilterRewrite, enableFilterRewriteValueFilters, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER, QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE ), clauses, diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java index 8d68e6b79c5f..8c69cdeae9b9 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/JoinAndLookupBenchmark.java @@ -150,6 +150,7 @@ public void setup() throws IOException false, false, false, + false, 0 ), joinableClausesLookupStringKey, @@ -185,6 +186,7 @@ public void setup() throws IOException false, false, false, + false, 0 ), joinableClausesLookupLongKey, @@ -220,6 +222,7 @@ public void setup() throws IOException false, false, false, + false, 0 ), joinableClausesLookupLongKey, @@ -255,6 +258,7 @@ public void setup() throws IOException false, false, false, + false, 0 ), joinableClausesIndexedTableLongKey, diff --git a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java index af346e2b45aa..2bffc36ee808 100644 --- a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java +++ b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/LoadingLookup.java @@ -30,6 +30,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; @@ -111,12 +112,24 @@ public boolean canIterate() return false; } + @Override + public boolean canGetKeySet() + { + return false; + } + @Override public Iterable> iterable() { throw new UnsupportedOperationException("Cannot iterate"); } + @Override + public Set keySet() + { + throw new UnsupportedOperationException("Cannot get key set"); + } + @Override public byte[] getCacheKey() { diff --git a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java index 375f3d06d4a2..84c20d5d64b6 100644 --- a/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java +++ b/extensions-core/lookups-cached-single/src/main/java/org/apache/druid/server/lookup/PollingLookup.java @@ -37,6 +37,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -173,12 +174,24 @@ public boolean canIterate() return false; } + @Override + public boolean canGetKeySet() + { + return false; + } + @Override public Iterable> 
iterable() { throw new UnsupportedOperationException("Cannot iterate"); } + @Override + public Set keySet() + { + throw new UnsupportedOperationException("Cannot get key set"); + } + @Override public byte[] getCacheKey() { diff --git a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java index 93e147dee460..0a28454aa58c 100644 --- a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java +++ b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/LoadingLookupTest.java @@ -26,7 +26,9 @@ import org.apache.druid.testing.InitializedNullHandlingTest; import org.easymock.EasyMock; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.Arrays; import java.util.Collections; @@ -40,6 +42,9 @@ public class LoadingLookupTest extends InitializedNullHandlingTest LoadingCache reverseLookupCache = EasyMock.createStrictMock(LoadingCache.class); LoadingLookup loadingLookup = new LoadingLookup(dataFetcher, lookupCache, reverseLookupCache); + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @Test public void testApplyEmptyOrNull() throws ExecutionException { @@ -123,4 +128,17 @@ public void testGetCacheKey() { Assert.assertFalse(Arrays.equals(loadingLookup.getCacheKey(), loadingLookup.getCacheKey())); } + + @Test + public void testCanGetKeySet() + { + Assert.assertFalse(loadingLookup.canGetKeySet()); + } + + @Test + public void testKeySet() + { + expectedException.expect(UnsupportedOperationException.class); + loadingLookup.keySet(); + } } diff --git a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java index c276b742b101..715100d359da 100644 --- a/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java +++ b/extensions-core/lookups-cached-single/src/test/java/org/apache/druid/server/lookup/PollingLookupTest.java @@ -34,7 +34,9 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -62,6 +64,9 @@ public class PollingLookupTest extends InitializedNullHandlingTest private static final long POLL_PERIOD = 1000L; + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @JsonTypeName("mock") private static class MockDataFetcher implements DataFetcher { @@ -204,6 +209,19 @@ public void testGetCacheKey() Assert.assertFalse(Arrays.equals(pollingLookup2.getCacheKey(), pollingLookup.getCacheKey())); } + @Test + public void testCanGetKeySet() + { + Assert.assertFalse(pollingLookup.canGetKeySet()); + } + + @Test + public void testKeySet() + { + expectedException.expect(UnsupportedOperationException.class); + pollingLookup.keySet(); + } + private void assertMapLookup(Map map, LookupExtractor lookup) { for (Map.Entry entry : map.entrySet()) { diff --git a/processing/src/main/java/org/apache/druid/query/Queries.java b/processing/src/main/java/org/apache/druid/query/Queries.java index e25a88ea38b8..58de4695faf4 100644 --- 
a/processing/src/main/java/org/apache/druid/query/Queries.java
+++ b/processing/src/main/java/org/apache/druid/query/Queries.java
@@ -26,11 +26,16 @@
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.PostAggregator;
+import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.query.filter.DimFilter;
 import org.apache.druid.query.planning.DataSourceAnalysis;
 import org.apache.druid.query.planning.PreJoinableClause;
 import org.apache.druid.query.spec.MultipleSpecificSegmentSpec;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.VirtualColumns;
+import org.apache.druid.segment.column.ColumnHolder;
+import javax.annotation.Nullable;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -219,4 +224,73 @@ public static <T> Query<T> withBaseDataSource(final Query<T> query, final DataSource newBaseDataSource)
     return retVal;
   }
+
+  /**
+   * Helper for implementations of {@link Query#getRequiredColumns()}. Returns the set of columns that will be read
+   * out of a datasource by a query that uses the provided objects in the usual way.
+   *
+   * The returned set always contains {@code __time}, no matter what.
+   *
+   * If the virtual columns, filter, dimensions, aggregators, or additional columns refer to a virtual column, then the
+   * inputs of the virtual column will be returned instead of the name of the virtual column itself. Therefore, the
+   * returned set will never contain the names of any virtual columns.
+   *
+   * @param virtualColumns    virtual columns whose inputs should be included.
+   * @param filter            optional filter whose inputs should be included.
+   * @param dimensions        dimension specs whose inputs should be included.
+   * @param aggregators       aggregators whose inputs should be included.
+   * @param additionalColumns additional columns to include. Each of these will be added to the returned set, unless it
+   *                          refers to a virtual column, in which case the virtual column inputs will be added instead.
+   */
+  public static Set<String> computeRequiredColumns(
+      final VirtualColumns virtualColumns,
+      @Nullable final DimFilter filter,
+      final List<DimensionSpec> dimensions,
+      final List<AggregatorFactory> aggregators,
+      final List<String> additionalColumns
+  )
+  {
+    final Set<String> requiredColumns = new HashSet<>();
+
+    // Everyone needs __time (it's used by intervals filters).
+    requiredColumns.add(ColumnHolder.TIME_COLUMN_NAME);
+
+    for (VirtualColumn virtualColumn : virtualColumns.getVirtualColumns()) {
+      for (String column : virtualColumn.requiredColumns()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    if (filter != null) {
+      for (String column : filter.getRequiredColumns()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    for (DimensionSpec dimensionSpec : dimensions) {
+      if (!virtualColumns.exists(dimensionSpec.getDimension())) {
+        requiredColumns.add(dimensionSpec.getDimension());
+      }
+    }
+
+    for (AggregatorFactory aggregator : aggregators) {
+      for (String column : aggregator.requiredFields()) {
+        if (!virtualColumns.exists(column)) {
+          requiredColumns.add(column);
+        }
+      }
+    }
+
+    for (String column : additionalColumns) {
+      if (!virtualColumns.exists(column)) {
+        requiredColumns.add(column);
+      }
+    }
+
+    return requiredColumns;
+  }
 }
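To make the helper's contract concrete, here is a minimal sketch (not part of the patch; the column names are hypothetical) of how a virtual column is resolved to its physical inputs. It uses only types that appear elsewhere in this patch's tests, plus `VirtualColumns.create`:

```java
import java.util.Collections;
import java.util.List;
import java.util.Set;

import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.Queries;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;

public class RequiredColumnsSketch
{
  public static void main(String[] args)
  {
    // "v" is a virtual column that reads the physical column "other".
    final VirtualColumns virtualColumns = VirtualColumns.create(
        Collections.singletonList(
            new ExpressionVirtualColumn("v", "\"other\"", ValueType.STRING, ExprMacroTable.nil())
        )
    );

    final List<DimensionSpec> dimensions = Collections.singletonList(DefaultDimensionSpec.of("v"));
    final List<AggregatorFactory> aggregators =
        Collections.singletonList(new LongSumAggregatorFactory("idx", "index"));

    final Set<String> required = Queries.computeRequiredColumns(
        virtualColumns,
        null,         // no filter
        dimensions,   // grouping on the virtual column
        aggregators,
        Collections.emptyList()
    );

    // Prints __time, other, index (in some order): "v" resolves to its input "other".
    System.out.println(required);
  }
}
```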
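The nullability contract matters to callers: ScanQuery is the one built-in query type that sometimes cannot enumerate its columns. A sketch mirroring the tests added later in this patch (datasource name and intervals are placeholders):

```java
import java.util.Collections;

import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.query.Druids;
import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;

public class ScanRequiredColumnsSketch
{
  public static void main(String[] args)
  {
    final MultipleIntervalSegmentSpec intervals =
        new MultipleIntervalSegmentSpec(Collections.singletonList(Intervals.of("2000/3000")));

    // With an explicit column list, the required columns are knowable.
    final ScanQuery withColumns = Druids.newScanQueryBuilder()
        .dataSource("example")
        .intervals(intervals)
        .columns("foo", "bar")
        .build();
    System.out.println(withColumns.getRequiredColumns()); // [__time, foo, bar]

    // With no column list ("select *"), they are not, and the method returns null.
    final ScanQuery withoutColumns = Druids.newScanQueryBuilder()
        .dataSource("example")
        .intervals(intervals)
        .build();
    System.out.println(withoutColumns.getRequiredColumns()); // null
  }
}
```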
diff --git a/processing/src/main/java/org/apache/druid/query/QueryContexts.java b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
index 6edba6847771..f6528a6d0c65 100644
--- a/processing/src/main/java/org/apache/druid/query/QueryContexts.java
+++ b/processing/src/main/java/org/apache/druid/query/QueryContexts.java
@@ -54,6 +54,7 @@ public class QueryContexts
   public static final String JOIN_FILTER_PUSH_DOWN_KEY = "enableJoinFilterPushDown";
   public static final String JOIN_FILTER_REWRITE_ENABLE_KEY = "enableJoinFilterRewrite";
   public static final String JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY = "enableJoinFilterRewriteValueColumnFilters";
+  public static final String REWRITE_JOIN_TO_FILTER_ENABLE_KEY = "enableRewriteJoinToFilter";
   public static final String JOIN_FILTER_REWRITE_MAX_SIZE_KEY = "joinFilterRewriteMaxSize";
   // This flag controls whether a sql join query with left scan should be attempted to be run as direct table access
   // instead of being wrapped inside a query. With direct table access enabled, druid can push down the join operation to
@@ -80,6 +81,7 @@ public class QueryContexts
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN = true;
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_REWRITE = true;
   public static final boolean DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS = false;
+  public static final boolean DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER = false;
   public static final long DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE = 10000;
   public static final boolean DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT = false;
   public static final boolean DEFAULT_USE_FILTER_CNF = false;
@@ -274,6 +276,7 @@ public static int getParallelMergeParallelism(Query<?> query, int defaultValue)
   {
     return parseInt(query, BROKER_PARALLELISM, defaultValue);
   }
+
   public static boolean getEnableJoinFilterRewriteValueColumnFilters(Query<?> query)
   {
     return parseBoolean(
@@ -283,6 +286,15 @@ public static boolean getEnableJoinFilterRewriteValueColumnFilters(Query<?> query)
     );
   }
 
+  public static boolean getEnableRewriteJoinToFilter(Query<?> query)
+  {
+    return parseBoolean(
+        query,
+        REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
+        DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER
+    );
+  }
+
   public static long getJoinFilterRewriteMaxSize(Query<?> query)
   {
     return parseLong(query, JOIN_FILTER_REWRITE_MAX_SIZE_KEY, DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE);
diff --git a/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java b/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
index 23096020b60f..b00161566c80 100644
--- a/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
+++ b/processing/src/main/java/org/apache/druid/query/extraction/MapLookupExtractor.java
@@ -35,6 +35,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 @JsonTypeName("map")
@@ -128,12 +129,24 @@ public boolean canIterate()
     return true;
   }
 
+  @Override
+  public boolean canGetKeySet()
+  {
+    return true;
+  }
+
   @Override
   public Iterable<Map.Entry<String, String>> iterable()
   {
     return map.entrySet();
   }
 
+  @Override
+  public Set<String> keySet()
+  {
+    return Collections.unmodifiableSet(map.keySet());
+  }
+
   @Override
   public boolean equals(Object o)
   {
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
index b4fb0756cded..7923e8e9c9e9 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/GroupByQuery.java
@@ -73,6 +73,7 @@
 import javax.annotation.Nullable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
@@ -778,6 +779,19 @@ public Sequence<ResultRow> postProcess(Sequence<ResultRow> results)
     return postProcessingFn.apply(results);
   }
 
+  @Nullable
+  @Override
+  public Set<String> getRequiredColumns()
+  {
+    return Queries.computeRequiredColumns(
+        virtualColumns,
+        dimFilter,
+        dimensions,
+        aggregatorSpecs,
+        Collections.emptyList()
+    );
+  }
+
   @Override
   public GroupByQuery withOverriddenContext(Map<String, Object> contextOverride)
   {
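A quick sketch (not part of the patch) of the new key-set contract on the map-backed extractor, for contrast with the LoadingLookup and PollingLookup implementations earlier in this diff, which report `canGetKeySet() == false` and throw from `keySet()`:

```java
import com.google.common.collect.ImmutableMap;
import org.apache.druid.query.extraction.MapLookupExtractor;

public class KeySetSketch
{
  public static void main(String[] args)
  {
    final MapLookupExtractor extractor =
        new MapLookupExtractor(ImmutableMap.of("MX", "Mexico", "US", "United States"), false);

    // Map-backed lookups can hand out their key set cheaply and immutably.
    System.out.println(extractor.canGetKeySet()); // true
    System.out.println(extractor.keySet());       // [MX, US] (unmodifiable)
  }
}
```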
diff --git a/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java b/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
index f806a5555ab3..f24a965b0408 100644
--- a/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
+++ b/processing/src/main/java/org/apache/druid/query/lookup/LookupExtractor.java
@@ -29,6 +29,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
 @JsonSubTypes(value = {
@@ -105,6 +106,11 @@ public Map<String, List<String>> unapplyAll(Iterable<String> values)
    */
   public abstract boolean canIterate();
 
+  /**
+   * Returns true if this lookup extractor's {@link #keySet()} method will return a valid set.
+   */
+  public abstract boolean canGetKeySet();
+
   /**
    * Returns an Iterable that iterates over the keys and values in this lookup extractor.
    *
    */
   public abstract Iterable<Map.Entry<String, String>> iterable();
 
+  /**
+   * Returns a Set of all keys in this lookup extractor. The returned Set will not change.
+   *
+   * @throws UnsupportedOperationException if {@link #canGetKeySet()} returns false.
+   */
+  public abstract Set<String> keySet();
+
   /**
    * Create a cache key for use in results caching
    *
+ return null; + } else { + return Queries.computeRequiredColumns( + virtualColumns, + dimFilter, + Collections.emptyList(), + Collections.emptyList(), + columns + ); + } + } + public ScanQuery withOffset(final long newOffset) { return Druids.ScanQueryBuilder.copy(this).offset(newOffset).build(); diff --git a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java index 47567071ecbf..63c12de3670f 100644 --- a/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java +++ b/processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQuery.java @@ -38,10 +38,13 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.VirtualColumns; +import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; /** */ @@ -157,6 +160,19 @@ public boolean isSkipEmptyBuckets() return getContextBoolean(SKIP_EMPTY_BUCKETS, false); } + @Nullable + @Override + public Set getRequiredColumns() + { + return Queries.computeRequiredColumns( + virtualColumns, + dimFilter, + Collections.emptyList(), + aggregatorSpecs, + Collections.emptyList() + ); + } + @Override public TimeseriesQuery withQuerySegmentSpec(QuerySegmentSpec querySegmentSpec) { diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java index 3218139e8e32..7724a6d60fab 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQuery.java @@ -37,10 +37,13 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.VirtualColumns; +import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; /** */ @@ -156,6 +159,19 @@ public List getPostAggregatorSpecs() return postAggregatorSpecs; } + @Nullable + @Override + public Set getRequiredColumns() + { + return Queries.computeRequiredColumns( + virtualColumns, + dimFilter, + Collections.singletonList(dimensionSpec), + aggregatorSpecs, + Collections.emptyList() + ); + } + public void initTopNAlgorithmSelector(TopNAlgorithmSelector selector) { if (dimensionSpec.getExtractionFn() != null) { diff --git a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java index 03209d057bfc..abecb0d48a5b 100644 --- a/processing/src/main/java/org/apache/druid/segment/filter/Filters.java +++ b/processing/src/main/java/org/apache/druid/segment/filter/Filters.java @@ -59,6 +59,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.NoSuchElementException; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; @@ -486,14 +487,14 @@ public static boolean shouldUseBitmapIndex( /** * Create a filter representing an AND relationship across a list of filters. Deduplicates filters, flattens stacks, - * and removes literal "false" filters. + * and removes null filters and literal "false" filters. * * @param filters List of filters * * @return If "filters" has more than one filter remaining after processing, returns {@link AndFilter}. 
* If "filters" has a single element remaining after processing, return that filter alone. * - * @throws IllegalArgumentException if "filters" is empty + * @throws IllegalArgumentException if "filters" is empty or only contains nulls */ public static Filter and(final List filters) { @@ -501,15 +502,18 @@ public static Filter and(final List filters) } /** - * Like {@link #and}, but returns an empty Optional instead of throwing an exception if "filters" is empty. + * Like {@link #and}, but returns an empty Optional instead of throwing an exception if "filters" is empty + * or only contains nulls. */ public static Optional maybeAnd(List filters) { - if (filters.isEmpty()) { + final List nonNullFilters = nonNull(filters); + + if (nonNullFilters.isEmpty()) { return Optional.empty(); } - final LinkedHashSet filtersToUse = flattenAndChildren(filters); + final LinkedHashSet filtersToUse = flattenAndChildren(nonNullFilters); if (filtersToUse.isEmpty()) { assert !filters.isEmpty(); @@ -527,7 +531,7 @@ public static Optional maybeAnd(List filters) /** * Create a filter representing an OR relationship across a list of filters. Deduplicates filters, flattens stacks, - * and removes literal "false" filters. + * and removes null filters and literal "false" filters. * * @param filters List of filters * @@ -542,18 +546,21 @@ public static Filter or(final List filters) } /** - * Like {@link #or}, but returns an empty Optional instead of throwing an exception if "filters" is empty. + * Like {@link #or}, but returns an empty Optional instead of throwing an exception if "filters" is empty + * or only contains nulls. */ public static Optional maybeOr(final List filters) { - if (filters.isEmpty()) { + final List nonNullFilters = nonNull(filters); + + if (nonNullFilters.isEmpty()) { return Optional.empty(); } - final LinkedHashSet filtersToUse = flattenOrChildren(filters); + final LinkedHashSet filtersToUse = flattenOrChildren(nonNullFilters); if (filtersToUse.isEmpty()) { - assert !filters.isEmpty(); + assert !nonNullFilters.isEmpty(); // Original "filters" list must have been 100% literally-false filters. return Optional.of(FalseFilter.instance()); } else if (filtersToUse.stream().anyMatch(filter -> filter instanceof TrueFilter)) { @@ -595,6 +602,20 @@ public static boolean filterMatchesNull(Filter filter) return valueMatcher.matches(); } + + /** + * Returns a list equivalent to the input list, but with nulls removed. If the original list has no nulls, + * it is returned directly. + */ + private static List nonNull(final List filters) + { + if (filters.stream().anyMatch(Objects::isNull)) { + return filters.stream().filter(Objects::nonNull).collect(Collectors.toList()); + } else { + return filters; + } + } + /** * Flattens children of an AND, removes duplicates, and removes literally-true filters. 
*/ diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java index 2002ee100997..34ac51c2b549 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegment.java @@ -66,9 +66,9 @@ public HashJoinSegment( this.clauses = clauses; this.joinFilterPreAnalysis = joinFilterPreAnalysis; - // Verify 'clauses' is nonempty (otherwise it's a waste to create this object, and the caller should know) - if (clauses.isEmpty()) { - throw new IAE("'clauses' is empty, no need to create HashJoinSegment"); + // Verify this virtual segment is doing something useful (otherwise it's a waste to create this object) + if (clauses.isEmpty() && baseFilter == null) { + throw new IAE("'clauses' and 'baseFilter' are both empty, no need to create HashJoinSegment"); } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java index 4df490f6511e..86b7ef4aa7dc 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java @@ -37,16 +37,19 @@ import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.data.Indexed; import org.apache.druid.segment.data.ListIndexed; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.filter.JoinFilterAnalyzer; import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis; import org.apache.druid.segment.join.filter.JoinFilterPreAnalysisKey; import org.apache.druid.segment.join.filter.JoinFilterSplit; +import org.apache.druid.segment.vector.VectorCursor; import org.joda.time.DateTime; import org.joda.time.Interval; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; @@ -56,6 +59,8 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter { private final StorageAdapter baseAdapter; + + @Nullable private final Filter baseFilter; private final List clauses; private final JoinFilterPreAnalysis joinFilterPreAnalysis; @@ -84,7 +89,7 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter */ HashJoinSegmentStorageAdapter( final StorageAdapter baseAdapter, - final Filter baseFilter, + @Nullable final Filter baseFilter, final List clauses, final JoinFilterPreAnalysis joinFilterPreAnalysis ) @@ -221,6 +226,43 @@ public Metadata getMetadata() throw new UnsupportedOperationException("Cannot retrieve metadata from join segment"); } + @Override + public boolean canVectorize(@Nullable Filter filter, VirtualColumns virtualColumns, boolean descending) + { + // HashJoinEngine isn't vectorized yet. + // However, we can still vectorize if there are no clauses, since that means all we need to do is apply + // a base filter. That's easy enough! 
+ return clauses.isEmpty() && baseAdapter.canVectorize(baseFilterAnd(filter), virtualColumns, descending); + } + + @Nullable + @Override + public VectorCursor makeVectorCursor( + @Nullable Filter filter, + Interval interval, + VirtualColumns virtualColumns, + boolean descending, + int vectorSize, + @Nullable QueryMetrics queryMetrics + ) + { + if (!canVectorize(filter, virtualColumns, descending)) { + throw new ISE("Cannot vectorize. Check 'canVectorize' before calling 'makeVectorCursor'."); + } + + // Should have been checked by canVectorize. + assert clauses.isEmpty(); + + return baseAdapter.makeVectorCursor( + baseFilterAnd(filter), + interval, + virtualColumns, + descending, + vectorSize, + queryMetrics + ); + } + @Override public Sequence makeCursors( @Nullable final Filter filter, @@ -231,6 +273,19 @@ public Sequence makeCursors( @Nullable final QueryMetrics queryMetrics ) { + final Filter combinedFilter = baseFilterAnd(filter); + + if (clauses.isEmpty()) { + return baseAdapter.makeCursors( + combinedFilter, + interval, + virtualColumns, + gran, + descending, + queryMetrics + ); + } + // Filter pre-analysis key implied by the call to "makeCursors". We need to sanity-check that it matches // the actual pre-analysis that was done. Note: we can't infer a rewrite config from the "makeCursors" call (it // requires access to the query context) so we'll need to skip sanity-checking it, by re-using the one present @@ -240,7 +295,7 @@ public Sequence makeCursors( joinFilterPreAnalysis.getKey().getRewriteConfig(), clauses, virtualColumns, - filter + combinedFilter ); final JoinFilterPreAnalysisKey keyCached = joinFilterPreAnalysis.getKey(); @@ -363,4 +418,10 @@ private Optional getClauseForColumn(final String column) .filter(clause -> clause.includesColumn(column)) .findFirst(); } + + @Nullable + private Filter baseFilterAnd(@Nullable final Filter other) + { + return Filters.maybeAnd(Arrays.asList(baseFilter, other)).orElse(null); + } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java index 23875a0aceb0..53460b7f13da 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -58,6 +59,7 @@ public class JoinConditionAnalysis private final boolean isAlwaysTrue; private final boolean canHashJoin; private final Set rightKeyColumns; + private final Set requiredColumns; private JoinConditionAnalysis( final String originalExpression, @@ -80,6 +82,7 @@ private JoinConditionAnalysis( ExprUtils.nilBindings()).asBoolean()); canHashJoin = nonEquiConditions.stream().allMatch(Expr::isLiteral); rightKeyColumns = getEquiConditions().stream().map(Equality::getRightColumn).collect(Collectors.toSet()); + requiredColumns = computeRequiredColumns(rightPrefix, equiConditions, nonEquiConditions); } /** @@ -192,6 +195,15 @@ public Set getRightEquiConditionKeys() return rightKeyColumns; } + /** + * Returns the set of column names required by this join condition. Columns from the right-hand side are returned + * with their prefixes included. 
+   */
+  public Set<String> getRequiredColumns()
+  {
+    return requiredColumns;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -217,4 +229,24 @@ public String toString()
   {
     return originalExpression;
   }
+
+  private static Set<String> computeRequiredColumns(
+      final String rightPrefix,
+      final List<Equality> equiConditions,
+      final List<Expr> nonEquiConditions
+  )
+  {
+    final Set<String> requiredColumns = new HashSet<>();
+
+    for (Equality equality : equiConditions) {
+      requiredColumns.add(rightPrefix + equality.getRightColumn());
+      requiredColumns.addAll(equality.getLeftExpr().analyzeInputs().getRequiredBindings());
+    }
+
+    for (Expr expr : nonEquiConditions) {
+      requiredColumns.addAll(expr.analyzeInputs().getRequiredBindings());
+    }
+
+    return requiredColumns;
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
index f22134bc0c28..25957f7e9a8f 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java
@@ -85,6 +85,15 @@ JoinMatcher makeJoinMatcher(
       Closer closer
   );
 
+  /**
+   * Returns all nonnull values from a particular column if they are all unique, if there are "maxNumValues" or fewer,
+   * and if the column exists and supports this operation. Otherwise, returns an empty Optional.
+   *
+   * @param columnName   name of the column
+   * @param maxNumValues maximum number of values to return
+   */
+  Optional<Set<String>> getNonNullColumnValuesIfAllUnique(String columnName, int maxNumValues);
+
   /**
    * Searches a column from this Joinable for a particular value, finds rows that match,
    * and returns values of a second column for those rows.
    *
    * @param searchColumnValue       Target value of the search column. This is the value that is being filtered on.
    * @param retrievalColumnName     The column to retrieve values from. This is the column that is being joined against.
    * @param maxCorrelationSetSize   Maximum number of values to retrieve. If we detect that more values would be
-   *                                returned than this limit, return an empty set.
+   *                                returned than this limit, return absent.
    * @param allowNonKeyColumnSearch If true, allow searches on non-key columns. If this is false,
-   *                                a search on a non-key column should return an empty set.
+   *                                a search on a non-key column returns absent.
    * @return The set of correlated column values. If we cannot determine correlated values, return absent.
    *
    * In case either the search or retrieval column names are not found, this will return absent.
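The new `Joinable` method is the hook that makes the join-to-filter rewrite (implemented in JoinableFactoryWrapper, next) possible. A minimal sketch of the idea, using hypothetical table and column names:

```java
import java.util.Optional;
import java.util.Set;

import org.apache.druid.query.filter.Filter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.join.Joinable;

public class JoinToFilterSketch
{
  /**
   * For a query shaped like
   *   ... INNER JOIN countries c ON t."country" = c."k"
   * that never reads any "c." column, the join only acts as a filter on t."country".
   * If the right side's "k" values are unique and few, the whole clause can be
   * replaced by: country IN (all non-null values of k).
   */
  static Optional<Filter> rewrite(Joinable countries, int maxNumFilterValues)
  {
    final Optional<Set<String>> keys =
        countries.getNonNullColumnValuesIfAllUnique("k", maxNumFilterValues);

    // Absent if "k" has duplicates, too many values, or doesn't support the operation.
    return keys.map(values -> Filters.toFilter(new InDimFilter("country", values)));
  }
}
```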
diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java index b076b1ad825c..f462b93cab81 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java @@ -19,12 +19,22 @@ package org.apache.druid.segment.join; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import com.google.common.collect.Sets; +import com.google.common.primitives.Ints; import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.Query; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.segment.SegmentReference; @@ -36,8 +46,13 @@ import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig; import org.apache.druid.utils.JvmUtils; +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; @@ -61,6 +76,7 @@ public JoinableFactoryWrapper(final JoinableFactory joinableFactory) * Creates a Function that maps base segments to {@link HashJoinSegment} if needed (i.e. if the number of join * clauses is > 0). If mapping is not needed, this method will return {@link Function#identity()}. * + * @param baseFilter Filter to apply before the join takes place * @param clauses Pre-joinable clauses * @param cpuTimeAccumulator An accumulator that we will add CPU nanos to; this is part of the function to encourage * callers to remember to track metrics on CPU time required for creation of Joinables @@ -70,7 +86,7 @@ public JoinableFactoryWrapper(final JoinableFactory joinableFactory) * query from the end user. */ public Function createSegmentMapFn( - final Filter baseFilter, + @Nullable final Filter baseFilter, final List clauses, final AtomicLong cpuTimeAccumulator, final Query query @@ -84,22 +100,48 @@ public Function createSegmentMapFn( return Function.identity(); } else { final JoinableClauses joinableClauses = JoinableClauses.createClauses(clauses, joinableFactory); + final JoinFilterRewriteConfig filterRewriteConfig = JoinFilterRewriteConfig.forQuery(query); + + // Pick off any join clauses that can be converted into filters. 
+        final Set<String> requiredColumns = query.getRequiredColumns();
+        final Filter baseFilterToUse;
+        final List<JoinableClause> clausesToUse;
+
+        if (requiredColumns != null && filterRewriteConfig.isEnableRewriteJoinToFilter()) {
+          final Pair<List<Filter>, List<JoinableClause>> conversionResult = convertJoinsToFilters(
+              joinableClauses.getJoinableClauses(),
+              requiredColumns,
+              Ints.checkedCast(Math.min(filterRewriteConfig.getFilterRewriteMaxSize(), Integer.MAX_VALUE))
+          );
+
+          baseFilterToUse =
+              Filters.maybeAnd(
+                  Lists.newArrayList(
+                      Iterables.concat(
+                          Collections.singleton(baseFilter),
+                          conversionResult.lhs
+                      )
+                  )
+              ).orElse(null);
+          clausesToUse = conversionResult.rhs;
+        } else {
+          baseFilterToUse = baseFilter;
+          clausesToUse = joinableClauses.getJoinableClauses();
+        }
+
+        // Analyze remaining join clauses to see if filters on them can be pushed down.
         final JoinFilterPreAnalysis joinFilterPreAnalysis = JoinFilterAnalyzer.computeJoinFilterPreAnalysis(
             new JoinFilterPreAnalysisKey(
-                JoinFilterRewriteConfig.forQuery(query),
-                joinableClauses.getJoinableClauses(),
+                filterRewriteConfig,
+                clausesToUse,
                 query.getVirtualColumns(),
-                Filters.toFilter(query.getFilter())
+                Filters.maybeAnd(Arrays.asList(baseFilterToUse, Filters.toFilter(query.getFilter())))
+                       .orElse(null)
             )
         );
 
         return baseSegment ->
-            new HashJoinSegment(
-                baseSegment,
-                baseFilter,
-                joinableClauses.getJoinableClauses(),
-                joinFilterPreAnalysis
-            );
+            new HashJoinSegment(baseSegment, baseFilterToUse, clausesToUse, joinFilterPreAnalysis);
       }
     }
   );
@@ -116,7 +158,9 @@ public Function<SegmentReference, SegmentReference> createSegmentMapFn(
    * in the JOIN is not cacheable.
    *
    * @param dataSourceAnalysis for the join datasource
+   *
    * @return the optional cache key to be used as part of query cache key
+   *
    * @throws {@link IAE} if this operation is called on a non-join data source
    */
   public Optional<byte[]> computeJoinDataSourceCacheKey(
@@ -148,4 +192,112 @@ public Optional<byte[]> computeJoinDataSourceCacheKey(
     return Optional.of(keyBuilder.build());
   }
+
+  /**
+   * Converts any join clauses to filters that can be converted, and returns the rest as-is.
+   *
+   * See {@link #convertJoinToFilter} for details on the logic.
+   */
+  @VisibleForTesting
+  static Pair<List<Filter>, List<JoinableClause>> convertJoinsToFilters(
+      final List<JoinableClause> clauses,
+      final Set<String> requiredColumns,
+      final int maxNumFilterValues
+  )
+  {
+    final List<Filter> filterList = new ArrayList<>();
+    final List<JoinableClause> clausesToUse = new ArrayList<>();
+
+    // Join clauses may depend on other, earlier join clauses.
+    // We track that using a Multiset, because we'll need to remove required columns one by one as we convert clauses,
+    // and multiple clauses may refer to the same column.
+    final Multiset<String> columnsRequiredByJoinClauses = HashMultiset.create();
+
+    for (JoinableClause clause : clauses) {
+      for (String column : clause.getCondition().getRequiredColumns()) {
+        columnsRequiredByJoinClauses.add(column, 1);
+      }
+    }
+
+    // Walk through the list of clauses, picking off any from the start of the list that can be converted to filters.
+    boolean atStart = true;
+    for (JoinableClause clause : clauses) {
+      if (atStart) {
+        // Remove this clause from columnsRequiredByJoinClauses. It's ok if it relies on itself.
+        for (String column : clause.getCondition().getRequiredColumns()) {
+          columnsRequiredByJoinClauses.remove(column, 1);
+        }
+
+        final Optional<Filter> filter =
+            convertJoinToFilter(
+                clause,
+                Sets.union(requiredColumns, columnsRequiredByJoinClauses.elementSet()),
+                maxNumFilterValues
+            );
+
+        if (filter.isPresent()) {
+          filterList.add(filter.get());
+        } else {
+          clausesToUse.add(clause);
+          atStart = false;
+        }
+      } else {
+        clausesToUse.add(clause);
+      }
+    }
+
+    // Sanity check. If this exception is ever thrown, it's a bug.
+    if (filterList.size() + clausesToUse.size() != clauses.size()) {
+      throw new ISE("Lost a join clause during planning");
+    }
+
+    return Pair.of(filterList, clausesToUse);
+  }
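One subtlety worth spelling out: conversion only picks clauses off the start of the list and stops at the first one that must be kept, because later clauses may join against columns produced by earlier ones. A hedged sketch of the resulting contract (the clause objects are hypothetical, and this would have to live in the same package since the method is package-private):

```java
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.filter.Filter;

import java.util.List;

class ConvertJoinsToFiltersSketch
{
  static void sketch(JoinableClause countriesClause, JoinableClause productsClause)
  {
    final Pair<List<Filter>, List<JoinableClause>> result =
        JoinableFactoryWrapper.convertJoinsToFilters(
            ImmutableList.of(countriesClause, productsClause),
            ImmutableSet.of("country", "p.name"), // columns the query reads
            10_000
        );

    // If "countriesClause" is an INNER equi-join whose right side is never read, it
    // collapses into result.lhs as an "in" filter; "productsClause" stays in result.rhs
    // because the query reads the "p."-prefixed column. Once a clause is kept, every
    // clause after it is kept too, since it may depend on the kept join's output.
  }
}
```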
+
+  /**
+   * Converts a join clause into an "in" filter if possible.
+   *
+   * The requirements are:
+   *
+   * - it must be an INNER equi-join
+   * - the right-hand columns referenced by the condition must not have any duplicate values
+   * - no columns from the right-hand side can appear in "requiredColumns"
+   */
+  @VisibleForTesting
+  static Optional<Filter> convertJoinToFilter(
+      final JoinableClause clause,
+      final Set<String> requiredColumns,
+      final int maxNumFilterValues
+  )
+  {
+    if (clause.getJoinType() == JoinType.INNER
+        && requiredColumns.stream().noneMatch(clause::includesColumn)
+        && clause.getCondition().getNonEquiConditions().isEmpty()
+        && clause.getCondition().getEquiConditions().size() > 0) {
+      final List<Filter> filters = new ArrayList<>();
+      int numValues = maxNumFilterValues;
+
+      for (final Equality condition : clause.getCondition().getEquiConditions()) {
+        final String leftColumn = condition.getLeftExpr().getBindingIfIdentifier();
+
+        if (leftColumn == null) {
+          return Optional.empty();
+        }
+
+        final Optional<Set<String>> columnValuesForFilter =
+            clause.getJoinable().getNonNullColumnValuesIfAllUnique(condition.getRightColumn(), numValues);
+
+        if (columnValuesForFilter.isPresent()) {
+          numValues -= columnValuesForFilter.get().size();
+          filters.add(Filters.toFilter(new InDimFilter(leftColumn, columnValuesForFilter.get())));
+        } else {
+          return Optional.empty();
+        }
+      }
+
+      return Optional.of(Filters.and(filters));
+    }
+
+    return Optional.empty();
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
index ec18f03c0a68..88bf00bf4e4f 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfig.java
@@ -47,6 +47,12 @@ public class JoinFilterRewriteConfig
    */
   private final boolean enableRewriteValueColumnFilters;
 
+  /**
+   * Whether to enable eliminating entire inner join clauses by rewriting them into filters on the base segment.
+   * In production this should generally be {@code QueryContexts.getEnableRewriteJoinToFilter(query)}.
+   */
+  private final boolean enableRewriteJoinToFilter;
+
   /**
    * The max allowed size of correlated value sets for RHS rewrites. In production
    * this should generally be {@code QueryContexts.getJoinFilterRewriteMaxSize(query)}.
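For reference, the expanded five-argument constructor lines up one-to-one with the query-context getters. A sketch equivalent to the `forQuery(query)` factory method in the hunk that follows:

```java
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig;

class RewriteConfigSketch
{
  static JoinFilterRewriteConfig configFor(Query<?> query)
  {
    // Each argument is sourced from the corresponding query-context getter.
    return new JoinFilterRewriteConfig(
        QueryContexts.getEnableJoinFilterPushDown(query),
        QueryContexts.getEnableJoinFilterRewrite(query),
        QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query),
        QueryContexts.getEnableRewriteJoinToFilter(query), // the new flag
        QueryContexts.getJoinFilterRewriteMaxSize(query)
    );
  }
}
```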
@@ -57,12 +63,14 @@ public JoinFilterRewriteConfig( boolean enableFilterPushDown, boolean enableFilterRewrite, boolean enableRewriteValueColumnFilters, + boolean enableRewriteJoinToFilter, long filterRewriteMaxSize ) { this.enableFilterPushDown = enableFilterPushDown; this.enableFilterRewrite = enableFilterRewrite; this.enableRewriteValueColumnFilters = enableRewriteValueColumnFilters; + this.enableRewriteJoinToFilter = enableRewriteJoinToFilter; this.filterRewriteMaxSize = filterRewriteMaxSize; } @@ -72,6 +80,7 @@ public static JoinFilterRewriteConfig forQuery(final Query query) QueryContexts.getEnableJoinFilterPushDown(query), QueryContexts.getEnableJoinFilterRewrite(query), QueryContexts.getEnableJoinFilterRewriteValueColumnFilters(query), + QueryContexts.getEnableRewriteJoinToFilter(query), QueryContexts.getJoinFilterRewriteMaxSize(query) ); } @@ -91,6 +100,11 @@ public boolean isEnableRewriteValueColumnFilters() return enableRewriteValueColumnFilters; } + public boolean isEnableRewriteJoinToFilter() + { + return enableRewriteJoinToFilter; + } + public long getFilterRewriteMaxSize() { return filterRewriteMaxSize; @@ -106,10 +120,11 @@ public boolean equals(Object o) return false; } JoinFilterRewriteConfig that = (JoinFilterRewriteConfig) o; - return enableFilterPushDown == that.enableFilterPushDown && - enableFilterRewrite == that.enableFilterRewrite && - enableRewriteValueColumnFilters == that.enableRewriteValueColumnFilters && - filterRewriteMaxSize == that.filterRewriteMaxSize; + return enableFilterPushDown == that.enableFilterPushDown + && enableFilterRewrite == that.enableFilterRewrite + && enableRewriteValueColumnFilters == that.enableRewriteValueColumnFilters + && enableRewriteJoinToFilter == that.enableRewriteJoinToFilter + && filterRewriteMaxSize == that.filterRewriteMaxSize; } @Override @@ -119,7 +134,20 @@ public int hashCode() enableFilterPushDown, enableFilterRewrite, enableRewriteValueColumnFilters, + enableRewriteJoinToFilter, filterRewriteMaxSize ); } + + @Override + public String toString() + { + return "JoinFilterRewriteConfig{" + + "enableFilterPushDown=" + enableFilterPushDown + + ", enableFilterRewrite=" + enableFilterRewrite + + ", enableRewriteValueColumnFilters=" + enableRewriteValueColumnFilters + + ", enableRewriteJoinToFilter=" + enableRewriteJoinToFilter + + ", filterRewriteMaxSize=" + filterRewriteMaxSize + + '}'; + } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java index 109da85ab460..2d3c43d76883 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java @@ -21,6 +21,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.ColumnSelectorFactory; @@ -34,6 +36,7 @@ import javax.annotation.Nullable; import java.io.Closeable; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -92,6 +95,39 @@ public JoinMatcher makeJoinMatcher( return LookupJoinMatcher.create(extractor, leftSelectorFactory, condition, remainderNeeded); } + @Override + public Optional> 
getNonNullColumnValuesIfAllUnique(String columnName, int maxNumValues) + { + if (LookupColumnSelectorFactory.KEY_COLUMN.equals(columnName) && extractor.canGetKeySet()) { + final Set keys = extractor.keySet(); + + final Set nullEquivalentValues = new HashSet<>(); + nullEquivalentValues.add(null); + if (NullHandling.replaceWithDefault()) { + nullEquivalentValues.add(NullHandling.defaultStringValue()); + } + + // size() of Sets.difference is slow; avoid it. + int nonNullKeys = keys.size(); + + for (String value : nullEquivalentValues) { + if (keys.contains(value)) { + nonNullKeys--; + } + } + + if (nonNullKeys > maxNumValues) { + return Optional.empty(); + } else if (nonNullKeys == keys.size()) { + return Optional.of(keys); + } else { + return Optional.of(Sets.difference(keys, nullEquivalentValues)); + } + } else { + return Optional.empty(); + } + } + @Override public Optional> getCorrelatedColumnValues( String searchColumnName, diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java index 4faaf549cd0b..e59b4fe999fe 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java @@ -20,8 +20,10 @@ package org.apache.druid.segment.join.table; import it.unimi.dsi.fastutil.ints.IntList; +import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinMatcher; @@ -35,6 +37,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.TreeSet; public class IndexedTableJoinable implements Joinable { @@ -88,6 +91,42 @@ public JoinMatcher makeJoinMatcher( ); } + @Override + public Optional> getNonNullColumnValuesIfAllUnique(final String columnName, final int maxNumValues) + { + final int columnPosition = table.rowSignature().indexOf(columnName); + + if (columnPosition < 0) { + return Optional.empty(); + } + + try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) { + // Sorted set to encourage "in" filters that result from this method to do dictionary lookups in order. + // The hopes are that this will improve locality and therefore improve performance. + final Set allValues = new TreeSet<>(); + + for (int i = 0; i < table.numRows(); i++) { + final String s = DimensionHandlerUtils.convertObjectToString(reader.read(i)); + + if (!NullHandling.isNullOrEquivalent(s)) { + if (!allValues.add(s)) { + // Duplicate found. Since the values are not all unique, we must return an empty Optional. 
+ return Optional.empty(); + } + + if (allValues.size() > maxNumValues) { + return Optional.empty(); + } + } + } + + return Optional.of(allValues); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + @Override public Optional> getCorrelatedColumnValues( String searchColumnName, @@ -112,7 +151,7 @@ public Optional> getCorrelatedColumnValues( IntList rowIndex = index.find(searchColumnValue); for (int i = 0; i < rowIndex.size(); i++) { int rowNum = rowIndex.getInt(i); - String correlatedDimVal = Objects.toString(reader.read(rowNum), null); + String correlatedDimVal = DimensionHandlerUtils.convertObjectToString(reader.read(rowNum)); correlatedValues.add(correlatedDimVal); if (correlatedValues.size() > maxCorrelationSetSize) { @@ -132,7 +171,7 @@ public Optional> getCorrelatedColumnValues( for (int i = 0; i < table.numRows(); i++) { String dimVal = Objects.toString(dimNameReader.read(i), null); if (searchColumnValue.equals(dimVal)) { - String correlatedDimVal = Objects.toString(correlatedColumnReader.read(i), null); + String correlatedDimVal = DimensionHandlerUtils.convertObjectToString(correlatedColumnReader.read(i)); correlatedValues.add(correlatedDimVal); if (correlatedValues.size() > maxCorrelationSetSize) { return Optional.empty(); diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java index b6f76b4be160..cec90ed2df28 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryTest.java @@ -21,11 +21,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.Warning; import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.BaseQuery; import org.apache.druid.query.Query; import org.apache.druid.query.QueryRunnerTestHelper; @@ -40,6 +42,7 @@ import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.junit.Assert; import org.junit.Test; @@ -80,6 +83,33 @@ public void testQuerySerialization() throws IOException Assert.assertEquals(query, serdeQuery); } + @Test + public void testGetRequiredColumns() + { + final GroupByQuery query = GroupByQuery + .builder() + .setDataSource(QueryRunnerTestHelper.DATA_SOURCE) + .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD) + .setVirtualColumns(new ExpressionVirtualColumn("v", "\"other\"", ValueType.STRING, ExprMacroTable.nil())) + .setDimensions(new DefaultDimensionSpec("quality", "alias"), DefaultDimensionSpec.of("v")) + .setAggregatorSpecs(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .setGranularity(QueryRunnerTestHelper.DAY_GRAN) + .setPostAggregatorSpecs(ImmutableList.of(new FieldAccessPostAggregator("x", "idx"))) + .setLimitSpec( + new DefaultLimitSpec( + ImmutableList.of(new OrderByColumnSpec( + "alias", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + )), + 100 + ) + ) + .build(); + + 
Assert.assertEquals(ImmutableSet.of("__time", "quality", "other", "index"), query.getRequiredColumns()); + } + @Test public void testRowOrderingMixTypes() { diff --git a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java index 1854883ca5b9..7972725deae7 100644 --- a/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/scan/ScanQueryTest.java @@ -269,4 +269,47 @@ public void testTimeOrderingWithoutTimeColumn() // This should throw an ISE List res = borkedSequence.toList(); } + + @Test + public void testGetRequiredColumnsWithNoColumns() + { + final ScanQuery query = + Druids.newScanQueryBuilder() + .order(ScanQuery.Order.DESCENDING) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST) + .dataSource("some src") + .intervals(intervalSpec) + .build(); + + Assert.assertNull(query.getRequiredColumns()); + } + + @Test + public void testGetRequiredColumnsWithEmptyColumns() + { + final ScanQuery query = + Druids.newScanQueryBuilder() + .order(ScanQuery.Order.DESCENDING) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST) + .dataSource("some src") + .intervals(intervalSpec) + .columns(Collections.emptyList()) + .build(); + + Assert.assertNull(query.getRequiredColumns()); + } + + @Test + public void testGetRequiredColumnsWithColumns() + { + final ScanQuery query = + Druids.newScanQueryBuilder() + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_LIST) + .dataSource("some src") + .intervals(intervalSpec) + .columns("foo", "bar") + .build(); + + Assert.assertEquals(ImmutableSet.of("__time", "foo", "bar"), query.getRequiredColumns()); + } } diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java index 310880244e29..54bebf8df406 100644 --- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryTest.java @@ -20,10 +20,15 @@ package org.apache.druid.query.timeseries; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Druids; import org.apache.druid.query.Query; import org.apache.druid.query.QueryRunnerTestHelper; +import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -54,13 +59,13 @@ public TimeseriesQueryTest(boolean descending) public void testQuerySerialization() throws IOException { Query query = Druids.newTimeseriesQueryBuilder() - .dataSource(QueryRunnerTestHelper.DATA_SOURCE) - .granularity(QueryRunnerTestHelper.DAY_GRAN) - .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC) - .aggregators(QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.INDEX_DOUBLE_SUM) - .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT) - .descending(descending) - .build(); + .dataSource(QueryRunnerTestHelper.DATA_SOURCE) + .granularity(QueryRunnerTestHelper.DAY_GRAN) + .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC) + .aggregators(QueryRunnerTestHelper.ROWS_COUNT, QueryRunnerTestHelper.INDEX_DOUBLE_SUM) 
+ .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT) + .descending(descending) + .build(); String json = JSON_MAPPER.writeValueAsString(query); Query serdeQuery = JSON_MAPPER.readValue(json, Query.class); @@ -68,4 +73,32 @@ public void testQuerySerialization() throws IOException Assert.assertEquals(query, serdeQuery); } + @Test + public void testGetRequiredColumns() + { + final TimeseriesQuery query = + Druids.newTimeseriesQueryBuilder() + .dataSource(QueryRunnerTestHelper.DATA_SOURCE) + .granularity(QueryRunnerTestHelper.DAY_GRAN) + .virtualColumns( + new ExpressionVirtualColumn( + "index", + "\"fieldFromVirtualColumn\"", + ValueType.LONG, + ExprMacroTable.nil() + ) + ) + .intervals(QueryRunnerTestHelper.FULL_ON_INTERVAL_SPEC) + .aggregators( + QueryRunnerTestHelper.ROWS_COUNT, + QueryRunnerTestHelper.INDEX_DOUBLE_SUM, + QueryRunnerTestHelper.INDEX_LONG_MAX, + new LongSumAggregatorFactory("beep", "aField") + ) + .postAggregators(QueryRunnerTestHelper.ADD_ROWS_INDEX_CONSTANT) + .descending(descending) + .build(); + + Assert.assertEquals(ImmutableSet.of("__time", "fieldFromVirtualColumn", "aField"), query.getRequiredColumns()); + } } diff --git a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java index 82c77b1ebd3e..5aede90d29fa 100644 --- a/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java +++ b/processing/src/test/java/org/apache/druid/query/topn/TopNQueryTest.java @@ -20,19 +20,27 @@ package org.apache.druid.query.topn; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Query; import org.apache.druid.query.QueryRunnerTestHelper; import org.apache.druid.query.aggregation.DoubleMaxAggregatorFactory; import org.apache.druid.query.aggregation.DoubleMinAggregatorFactory; +import org.apache.druid.query.aggregation.LongSumAggregatorFactory; +import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; +import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.ExtractionDimensionSpec; import org.apache.druid.query.dimension.LegacyDimensionSpec; import org.apache.druid.query.extraction.MapLookupExtractor; import org.apache.druid.query.lookup.LookupExtractionFn; import org.apache.druid.query.ordering.StringComparators; import org.apache.druid.segment.TestHelper; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -240,4 +248,22 @@ public void testQueryNullMetric() throws IOException String json = JSON_MAPPER.writeValueAsString(query); JSON_MAPPER.readValue(json, Query.class); } + + @Test + public void testGetRequiredColumns() + { + final TopNQuery query = new TopNQueryBuilder() + .dataSource(QueryRunnerTestHelper.DATA_SOURCE) + .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD) + .virtualColumns(new ExpressionVirtualColumn("v", "\"other\"", ValueType.STRING, ExprMacroTable.nil())) + .dimension(DefaultDimensionSpec.of("v")) + .aggregators(QueryRunnerTestHelper.ROWS_COUNT, new LongSumAggregatorFactory("idx", "index")) + .granularity(QueryRunnerTestHelper.DAY_GRAN) + 
.postAggregators(ImmutableList.of(new FieldAccessPostAggregator("x", "idx"))) + .metric(new NumericTopNMetricSpec("idx")) + .threshold(100) + .build(); + + Assert.assertEquals(ImmutableSet.of("__time", "other", "index"), query.getRequiredColumns()); + } } diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java index d5dc9a2f1f8f..26ba16119a0e 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java @@ -30,6 +30,7 @@ import org.apache.druid.segment.VirtualColumn; import org.apache.druid.segment.VirtualColumns; import org.apache.druid.segment.column.ValueType; +import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.join.filter.JoinFilterAnalyzer; import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis; import org.apache.druid.segment.join.filter.JoinFilterPreAnalysisKey; @@ -48,6 +49,7 @@ import org.junit.rules.TemporaryFolder; import java.io.IOException; +import java.util.Collections; import java.util.List; public class BaseHashJoinSegmentStorageAdapterTest @@ -56,6 +58,7 @@ public class BaseHashJoinSegmentStorageAdapterTest true, true, true, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER, QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE ); @@ -235,12 +238,16 @@ protected static JoinFilterPreAnalysis makeDefaultConfigPreAnalysis( VirtualColumns virtualColumns ) { + // Seemingly-useless "Filter.maybeAnd" is here to dedupe filters, flatten stacks, etc, in the same way that + // JoinableFactoryWrapper's segmentMapFn would do. 
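+ // A hedged sketch of the behavior relied on here (not asserted directly by these tests):
+ // Filters.maybeAnd(Collections.emptyList()) is expected to return Optional.empty(), and a singleton
+ // list containing only a null filter should behave the same, so a null originalFilter survives
+ // orElse(null) unchanged. A list of real filters should come back with nested ANDs flattened and
+ // duplicates removed, keeping this JoinFilterPreAnalysisKey equal to the one segmentMapFn computes.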
+ final Filter filterToUse = Filters.maybeAnd(Collections.singletonList(originalFilter)).orElse(null); + return JoinFilterAnalyzer.computeJoinFilterPreAnalysis( new JoinFilterPreAnalysisKey( DEFAULT_JOIN_FILTER_REWRITE_CONFIG, joinableClauses, virtualColumns, - originalFilter + filterToUse ) ); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java index 68e6426551e3..10d048305f81 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java @@ -2024,20 +2024,22 @@ public void test_makeCursors_originalFilterDoesNotMatchPreAnalysis_shouldThrowIS @Test public void test_makeCursors_factToCountryLeftWithBaseFilter() { + final Filter baseFilter = Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )); + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT)); JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( - null, + baseFilter, joinableClauses, VirtualColumns.EMPTY ); JoinTestHelper.verifyCursors( new HashJoinSegmentStorageAdapter( factSegment.asStorageAdapter(), - Filters.or(Arrays.asList( - new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), - new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() - )), + baseFilter, joinableClauses, joinFilterPreAnalysis ).makeCursors( @@ -2067,19 +2069,21 @@ public void test_makeCursors_factToCountryLeftWithBaseFilter() @Test public void test_makeCursors_factToCountryInnerWithBaseFilter() { + final Filter baseFilter = Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )); + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.INNER)); JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( - null, + baseFilter, joinableClauses, VirtualColumns.EMPTY ); JoinTestHelper.verifyCursors( new HashJoinSegmentStorageAdapter( factSegment.asStorageAdapter(), - Filters.or(Arrays.asList( - new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), - new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() - )), + baseFilter, joinableClauses, joinFilterPreAnalysis ).makeCursors( @@ -2108,19 +2112,21 @@ public void test_makeCursors_factToCountryInnerWithBaseFilter() @Test public void test_makeCursors_factToCountryRightWithBaseFilter() { + final Filter baseFilter = Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )); + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.RIGHT)); JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( - null, + baseFilter, joinableClauses, VirtualColumns.EMPTY ); JoinTestHelper.verifyCursors( new HashJoinSegmentStorageAdapter( factSegment.asStorageAdapter(), - Filters.or(Arrays.asList( - new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), - new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() - )), + baseFilter, joinableClauses, joinFilterPreAnalysis ).makeCursors( @@ -2166,19 +2172,21 @@ public 
void test_makeCursors_factToCountryRightWithBaseFilter() @Test public void test_makeCursors_factToCountryFullWithBaseFilter() { + final Filter baseFilter = Filters.or(Arrays.asList( + new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), + new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() + )); + List joinableClauses = ImmutableList.of(factToCountryOnIsoCode(JoinType.FULL)); JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( - null, + baseFilter, joinableClauses, VirtualColumns.EMPTY ); JoinTestHelper.verifyCursors( new HashJoinSegmentStorageAdapter( factSegment.asStorageAdapter(), - Filters.or(Arrays.asList( - new SelectorDimFilter("countryIsoCode", "CA", null).toFilter(), - new SelectorDimFilter("countryIsoCode", "MatchNothing", null).toFilter() - )), + baseFilter, joinableClauses, joinFilterPreAnalysis ).makeCursors( diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java index 9a56b3b6bdca..581c9a1c705f 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentTest.java @@ -22,13 +22,11 @@ import com.google.common.collect.ImmutableList; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.math.expr.ExprMacroTable; -import org.apache.druid.query.QueryContexts; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.QueryableIndexSegment; import org.apache.druid.segment.ReferenceCountingSegment; import org.apache.druid.segment.SegmentReference; import org.apache.druid.segment.StorageAdapter; -import org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig; import org.apache.druid.segment.join.table.IndexedTableJoinable; import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.timeline.SegmentId; @@ -49,14 +47,6 @@ public class HashJoinSegmentTest extends InitializedNullHandlingTest { - private static final JoinFilterRewriteConfig DEFAULT_JOIN_FILTER_REWRITE_CONFIG = - new JoinFilterRewriteConfig( - true, - true, - true, - QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE - ); - @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); @@ -205,7 +195,7 @@ public Optional acquireReferences() public void test_constructor_noClauses() { expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("'clauses' is empty, no need to create HashJoinSegment"); + expectedException.expectMessage("'clauses' and 'baseFilter' are both empty, no need to create HashJoinSegment"); List joinableClauses = ImmutableList.of(); diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java index 875f686af577..1ab6c0922994 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java @@ -274,6 +274,15 @@ public void test_forExpression_mixedAndWithOr() Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("y")); } + @Test + public void test_getRequiredColumns() + { + final String expression = "(x == \"j.y\") && ((x + y == \"j.z\") || (z == \"j.zz\"))"; + final JoinConditionAnalysis analysis = analyze(expression); + + 
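+ // The assertion below assumes getRequiredColumns() collects identifiers from both sides of every
+ // condition: the equi-condition (x == j.y) plus both branches of the OR, including columns that only
+ // appear inside expressions such as x + y.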
Assert.assertEquals(ImmutableSet.of("x", "j.y", "y", "j.z", "z", "j.zz"), analysis.getRequiredColumns()); + } + @Test public void test_equals() { @@ -281,7 +290,7 @@ public void test_equals() .usingGetClass() .withIgnoredFields( // These fields are tightly coupled with originalExpression - "equiConditions", "nonEquiConditions", + "equiConditions", "nonEquiConditions", "requiredColumns", // These fields are calculated from other other fields in the class "isAlwaysTrue", "isAlwaysFalse", "canHashJoin", "rightKeyColumns") .verify(); diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java index e422423b9521..694edde1e9de 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java @@ -2092,6 +2092,7 @@ public void test_filterPushDown_factToRegionToCountryLeftFilterOnPageDisablePush false, true, true, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER, QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE ), joinableClauses.getJoinableClauses(), @@ -2171,6 +2172,7 @@ public void test_filterPushDown_factToRegionToCountryLeftEnablePushDownDisableRe true, false, true, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER, QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE ), joinableClauses.getJoinableClauses(), @@ -2591,6 +2593,7 @@ public void test_filterPushDown_factToRegionExprToCountryLeftFilterOnCountryName true, true, true, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER, QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE ), joinableClauses, diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java index 94067c34412f..70491adc573f 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java @@ -21,25 +21,30 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterators; +import com.google.common.collect.Sets; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.common.config.NullHandlingTest; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.DataSource; import org.apache.druid.query.GlobalTableDataSource; import org.apache.druid.query.LookupDataSource; -import org.apache.druid.query.QueryContexts; import org.apache.druid.query.TableDataSource; import org.apache.druid.query.TestQuery; import org.apache.druid.query.extraction.MapLookupExtractor; import org.apache.druid.query.filter.FalseDimFilter; +import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.TrueDimFilter; import org.apache.druid.query.planning.DataSourceAnalysis; import org.apache.druid.query.planning.PreJoinableClause; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.segment.SegmentReference; -import 
org.apache.druid.segment.join.filter.rewrite.JoinFilterRewriteConfig; import org.apache.druid.segment.join.lookup.LookupJoinable; import org.easymock.EasyMock; import org.junit.Assert; @@ -51,22 +56,31 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; -public class JoinableFactoryWrapperTest +public class JoinableFactoryWrapperTest extends NullHandlingTest { - private static final JoinFilterRewriteConfig DEFAULT_JOIN_FILTER_REWRITE_CONFIG = new JoinFilterRewriteConfig( - QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_PUSH_DOWN, - QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE, - QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS, - QueryContexts.DEFAULT_ENABLE_JOIN_FILTER_REWRITE_MAX_SIZE - ); - private static final JoinableFactoryWrapper NOOP_JOINABLE_FACTORY_WRAPPER = new JoinableFactoryWrapper( NoopJoinableFactory.INSTANCE); + private static final Map<String, String> TEST_LOOKUP = + ImmutableMap.<String, String>builder() + .put("MX", "Mexico") + .put("NO", "Norway") + .put("SV", "El Salvador") + .put("US", "United States") + .put("", "Empty key") + .build(); + + private static final Set<String> TEST_LOOKUP_KEYS = + NullHandling.sqlCompatible() + ? TEST_LOOKUP.keySet() + : Sets.difference(TEST_LOOKUP.keySet(), Collections.singleton("")); + @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -428,6 +442,300 @@ public void test_checkClausePrefixesForDuplicatesAndShadowing_shadowing() JoinPrefixUtils.checkPrefixesForDuplicatesAndShadowing(prefixes); } + @Test + public void test_convertJoinsToFilters_convertInnerJoin() + { + final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of( + new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ) + ), + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS)), + ImmutableList.of() + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_convertTwoInnerJoins() + { + final ImmutableList<JoinableClause> clauses = ImmutableList.of( + new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "_j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"_j.k\"", "_j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "__j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.LEFT, + JoinConditionAnalysis.forExpression("x == \"__j.k\"", "__j.", ExprMacroTable.nil()) + ) + ); + + final Pair<List<Filter>, List<JoinableClause>> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + clauses, + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS), new InDimFilter("x", TEST_LOOKUP_KEYS)), + ImmutableList.of(clauses.get(2)) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertTooManyValues() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + 
JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of( + clause + ), + ImmutableSet.of("x"), + 2 + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertLeftJoin() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.LEFT, + JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of(clause), + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertWhenColumnIsUsed() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of(clause), + ImmutableSet.of("x", "j.k"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertLhsFunctions() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("concat(x,'') == \"j.k\"", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of(clause), + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertRhsFunctions() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == concat(\"j.k\",'')", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of(clause), + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertNonEquiJoin() + { + final JoinableClause clause = new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x != \"j.k\"", "j.", ExprMacroTable.nil()) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + ImmutableList.of(clause), + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + ImmutableList.of(clause) + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertJoinsDependedOnByLaterJoins() + { + final ImmutableList clauses = ImmutableList.of( + new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + 
JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "_j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("\"j.k\" == \"_j.k\"", "_j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "__j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.LEFT, + JoinConditionAnalysis.forExpression("x == \"__j.k\"", "__j.", ExprMacroTable.nil()) + ) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + clauses, + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(), + clauses + ), + conversion + ); + } + + @Test + public void test_convertJoinsToFilters_dontConvertJoinsDependedOnByLaterJoins2() + { + final ImmutableList clauses = ImmutableList.of( + new JoinableClause( + "j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"j.k\"", "j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "_j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.INNER, + JoinConditionAnalysis.forExpression("x == \"_j.k\"", "_j.", ExprMacroTable.nil()) + ), + new JoinableClause( + "__j.", + LookupJoinable.wrap(new MapLookupExtractor(TEST_LOOKUP, false)), + JoinType.LEFT, + JoinConditionAnalysis.forExpression("\"_j.v\" == \"__j.k\"", "__j.", ExprMacroTable.nil()) + ) + ); + + final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters( + clauses, + ImmutableSet.of("x"), + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Pair.of( + ImmutableList.of(new InDimFilter("x", TEST_LOOKUP_KEYS)), + clauses.subList(1, clauses.size()) + ), + conversion + ); + } + private PreJoinableClause makeGlobalPreJoinableClause(String tableName, String expression, String prefix) { return makeGlobalPreJoinableClause(tableName, expression, prefix, JoinType.LEFT); diff --git a/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java b/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java new file mode 100644 index 000000000000..5d0b2f842507 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/segment/join/filter/rewrite/JoinFilterRewriteConfigTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.segment.join.filter.rewrite; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.Test; + +public class JoinFilterRewriteConfigTest +{ + @Test + public void testEquals() + { + EqualsVerifier.forClass(JoinFilterRewriteConfig.class).usingGetClass().verify(); + } +} diff --git a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java index 2037f7763172..4b1dcdafa675 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java @@ -21,6 +21,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.common.config.NullHandlingTest; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ValueType; @@ -35,12 +37,13 @@ import org.mockito.junit.MockitoJUnitRunner; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @RunWith(MockitoJUnitRunner.class) -public class LookupJoinableTest +public class LookupJoinableTest extends NullHandlingTest { private static final String UNKNOWN_COLUMN = "UNKNOWN_COLUMN"; private static final String SEARCH_KEY_VALUE = "SEARCH_KEY_VALUE"; @@ -56,9 +59,17 @@ public class LookupJoinableTest @Before public void setUp() { + final Set keyValues = new HashSet<>(); + keyValues.add("foo"); + keyValues.add("bar"); + keyValues.add(""); + keyValues.add(null); + Mockito.doReturn(SEARCH_VALUE_VALUE).when(extractor).apply(SEARCH_KEY_VALUE); Mockito.doReturn(ImmutableList.of(SEARCH_KEY_VALUE)).when(extractor).unapply(SEARCH_VALUE_VALUE); Mockito.doReturn(ImmutableList.of()).when(extractor).unapply(SEARCH_VALUE_UNKNOWN); + Mockito.doReturn(true).when(extractor).canGetKeySet(); + Mockito.doReturn(keyValues).when(extractor).keySet(); target = LookupJoinable.wrap(extractor); } @@ -124,7 +135,8 @@ public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmptySet() SEARCH_KEY_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, 0, - false); + false + ); Assert.assertFalse(correlatedValues.isPresent()); } @@ -138,10 +150,12 @@ public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmptySet SEARCH_KEY_VALUE, UNKNOWN_COLUMN, 0, - false); + false + ); Assert.assertFalse(correlatedValues.isPresent()); } + @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue() { @@ -150,7 +164,8 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldRetur SEARCH_KEY_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, 0, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @@ -162,7 +177,8 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldRet SEARCH_KEY_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, 0, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @@ -174,7 +190,8 @@ public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnSh SEARCH_KEY_NULL_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, 0, - false); + false + ); 
Assert.assertEquals(Optional.of(Collections.singleton(null)), correlatedValues); } @@ -186,14 +203,16 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonK SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, 10, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, 10, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -205,7 +224,8 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldR SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, 0, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @@ -217,7 +237,8 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldRet SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @@ -234,7 +255,8 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLi SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, 0, - true); + true + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -246,7 +268,46 @@ public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnSh SEARCH_VALUE_UNKNOWN, LookupColumnSelectorFactory.KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues); } + + @Test + public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique( + LookupColumnSelectorFactory.VALUE_COLUMN, + Integer.MAX_VALUE + ); + + Assert.assertEquals(Optional.empty(), values); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique( + LookupColumnSelectorFactory.KEY_COLUMN, + Integer.MAX_VALUE + ); + + Assert.assertEquals( + Optional.of( + NullHandling.replaceWithDefault() ? 
ImmutableSet.of("foo", "bar") : ImmutableSet.of("foo", "bar", "") + ), + values + ); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique( + LookupColumnSelectorFactory.KEY_COLUMN, + 1 + ); + + Assert.assertEquals(Optional.empty(), values); + } } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java index 5f54aa24e56c..a9b1ae599d88 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java @@ -50,6 +50,7 @@ public class IndexedTableJoinableTest private static final String PREFIX = "j."; private static final String KEY_COLUMN = "str"; private static final String VALUE_COLUMN = "long"; + private static final String ALL_SAME_COLUMN = "allsame"; private static final String UNKNOWN_COLUMN = "unknown"; private static final String SEARCH_KEY_NULL_VALUE = "baz"; private static final String SEARCH_KEY_VALUE = "foo"; @@ -84,13 +85,14 @@ public ColumnCapabilities getColumnCapabilities(String columnName) private final InlineDataSource inlineDataSource = InlineDataSource.fromIterable( ImmutableList.of( - new Object[]{"foo", 1L}, - new Object[]{"bar", 2L}, - new Object[]{"baz", null} + new Object[]{"foo", 1L, 1L}, + new Object[]{"bar", 2L, 1L}, + new Object[]{"baz", null, 1L} ), RowSignature.builder() - .add("str", ValueType.STRING) - .add("long", ValueType.LONG) + .add(KEY_COLUMN, ValueType.STRING) + .add(VALUE_COLUMN, ValueType.LONG) + .add(ALL_SAME_COLUMN, ValueType.LONG) .build() ); @@ -113,7 +115,7 @@ public void setUp() @Test public void getAvailableColumns() { - Assert.assertEquals(ImmutableList.of("str", "long"), target.getAvailableColumns()); + Assert.assertEquals(ImmutableList.of(KEY_COLUMN, VALUE_COLUMN, ALL_SAME_COLUMN), target.getAvailableColumns()); } @Test @@ -340,4 +342,50 @@ public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnSh true); Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues); } + + @Test + public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnValues() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique(VALUE_COLUMN, Integer.MAX_VALUE); + + Assert.assertEquals(Optional.of(ImmutableSet.of("1", "2")), values); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique("nonexistent", Integer.MAX_VALUE); + + Assert.assertEquals(Optional.empty(), values); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique(KEY_COLUMN, Integer.MAX_VALUE); + + Assert.assertEquals( + Optional.of(ImmutableSet.of("foo", "bar", "baz")), + values + ); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() + { + final Optional> values = target.getNonNullColumnValuesIfAllUnique(ALL_SAME_COLUMN, Integer.MAX_VALUE); + + Assert.assertEquals( + Optional.empty(), + values + ); + } + + @Test + public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + { + final Optional> values = 
target.getNonNullColumnValuesIfAllUnique(KEY_COLUMN, 1); + + Assert.assertEquals(Optional.empty(), values); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index 8e039263b32e..3f61ebab2e86 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -849,6 +849,14 @@ protected void skipVectorize() skipVectorize = true; } + protected static boolean isRewriteJoinToFilter(final Map<String, Object> queryContext) + { + return (boolean) queryContext.getOrDefault( + QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, + QueryContexts.DEFAULT_ENABLE_REWRITE_JOIN_TO_FILTER + ); + } + /** * This is a provider of query contexts that should be used by join tests. * It tests various configs that can be passed to join queries. All the configs provided by this provider should @@ -862,23 +870,48 @@ public static Object[] provideQueryContexts() return new Object[]{ // default behavior QUERY_CONTEXT_DEFAULT, - // filter value re-writes enabled + // all rewrites enabled new ImmutableMap.Builder<String, Object>() .putAll(QUERY_CONTEXT_DEFAULT) .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true) .build(), - // rewrite values enabled but filter re-writes disabled. - // This should be drive the same behavior as the previous config + // filter-on-value-column rewrites disabled, everything else enabled + new ImmutableMap.Builder<String, Object>() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true) + .build(), + // filter rewrites fully disabled, join-to-filter enabled + new ImmutableMap.Builder<String, Object>() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true) + .build(), + // filter rewrites disabled, but value column filters still set to true (it should be ignored and this should + // behave the same as the previous context) new ImmutableMap.Builder<String, Object>() .putAll(QUERY_CONTEXT_DEFAULT) .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, true) + .build(), + // filter rewrites fully enabled, join-to-filter disabled + new ImmutableMap.Builder<String, Object>() + .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, true) + .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, true) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, false) .build(), - // filter re-writes disabled + // all rewrites disabled new ImmutableMap.Builder<String, Object>() .putAll(QUERY_CONTEXT_DEFAULT) + .put(QueryContexts.JOIN_FILTER_REWRITE_VALUE_COLUMN_FILTERS_ENABLE_KEY, false) .put(QueryContexts.JOIN_FILTER_REWRITE_ENABLE_KEY, false) + .put(QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY, false) .build(), }; } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index c23b78a4b77d..155e39b74e1a 100644 ---
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -342,13 +342,17 @@ public void testJoinOuterGroupByAndSubqueryHasLimit() throws Exception } @Test - public void testJoinOuterGroupByAndSubqueryNoLimit() throws Exception + @Parameters(source = QueryContextForJoinProvider.class) + public void testJoinOuterGroupByAndSubqueryNoLimit(Map queryContext) throws Exception { - // Cannot vectorize JOIN operator. - cannotVectorize(); + // Fully removing the join allows this query to vectorize. + if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } testQuery( "SELECT dim2, AVG(m2) FROM (SELECT * FROM foo AS t1 INNER JOIN foo AS t2 ON t1.m1 = t2.m1) AS t3 GROUP BY dim2", + queryContext, ImmutableList.of( GroupByQuery.builder() .setDataSource( @@ -362,6 +366,7 @@ public void testJoinOuterGroupByAndSubqueryNoLimit() throws Exception .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .context(QUERY_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), "j0.", equalsCondition( @@ -403,6 +408,7 @@ public void testJoinOuterGroupByAndSubqueryNoLimit() throws Exception ) .setContext(QUERY_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), NullHandling.sqlCompatible() ? ImmutableList.of( @@ -4273,12 +4279,17 @@ public void testUnionAllSameTableThreeTimesWithSameMapping() throws Exception } @Test - public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception + @Parameters(source = QueryContextForJoinProvider.class) + public void testUnionAllTwoQueriesLeftQueryIsJoin(Map queryContext) throws Exception { - cannotVectorize(); + // Fully removing the join allows this query to vectorize. + if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } testQuery( "(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) UNION ALL SELECT SUM(cnt) FROM foo", + queryContext, ImmutableList.of( Druids.newTimeseriesQueryBuilder() .dataSource( @@ -4293,7 +4304,8 @@ public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception .granularity(Granularities.ALL) .aggregators(aggregators(new CountAggregatorFactory("a0"))) .context(TIMESERIES_CONTEXT_DEFAULT) - .build(), + .build() + .withOverriddenContext(queryContext), Druids.newTimeseriesQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(querySegmentSpec(Filtration.eternity())) @@ -4301,18 +4313,24 @@ public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) .context(TIMESERIES_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), ImmutableList.of(new Object[]{1L}, new Object[]{6L}) ); } @Test - public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception + @Parameters(source = QueryContextForJoinProvider.class) + public void testUnionAllTwoQueriesRightQueryIsJoin(Map queryContext) throws Exception { - cannotVectorize(); + // Fully removing the join allows this query to vectorize. 
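+ // isRewriteJoinToFilter(queryContext) just reads QueryContexts.REWRITE_JOIN_TO_FILTER_ENABLE_KEY,
+ // falling back to the default, so only the provider entries that turn the rewrite on take the
+ // vectorized path here.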
+ if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } testQuery( "(SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) ", + queryContext, ImmutableList.of( Druids.newTimeseriesQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) @@ -4320,7 +4338,8 @@ public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception .granularity(Granularities.ALL) .aggregators(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) .context(TIMESERIES_CONTEXT_DEFAULT) - .build(), + .build() + .withOverriddenContext(queryContext), Druids.newTimeseriesQueryBuilder() .dataSource( join( @@ -4335,6 +4354,7 @@ public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception .aggregators(aggregators(new CountAggregatorFactory("a0"))) .context(TIMESERIES_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), ImmutableList.of(new Object[]{6L}, new Object[]{1L}) ); @@ -8107,8 +8127,10 @@ public void testAvgDailyCountDistinct() throws Exception @Parameters(source = QueryContextForJoinProvider.class) public void testTopNFilterJoin(Map queryContext) throws Exception { - // Cannot vectorize JOIN operator. - cannotVectorize(); + // Fully removing the join allows this query to vectorize. + if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } // Filters on top N values of some dimension by using an inner join. testQuery( @@ -13173,8 +13195,10 @@ public void testGroupingSetsWithOrderByAggregatorWithLimit() throws Exception @Parameters(source = QueryContextForJoinProvider.class) public void testUsingSubqueryAsPartOfAndFilter(Map queryContext) throws Exception { - // Cannot vectorize JOIN operator. - cannotVectorize(); + // Fully removing the join allows this query to vectorize. + if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } testQuery( "SELECT dim1, dim2, COUNT(*) FROM druid.foo\n" @@ -13641,6 +13665,234 @@ public void testSemiJoinWithOuterTimeExtractScan() throws Exception ); } + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testTwoSemiJoinsSimultaneously(Map queryContext) throws Exception + { + // Fully removing the join allows this query to vectorize. 
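+ // Hedged reasoning: each IN (SELECT MAX(__time) ...) subquery plans as an INNER join against a
+ // single-row aggregate, so its key column is trivially unique and small enough that the
+ // segment-mapping step should be able to turn both joins into plain __time filters.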
+ if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } + + testQuery( + "SELECT dim1, COUNT(*) FROM foo\n" + + "WHERE dim1 IN ('abc', 'def')" + + "AND __time IN (SELECT MAX(__time) FROM foo WHERE cnt = 1)\n" + + "AND __time IN (SELECT MAX(__time) FROM foo WHERE cnt <> 2)\n" + + "GROUP BY 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(selector("cnt", "1", null)) + .aggregators(new LongMaxAggregatorFactory("a0", "__time")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "(\"__time\" == \"j0.a0\")", + JoinType.INNER + ), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(not(selector("cnt", "2", null))) + .aggregators(new LongMaxAggregatorFactory("a0", "__time")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "_j0.", + "(\"__time\" == \"_j0.a0\")", + JoinType.INNER + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(in("dim1", ImmutableList.of("abc", "def"), null)) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(queryContext) + .build() + ), + ImmutableList.of(new Object[]{"abc", 1L}) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testSemiAndAntiJoinSimultaneouslyUsingWhereInSubquery(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "SELECT dim1, COUNT(*) FROM foo\n" + + "WHERE dim1 IN ('abc', 'def')\n" + + "AND __time IN (SELECT MAX(__time) FROM foo)\n" + + "AND __time NOT IN (SELECT MIN(__time) FROM foo)\n" + + "GROUP BY 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + join( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(new LongMaxAggregatorFactory("a0", "__time")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "(\"__time\" == \"j0.a0\")", + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setDataSource( + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators( + new LongMinAggregatorFactory("a0", "__time") + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs( + new CountAggregatorFactory("_a0"), + NullHandling.sqlCompatible() + ? 
new FilteredAggregatorFactory( + new CountAggregatorFactory("_a1"), + not(selector("a0", null, null)) + ) + : new CountAggregatorFactory("_a1") + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + "_j0.", + "1", + JoinType.INNER + ), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(new LongMinAggregatorFactory("a0", "__time")) + .postAggregators(expressionPostAgg("p0", "1")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "__j0.", + "(\"__time\" == \"__j0.a0\")", + JoinType.LEFT + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter( + and( + in("dim1", ImmutableList.of("abc", "def"), null), + or( + selector("_j0._a0", "0", null), + and(selector("__j0.p0", null, null), expressionFilter("(\"_j0._a1\" >= \"_j0._a0\")")) + ) + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(queryContext) + .build() + ), + ImmutableList.of(new Object[]{"abc", 1L}) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testSemiAndAntiJoinSimultaneouslyUsingExplicitJoins(Map queryContext) throws Exception + { + cannotVectorize(); + + testQuery( + "SELECT dim1, COUNT(*) FROM\n" + + "foo\n" + + "INNER JOIN (SELECT MAX(__time) t FROM foo) t0 on t0.t = foo.__time\n" + + "LEFT JOIN (SELECT MIN(__time) t FROM foo) t1 on t1.t = foo.__time\n" + + "WHERE dim1 IN ('abc', 'def') AND t1.t is null\n" + + "GROUP BY 1", + queryContext, + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + join( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(new LongMaxAggregatorFactory("a0", "__time")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "j0.", + "(\"__time\" == \"j0.a0\")", + JoinType.INNER + ), + new QueryDataSource( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(new LongMinAggregatorFactory("a0", "__time")) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + "_j0.", + "(\"__time\" == \"_j0.a0\")", + JoinType.LEFT + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter( + and( + in("dim1", ImmutableList.of("abc", "def"), null), + selector("_j0.a0", null, null) + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0", ValueType.STRING))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(queryContext) + .build() + ), + ImmutableList.of(new Object[]{"abc", 1L}) + ); + } + @Test public void testSemiJoinWithOuterTimeExtractAggregateWithOrderBy() throws Exception { @@ -13723,8 +13975,10 @@ public void testSemiJoinWithOuterTimeExtractAggregateWithOrderBy() throws Except @Parameters(source = QueryContextForJoinProvider.class) public void testInAggregationSubquery(Map queryContext) throws Exception { - // Cannot vectorize JOIN operator. - cannotVectorize(); + // Fully removing the join allows this query to vectorize. 
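+ // Same gating as the other parameterized join tests: with join-to-filter rewriting off, the JOIN
+ // operator survives into the native query and cannotVectorize() still applies.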
+ if (!isRewriteJoinToFilter(queryContext)) { + cannotVectorize(); + } testQuery( "SELECT DISTINCT __time FROM druid.foo WHERE __time IN (SELECT MAX(__time) FROM druid.foo)", @@ -13742,6 +13996,7 @@ public void testInAggregationSubquery(Map queryContext) throws E .aggregators(new LongMaxAggregatorFactory("a0", "__time")) .context(TIMESERIES_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), "j0.", equalsCondition( @@ -13754,8 +14009,9 @@ public void testInAggregationSubquery(Map queryContext) throws E .setInterval(querySegmentSpec(Filtration.eternity())) .setGranularity(Granularities.ALL) .setDimensions(dimensions(new DefaultDimensionSpec("__time", "d0", ValueType.LONG))) - .setContext(queryContext) + .setContext(QUERY_CONTEXT_DEFAULT) .build() + .withOverriddenContext(queryContext) ), ImmutableList.of( new Object[]{timestamp("2001-01-03")}