diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java index d3e73cf3d681..2a8b0c4791f0 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java @@ -62,11 +62,13 @@ public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOpera "'ARRAY_OVERLAP(array, array)'", OperandTypes.or( OperandTypes.family(SqlTypeFamily.ARRAY), - OperandTypes.family(SqlTypeFamily.STRING) + OperandTypes.family(SqlTypeFamily.STRING), + OperandTypes.family(SqlTypeFamily.NUMERIC) ), OperandTypes.or( OperandTypes.family(SqlTypeFamily.ARRAY), - OperandTypes.family(SqlTypeFamily.STRING) + OperandTypes.family(SqlTypeFamily.STRING), + OperandTypes.family(SqlTypeFamily.NUMERIC) ) ) ) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java index 29da45b085fd..326a535b8edd 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java @@ -129,16 +129,17 @@ private SqlNode createArrayLiteral(Object value, int posn) List args = new ArrayList<>(list.size()); for (int i = 0, listSize = list.size(); i < listSize; i++) { Object element = list.get(i); - if (element == null) { - throw InvalidSqlInput.exception("parameter [%d] is an array, with an illegal null at index [%d]", posn + 1, i); - } SqlNode node; - if (element instanceof String) { + if (element == null) { + node = SqlLiteral.createNull(SqlParserPos.ZERO); + } else if (element instanceof String) { node = SqlLiteral.createCharString((String) element, SqlParserPos.ZERO); } else if (element instanceof Integer || element instanceof Long) { // No direct way to create a literal from an Integer or Long, have // to parse a string, sadly. node = SqlLiteral.createExactNumeric(element.toString(), SqlParserPos.ZERO); + } else if (element instanceof Double || element instanceof Float) { + node = SqlLiteral.createApproxNumeric(element.toString(), SqlParserPos.ZERO); } else if (element instanceof Boolean) { node = SqlLiteral.createBoolean((Boolean) value, SqlParserPos.ZERO); } else { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index b870e45967b2..958689997f74 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.apache.calcite.avatica.SqlType; import org.apache.druid.common.config.NullHandling; import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.guice.NestedDataModule; @@ -70,6 +71,7 @@ import org.apache.druid.segment.virtual.ExpressionVirtualColumn; import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.http.SqlParameter; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -933,6 +935,94 @@ public void testArrayOverlapFilterArrayDoubleColumns() ); } + @Test + public void testArrayOverlapFilterNumeric() + { + testQuery( + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_OVERLAP(l1, ARRAY[1, 7]) LIMIT 5", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(new InDimFilter("l1", ImmutableList.of("1", "7"), null)) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterWithDynamicParameter() + { + Druids.ScanQueryBuilder builder = newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(new InDimFilter("d1", Arrays.asList("1.0", "1.7", null), null)) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT); + + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DEFAULT, + ImmutableList.of( + new SqlParameter(SqlType.ARRAY, Arrays.asList(1.0, 1.7, null)) + ), + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_OVERLAP(?, d1) LIMIT 5", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(builder.build()), + NullHandling.sqlCompatible() ? ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"}, + new Object[]{""}, + new Object[]{null}, + new Object[]{null} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{"[\"b\",\"c\"]"} + ) + ); + } + + @Test + public void testArrayOverlapFilterWithDynamicParameterLong() + { + Druids.ScanQueryBuilder builder = newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(new InDimFilter("l1", Arrays.asList("1", "7", null), null)) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT); + + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DEFAULT, + ImmutableList.of( + new SqlParameter(SqlType.ARRAY, Arrays.asList(1, 7, null)) + ), + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_OVERLAP(?, l1) LIMIT 5", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(builder.build()), + NullHandling.sqlCompatible() ? ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"}, + new Object[]{""}, + new Object[]{null}, + new Object[]{null} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + @Test public void testArrayContainsFilter() { @@ -1231,6 +1321,48 @@ public void testArrayContainsFilterArrayDoubleColumns() ); } + @Test + public void testArrayContainsFilterWithDynamicParameter() + { + Druids.ScanQueryBuilder builder = newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim3") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(5) + .context(QUERY_CONTEXT_DEFAULT); + + if (NullHandling.sqlCompatible()) { + builder = builder.filters( + and( + equality("dim3", "a", ColumnType.STRING), + equality("dim3", "b", ColumnType.STRING) + ) + ); + } else { + builder = builder.filters( + and( + selector("dim3", "a"), + selector("dim3", "b") + ) + ); + } + + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DEFAULT, + ImmutableList.of( + new SqlParameter(SqlType.ARRAY, Arrays.asList("a", "b")) + ), + "SELECT dim3 FROM druid.numfoo WHERE ARRAY_CONTAINS(dim3, ?) LIMIT 5", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(builder.build()), + ImmutableList.of( + new Object[]{"[\"a\",\"b\"]"} + ) + ); + } + @Test public void testArrayContainsConstantNull() {