diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java index a9e9f21bfae0..6f2225d8bcfc 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java @@ -32,6 +32,8 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; /** * Represents analysis of a join condition. @@ -55,6 +57,7 @@ public class JoinConditionAnalysis private final boolean isAlwaysFalse; private final boolean isAlwaysTrue; private final boolean canHashJoin; + private final Set rightKeyColumns; private JoinConditionAnalysis( final String originalExpression, @@ -76,6 +79,7 @@ private JoinConditionAnalysis( .allMatch(expr -> expr.isLiteral() && expr.eval( ExprUtils.nilBindings()).asBoolean()); canHashJoin = nonEquiConditions.stream().allMatch(Expr::isLiteral); + rightKeyColumns = getEquiConditions().stream().map(Equality::getRightColumn).distinct().collect(Collectors.toSet()); } /** @@ -176,6 +180,14 @@ public boolean canHashJoin() return canHashJoin; } + /** + * Returns the distinct column keys from the RHS required to evaluate the equi conditions. + */ + public Set getRightEquiConditionKeys() + { + return rightKeyColumns; + } + @Override public boolean equals(Object o) { diff --git a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinMatcher.java b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinMatcher.java index 9af6df93fdc6..cbf99f1475a4 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinMatcher.java +++ b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinMatcher.java @@ -191,9 +191,9 @@ public static LookupJoinMatcher create( keyExprs = null; } else if (!condition.getNonEquiConditions().isEmpty()) { throw new IAE("Cannot join lookup with non-equi condition: %s", condition); - } else if (!condition.getEquiConditions() + } else if (!condition.getRightEquiConditionKeys() .stream() - .allMatch(eq -> eq.getRightColumn().equals(LookupColumnSelectorFactory.KEY_COLUMN))) { + .allMatch(LookupColumnSelectorFactory.KEY_COLUMN::equals)) { throw new IAE("Cannot join lookup with condition referring to non-key column: %s", condition); } else { keyExprs = condition.getEquiConditions().stream().map(Equality::getLeftExpr).collect(Collectors.toList()); diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java index a7b6b748b948..dfd64b1e4eaf 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java @@ -25,6 +25,7 @@ import javax.annotation.Nullable; import java.util.List; import java.util.Map; +import java.util.Set; /** * An interface to a table where some columns (the 'key columns') have indexes that enable fast lookups. @@ -36,7 +37,7 @@ public interface IndexedTable /** * Returns the columns of this table that have indexes. */ - List keyColumns(); + Set keyColumns(); /** * Returns all columns of this table, including the key and non-key columns. diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexedTable.java b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexedTable.java index 698382387004..37e9cb59e8a5 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexedTable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexedTable.java @@ -33,6 +33,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -48,13 +49,13 @@ public class RowBasedIndexedTable implements IndexedTable private final List columns; private final List columnTypes; private final List> columnFunctions; - private final List keyColumns; + private final Set keyColumns; public RowBasedIndexedTable( final List table, final RowAdapter rowAdapter, final Map rowSignature, - final List keyColumns + final Set keyColumns ) { this.table = table; @@ -107,7 +108,7 @@ public RowBasedIndexedTable( } @Override - public List keyColumns() + public Set keyColumns() { return keyColumns; } diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java index 87ad4516adfd..875f686af577 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinConditionAnalysisTest.java @@ -20,6 +20,7 @@ package org.apache.druid.segment.join; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.Pair; @@ -60,6 +61,7 @@ public void test_forExpression_simple() ImmutableList.of(), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("y")); } @Test @@ -80,6 +82,7 @@ public void test_forExpression_simpleFlipped() ImmutableList.of(), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("y")); } @Test @@ -100,6 +103,7 @@ public void test_forExpression_leftFunction() ImmutableList.of(), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("z")); } @Test @@ -120,6 +124,7 @@ public void test_forExpression_rightFunction() ImmutableList.of("(== (+ j.x j.y) z)"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertTrue(analysis.getRightEquiConditionKeys().isEmpty()); } @Test @@ -140,6 +145,7 @@ public void test_forExpression_mixedFunction() ImmutableList.of("(== (+ x j.y) j.z)"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertTrue(analysis.getRightEquiConditionKeys().isEmpty()); } @Test @@ -160,6 +166,7 @@ public void test_forExpression_trueConstant() ImmutableList.of("2"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertTrue(analysis.getRightEquiConditionKeys().isEmpty()); } @Test @@ -180,6 +187,7 @@ public void test_forExpression_falseConstant() ImmutableList.of("0"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertTrue(analysis.getRightEquiConditionKeys().isEmpty()); } @Test @@ -200,6 +208,7 @@ public void test_forExpression_onlyLeft() ImmutableList.of("(== x 1)"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertTrue(analysis.getRightEquiConditionKeys().isEmpty()); } @Test @@ -220,6 +229,7 @@ public void test_forExpression_onlyRight() ImmutableList.of(), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("x")); } @Test @@ -240,6 +250,7 @@ public void test_forExpression_andOfThreeConditions() ImmutableList.of(), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("y", "z", "zz")); } @Test @@ -260,6 +271,7 @@ public void test_forExpression_mixedAndWithOr() ImmutableList.of("(|| (== (+ x y) j.z) (== z j.zz))"), exprsToStrings(analysis.getNonEquiConditions()) ); + Assert.assertEquals(analysis.getRightEquiConditionKeys(), ImmutableSet.of("y")); } @Test @@ -270,8 +282,8 @@ public void test_equals() .withIgnoredFields( // These fields are tightly coupled with originalExpression "equiConditions", "nonEquiConditions", - // These fields are calculated from nonEquiConditions - "isAlwaysTrue", "isAlwaysFalse", "canHashJoin") + // These fields are calculated from other other fields in the class + "isAlwaysTrue", "isAlwaysFalse", "canHashJoin", "rightKeyColumns") .verify(); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinTestHelper.java b/processing/src/test/java/org/apache/druid/segment/join/JoinTestHelper.java index 88d36e61e2c4..ac0022dc1ff1 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinTestHelper.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinTestHelper.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.druid.common.config.NullHandling; import org.apache.druid.data.input.InputRow; @@ -252,7 +253,7 @@ public static RowBasedIndexedTable> createCountriesIndexedTa rows, createMapRowAdapter(COUNTRIES_SIGNATURE), COUNTRIES_SIGNATURE, - ImmutableList.of("countryNumber", "countryIsoCode") + ImmutableSet.of("countryNumber", "countryIsoCode") ) ); } @@ -265,7 +266,7 @@ public static RowBasedIndexedTable> createRegionsIndexedTabl rows, createMapRowAdapter(REGIONS_SIGNATURE), REGIONS_SIGNATURE, - ImmutableList.of("regionIsoCode", "countryIsoCode") + ImmutableSet.of("regionIsoCode", "countryIsoCode") ) ); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java index ad785c1f08ee..30beafcfc98b 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java @@ -20,6 +20,7 @@ package org.apache.druid.segment.join.table; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.dimension.DefaultDimensionSpec; @@ -72,7 +73,7 @@ public ColumnCapabilities getColumnCapabilities(String columnName) inlineDataSource.getRowsAsList(), inlineDataSource.rowAdapter(), inlineDataSource.getRowSignature(), - ImmutableList.of("str") + ImmutableSet.of("str") ); @Test diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java index d57f23a0fa4e..fed0a7019853 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.join.JoinTestHelper; @@ -64,7 +65,7 @@ public void setUp() throws IOException @Test public void test_keyColumns_countries() { - Assert.assertEquals(ImmutableList.of("countryNumber", "countryIsoCode"), countriesTable.keyColumns()); + Assert.assertEquals(ImmutableSet.of("countryNumber", "countryIsoCode"), countriesTable.keyColumns()); } @Test diff --git a/server/src/main/java/org/apache/druid/segment/join/InlineJoinableFactory.java b/server/src/main/java/org/apache/druid/segment/join/InlineJoinableFactory.java index d4b9937165bb..69ee6cc16c74 100644 --- a/server/src/main/java/org/apache/druid/segment/join/InlineJoinableFactory.java +++ b/server/src/main/java/org/apache/druid/segment/join/InlineJoinableFactory.java @@ -25,9 +25,8 @@ import org.apache.druid.segment.join.table.IndexedTableJoinable; import org.apache.druid.segment.join.table.RowBasedIndexedTable; -import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; +import java.util.Set; /** * A {@link JoinableFactory} for {@link InlineDataSource}. It works by building an {@link IndexedTable}. @@ -39,8 +38,7 @@ public Optional build(final DataSource dataSource, final JoinCondition { if (condition.canHashJoin() && dataSource instanceof InlineDataSource) { final InlineDataSource inlineDataSource = (InlineDataSource) dataSource; - final List rightKeyColumns = - condition.getEquiConditions().stream().map(Equality::getRightColumn).distinct().collect(Collectors.toList()); + final Set rightKeyColumns = condition.getRightEquiConditionKeys(); return Optional.of( new IndexedTableJoinable(