diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java index c2b2987f6af4..5500f2c06201 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlBenchmark.java @@ -220,7 +220,7 @@ public void querySql(Blackhole blackhole) throws Exception final Map context = ImmutableMap.of("vectorize", vectorize); final AuthenticationResult authenticationResult = NoopEscalator.getInstance() .createEscalatedAuthenticationResult(); - try (final DruidPlanner planner = plannerFactory.createPlanner(context, authenticationResult)) { + try (final DruidPlanner planner = plannerFactory.createPlanner(context, ImmutableList.of(), authenticationResult)) { final PlannerResult plannerResult = planner.plan(QUERIES.get(Integer.parseInt(query))); final Sequence resultSequence = plannerResult.run(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); diff --git a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java index deacbe5561a6..4690d02e926d 100644 --- a/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java +++ b/benchmarks/src/test/java/org/apache/druid/benchmark/query/SqlVsNativeBenchmark.java @@ -19,6 +19,7 @@ package org.apache.druid.benchmark.query; +import com.google.common.collect.ImmutableList; import org.apache.calcite.schema.SchemaPlus; import org.apache.druid.benchmark.datagen.BenchmarkSchemaInfo; import org.apache.druid.benchmark.datagen.BenchmarkSchemas; @@ -165,9 +166,9 @@ public void queryNative(Blackhole blackhole) @OutputTimeUnit(TimeUnit.MILLISECONDS) public void queryPlanner(Blackhole blackhole) throws Exception { - final AuthenticationResult authenticationResult = NoopEscalator.getInstance() 
- .createEscalatedAuthenticationResult(); - try (final DruidPlanner planner = plannerFactory.createPlanner(null, authenticationResult)) { + final AuthenticationResult authResult = NoopEscalator.getInstance() + .createEscalatedAuthenticationResult(); + try (final DruidPlanner planner = plannerFactory.createPlanner(null, ImmutableList.of(), authResult)) { final PlannerResult plannerResult = planner.plan(sqlQuery); final Sequence resultSequence = plannerResult.run(); final Object[] lastRow = resultSequence.accumulate(null, (accumulated, in) -> in); diff --git a/docs/querying/sql.md b/docs/querying/sql.md index 6e9a230c6283..269decfbcfe7 100644 --- a/docs/querying/sql.md +++ b/docs/querying/sql.md @@ -55,6 +55,8 @@ like `100` (denoting an integer), `100.0` (denoting a floating point value), or timestamps can be written like `TIMESTAMP '2000-01-01 00:00:00'`. Literal intervals, used for time arithmetic, can be written like `INTERVAL '1' HOUR`, `INTERVAL '1 02:03' DAY TO MINUTE`, `INTERVAL '1-2' YEAR TO MONTH`, and so on. +Druid SQL supports dynamic parameters in question mark (`?`) syntax, where parameters are bound to the `?` placeholders at execution time. To use dynamic parameters, replace any literal in the query with a `?` character and ensure that corresponding parameter values are provided at execution time. Parameters are bound to the placeholders in the order in which they are passed. + Druid SQL supports SELECT queries with the following structure: ``` @@ -518,6 +520,17 @@ of configuration. You can make Druid SQL queries using JSON over HTTP by posting to the endpoint `/druid/v2/sql/`. The request should be a JSON object with a "query" field, like `{"query" : "SELECT COUNT(*) FROM data_source WHERE foo = 'bar'"}`. 
+##### Request + +|Property|Type|Description|Required| +|--------|----|-----------|--------| +|`query`|`String`| SQL query to run| yes | +|`resultFormat`|`String` (`ResultFormat`)| Result format for output | no (default `"object"`)| +|`header`|`Boolean`| Write column name header for supporting formats| no (default `false`)| +|`context`|`Object`| Connection context map. see [connection context parameters](#connection-context)| no | +|`parameters`|`SqlParameter` list| List of query parameters for parameterized queries. | no | + + You can use _curl_ to send SQL queries from the command-line: ```bash @@ -540,7 +553,27 @@ like: } ``` -Metadata is available over the HTTP API by querying [system tables](#metadata-tables). +Parameterized SQL queries are also supported: + +```json +{ + "query" : "SELECT COUNT(*) FROM data_source WHERE foo = ? AND __time > ?", + "parameters": [ + { "type": "VARCHAR", "value": "bar"}, + { "type": "TIMESTAMP", "value": "2000-01-01 00:00:00" } + ] +} +``` + +##### SqlParameter + +|Property|Type|Description|Required| +|--------|----|-----------|--------| +|`type`|`String` (`SqlType`) | String value of `SqlType` of parameter. [`SqlType`](https://calcite.apache.org/avatica/javadocAggregate/org/apache/calcite/avatica/SqlType.html) is a friendly wrapper around [`java.sql.Types`](https://docs.oracle.com/javase/8/docs/api/java/sql/Types.html?is-external=true)|yes| +|`value`|`Object`| Value of the parameter|yes| + + +Metadata is also available over the HTTP API by querying [system tables](#metadata-tables). #### Responses @@ -617,8 +650,7 @@ try (Connection connection = DriverManager.getConnection(url, connectionProperti ``` Table metadata is available over JDBC using `connection.getMetaData()` or by querying the -["INFORMATION_SCHEMA" tables](#metadata-tables). Parameterized queries (using `?` or other placeholders) don't work properly, -so avoid those. +["INFORMATION_SCHEMA" tables](#metadata-tables). 
#### Connection stickiness @@ -630,6 +662,17 @@ the necessary stickiness even with a normal non-sticky load balancer. Please see Note that the non-JDBC [JSON over HTTP](#json-over-http) API is stateless and does not require stickiness. +### Dynamic Parameters + +You can also use parameterized queries in JDBC code, as in this example; + +```java +PreparedStatement statement = connection.prepareStatement("SELECT COUNT(*) AS cnt FROM druid.foo WHERE dim1 = ? OR dim1 = ?"); +statement.setString(1, "abc"); +statement.setString(2, "def"); +final ResultSet resultSet = statement.executeQuery(); +``` + ### Connection context Druid SQL supports setting connection parameters on the client. The parameters in the table below affect SQL planning. diff --git a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java index ef39376e7751..a74c3e7b275f 100644 --- a/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java +++ b/extensions-contrib/tdigestsketch/src/test/java/org/apache/druid/query/aggregation/tdigestsketch/sql/TDigestSketchSqlAggregatorTest.java @@ -186,7 +186,12 @@ public void testComputingSketchOnNumericValues() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new String[] { "\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\"" @@ -219,7 +224,12 @@ public void 
testDefaultCompressionForTDigestGenerateSketchAgg() throws Exception + "FROM foo"; // Log query - sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); // Verify query Assert.assertEquals( @@ -248,7 +258,12 @@ public void testComputingQuantileOnPreAggregatedSketch() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new double[] { 1.1, @@ -297,7 +312,12 @@ public void testGeneratingSketchAndComputingQuantileOnFly() throws Exception + "FROM (SELECT dim1, TDIGEST_GENERATE_SKETCH(m1, 200) AS x FROM foo group by dim1)"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new double[] { 1.0, @@ -363,7 +383,12 @@ public void testQuantileOnNumericValues() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new double[] { 1.0, @@ -410,7 +435,12 @@ public void testCompressionParamForTDigestQuantileAgg() throws Exception + "FROM foo"; // Log query - sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + 
authenticationResult + ).toList(); // Verify query Assert.assertEquals( diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java index 74376a22ac99..cd18115bca70 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/hll/sql/HllSketchSqlAggregatorTest.java @@ -221,7 +221,12 @@ public void testApproxCountDistinctHllSketch() throws Exception + "FROM druid.foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { @@ -334,7 +339,12 @@ public void testAvgDailyCountDistinctHllSketch() throws Exception + ")"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1L @@ -430,7 +440,8 @@ public void testApproxCountDistinctHllSketchIsRounded() throws Exception + " HAVING APPROX_COUNT_DISTINCT_DS_HLL(m1) = 2"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = + sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, DEFAULT_PARAMETERS, authenticationResult).toList(); final int expected = NullHandling.replaceWithDefault() ? 
1 : 2; Assert.assertEquals(expected, results.size()); } @@ -457,7 +468,12 @@ public void testHllSketchPostAggs() throws Exception + "FROM druid.foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ "\"AgEHDAMIAgDhUv8P63iABQ==\"", @@ -605,7 +621,12 @@ public void testtHllSketchPostAggsPostSort() throws Exception final String sql2 = StringUtils.format("SELECT HLL_SKETCH_ESTIMATE(y), HLL_SKETCH_TO_STRING(y) from (%s)", sql); // Verify results - final List results = sqlLifecycle.runSimple(sql2, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql2, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 2.000000004967054d, diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java index 0a0d2a6f5a52..f50c87425a9d 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/quantiles/sql/DoublesSketchSqlAggregatorTest.java @@ -223,7 +223,12 @@ public void testQuantileOnFloatAndLongs() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + 
authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1.0, @@ -303,7 +308,12 @@ public void testQuantileOnComplexColumn() throws Exception + "FROM foo"; // Verify results - final List results = lifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = lifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1.0, @@ -362,7 +372,12 @@ public void testQuantileOnInnerQuery() throws Exception + "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { expectedResults = ImmutableList.of(new Object[]{7.0, 11.0}); @@ -431,7 +446,12 @@ public void testQuantileOnInnerQuantileQuery() throws Exception + "FROM (SELECT dim1, dim2, APPROX_QUANTILE_DS(m1, 0.5) AS x FROM foo GROUP BY dim1, dim2) GROUP BY dim1"; - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); ImmutableList.Builder builder = ImmutableList.builder(); builder.add(new Object[]{"", 1.0}); @@ -512,7 +532,12 @@ public void testDoublesSketchPostAggs() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 6L, @@ 
-679,7 +704,12 @@ public void testDoublesSketchPostAggsPostSort() throws Exception final String sql2 = StringUtils.format("SELECT DS_GET_QUANTILE(y, 0.5), DS_GET_QUANTILE(y, 0.98) from (%s)", sql); // Verify results - final List results = sqlLifecycle.runSimple(sql2, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql2, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 4.0d, diff --git a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java index 7484f41e3a98..ec87c57dcff6 100644 --- a/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java +++ b/extensions-core/datasketches/src/test/java/org/apache/druid/query/aggregation/datasketches/theta/sql/ThetaSketchSqlAggregatorTest.java @@ -218,7 +218,12 @@ public void testApproxCountDistinctThetaSketch() throws Exception + "FROM druid.foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { @@ -330,7 +335,12 @@ public void testAvgDailyCountDistinctThetaSketch() throws Exception + "FROM (SELECT FLOOR(__time TO DAY), APPROX_COUNT_DISTINCT_DS_THETA(cnt) AS u FROM druid.foo GROUP BY 1)"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + 
authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1L @@ -431,7 +441,12 @@ public void testThetaSketchPostAggs() throws Exception + "FROM druid.foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { @@ -604,7 +619,12 @@ public void testThetaSketchPostAggsPostSort() throws Exception final String sql2 = StringUtils.format("SELECT THETA_SKETCH_ESTIMATE(y) from (%s)", sql); // Verify results - final List results = sqlLifecycle.runSimple(sql2, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql2, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 2.0d diff --git a/extensions-core/druid-bloom-filter/pom.xml b/extensions-core/druid-bloom-filter/pom.xml index e579429f378e..696557ca20da 100644 --- a/extensions-core/druid-bloom-filter/pom.xml +++ b/extensions-core/druid-bloom-filter/pom.xml @@ -110,6 +110,11 @@ guava provided + + org.apache.calcite.avatica + avatica-core + provided + diff --git a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java index 6306078cdba8..1e875f2f0126 100644 --- a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java +++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/aggregation/bloom/sql/BloomFilterSqlAggregatorTest.java @@ -73,6 +73,7 @@ import 
org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerFactory; +import org.apache.druid.sql.calcite.util.CalciteTestBase; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.QueryLogHook; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; @@ -229,7 +230,12 @@ public void testBloomFilterAgg() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { @@ -281,7 +287,12 @@ public void testBloomFilterTwoAggs() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); BloomKFilter expected2 = new BloomKFilter(TEST_NUM_ENTRIES); @@ -351,7 +362,12 @@ public void testBloomFilterAggExtractionFn() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { @@ -407,8 +423,12 @@ public void testBloomFilterAggLong() throws Exception + "FROM 
numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); - + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected3 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { @@ -462,7 +482,12 @@ public void testBloomFilterAggLongVirtualColumn() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { @@ -524,7 +549,12 @@ public void testBloomFilterAggFloatVirtualColumn() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { @@ -587,7 +617,12 @@ public void testBloomFilterAggDoubleVirtualColumn() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); BloomKFilter expected1 = new BloomKFilter(TEST_NUM_ENTRIES); for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) { diff --git 
a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java index 188dfdf148b3..b93b46bc382d 100644 --- a/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java +++ b/extensions-core/druid-bloom-filter/src/test/java/org/apache/druid/query/filter/sql/BloomDimFilterSqlTest.java @@ -26,6 +26,7 @@ import com.google.inject.Guice; import com.google.inject.Injector; import com.google.inject.Key; +import org.apache.calcite.avatica.SqlType; import org.apache.druid.common.config.NullHandling; import org.apache.druid.guice.BloomFilterExtensionModule; import org.apache.druid.guice.BloomFilterSerializersModule; @@ -54,6 +55,8 @@ import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.QueryLogHook; +import org.apache.druid.sql.http.SqlParameter; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -273,10 +276,97 @@ public void testBloomFilters() throws Exception ); } + @Ignore("this test is really slow and is intended to use for comparisons with testBloomFilterBigParameter") + @Test + public void testBloomFilterBigNoParam() throws Exception + { + BloomKFilter filter = new BloomKFilter(5_000_000); + filter.addString("def"); + byte[] bytes = BloomFilterSerializersModule.bloomKFilterToBytes(filter); + String base64 = StringUtils.encodeBase64String(bytes); + testQuery( + StringUtils.format("SELECT COUNT(*) FROM druid.foo WHERE bloom_filter_test(dim1, '%s')", base64), + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + new BloomDimFilter("dim1", BloomKFilterHolder.fromBloomKFilter(filter), null) + ) + 
.aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L} + ) + ); + } + + @Ignore("this test is for comparison with testBloomFilterBigNoParam") + @Test + public void testBloomFilterBigParameter() throws Exception + { + BloomKFilter filter = new BloomKFilter(5_000_000); + filter.addString("def"); + byte[] bytes = BloomFilterSerializersModule.bloomKFilterToBytes(filter); + String base64 = StringUtils.encodeBase64String(bytes); + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE bloom_filter_test(dim1, ?)", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + new BloomDimFilter("dim1", BloomKFilterHolder.fromBloomKFilter(filter), null) + ) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L} + ), + ImmutableList.of(new SqlParameter(SqlType.VARCHAR, base64)) + ); + } + + @Test + public void testBloomFilterNullParameter() throws Exception + { + BloomKFilter filter = new BloomKFilter(1500); + filter.addBytes(null, 0, 0); + byte[] bytes = BloomFilterSerializersModule.bloomKFilterToBytes(filter); + String base64 = StringUtils.encodeBase64String(bytes); + + // bloom filter expression is evaluated and optimized out at planning time since parameter is null and null matches + // the supplied filter of the other parameter + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE bloom_filter_test(?, ?)", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{6L} + 
), + // there are no empty strings in the druid expression language since empty is coerced into a null when parsed + ImmutableList.of(new SqlParameter(SqlType.VARCHAR, NullHandling.defaultStringValue()), new SqlParameter(SqlType.VARCHAR, base64)) + ); + } + @Override public List getResults( final PlannerConfig plannerConfig, final Map queryContext, + final List parameters, final String sql, final AuthenticationResult authenticationResult ) throws Exception @@ -288,6 +378,7 @@ public List getResults( return getResults( plannerConfig, queryContext, + parameters, sql, authenticationResult, operatorTable, diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java index cbfc61b92e98..24651405bcbd 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/FixedBucketsHistogramQuantileSqlAggregatorTest.java @@ -206,7 +206,12 @@ public void testQuantileOnFloatAndLongs() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1.0299999713897705, @@ -327,7 +332,12 @@ public void testQuantileOnComplexColumn() throws Exception + "FROM foo"; // Verify results - final List results = lifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = lifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + 
authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1.0299999713897705, @@ -417,7 +427,12 @@ public void testQuantileOnInnerQuery() throws Exception + "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { expectedResults = ImmutableList.of(new Object[]{7.0, 11.940000534057617}); diff --git a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java index 1d80893c8986..0ded89c43608 100644 --- a/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java +++ b/extensions-core/histogram/src/test/java/org/apache/druid/query/aggregation/histogram/sql/QuantileSqlAggregatorTest.java @@ -206,7 +206,12 @@ public void testQuantileOnFloatAndLongs() throws Exception + "FROM foo"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults = ImmutableList.of( new Object[]{ 1.0, @@ -285,7 +290,12 @@ public void testQuantileOnComplexColumn() throws Exception + "FROM foo"; // Verify results - final List results = lifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = lifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List 
expectedResults = ImmutableList.of( new Object[]{1.0, 3.0, 5.880000114440918, 5.940000057220459, 6.0, 4.994999885559082, 6.0} ); @@ -335,7 +345,12 @@ public void testQuantileOnInnerQuery() throws Exception + "FROM (SELECT dim2, SUM(m1) AS x FROM foo GROUP BY dim2)"; // Verify results - final List results = sqlLifecycle.runSimple(sql, QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + final List results = sqlLifecycle.runSimple( + sql, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + authenticationResult + ).toList(); final List expectedResults; if (NullHandling.replaceWithDefault()) { expectedResults = ImmutableList.of(new Object[]{7.0, 8.26386833190918}); diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index 192788fed357..56866d4662b0 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -59,6 +59,7 @@ import org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.PlannerConfig; import org.apache.druid.sql.calcite.planner.PlannerFactory; +import org.apache.druid.sql.calcite.util.CalciteTestBase; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.QueryLogHook; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; @@ -230,7 +231,12 @@ public void testVarPop() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + 
).toList(); VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector(); VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector(); @@ -285,7 +291,12 @@ public void testVarSamp() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector(); VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector(); @@ -340,7 +351,12 @@ public void testStdDevPop() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector(); VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector(); @@ -402,7 +418,12 @@ public void testStdDevSamp() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector(); VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector(); @@ -463,7 +484,12 @@ public void testStdDevWithVirtualColumns() throws Exception + "FROM numfoo"; final List results = - sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList(); + sqlLifecycle.runSimple( + sql, + BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, + 
CalciteTestBase.DEFAULT_PARAMETERS, + authenticationResult + ).toList(); VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector(); VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector(); diff --git a/sql/pom.xml b/sql/pom.xml index 8caa6ef53f33..36c08a0f0314 100644 --- a/sql/pom.xml +++ b/sql/pom.xml @@ -221,6 +221,11 @@ hamcrest-core test + + nl.jqno.equalsverifier + equalsverifier + test + diff --git a/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java b/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java index 700211dfc318..4c9135941ca0 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java @@ -19,8 +19,9 @@ package org.apache.druid.sql; -import com.google.common.base.Preconditions; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Iterables; +import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.tools.RelConversionException; @@ -46,11 +47,16 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.planner.PlannerResult; +import org.apache.druid.sql.calcite.planner.PrepareResult; +import org.apache.druid.sql.http.SqlParameter; +import org.apache.druid.sql.http.SqlQuery; import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -86,9 +92,11 @@ public class SqlLifecycle // init during intialize private String sql; private Map queryContext; + private List parameters; // init during plan @Nullable private HttpServletRequest req; private PlannerContext 
plannerContext; + private PrepareResult prepareResult; private PlannerResult plannerResult; public SqlLifecycle( @@ -104,6 +112,7 @@ public SqlLifecycle( this.requestLogger = requestLogger; this.startMs = startMs; this.startNs = startNs; + this.parameters = Collections.emptyList(); } public String initialize(String sql, Map queryContext) @@ -131,12 +140,30 @@ private String sqlQueryId() return (String) this.queryContext.get(PlannerContext.CTX_SQL_QUERY_ID); } + public void setParameters(List parameters) + { + this.parameters = parameters; + } + + public PrepareResult prepare(AuthenticationResult authenticationResult) + throws ValidationException, RelConversionException, SqlParseException + { + synchronized (lock) { + try (DruidPlanner planner = plannerFactory.createPlanner(queryContext, parameters, authenticationResult)) { + // set planner context for logs/metrics in case something explodes early + this.plannerContext = planner.getPlannerContext(); + this.prepareResult = planner.prepare(sql); + return prepareResult; + } + } + } + public PlannerContext plan(AuthenticationResult authenticationResult) throws ValidationException, RelConversionException, SqlParseException { synchronized (lock) { transition(State.INITIALIZED, State.PLANNED); - try (DruidPlanner planner = plannerFactory.createPlanner(queryContext, authenticationResult)) { + try (DruidPlanner planner = plannerFactory.createPlanner(queryContext, parameters, authenticationResult)) { this.plannerContext = planner.getPlannerContext(); this.plannerResult = planner.plan(sql); } @@ -156,9 +183,7 @@ public PlannerContext plan(HttpServletRequest req) public RelDataType rowType() { synchronized (lock) { - Preconditions.checkState(plannerResult != null, - "must be called after SQL has been planned"); - return plannerResult.rowType(); + return plannerResult != null ? 
plannerResult.rowType() : prepareResult.getRowType(); } } @@ -171,10 +196,7 @@ public Access authorize() return doAuthorize( AuthorizationUtils.authorizeAllResourceActions( req, - Iterables.transform( - plannerResult.datasourceNames(), - AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR - ), + Iterables.transform(plannerResult.datasourceNames(), AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR), plannerFactory.getAuthorizerMapper() ) ); @@ -231,9 +253,11 @@ public Sequence execute() } } + @VisibleForTesting public Sequence runSimple( String sql, Map queryContext, + List parameters, AuthenticationResult authenticationResult ) throws ValidationException, RelConversionException, SqlParseException { @@ -241,6 +265,7 @@ public Sequence runSimple( initialize(sql, queryContext); try { + setParameters(SqlQuery.getParameterList(parameters)); planAndAuthorize(authenticationResult); result = execute(); } diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index 030755b16118..d5b27aad6489 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -49,6 +49,7 @@ import javax.annotation.Nonnull; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -187,13 +188,13 @@ public ExecuteResult prepareAndExecute( if (authenticationResult == null) { throw new ForbiddenException("Authentication failed."); } - final Signature signature = druidStatement.prepare(sql, maxRowCount, authenticationResult).getSignature(); - final Frame firstFrame = druidStatement.execute() + druidStatement.prepare(sql, maxRowCount, authenticationResult); + final Frame firstFrame = druidStatement.execute(Collections.emptyList()) .nextFrame( DruidStatement.START_OFFSET, getEffectiveMaxRowsPerFrame(maxRowsInFirstFrame) ); - + final Signature signature = 
druidStatement.getSignature(); return new ExecuteResult( ImmutableList.of( MetaResultSet.create( @@ -256,16 +257,14 @@ public ExecuteResult execute( final int maxRowsInFirstFrame ) throws NoSuchStatementException { - Preconditions.checkArgument(parameterValues.isEmpty(), "Expected parameterValues to be empty"); - final DruidStatement druidStatement = getDruidStatement(statement); - final Signature signature = druidStatement.getSignature(); - final Frame firstFrame = druidStatement.execute() + final Frame firstFrame = druidStatement.execute(parameterValues) .nextFrame( DruidStatement.START_OFFSET, getEffectiveMaxRowsPerFrame(maxRowsInFirstFrame) ); + final Signature signature = druidStatement.getSignature(); return new ExecuteResult( ImmutableList.of( MetaResultSet.create( diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java index 4c9ac0391906..2b64c3039e6a 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java @@ -22,8 +22,10 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.druid.java.util.common.ISE; @@ -35,6 +37,8 @@ import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.sql.SqlLifecycle; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.PrepareResult; import org.apache.druid.sql.calcite.rel.QueryMaker; import 
java.io.Closeable; @@ -79,6 +83,7 @@ public class DruidStatement implements Closeable private Yielder yielder; private int offset = 0; private Throwable throwable; + private AuthenticationResult authenticationResult; public DruidStatement( final String connectionId, @@ -152,42 +157,43 @@ public DruidStatement prepare( try { ensure(State.NEW); sqlLifecycle.initialize(query, queryContext); - sqlLifecycle.planAndAuthorize(authenticationResult); + + this.authenticationResult = authenticationResult; + PrepareResult prepareResult = sqlLifecycle.prepare(authenticationResult); this.maxRowCount = maxRowCount; this.query = query; + List params = new ArrayList<>(); + final RelDataType parameterRowType = prepareResult.getParameterRowType(); + for (RelDataTypeField field : parameterRowType.getFieldList()) { + RelDataType type = field.getType(); + params.add(createParameter(field, type)); + } this.signature = Meta.Signature.create( - createColumnMetaData(sqlLifecycle.rowType()), + createColumnMetaData(prepareResult.getRowType()), query, - new ArrayList<>(), + params, Meta.CursorFactory.ARRAY, Meta.StatementType.SELECT // We only support SELECT ); this.state = State.PREPARED; } catch (Throwable t) { - this.throwable = t; - try { - close(); - } - catch (Throwable t1) { - t.addSuppressed(t1); - } - throw new RuntimeException(t); + return closeAndPropagateThrowable(t); } return this; } } - public DruidStatement execute() + + public DruidStatement execute(List parameters) { synchronized (lock) { ensure(State.PREPARED); - try { - final Sequence baseSequence = yielderOpenCloseExecutor.submit( - sqlLifecycle::execute - ).get(); + sqlLifecycle.setParameters(parameters); + sqlLifecycle.planAndAuthorize(authenticationResult); + final Sequence baseSequence = yielderOpenCloseExecutor.submit(sqlLifecycle::execute).get(); // We can't apply limits greater than Integer.MAX_VALUE, ignore them. 
final Sequence retSequence = @@ -199,14 +205,7 @@ public DruidStatement execute() state = State.RUNNING; } catch (Throwable t) { - this.throwable = t; - try { - close(); - } - catch (Throwable t1) { - t.addSuppressed(t1); - } - throw new RuntimeException(t); + closeAndPropagateThrowable(t); } return this; @@ -350,6 +349,34 @@ public void close() } } + private AvaticaParameter createParameter(RelDataTypeField field, RelDataType type) + { + // signed is always false because no way to extract from RelDataType, and the only usage of this AvaticaParameter + // constructor I can find, in CalcitePrepareImpl, does it this way with hard coded false + return new AvaticaParameter( + false, + type.getPrecision(), + type.getScale(), + type.getSqlTypeName().getJdbcOrdinal(), + type.getSqlTypeName().getName(), + Calcites.sqlTypeNameJdbcToJavaClass(type.getSqlTypeName()).getName(), + field.getName()); + } + + + + private DruidStatement closeAndPropagateThrowable(Throwable t) + { + this.throwable = t; + try { + close(); + } + catch (Throwable t1) { + t.addSuppressed(t1); + } + throw new RuntimeException(t); + } + @GuardedBy("lock") private void ensure(final State... 
desiredStates) { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java index dd7ecf5e063d..b719cad892a6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/OperatorConversions.java @@ -20,6 +20,7 @@ package org.apache.druid.sql.calcite.expression; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; @@ -36,6 +37,7 @@ import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlOperandCountRanges; import org.apache.calcite.sql.type.SqlOperandTypeChecker; @@ -49,6 +51,7 @@ import org.apache.druid.query.aggregation.PostAggregator; import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.DruidTypeSystem; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.table.RowSignature; @@ -235,6 +238,7 @@ public static class OperatorBuilder private SqlOperandTypeChecker operandTypeChecker; private List operandTypes; private Integer requiredOperands = null; + private SqlOperandTypeInference operandTypeInference; private OperatorBuilder(final String name) { @@ -287,6 +291,12 @@ public OperatorBuilder operandTypes(final SqlTypeFamily... 
operandTypes) return this; } + public OperatorBuilder operandTypeInference(final SqlOperandTypeInference operandTypeInference) + { + this.operandTypeInference = operandTypeInference; + return this; + } + public OperatorBuilder requiredOperands(final int requiredOperands) { this.requiredOperands = requiredOperands; @@ -317,11 +327,27 @@ public SqlFunction build() ); } + if (operandTypeInference == null) { + SqlOperandTypeInference defaultInference = new DefaultOperandTypeInference(operandTypes, nullableOperands); + operandTypeInference = (callBinding, returnType, types) -> { + for (int i = 0; i < types.length; i++) { + // calcite sql validate tries to do bad things to dynamic parameters if the type is inferred to be a string + if (callBinding.operand(i).isA(ImmutableSet.of(SqlKind.DYNAMIC_PARAM))) { + types[i] = new BasicSqlType( + DruidTypeSystem.INSTANCE, + SqlTypeName.ANY + ); + } else { + defaultInference.inferOperandTypes(callBinding, returnType, types); + } + } + }; + } return new SqlFunction( name, kind, Preconditions.checkNotNull(returnTypeInference, "returnTypeInference"), - new DefaultOperandTypeInference(operandTypes, nullableOperands), + operandTypeInference, theOperandTypeChecker, functionCategory ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index c7ea655088bf..9ba5214446ff 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -51,7 +51,12 @@ import org.joda.time.format.ISODateTimeFormat; import javax.annotation.Nullable; +import java.math.BigDecimal; import java.nio.charset.Charset; +import java.sql.Date; +import java.sql.JDBCType; +import java.sql.Time; +import java.sql.Timestamp; import java.util.NavigableSet; import java.util.TreeSet; import java.util.regex.Pattern; @@ -397,4 +402,45 @@ public static int collapseFetch(int innerFetch, 
int outerFetch, int outerOffset) } return fetch; } + + public static Class sqlTypeNameJdbcToJavaClass(SqlTypeName typeName) + { + // reference: https://docs.oracle.com/javase/1.5.0/docs/guide/jdbc/getstart/mapping.html + JDBCType jdbcType = JDBCType.valueOf(typeName.getJdbcOrdinal()); + switch (jdbcType) { + case CHAR: + case VARCHAR: + case LONGVARCHAR: + return String.class; + case NUMERIC: + case DECIMAL: + return BigDecimal.class; + case BIT: + return Boolean.class; + case TINYINT: + return Byte.class; + case SMALLINT: + return Short.class; + case INTEGER: + return Integer.class; + case BIGINT: + return Long.class; + case REAL: + return Float.class; + case FLOAT: + case DOUBLE: + return Double.class; + case BINARY: + case VARBINARY: + return Byte[].class; + case DATE: + return Date.class; + case TIME: + return Time.class; + case TIMESTAMP: + return Timestamp.class; + default: + return Object.class; + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidConvertletTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidConvertletTable.java index 84b26cdb4bda..c8cfb045f6d1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidConvertletTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidConvertletTable.java @@ -25,7 +25,7 @@ import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.OracleSqlOperatorTable; +import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql2rel.SqlRexContext; import org.apache.calcite.sql2rel.SqlRexConvertlet; @@ -70,7 +70,7 @@ public class DruidConvertletTable implements SqlRexConvertletTable .add(SqlStdOperatorTable.UNION_ALL) .add(SqlStdOperatorTable.NULLIF) .add(SqlStdOperatorTable.COALESCE) - .add(OracleSqlOperatorTable.NVL) + .add(SqlLibraryOperators.NVL) .build(); 
private final Map table; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java index 88622e8154a7..795bacaa6f00 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java @@ -19,6 +19,7 @@ package org.apache.druid.sql.calcite.planner; +import com.google.common.base.Preconditions; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; @@ -26,17 +27,24 @@ import com.google.common.primitives.Ints; import org.apache.calcite.DataContext; import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionConfigImpl; +import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.interpreter.BindableConvention; import org.apache.calcite.interpreter.BindableRel; import org.apache.calcite.interpreter.Bindables; +import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.linq4j.Enumerable; import org.apache.calcite.linq4j.Enumerator; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; @@ -46,6 +54,10 @@ import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.type.BasicSqlType; import org.apache.calcite.sql.type.SqlTypeName; +import 
org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.tools.FrameworkConfig; +import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.Planner; import org.apache.calcite.tools.RelConversionException; import org.apache.calcite.tools.ValidationException; @@ -62,23 +74,50 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Properties; import java.util.Set; public class DruidPlanner implements Closeable { + private final FrameworkConfig frameworkConfig; private final Planner planner; private final PlannerContext plannerContext; private RexBuilder rexBuilder; public DruidPlanner( - final Planner planner, + final FrameworkConfig frameworkConfig, final PlannerContext plannerContext ) { - this.planner = planner; + this.frameworkConfig = frameworkConfig; + this.planner = Frameworks.getPlanner(frameworkConfig); this.plannerContext = plannerContext; } + public PrepareResult prepare(final String sql) throws SqlParseException, ValidationException, RelConversionException + { + SqlNode parsed = planner.parse(sql); + SqlExplain explain = null; + if (parsed.getKind() == SqlKind.EXPLAIN) { + explain = (SqlExplain) parsed; + parsed = explain.getExplicandum(); + } + final SqlNode validated = planner.validate(parsed); + RelRoot root = planner.rel(validated); + RelDataType rowType = root.validatedRowType; + + // this is sort of lame, planner won't cough up its validator, it is private and has no accessors, so make another + // one so we can get the parameter types... 
but i suppose beats creating our own Prepare and Planner implementations + SqlValidator validator = getValidator(); + RelDataType parameterTypes = validator.getParameterRowType(validator.validate(parsed)); + + if (explain != null) { + final RelDataTypeFactory typeFactory = root.rel.getCluster().getTypeFactory(); + return new PrepareResult(getExplainStructType(typeFactory), parameterTypes); + } + return new PrepareResult(rowType, parameterTypes); + } + public PlannerResult plan(final String sql) throws SqlParseException, ValidationException, RelConversionException { @@ -91,7 +130,9 @@ public PlannerResult plan(final String sql) // the planner's type factory is not available until after parsing this.rexBuilder = new RexBuilder(planner.getTypeFactory()); - final SqlNode validated = planner.validate(parsed); + SqlParameterizerShuttle sshuttle = new SqlParameterizerShuttle(plannerContext); + SqlNode parametized = parsed.accept(sshuttle); + final SqlNode validated = planner.validate(parametized); final RelRoot root = planner.rel(validated); try { @@ -120,6 +161,38 @@ public void close() planner.close(); } + private SqlValidator getValidator() + { + Preconditions.checkNotNull(planner.getTypeFactory()); + + final CalciteConnectionConfig connectionConfig; + + if (frameworkConfig.getContext() != null) { + connectionConfig = frameworkConfig.getContext().unwrap(CalciteConnectionConfig.class); + } else { + Properties properties = new Properties(); + properties.setProperty( + CalciteConnectionProperty.CASE_SENSITIVE.camelName(), + String.valueOf(PlannerFactory.PARSER_CONFIG.caseSensitive()) + ); + connectionConfig = new CalciteConnectionConfigImpl(properties); + } + + Prepare.CatalogReader catalogReader = new CalciteCatalogReader( + CalciteSchema.from(frameworkConfig.getDefaultSchema().getParentSchema()), + CalciteSchema.from(frameworkConfig.getDefaultSchema()).path(null), + planner.getTypeFactory(), + connectionConfig + ); + + return SqlValidatorUtil.newValidator( + 
frameworkConfig.getOperatorTable(), + catalogReader, + planner.getTypeFactory(), + DruidConformance.instance() + ); + } + private PlannerResult planWithDruidConvention( final SqlExplain explain, final RelRoot root @@ -127,12 +200,14 @@ private PlannerResult planWithDruidConvention( { final RelNode possiblyWrappedRootRel = possiblyWrapRootWithOuterLimitFromContext(root); + RelParameterizerShuttle parametizer = new RelParameterizerShuttle(plannerContext); + RelNode parametized = possiblyWrappedRootRel.accept(parametizer); final DruidRel druidRel = (DruidRel) planner.transform( Rules.DRUID_CONVENTION_RULES, planner.getEmptyTraitSet() .replace(DruidConvention.instance()) .plus(root.collation), - possiblyWrappedRootRel + parametized ); final Set dataSourceNames = ImmutableSet.copyOf(druidRel.getDataSourceNames()); @@ -195,7 +270,7 @@ private PlannerResult planWithBindableConvention( return planExplanation(bindableRel, explain, ImmutableSet.of()); } else { final BindableRel theRel = bindableRel; - final DataContext dataContext = plannerContext.createDataContext((JavaTypeFactory) planner.getTypeFactory()); + final DataContext dataContext = plannerContext.createDataContext((JavaTypeFactory) planner.getTypeFactory(), plannerContext.getParameters()); final Supplier> resultsSupplier = () -> { final Enumerable enumerable = theRel.bind(dataContext); final Enumerator enumerator = enumerable.enumerator(); @@ -294,6 +369,26 @@ private RexNode makeBigIntLiteral(long value) ); } + private PlannerResult planExplanation( + final RelNode rel, + final SqlExplain explain, + final Set datasourceNames + ) + { + final String explanation = RelOptUtil.dumpPlan("", rel, explain.getFormat(), explain.getDetailLevel()); + final Supplier> resultsSupplier = Suppliers.ofInstance( + Sequences.simple(ImmutableList.of(new Object[]{explanation}))); + return new PlannerResult(resultsSupplier, getExplainStructType(rel.getCluster().getTypeFactory()), datasourceNames); + } + + private static RelDataType 
getExplainStructType(RelDataTypeFactory typeFactory) + { + return typeFactory.createStructType( + ImmutableList.of(Calcites.createSqlType(typeFactory, SqlTypeName.VARCHAR)), + ImmutableList.of("PLAN") + ); + } + private static class EnumeratorIterator implements Iterator { private final Iterator it; @@ -315,24 +410,4 @@ public T next() return it.next(); } } - - private PlannerResult planExplanation( - final RelNode rel, - final SqlExplain explain, - final Set datasourceNames - ) - { - final String explanation = RelOptUtil.dumpPlan("", rel, explain.getFormat(), explain.getDetailLevel()); - final Supplier> resultsSupplier = Suppliers.ofInstance( - Sequences.simple(ImmutableList.of(new Object[]{explanation}))); - final RelDataTypeFactory typeFactory = rel.getCluster().getTypeFactory(); - return new PlannerResult( - resultsSupplier, - typeFactory.createStructType( - ImmutableList.of(Calcites.createSqlType(typeFactory, SqlTypeName.VARCHAR)), - ImmutableList.of("PLAN") - ), - datasourceNames - ); - } } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java index 2dafa252d55c..48c13c08938b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidRexExecutor.java @@ -116,13 +116,17 @@ public void reduce( } else if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) { final BigDecimal bigDecimal; - if (exprResult.type() == ExprType.LONG) { - bigDecimal = BigDecimal.valueOf(exprResult.asLong()); + if (exprResult.isNumericNull()) { + literal = rexBuilder.makeNullLiteral(constExp.getType()); } else { - bigDecimal = BigDecimal.valueOf(exprResult.asDouble()); - } + if (exprResult.type() == ExprType.LONG) { + bigDecimal = BigDecimal.valueOf(exprResult.asLong()); + } else { + bigDecimal = BigDecimal.valueOf(exprResult.asDouble()); + } - literal = 
rexBuilder.makeLiteral(bigDecimal, constExp.getType(), true); + literal = rexBuilder.makeLiteral(bigDecimal, constExp.getType(), true); + } } else if (sqlTypeName == SqlTypeName.ARRAY) { assert exprResult.isArray(); literal = rexBuilder.makeLiteral(Arrays.asList(exprResult.asArray()), constExp.getType(), true); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index 03dc196c6ad2..db8ff97deadd 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.calcite.DataContext; import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.linq4j.QueryProvider; import org.apache.calcite.schema.SchemaPlus; import org.apache.druid.java.util.common.DateTimes; @@ -62,16 +63,19 @@ public class PlannerContext private final PlannerConfig plannerConfig; private final DateTime localNow; private final Map queryContext; + private final List parameters; private final AuthenticationResult authenticationResult; private final String sqlQueryId; private final List nativeQueryIds = new CopyOnWriteArrayList<>(); + private PlannerContext( final DruidOperatorTable operatorTable, final ExprMacroTable macroTable, final PlannerConfig plannerConfig, final DateTime localNow, final Map queryContext, + final List parameters, final AuthenticationResult authenticationResult ) { @@ -79,6 +83,7 @@ private PlannerContext( this.macroTable = macroTable; this.plannerConfig = Preconditions.checkNotNull(plannerConfig, "plannerConfig"); this.queryContext = queryContext != null ? 
new HashMap<>(queryContext) : new HashMap<>(); + this.parameters = Preconditions.checkNotNull(parameters); this.localNow = Preconditions.checkNotNull(localNow, "localNow"); this.authenticationResult = Preconditions.checkNotNull(authenticationResult, "authenticationResult"); @@ -95,6 +100,7 @@ public static PlannerContext create( final ExprMacroTable macroTable, final PlannerConfig plannerConfig, final Map queryContext, + final List parameters, final AuthenticationResult authenticationResult ) { @@ -127,6 +133,7 @@ public static PlannerContext create( plannerConfig.withOverrides(queryContext), utcNow.withZone(timeZone), queryContext, + parameters, authenticationResult ); } @@ -161,6 +168,11 @@ public Map getQueryContext() return queryContext; } + public List getParameters() + { + return parameters; + } + public AuthenticationResult getAuthenticationResult() { return authenticationResult; @@ -181,11 +193,11 @@ public void addNativeQueryId(String queryId) this.nativeQueryIds.add(queryId); } - public DataContext createDataContext(final JavaTypeFactory typeFactory) + public DataContext createDataContext(final JavaTypeFactory typeFactory, List parameters) { class DruidDataContext implements DataContext { - private final Map context = ImmutableMap.of( + private final Map base_context = ImmutableMap.of( DataContext.Variable.UTC_TIMESTAMP.camelName, localNow.getMillis(), DataContext.Variable.CURRENT_TIMESTAMP.camelName, localNow.getMillis(), DataContext.Variable.LOCAL_TIMESTAMP.camelName, new Interval( @@ -195,6 +207,19 @@ DataContext.Variable.LOCAL_TIMESTAMP.camelName, new Interval( DataContext.Variable.TIME_ZONE.camelName, localNow.getZone().toTimeZone().clone(), DATA_CTX_AUTHENTICATION_RESULT, authenticationResult ); + private final Map context; + + DruidDataContext() + { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.putAll(base_context); + int i = 0; + for (TypedValue parameter : parameters) { + builder.put("?" 
+ i, parameter.value); + i++; + } + context = builder.build(); + } @Override public SchemaPlus getRootSchema() diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index fb14444d3ea1..ffa325273ae9 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Inject; +import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.avatica.util.Quoting; import org.apache.calcite.config.CalciteConnectionConfig; @@ -43,12 +44,13 @@ import org.apache.druid.sql.calcite.rel.QueryMaker; import org.apache.druid.sql.calcite.schema.DruidSchemaName; +import java.util.List; import java.util.Map; import java.util.Properties; public class PlannerFactory { - private static final SqlParser.Config PARSER_CONFIG = SqlParser + static final SqlParser.Config PARSER_CONFIG = SqlParser .configBuilder() .setCaseSensitive(true) .setUnquotedCasing(Casing.UNCHANGED) @@ -90,6 +92,7 @@ public PlannerFactory( public DruidPlanner createPlanner( final Map queryContext, + final List parameters, final AuthenticationResult authenticationResult ) { @@ -98,6 +101,7 @@ public DruidPlanner createPlanner( macroTable, plannerConfig, queryContext, + parameters, authenticationResult ); final QueryMaker queryMaker = new QueryMaker(queryLifecycleFactory, plannerContext, jsonMapper); @@ -152,7 +156,7 @@ public SqlConformance conformance() .build(); return new DruidPlanner( - Frameworks.getPlanner(frameworkConfig), + frameworkConfig, plannerContext ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PrepareResult.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PrepareResult.java new file mode 100644 index 
000000000000..9e6b27b6cf56 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PrepareResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.planner; + +import org.apache.calcite.rel.type.RelDataType; + +public class PrepareResult +{ + private final RelDataType rowType; + private final RelDataType parameterRowType; + + public PrepareResult(final RelDataType rowType, final RelDataType parameterRowType) + { + this.rowType = rowType; + this.parameterRowType = parameterRowType; + } + + public RelDataType getRowType() + { + return rowType; + } + + public RelDataType getParameterRowType() + { + return parameterRowType; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/RelParameterizerShuttle.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/RelParameterizerShuttle.java new file mode 100644 index 000000000000..7607f1d3e45c --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/RelParameterizerShuttle.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.planner; + +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.RelVisitor; +import org.apache.calcite.rel.core.TableFunctionScan; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.logical.LogicalExchange; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalIntersect; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalMatch; +import org.apache.calcite.rel.logical.LogicalMinus; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rel.logical.LogicalValues; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexDynamicParam; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.type.SqlTypeName; +import 
org.apache.druid.java.util.common.ISE; + +/** + * Traverse {@link RelNode} tree and replaces all {@link RexDynamicParam} with {@link org.apache.calcite.rex.RexLiteral} + * using {@link RexBuilder} if a value binding exists for the parameter. All parameters must have a value by the time + * {@link RelParameterizerShuttle} is run, or else it will throw an exception. + * + * Note: none of the tests currently hit this anymore since {@link SqlParameterizerShuttle} has been modified to handle + * most common jdbc types, but leaving this here provides a safety net to try again to convert parameters + * to literal values in case {@link SqlParameterizerShuttle} fails. + */ +public class RelParameterizerShuttle implements RelShuttle +{ + private final PlannerContext plannerContext; + + public RelParameterizerShuttle(PlannerContext plannerContext) + { + this.plannerContext = plannerContext; + } + + @Override + public RelNode visit(TableScan scan) + { + return bindRel(scan); + } + + @Override + public RelNode visit(TableFunctionScan scan) + { + return bindRel(scan); + } + + @Override + public RelNode visit(LogicalValues values) + { + return bindRel(values); + } + + @Override + public RelNode visit(LogicalFilter filter) + { + return bindRel(filter); + } + + @Override + public RelNode visit(LogicalProject project) + { + return bindRel(project); + } + + @Override + public RelNode visit(LogicalJoin join) + { + return bindRel(join); + } + + @Override + public RelNode visit(LogicalCorrelate correlate) + { + return bindRel(correlate); + } + + @Override + public RelNode visit(LogicalUnion union) + { + return bindRel(union); + } + + @Override + public RelNode visit(LogicalIntersect intersect) + { + return bindRel(intersect); + } + + @Override + public RelNode visit(LogicalMinus minus) + { + return bindRel(minus); + } + + @Override + public RelNode visit(LogicalAggregate aggregate) + { + return bindRel(aggregate); + } + + @Override + public RelNode visit(LogicalMatch match) + { + return 
bindRel(match); + } + + @Override + public RelNode visit(LogicalSort sort) + { + final RexBuilder builder = sort.getCluster().getRexBuilder(); + final RelDataTypeFactory typeFactory = sort.getCluster().getTypeFactory(); + RexNode newFetch = bind(sort.fetch, builder, typeFactory); + RexNode newOffset = bind(sort.offset, builder, typeFactory); + sort = (LogicalSort) sort.copy(sort.getTraitSet(), sort.getInput(), sort.getCollation(), newOffset, newFetch); + return bindRel(sort, builder, typeFactory); + } + + @Override + public RelNode visit(LogicalExchange exchange) + { + return bindRel(exchange); + } + + @Override + public RelNode visit(RelNode other) + { + return bindRel(other); + } + + private RelNode bindRel(RelNode node) + { + final RexBuilder builder = node.getCluster().getRexBuilder(); + final RelDataTypeFactory typeFactory = node.getCluster().getTypeFactory(); + return bindRel(node, builder, typeFactory); + } + + private RelNode bindRel(RelNode node, RexBuilder builder, RelDataTypeFactory typeFactory) + { + final RexShuttle binder = new RexShuttle() + { + @Override + public RexNode visitDynamicParam(RexDynamicParam dynamicParam) + { + return bind(dynamicParam, builder, typeFactory); + } + }; + node = node.accept(binder); + node.childrenAccept(new RelVisitor() + { + @Override + public void visit(RelNode node, int ordinal, RelNode parent) + { + super.visit(node, ordinal, parent); + RelNode transformed = node.accept(binder); + if (!node.equals(transformed)) { + parent.replaceInput(ordinal, transformed); + } + } + }); + return node; + } + + private RexNode bind(RexNode node, RexBuilder builder, RelDataTypeFactory typeFactory) + { + if (node instanceof RexDynamicParam) { + RexDynamicParam dynamicParam = (RexDynamicParam) node; + // if we have a value for dynamic parameter, replace with a literal, else add to list of unbound parameters + if (plannerContext.getParameters().size() > dynamicParam.getIndex()) { + TypedValue param = 
plannerContext.getParameters().get(dynamicParam.getIndex()); + if (param.value == null) { + return builder.makeNullLiteral(typeFactory.createSqlType(SqlTypeName.NULL)); + } + SqlTypeName typeName = SqlTypeName.getNameForJdbcType(param.type.typeId); + return builder.makeLiteral( + param.value, + typeFactory.createSqlType(typeName), + true + ); + } else { + throw new ISE("Parameter: [%s] is not bound", dynamicParam.getName()); + } + } + return node; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java new file mode 100644 index 000000000000..52c486cf18b9 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/SqlParameterizerShuttle.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.planner; + +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.sql.SqlDynamicParam; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.util.SqlShuttle; +import org.apache.calcite.util.TimestampString; + +/** + * Replaces all {@link SqlDynamicParam} encountered in an {@link SqlNode} tree with a {@link SqlLiteral} if a value + * binding exists for the parameter, if possible. This is used in tandem with {@link RelParameterizerShuttle}. + * + * It is preferable that all parameters are placed here to pick up as many optimizations as possible, but the facilities + * to convert jdbc types to {@link SqlLiteral} are a bit less rich here than exist for converting a + * {@link org.apache.calcite.rex.RexDynamicParam} to {@link org.apache.calcite.rex.RexLiteral}, which is why + * {@link SqlParameterizerShuttle} and {@link RelParameterizerShuttle} both exist. 
+ */ +public class SqlParameterizerShuttle extends SqlShuttle +{ + private final PlannerContext plannerContext; + + public SqlParameterizerShuttle(PlannerContext plannerContext) + { + this.plannerContext = plannerContext; + } + + @Override + public SqlNode visit(SqlDynamicParam param) + { + try { + if (plannerContext.getParameters().size() > param.getIndex()) { + TypedValue paramBinding = plannerContext.getParameters().get(param.getIndex()); + if (paramBinding.value == null) { + return SqlLiteral.createNull(param.getParserPosition()); + } + SqlTypeName typeName = SqlTypeName.getNameForJdbcType(paramBinding.type.typeId); + if (SqlTypeName.APPROX_TYPES.contains(typeName)) { + return SqlLiteral.createApproxNumeric(paramBinding.value.toString(), param.getParserPosition()); + } + if (SqlTypeName.TIMESTAMP.equals(typeName) && paramBinding.value instanceof Long) { + return SqlLiteral.createTimestamp( + TimestampString.fromMillisSinceEpoch((Long) paramBinding.value), + 0, + param.getParserPosition() + ); + } + + return typeName.createLiteral(paramBinding.value, param.getParserPosition()); + } + } + catch (ClassCastException ignored) { + // suppress + } + return param; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java index 96e3eb13f91a..f9c1ff98998f 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java @@ -63,8 +63,8 @@ public TranslatableTable apply(final List arguments) final RelDataType rowType; // Using an escalator here is a hack, but it's currently needed to get the row type. Ideally, some // later refactoring would make this unnecessary, since there is no actual query going out herem. 
- final AuthenticationResult authenticationResult = escalator.createEscalatedAuthenticationResult(); - try (final DruidPlanner planner = plannerFactory.createPlanner(null, authenticationResult)) { + final AuthenticationResult authResult = escalator.createEscalatedAuthenticationResult(); + try (final DruidPlanner planner = plannerFactory.createPlanner(null, ImmutableList.of(), authResult)) { rowType = planner.plan(viewSql).rowType(); } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlParameter.java b/sql/src/main/java/org/apache/druid/sql/http/SqlParameter.java new file mode 100644 index 000000000000..7e8e190d3efa --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlParameter.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.http; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.SqlType; +import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.runtime.SqlFunctions; +import org.apache.calcite.util.TimestampString; +import org.apache.druid.java.util.common.DateTimes; + +import javax.annotation.Nullable; +import java.sql.Date; +import java.util.Objects; + +public class SqlParameter +{ + private final SqlType type; + private final Object value; + + @JsonCreator + public SqlParameter( + @JsonProperty("type") SqlType type, + @JsonProperty("value") @Nullable Object value + ) + { + this.type = Preconditions.checkNotNull(type); + this.value = value; + } + + @JsonProperty + public Object getValue() + { + return value; + } + + @JsonProperty + public SqlType getType() + { + return type; + } + + @JsonIgnore + public TypedValue getTypedValue() + { + Object adjustedValue = value; + + // perhaps there is a better way to do this? + if (type == SqlType.TIMESTAMP) { + // TypedValue.create for TIMESTAMP expects a long... 
+ // but be lenient and try to accept ISO format and SQL 'timestamp' format + if (value instanceof String) { + try { + adjustedValue = DateTimes.of((String) value).getMillis(); + } + catch (IllegalArgumentException ignore) { + } + try { + adjustedValue = new TimestampString((String) value).getMillisSinceEpoch(); + } + catch (IllegalArgumentException ignore) { + } + } + } else if (type == SqlType.DATE) { + // TypedValue.create for DATE expects Calcite's internal int representation of SQL dates + // but be lenient and try to accept SQL date 'yyyy-MM-dd' format and convert to internal Calcite int representation + if (value instanceof String) { + try { + adjustedValue = SqlFunctions.toInt(Date.valueOf((String) value)); + } + catch (IllegalArgumentException ignore) { + } + } + } + return TypedValue.create(ColumnMetaData.Rep.nonPrimitiveRepOf(type).name(), adjustedValue); + } + + @Override + public String toString() + { + return "SqlParameter{" + + ", value={" + type.name() + ',' + value + '}' + + '}'; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SqlParameter that = (SqlParameter) o; + return Objects.equals(type, that.type) && + Objects.equals(value, that.value); + } + + @Override + public int hashCode() + { + return Objects.hash(type, value); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java b/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java index 4e2c8739a42f..1df21a652de3 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlQuery.java @@ -22,30 +22,44 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.apache.calcite.avatica.remote.TypedValue;
+import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; public class SqlQuery { + public static List getParameterList(List parameters) + { + return parameters.stream() + .map(SqlParameter::getTypedValue) + .collect(Collectors.toList()); + } + private final String query; private final ResultFormat resultFormat; private final boolean header; private final Map context; + private final List parameters; @JsonCreator public SqlQuery( @JsonProperty("query") final String query, @JsonProperty("resultFormat") final ResultFormat resultFormat, @JsonProperty("header") final boolean header, - @JsonProperty("context") final Map context + @JsonProperty("context") final Map context, + @JsonProperty("parameters") final List parameters ) { this.query = Preconditions.checkNotNull(query, "query"); this.resultFormat = resultFormat == null ? ResultFormat.OBJECT : resultFormat; this.header = header; this.context = context == null ? ImmutableMap.of() : context; + this.parameters = parameters == null ? 
ImmutableList.of() : parameters; } @JsonProperty @@ -72,6 +86,17 @@ public Map getContext() return context; } + @JsonProperty + public List getParameters() + { + return parameters; + } + + public List getParameterList() + { + return getParameterList(parameters); + } + @Override public boolean equals(final Object o) { @@ -85,13 +110,14 @@ public boolean equals(final Object o) return header == sqlQuery.header && Objects.equals(query, sqlQuery.query) && resultFormat == sqlQuery.resultFormat && - Objects.equals(context, sqlQuery.context); + Objects.equals(context, sqlQuery.context) && + Objects.equals(parameters, sqlQuery.parameters); } @Override public int hashCode() { - return Objects.hash(query, resultFormat, header, context); + return Objects.hash(query, resultFormat, header, context, parameters); } @Override @@ -102,6 +128,7 @@ public String toString() ", resultFormat=" + resultFormat + ", header=" + header + ", context=" + context + + ", parameters=" + parameters + '}'; } } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index c457066ca6a3..cc7e6cdf2371 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -88,6 +88,8 @@ public Response doPost( try { Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); + lifecycle.setParameters(sqlQuery.getParameterList()); + final PlannerContext plannerContext = lifecycle.planAndAuthorize(req); final DateTimeZone timeZone = plannerContext.getTimeZone(); diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java index 9a613b8ce78d..5e1d43844651 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -84,6 +84,7 
@@ import java.sql.DatabaseMetaData; import java.sql.Date; import java.sql.DriverManager; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; @@ -926,6 +927,37 @@ public void testSqlRequestLog() throws Exception Assert.assertEquals(0, testRequestLogger.getSqlQueryLogs().size()); } + @Test + public void testParameterBinding() throws Exception + { + PreparedStatement statement = client.prepareStatement("SELECT COUNT(*) AS cnt FROM druid.foo WHERE dim1 = ? OR dim1 = ?"); + statement.setString(1, "abc"); + statement.setString(2, "def"); + final ResultSet resultSet = statement.executeQuery(); + final List> rows = getRows(resultSet); + Assert.assertEquals( + ImmutableList.of( + ImmutableMap.of("cnt", 2L) + ), + rows + ); + } + + @Test + public void testSysTableParameterBinding() throws Exception + { + PreparedStatement statement = client.prepareStatement("SELECT COUNT(*) AS cnt FROM sys.servers WHERE servers.host = ?"); + statement.setString(1, "dummy"); + final ResultSet resultSet = statement.executeQuery(); + final List> rows = getRows(resultSet); + Assert.assertEquals( + ImmutableList.of( + ImmutableMap.of("cnt", 1L) + ), + rows + ); + } + private static List> getRows(final ResultSet resultSet) throws SQLException { return getRows(resultSet, null); diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index 61980bf396b9..93d5bf4dae7c 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -50,6 +50,7 @@ import org.junit.rules.TemporaryFolder; import java.io.IOException; +import java.util.Collections; import java.util.List; public class DruidStatementTest extends CalciteTestBase @@ -159,7 +160,7 @@ public void testSelectAllInFirstFrame() }).prepare(sql, -1, 
AllowAllAuthenticator.ALLOW_ALL_RESULT); // First frame, ask for all rows. - Meta.Frame frame = statement.execute().nextFrame(DruidStatement.START_OFFSET, 6); + Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 6); Assert.assertEquals( Meta.Frame.create( 0, @@ -192,7 +193,7 @@ public void testSelectSplitOverTwoFrames() }).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT); // First frame, ask for 2 rows. - Meta.Frame frame = statement.execute().nextFrame(DruidStatement.START_OFFSET, 2); + Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 2); Assert.assertEquals( Meta.Frame.create( 0, diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index 91727eb9f376..67b0aeecb88e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -78,6 +78,7 @@ import org.apache.druid.sql.calcite.util.QueryLogHook; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; import org.apache.druid.sql.calcite.view.InProcessViewManager; +import org.apache.druid.sql.http.SqlParameter; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.Interval; @@ -495,6 +496,7 @@ public void testQuery( testQuery( PLANNER_CONFIG_DEFAULT, QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, sql, CalciteTests.REGULAR_USER_AUTH_RESULT, expectedQueries, @@ -504,14 +506,15 @@ public void testQuery( public void testQuery( final String sql, - final Map queryContext, + final Map context, final List expectedQueries, final List expectedResults ) throws Exception { testQuery( PLANNER_CONFIG_DEFAULT, - queryContext, + context, + DEFAULT_PARAMETERS, sql, CalciteTests.REGULAR_USER_AUTH_RESULT, expectedQueries, @@ -519,15 +522,57 @@ public void testQuery( 
); } + public void testQuery( + final String sql, + final List expectedQueries, + final List expectedResults, + final List parameters + ) throws Exception + { + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DEFAULT, + parameters, + sql, + CalciteTests.REGULAR_USER_AUTH_RESULT, + expectedQueries, + expectedResults + ); + } + + public void testQuery( + final PlannerConfig plannerConfig, + final String sql, + final AuthenticationResult authenticationResult, + final List expectedQueries, + final List expectedResults + ) throws Exception + { + testQuery( + plannerConfig, + QUERY_CONTEXT_DEFAULT, + DEFAULT_PARAMETERS, + sql, + authenticationResult, + expectedQueries, + expectedResults + ); + } + public void testQuery( final PlannerConfig plannerConfig, + final Map queryContext, final String sql, final AuthenticationResult authenticationResult, final List expectedQueries, final List expectedResults ) throws Exception { - testQuery(plannerConfig, QUERY_CONTEXT_DEFAULT, sql, authenticationResult, expectedQueries, expectedResults); + log.info("SQL: %s", sql); + queryLogHook.clearRecordedQueries(); + final List plannerResults = + getResults(plannerConfig, queryContext, DEFAULT_PARAMETERS, sql, authenticationResult); + verifyResults(sql, expectedQueries, expectedResults, plannerResults); } /** @@ -560,6 +605,7 @@ private DataSource recursivelyOverrideContext(final DataSource dataSource, final public void testQuery( final PlannerConfig plannerConfig, final Map queryContext, + final List parameters, final String sql, final AuthenticationResult authenticationResult, final List expectedQueries, @@ -596,7 +642,7 @@ public void testQuery( expectedException.expectMessage("Cannot vectorize"); } - final List plannerResults = getResults(plannerConfig, theQueryContext, sql, authenticationResult); + final List plannerResults = getResults(plannerConfig, theQueryContext, parameters, sql, authenticationResult); verifyResults(sql, theQueries, expectedResults, plannerResults); } } @@ 
-604,6 +650,7 @@ public void testQuery( public List getResults( final PlannerConfig plannerConfig, final Map queryContext, + final List parameters, final String sql, final AuthenticationResult authenticationResult ) throws Exception @@ -611,6 +658,7 @@ public List getResults( return getResults( plannerConfig, queryContext, + parameters, sql, authenticationResult, CalciteTests.createOperatorTable(), @@ -623,6 +671,7 @@ public List getResults( public List getResults( final PlannerConfig plannerConfig, final Map queryContext, + final List parameters, final String sql, final AuthenticationResult authenticationResult, final DruidOperatorTable operatorTable, @@ -666,7 +715,7 @@ public List getResults( + "WHERE __time >= CURRENT_TIMESTAMP + INTERVAL '1' DAY AND __time < TIMESTAMP '2002-01-01 00:00:00'" ); - return sqlLifecycleFactory.factorize().runSimple(sql, queryContext, authenticationResult).toList(); + return sqlLifecycleFactory.factorize().runSimple(sql, queryContext, parameters, authenticationResult).toList(); } public void verifyResults( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java new file mode 100644 index 000000000000..98263cb2dea4 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteParameterQueryTest.java @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.avatica.SqlType; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.query.Druids; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.groupby.GroupByQuery; +import org.apache.druid.query.ordering.StringComparators; +import org.apache.druid.query.scan.ScanQuery; +import org.apache.druid.segment.column.ValueType; +import org.apache.druid.sql.calcite.filtration.Filtration; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.http.SqlParameter; +import org.junit.Test; + +/** + * This class has copied a subset of the tests in {@link CalciteQueryTest} and replaced various parts of queries with + * dynamic parameters. 
It is NOT important that this file remains in sync with {@link CalciteQueryTest}, the tests + * were merely chosen to produce a selection of parameter types and positions within query expressions and have been + * renamed to reflect this + */ +public class CalciteParameterQueryTest extends BaseCalciteQueryTest +{ + private final boolean useDefault = NullHandling.replaceWithDefault(); + + @Test + public void testSelectConstantParamGetsConstant() throws Exception + { + testQuery( + "SELECT 1 + ?", + ImmutableList.of(), + ImmutableList.of( + new Object[]{2} + ), + ImmutableList.of(new SqlParameter(SqlType.INTEGER, 1)) + ); + } + + @Test + public void testParamsGetOptimizedIntoConstant() throws Exception + { + testQuery( + "SELECT 1 + ?, dim1 FROM foo LIMIT ?", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn("v0", "2", ValueType.LONG)) + .columns("dim1", "v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .limit(1) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{2, ""} + ), + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 1), + new SqlParameter(SqlType.INTEGER, 1) + ) + ); + } + + @Test + public void testParametersInSelectAndFilter() throws Exception + { + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS, + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter(SqlType.INTEGER, 0) + ), + "SELECT exp(count(*)) + ?, sum(m2) FROM druid.foo WHERE dim2 = ?", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .filters(numericSelector("dim2", "0", null)) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new CountAggregatorFactory("a0"), + new DoubleSumAggregatorFactory("a1", 
"m2") + )) + .postAggregators( + expressionPostAgg("p0", "(exp(\"a0\") + 10)") + ) + .context(QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS) + .build()), + ImmutableList.of( + new Object[]{11.0, NullHandling.defaultDoubleValue()} + ) + ); + } + + @Test + public void testSelectTrimFamilyWithParameters() throws Exception + { + // TRIM has some whacky parsing. Abuse this to test a bunch of parameters + + testQuery( + "SELECT\n" + + "TRIM(BOTH ? FROM ?),\n" + + "TRIM(TRAILING ? FROM ?),\n" + + "TRIM(? FROM ?),\n" + + "TRIM(TRAILING FROM ?),\n" + + "TRIM(?),\n" + + "BTRIM(?),\n" + + "BTRIM(?, ?),\n" + + "LTRIM(?),\n" + + "LTRIM(?, ?),\n" + + "RTRIM(?),\n" + + "RTRIM(?, ?),\n" + + "COUNT(*)\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .postAggregators( + expressionPostAgg("p0", "'foo'"), + expressionPostAgg("p1", "'xfoo'"), + expressionPostAgg("p2", "'foo'"), + expressionPostAgg("p3", "' foo'"), + expressionPostAgg("p4", "'foo'"), + expressionPostAgg("p5", "'foo'"), + expressionPostAgg("p6", "'foo'"), + expressionPostAgg("p7", "'foo '"), + expressionPostAgg("p8", "'foox'"), + expressionPostAgg("p9", "' foo'"), + expressionPostAgg("p10", "'xfoo'") + ) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"foo", "xfoo", "foo", " foo", "foo", "foo", "foo", "foo ", "foox", " foo", "xfoo", 6L} + ), + ImmutableList.of( + new SqlParameter(SqlType.VARCHAR, "x"), + new SqlParameter(SqlType.VARCHAR, "xfoox"), + new SqlParameter(SqlType.VARCHAR, "x"), + new SqlParameter(SqlType.VARCHAR, "xfoox"), + new SqlParameter(SqlType.VARCHAR, " "), + new SqlParameter(SqlType.VARCHAR, " foo "), + new SqlParameter(SqlType.VARCHAR, " foo "), + new SqlParameter(SqlType.VARCHAR, " foo "), + new SqlParameter(SqlType.VARCHAR, " foo "), + new 
SqlParameter(SqlType.VARCHAR, "xfoox"), + new SqlParameter(SqlType.VARCHAR, "x"), + new SqlParameter(SqlType.VARCHAR, " foo "), + new SqlParameter(SqlType.VARCHAR, "xfoox"), + new SqlParameter(SqlType.VARCHAR, "x"), + new SqlParameter(SqlType.VARCHAR, " foo "), + new SqlParameter(SqlType.VARCHAR, "xfoox"), + new SqlParameter(SqlType.VARCHAR, "x") + ) + ); + } + + @Test + public void testParamsInInformationSchema() throws Exception + { + // Not including COUNT DISTINCT, since it isn't supported by BindableAggregate, and so it can't work. + testQuery( + "SELECT\n" + + " COUNT(JDBC_TYPE),\n" + + " SUM(JDBC_TYPE),\n" + + " AVG(JDBC_TYPE),\n" + + " MIN(JDBC_TYPE),\n" + + " MAX(JDBC_TYPE)\n" + + "FROM INFORMATION_SCHEMA.COLUMNS\n" + + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", + ImmutableList.of(), + ImmutableList.of( + new Object[]{8L, 1249L, 156L, -5L, 1111L} + ), + ImmutableList.of( + new SqlParameter(SqlType.VARCHAR, "druid"), + new SqlParameter(SqlType.VARCHAR, "foo") + ) + ); + } + + @Test + public void testParamsInSelectExpressionAndLimit() throws Exception + { + testQuery( + "SELECT SUBSTRING(dim2, ?, ?) FROM druid.foo LIMIT ?", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn("v0", "substring(\"dim2\", 0, 1)", ValueType.STRING) + ) + .columns("v0") + .limit(2) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"a"}, + new Object[]{NULL_STRING} + ), + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 1), + new SqlParameter(SqlType.INTEGER, 1), + new SqlParameter(SqlType.INTEGER, 2) + ) + ); + } + + @Test + public void testParamsTuckedInACast() throws Exception + { + cannotVectorize(); + testQuery( + "SELECT dim1, m1, COUNT(*) FROM druid.foo WHERE m1 - CAST(? 
as INT) = dim1 GROUP BY dim1, m1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(expressionFilter("((\"m1\" - 1) == CAST(\"dim1\", 'DOUBLE'))")) + .setDimensions(dimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("m1", "d1", ValueType.FLOAT) + )) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() ? + ImmutableList.of( + new Object[]{"", 1.0f, 1L}, + new Object[]{"2", 3.0f, 1L} + ) : + ImmutableList.of( + new Object[]{"2", 3.0f, 1L} + ), + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 1) + ) + ); + } + + @Test + public void testParametersInStrangePlaces() throws Exception + { + testQuery( + "SELECT\n" + + " dim1,\n" + + " COUNT(*) FILTER(WHERE dim2 <> ?)/COUNT(*) as ratio\n" + + "FROM druid.foo\n" + + "GROUP BY dim1\n" + + "HAVING COUNT(*) FILTER(WHERE dim2 <> ?)/COUNT(*) = ?", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0"))) + .setAggregatorSpecs(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + not(selector("dim2", "a", null)) + ), + new CountAggregatorFactory("a1") + )) + .setPostAggregatorSpecs(ImmutableList.of( + expressionPostAgg("p0", "(\"a0\" / \"a1\")") + )) + .setHavingSpec(having(expressionFilter("((\"a0\" / \"a1\") == 1)"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"10.1", 1L}, + new Object[]{"2", 1L}, + new Object[]{"abc", 1L}, + new Object[]{"def", 1L} + ), + ImmutableList.of( + new SqlParameter(SqlType.VARCHAR, "a"), + new SqlParameter(SqlType.VARCHAR, "a"), + new 
SqlParameter(SqlType.INTEGER, 1) + ) + ); + } + + @Test + public void testParametersInCases() throws Exception + { + testQuery( + "SELECT\n" + + " CASE 'foo'\n" + + " WHEN ? THEN SUM(cnt) / CAST(? as INT)\n" + + " WHEN ? THEN SUM(m1) / CAST(? as INT)\n" + + " WHEN ? THEN SUM(m2) / CAST(? as INT)\n" + + " END\n" + + "FROM foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators(new DoubleSumAggregatorFactory("a0", "m1"))) + .postAggregators(ImmutableList.of(expressionPostAgg("p0", "(\"a0\" / 10)"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{2.1}), + ImmutableList.of( + new SqlParameter(SqlType.VARCHAR, "bar"), + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter(SqlType.VARCHAR, "foo"), + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter(SqlType.VARCHAR, "baz"), + new SqlParameter(SqlType.INTEGER, 10) + ) + ); + } + + + @Test + public void testTimestamp() throws Exception + { + // with millis + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS, + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter( + SqlType.TIMESTAMP, + DateTimes.of("2999-01-01T00:00:00Z").getMillis() + ) + ), + "SELECT exp(count(*)) + ?, sum(m2) FROM druid.foo WHERE __time >= ?", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.of( + "2999-01-01T00:00:00.000Z/146140482-04-24T15:36:27.903Z"))) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new CountAggregatorFactory("a0"), + new DoubleSumAggregatorFactory("a1", "m2") + )) + .postAggregators( + expressionPostAgg("p0", "(exp(\"a0\") + 10)") + ) + .context(QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS) + .build()), + ImmutableList.of( + new 
Object[]{11.0, NullHandling.defaultDoubleValue()} + ) + ); + + } + + @Test + public void testTimestampString() throws Exception + { + // with timestampstring + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS, + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter( + SqlType.TIMESTAMP, + "2999-01-01 00:00:00" + ) + ), + "SELECT exp(count(*)) + ?, sum(m2) FROM druid.foo WHERE __time >= ?", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.of( + "2999-01-01T00:00:00.000Z/146140482-04-24T15:36:27.903Z"))) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new CountAggregatorFactory("a0"), + new DoubleSumAggregatorFactory("a1", "m2") + )) + .postAggregators( + expressionPostAgg("p0", "(exp(\"a0\") + 10)") + ) + .context(QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS) + .build()), + ImmutableList.of( + new Object[]{11.0, NullHandling.defaultDoubleValue()} + ) + ); + } + + @Test + public void testDate() throws Exception + { + // with date from millis + + testQuery( + PLANNER_CONFIG_DEFAULT, + QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS, + ImmutableList.of( + new SqlParameter(SqlType.INTEGER, 10), + new SqlParameter( + SqlType.DATE, + "2999-01-01" + ) + ), + "SELECT exp(count(*)) + ?, sum(m2) FROM druid.foo WHERE __time >= ?", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.of( + "2999-01-01T00:00:00.000Z/146140482-04-24T15:36:27.903Z"))) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new CountAggregatorFactory("a0"), + new DoubleSumAggregatorFactory("a1", "m2") + )) + .postAggregators( + expressionPostAgg("p0", "(exp(\"a0\") + 10)") + ) + .context(QUERY_CONTEXT_DONT_SKIP_EMPTY_BUCKETS) + .build()), + ImmutableList.of( + new Object[]{11.0, 
NullHandling.defaultDoubleValue()} + ) + ); + } + + @Test + public void testDoubles() throws Exception + { + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE cnt > ? and cnt < ?", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + bound("cnt", "1.1", "100000001", true, true, null, StringComparators.NUMERIC) + ) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(), + ImmutableList.of( + new SqlParameter(SqlType.DOUBLE, 1.1), + new SqlParameter(SqlType.FLOAT, 100000001.0) + ) + ); + + + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE cnt = ? or cnt = ?", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + in("cnt", ImmutableList.of("1.0", "100000001"), null) + ) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{6L} + ), + ImmutableList.of( + new SqlParameter(SqlType.DOUBLE, 1.0), + new SqlParameter(SqlType.FLOAT, 100000001.0) + ) + ); + } + + @Test + public void testFloats() throws Exception + { + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE cnt = ?", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + selector("cnt", "1.0", null) + ) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{6L}), + ImmutableList.of(new SqlParameter(SqlType.REAL, 1.0f)) + ); + } + + @Test + public void testLongs() throws Exception + { + testQuery( + "SELECT 
COUNT(*)\n" + + "FROM druid.numfoo\n" + + "WHERE l1 > ?", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(bound("l1", "3", null, true, false, null, StringComparators.NUMERIC)) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of(new Object[]{2L}), + ImmutableList.of(new SqlParameter(SqlType.BIGINT, 3L)) + ); + } + + @Test + public void testMissingParameter() throws Exception + { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Parameter: [?0] is not bound"); + testQuery( + "SELECT COUNT(*)\n" + + "FROM druid.numfoo\n" + + "WHERE l1 > ?", + ImmutableList.of(), + ImmutableList.of(new Object[]{3L}), + ImmutableList.of() + ); + } + + @Test + public void testPartiallyMissingParameter() throws Exception + { + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("Parameter: [?1] is not bound"); + testQuery( + "SELECT COUNT(*)\n" + + "FROM druid.numfoo\n" + + "WHERE l1 > ? AND f1 = ?", + ImmutableList.of(), + ImmutableList.of(new Object[]{3L}), + ImmutableList.of(new SqlParameter(SqlType.BIGINT, 3L)) + ); + } + + @Test + public void testWrongTypeParameter() throws Exception + { + testQuery( + "SELECT COUNT(*)\n" + + "FROM druid.numfoo\n" + + "WHERE l1 > ? AND f1 = ?", + useDefault ? ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters( + and( + bound("l1", "3", null, true, false, null, StringComparators.NUMERIC), + selector("f1", useDefault ? "0.0" : null, null) + + ) + ) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(TIMESERIES_CONTEXT_DEFAULT) + .build() + ) : ImmutableList.of(), + useDefault ? 
ImmutableList.of() : ImmutableList.of(new Object[]{0L}), + ImmutableList.of(new SqlParameter(SqlType.BIGINT, 3L), new SqlParameter(SqlType.VARCHAR, "wat")) + ); + } + + @Test + public void testNullParameter() throws Exception + { + // contrived example of using null as an sql parameter to at least test the codepath because lots of things dont + // actually work as null and things like 'IS NULL' fail to parse in calcite if expressed as 'IS ?' + cannotVectorize(); + + // this will optimize out the 3rd argument because 2nd argument will be constant and not null + testQuery( + "SELECT COALESCE(dim2, ?, ?), COUNT(*) FROM druid.foo GROUP BY 1\n", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim2\"),\"dim2\",'parameter')", + ValueType.STRING + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "v0", ValueType.STRING))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() ? 
+ ImmutableList.of( + new Object[]{"a", 2L}, + new Object[]{"abc", 1L}, + new Object[]{"parameter", 3L} + ) : + ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"a", 2L}, + new Object[]{"abc", 1L}, + new Object[]{"parameter", 2L} + ), + ImmutableList.of(new SqlParameter(SqlType.VARCHAR, "parameter"), new SqlParameter(SqlType.VARCHAR, null)) + ); + + // when converting to rel expression, this will optimize out 2nd argument to coalesce which is null + testQuery( + "SELECT COALESCE(dim2, ?, ?), COUNT(*) FROM druid.foo GROUP BY 1\n", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn( + "v0", + "case_searched(notnull(\"dim2\"),\"dim2\",'parameter')", + ValueType.STRING + ) + ) + .setDimensions(dimensions(new DefaultDimensionSpec("v0", "v0", ValueType.STRING))) + .setAggregatorSpecs(aggregators(new CountAggregatorFactory("a0"))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.replaceWithDefault() ? 
+ ImmutableList.of( + new Object[]{"a", 2L}, + new Object[]{"abc", 1L}, + new Object[]{"parameter", 3L} + ) : + ImmutableList.of( + new Object[]{"", 1L}, + new Object[]{"a", 2L}, + new Object[]{"abc", 1L}, + new Object[]{"parameter", 2L} + ), + ImmutableList.of(new SqlParameter(SqlType.VARCHAR, null), new SqlParameter(SqlType.VARCHAR, "parameter")) + ); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index 254a7b2ee7cf..6fffe8028d80 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -19,6 +19,7 @@ package org.apache.druid.sql.calcite.expression; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.rel.type.RelDataType; @@ -55,6 +56,7 @@ class ExpressionTestHelper CalciteTests.createExprMacroTable(), new PlannerConfig(), ImmutableMap.of(), + ImmutableList.of(), CalciteTests.REGULAR_USER_AUTH_RESULT ); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/http/SqlQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/http/SqlQueryTest.java index aa85c70bb6e8..65275f0076a5 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/http/SqlQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/http/SqlQueryTest.java @@ -20,10 +20,14 @@ package org.apache.druid.sql.calcite.http; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.calcite.avatica.SqlType; import org.apache.druid.segment.TestHelper; import org.apache.druid.sql.calcite.util.CalciteTestBase; import 
org.apache.druid.sql.http.ResultFormat; +import org.apache.druid.sql.http.SqlParameter; import org.apache.druid.sql.http.SqlQuery; import org.junit.Assert; import org.junit.Test; @@ -34,7 +38,20 @@ public class SqlQueryTest extends CalciteTestBase public void testSerde() throws Exception { final ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); - final SqlQuery query = new SqlQuery("SELECT 1", ResultFormat.ARRAY, true, ImmutableMap.of("useCache", false)); + final SqlQuery query = new SqlQuery( + "SELECT ?", + ResultFormat.ARRAY, + true, + ImmutableMap.of("useCache", false), + ImmutableList.of(new SqlParameter(SqlType.INTEGER, 1)) + ); Assert.assertEquals(query, jsonMapper.readValue(jsonMapper.writeValueAsString(query), SqlQuery.class)); } + + @Test + public void testEquals() + { + EqualsVerifier.forClass(SqlQuery.class).withNonnullFields("query").usingGetClass().verify(); + EqualsVerifier.forClass(SqlParameter.class).withNonnullFields("type").usingGetClass().verify(); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java index 8c99ee52ee37..f8f530726620 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTestBase.java @@ -19,12 +19,18 @@ package org.apache.druid.sql.calcite.util; +import com.google.common.collect.ImmutableList; import org.apache.druid.common.config.NullHandling; import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.http.SqlParameter; import org.junit.BeforeClass; +import java.util.List; + public abstract class CalciteTestBase { + public static final List DEFAULT_PARAMETERS = ImmutableList.of(); + @BeforeClass public static void setupCalciteProperties() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java 
index 3113a11a5dce..132df255d689 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import com.google.common.util.concurrent.ListenableFuture; import com.google.inject.Guice; import com.google.inject.Injector; @@ -44,6 +45,7 @@ import org.apache.druid.data.input.impl.MapInputRowParser; import org.apache.druid.data.input.impl.TimeAndDimsParseSpec; import org.apache.druid.data.input.impl.TimestampSpec; +import org.apache.druid.discovery.DiscoveryDruidNode; import org.apache.druid.discovery.DruidLeaderClient; import org.apache.druid.discovery.DruidNodeDiscovery; import org.apache.druid.discovery.DruidNodeDiscoveryProvider; @@ -865,9 +867,17 @@ public static SystemSchema createMockSystemSchema( final AuthorizerMapper authorizerMapper ) { + + final DruidNode coordinatorNode = new DruidNode("test", "dummy", false, 8080, null, true, false); + FakeDruidNodeDiscoveryProvider provider = new FakeDruidNodeDiscoveryProvider( + ImmutableMap.of( + NodeRole.COORDINATOR, new FakeDruidNodeDiscovery(ImmutableMap.of(NodeRole.COORDINATOR, coordinatorNode)) + ) + ); + final DruidLeaderClient druidLeaderClient = new DruidLeaderClient( new FakeHttpClient(), - new FakeDruidNodeDiscoveryProvider(), + provider, NodeRole.COORDINATOR, "/simple/leader", () -> { @@ -889,7 +899,7 @@ public static SystemSchema createMockSystemSchema( authorizerMapper, druidLeaderClient, druidLeaderClient, - new FakeDruidNodeDiscoveryProvider(), + provider, getJsonMapper() ); } @@ -1019,19 +1029,67 @@ public ListenableFuture go( */ private static class FakeDruidNodeDiscoveryProvider extends DruidNodeDiscoveryProvider { + private final Map nodeDiscoveries; + + public FakeDruidNodeDiscoveryProvider(Map 
nodeDiscoveries) + { + this.nodeDiscoveries = nodeDiscoveries; + } + @Override public BooleanSupplier getForNode(DruidNode node, NodeRole nodeRole) { - throw new UnsupportedOperationException(); + boolean get = nodeDiscoveries.getOrDefault(nodeRole, new FakeDruidNodeDiscovery()) + .getAllNodes() + .stream() + .anyMatch(x -> x.getDruidNode().equals(node)); + return () -> get; } @Override public DruidNodeDiscovery getForNodeRole(NodeRole nodeRole) { - throw new UnsupportedOperationException(); + return nodeDiscoveries.getOrDefault(nodeRole, new FakeDruidNodeDiscovery()); } } + private static class FakeDruidNodeDiscovery implements DruidNodeDiscovery + { + private final Set nodes; + + FakeDruidNodeDiscovery() + { + this.nodes = new HashSet<>(); + } + + FakeDruidNodeDiscovery(Map nodes) + { + this.nodes = Sets.newHashSetWithExpectedSize(nodes.size()); + nodes.forEach((k, v) -> { + addNode(v, k); + }); + } + + @Override + public Collection getAllNodes() + { + return nodes; + } + + void addNode(DruidNode node, NodeRole role) + { + final DiscoveryDruidNode discoveryNode = new DiscoveryDruidNode(node, role, ImmutableMap.of()); + this.nodes.add(discoveryNode); + } + + @Override + public void registerListener(Listener listener) + { + + } + } + + /** * A fake {@link ServerInventoryView} for {@link #createMockSystemSchema}. 
*/ diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index a9201eb52121..ce925902a575 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import org.apache.calcite.avatica.SqlType; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.tools.ValidationException; import org.apache.druid.common.config.NullHandling; @@ -189,7 +190,7 @@ public void testUnauthorized() throws Exception try { resource.doPost( - new SqlQuery("select count(*) from forbiddenDatasource", null, false, null), + new SqlQuery("select count(*) from forbiddenDatasource", null, false, null, null), testRequest ); Assert.fail("doPost did not throw ForbiddenException for an unauthorized query"); @@ -204,7 +205,7 @@ public void testUnauthorized() throws Exception public void testCountStar() throws Exception { final List> rows = doPost( - new SqlQuery("SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", null, false, null) + new SqlQuery("SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", null, false, null, null) ).rhs; Assert.assertEquals( @@ -224,6 +225,7 @@ public void testTimestampsInResponse() throws Exception "SELECT __time, CAST(__time AS DATE) AS t2 FROM druid.foo LIMIT 1", ResultFormat.OBJECT, false, + null, null ) ).rhs; @@ -236,6 +238,27 @@ public void testTimestampsInResponse() throws Exception ); } + @Test + public void testTimestampsInResponseWithParameterizedLimit() throws Exception + { + final List> rows = doPost( + new SqlQuery( + "SELECT __time, CAST(__time AS DATE) AS t2 FROM druid.foo LIMIT ?", + ResultFormat.OBJECT, + false, + null, + ImmutableList.of(new SqlParameter(SqlType.INTEGER, 1)) + ) + ).rhs; + + 
Assert.assertEquals( + ImmutableList.of( + ImmutableMap.of("__time", "2000-01-01T00:00:00.000Z", "t2", "2000-01-01T00:00:00.000Z") + ), + rows + ); + } + @Test public void testTimestampsInResponseLosAngelesTimeZone() throws Exception { @@ -244,7 +267,8 @@ public void testTimestampsInResponseLosAngelesTimeZone() throws Exception "SELECT __time, CAST(__time AS DATE) AS t2 FROM druid.foo LIMIT 1", ResultFormat.OBJECT, false, - ImmutableMap.of(PlannerContext.CTX_SQL_TIME_ZONE, "America/Los_Angeles") + ImmutableMap.of(PlannerContext.CTX_SQL_TIME_ZONE, "America/Los_Angeles"), + null ) ).rhs; @@ -260,7 +284,7 @@ public void testTimestampsInResponseLosAngelesTimeZone() throws Exception public void testFieldAliasingSelect() throws Exception { final List> rows = doPost( - new SqlQuery("SELECT dim2 \"x\", dim2 \"y\" FROM druid.foo LIMIT 1", ResultFormat.OBJECT, false, null) + new SqlQuery("SELECT dim2 \"x\", dim2 \"y\" FROM druid.foo LIMIT 1", ResultFormat.OBJECT, false, null, null) ).rhs; Assert.assertEquals( @@ -275,7 +299,7 @@ public void testFieldAliasingSelect() throws Exception public void testFieldAliasingGroupBy() throws Exception { final List> rows = doPost( - new SqlQuery("SELECT dim2 \"x\", dim2 \"y\" FROM druid.foo GROUP BY dim2", ResultFormat.OBJECT, false, null) + new SqlQuery("SELECT dim2 \"x\", dim2 \"y\" FROM druid.foo GROUP BY dim2", ResultFormat.OBJECT, false, null, null) ).rhs; Assert.assertEquals( @@ -327,7 +351,7 @@ public void testArrayResultFormat() throws Exception nullStr ) ), - doPost(new SqlQuery(query, ResultFormat.ARRAY, false, null), new TypeReference>>() {}).rhs + doPost(new SqlQuery(query, ResultFormat.ARRAY, false, null, null), new TypeReference>>() {}).rhs ); } @@ -363,7 +387,7 @@ public void testArrayResultFormatWithHeader() throws Exception nullStr ) ), - doPost(new SqlQuery(query, ResultFormat.ARRAY, true, null), new TypeReference>>() {}).rhs + doPost(new SqlQuery(query, ResultFormat.ARRAY, true, null, null), new TypeReference>>() {}).rhs 
); } @@ -371,7 +395,7 @@ public void testArrayResultFormatWithHeader() throws Exception public void testArrayLinesResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, false, null)).rhs; + final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, false, null, null)).rhs; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final List lines = Splitter.on('\n').splitToList(response); @@ -412,7 +436,7 @@ public void testArrayLinesResultFormat() throws Exception public void testArrayLinesResultFormatWithHeader() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, true, null)).rhs; + final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, true, null, null)).rhs; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final List lines = Splitter.on('\n').splitToList(response); @@ -493,7 +517,7 @@ public void testObjectResultFormat() throws Exception .build() ).stream().map(transformer).collect(Collectors.toList()), doPost( - new SqlQuery(query, ResultFormat.OBJECT, false, null), + new SqlQuery(query, ResultFormat.OBJECT, false, null, null), new TypeReference>>() {} ).rhs ); @@ -503,7 +527,7 @@ public void testObjectResultFormat() throws Exception public void testObjectLinesResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.OBJECTLINES, false, null)).rhs; + final String response = doPostRaw(new SqlQuery(query, ResultFormat.OBJECTLINES, false, null, null)).rhs; final String nullStr = NullHandling.replaceWithDefault() ? 
"" : null; final Function, Map> transformer = m -> { return Maps.transformEntries( @@ -556,7 +580,7 @@ public void testObjectLinesResultFormat() throws Exception public void testCsvResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, false, null)).rhs; + final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, false, null, null)).rhs; final List lines = Splitter.on('\n').splitToList(response); Assert.assertEquals( @@ -574,7 +598,7 @@ public void testCsvResultFormat() throws Exception public void testCsvResultFormatWithHeaders() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, true, null)).rhs; + final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, true, null, null)).rhs; final List lines = Splitter.on('\n').splitToList(response); Assert.assertEquals( @@ -594,7 +618,7 @@ public void testExplainCountStar() throws Exception { Map queryContext = ImmutableMap.of(PlannerContext.CTX_SQL_QUERY_ID, DUMMY_SQL_QUERY_ID); final List> rows = doPost( - new SqlQuery("EXPLAIN PLAN FOR SELECT COUNT(*) AS cnt FROM druid.foo", ResultFormat.OBJECT, false, queryContext) + new SqlQuery("EXPLAIN PLAN FOR SELECT COUNT(*) AS cnt FROM druid.foo", ResultFormat.OBJECT, false, queryContext, null) ).rhs; Assert.assertEquals( @@ -619,6 +643,7 @@ public void testCannotValidate() throws Exception "SELECT dim4 FROM druid.foo", ResultFormat.OBJECT, false, + null, null ) ).lhs; @@ -635,7 +660,7 @@ public void testCannotConvert() throws Exception { // SELECT + ORDER unsupported final QueryInterruptedException exception = doPost( - new SqlQuery("SELECT dim1 FROM druid.foo ORDER BY dim1", ResultFormat.OBJECT, false, null) + new SqlQuery("SELECT dim1 FROM druid.foo ORDER BY dim1", ResultFormat.OBJECT, false, 
null, null) ).lhs; Assert.assertNotNull(exception); @@ -656,7 +681,8 @@ public void testResourceLimitExceeded() throws Exception "SELECT DISTINCT dim1 FROM foo", ResultFormat.OBJECT, false, - ImmutableMap.of("maxMergingDictionarySize", 1) + ImmutableMap.of("maxMergingDictionarySize", 1), + null ) ).lhs; diff --git a/website/.spelling b/website/.spelling index cb09e12d5460..6b5824d79356 100644 --- a/website/.spelling +++ b/website/.spelling @@ -156,6 +156,7 @@ SSL Samza Splunk SqlFirehose +SqlParameter StatsD TCP TGT