From 1dd50ab10c92c0fcb69b82ab688176124ebb4982 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 9 Jun 2025 12:01:19 -0700 Subject: [PATCH 1/2] SQL: Parse queries immediately for non-JDBC endpoints. SET statements (#17894) provide a new way to set query context. However, they are not available early enough to inform important operations like engine selection and planner rule construction. This patch moves parsing as early as possible: immediately on receiving the query from the user. As a consequence of moving parsing prior to statement creation, with this change, SQL queries that cannot be parsed are no longer logged. --- .../sql/resources/SqlStatementResource.java | 25 +- .../msq/sql/resources/SqlTaskResource.java | 11 +- .../msq/exec/ResultsContextSerdeTest.java | 5 +- .../apache/druid/msq/test/MSQTestBase.java | 12 +- .../druid/server/QueryResultPusher.java | 41 ++- .../org/apache/druid/sql/DirectStatement.java | 4 +- .../org/apache/druid/sql/HttpStatement.java | 11 +- .../apache/druid/sql/PreparedStatement.java | 10 +- .../org/apache/druid/sql/SqlQueryPlus.java | 112 +++++-- .../apache/druid/sql/SqlStatementFactory.java | 6 +- .../druid/sql/avatica/DruidJdbcStatement.java | 2 +- .../apache/druid/sql/avatica/DruidMeta.java | 13 +- .../sql/calcite/parser/DruidSqlParser.java | 305 ++++++++++++++++++ .../parser/StatementAndSetContext.java | 76 +++++ .../sql/calcite/planner/CalcitePlanner.java | 20 ++ .../druid/sql/calcite/planner/Calcites.java | 4 +- .../sql/calcite/planner/DruidPlanner.java | 220 +------------ .../sql/calcite/planner/PlannerContext.java | 11 + .../sql/calcite/planner/PlannerFactory.java | 44 +-- .../sql/calcite/view/DruidViewMacro.java | 5 +- .../druid/sql/http/SqlEngineRegistry.java | 4 +- .../apache/druid/sql/http/SqlResource.java | 89 ++++- .../apache/druid/sql/SqlQueryPlusTest.java | 72 +++++ .../apache/druid/sql/SqlStatementTest.java | 82 +---- .../druid/sql/avatica/DruidStatementTest.java | 143 ++++---- 
.../druid/sql/calcite/CalciteQueryTest.java | 48 +++ .../druid/sql/calcite/QueryTestRunner.java | 13 +- .../expression/ExpressionTestHelper.java | 1 + .../external/ExternalTableScanRuleTest.java | 4 +- .../calcite/parser/DruidSqlParserTest.java | 143 ++++++++ .../planner/CalcitePlannerModuleTest.java | 12 +- .../calcite/planner/DruidRexExecutorTest.java | 1 + .../sql/calcite/util/QueryFrameworkUtils.java | 8 +- .../druid/sql/http/SqlResourceTest.java | 52 ++- 34 files changed, 1082 insertions(+), 527 deletions(-) create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java create mode 100644 sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java create mode 100644 sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java create mode 100644 sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java index a45d77949228..99b16e0d2ecb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java @@ -85,6 +85,7 @@ import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.HttpStatement; +import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.http.ResultFormat; @@ -174,17 +175,29 @@ public Response doPost(@Context final HttpServletRequest req, } @VisibleForTesting - Response doPost(final SqlQuery sqlQuery, - final HttpServletRequest req) + Response doPost( + SqlQuery sqlQuery, // Not final: reassigned using 
createModifiedSqlQuery + final HttpServletRequest req + ) { - SqlQuery modifiedQuery = createModifiedSqlQuery(sqlQuery); + final SqlQueryPlus sqlQueryPlus; + final HttpStatement stmt; + final QueryContext queryContext; + + try { + sqlQuery = createModifiedSqlQuery(sqlQuery); + sqlQueryPlus = SqlResource.makeSqlQueryPlus(sqlQuery, req); + queryContext = QueryContext.of(sqlQueryPlus.context()); + stmt = msqSqlStatementFactory.httpStatement(SqlResource.makeSqlQueryPlus(sqlQuery, req), req); + } + catch (Exception e) { + return SqlResource.handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } - final HttpStatement stmt = msqSqlStatementFactory.httpStatement(modifiedQuery, req); final String sqlQueryId = stmt.sqlQueryId(); final String currThreadName = Thread.currentThread().getName(); boolean isDebug = false; try { - QueryContext queryContext = QueryContext.of(modifiedQuery.getContext()); isDebug = queryContext.isDebug(); contextChecks(queryContext); @@ -202,7 +215,7 @@ Response doPost(final SqlQuery sqlQuery, return buildTaskResponse(sequence, stmt.query().authResult()); } else { // Used for EXPLAIN - return buildStandardResponse(sequence, modifiedQuery, sqlQueryId, rowTransformer); + return buildStandardResponse(sequence, sqlQuery, sqlQueryId, rowTransformer); } } catch (DruidException e) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index 11e39bfa1267..6f2f6123804d 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -43,6 +43,7 @@ import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.HttpStatement; +import 
org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.http.ResultFormat; @@ -127,7 +128,15 @@ public Response doPost( ) { // Queries run as MSQ tasks look like regular queries, but return the task ID as their only output. - final HttpStatement stmt = sqlStatementFactory.httpStatement(sqlQuery, req); + final SqlQueryPlus sqlQueryPlus; + final HttpStatement stmt; + try { + sqlQueryPlus = SqlResource.makeSqlQueryPlus(sqlQuery, req); + stmt = sqlStatementFactory.httpStatement(sqlQueryPlus, req); + } catch (Exception e) { + return SqlResource.handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } + final String sqlQueryId = stmt.sqlQueryId(); try { final DirectStatement.ResultSet plan = stmt.plan(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java index 084bf8e92fff..3a6b23fe8701 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java @@ -31,6 +31,7 @@ import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; import org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -86,9 +87,11 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) EasyMock.createMock(QueryRunnerFactoryConglomerate.class) ); + final String sql = "SELECT 1"; PlannerContext plannerContext = PlannerContext.create( toolbox, - "DUMMY", + 
sql, + DruidSqlParser.parse(sql, false).getMainStatement(), engine, Collections.emptyMap(), null diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 2495251dec06..c05b6c384435 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -810,12 +810,12 @@ private String runMultiStageQuery( ) { final DirectStatement stmt = sqlStatementFactory.directStatement( - new SqlQueryPlus( - query, - context, - parameters, - authenticationResult - ) + SqlQueryPlus.builder() + .sql(query) + .context(context) + .parameters(parameters) + .auth(authenticationResult) + .build() ); final List sequence = stmt.execute().getResults().toList(); diff --git a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java index b4a57bea8e1c..4208f984cee1 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java +++ b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.io.CountingOutputStream; import org.apache.druid.client.DirectDruidClient; import org.apache.druid.error.DruidException; @@ -270,17 +271,14 @@ private Response handleDruidException(ResultsWriter resultsWriter, DruidExceptio } if (response == null) { - final Response.ResponseBuilder bob = Response - .status(e.getStatusCode()) - .type(contentType) - .entity(new ErrorResponse(e)); - - bob.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); - for (Map.Entry entry : extraHeaders.entrySet()) { - bob.header(entry.getKey(), 
entry.getValue()); - } - - return bob.build(); + return handleDruidExceptionBeforeResponseStarted( + e, + contentType, + ImmutableMap.builder() + .putAll(extraHeaders) + .put(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId) + .build() + ); } else { if (response.isCommitted()) { QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?"); @@ -302,6 +300,27 @@ private Response handleDruidException(ResultsWriter resultsWriter, DruidExceptio } } + /** + * Generates a response for a {@link DruidException} that occurs prior to any query results being sent out. + */ + public static Response handleDruidExceptionBeforeResponseStarted( + final DruidException e, + final MediaType contentType, + final Map extraHeaders + ) + { + final Response.ResponseBuilder bob = Response + .status(e.getStatusCode()) + .type(contentType) + .entity(new ErrorResponse(e)); + + for (Map.Entry entry : extraHeaders.entrySet()) { + bob.header(entry.getKey(), entry.getValue()); + } + + return bob.build(); + } + public interface ResultsWriter extends Closeable { /** diff --git a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java index f4f36c28b73f..010b326a94ea 100644 --- a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java @@ -241,9 +241,9 @@ protected DruidPlanner createPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - false + hook ); } diff --git a/sql/src/main/java/org/apache/druid/sql/HttpStatement.java b/sql/src/main/java/org/apache/druid/sql/HttpStatement.java index 8094f1f3890e..0a3011d4de15 100644 --- a/sql/src/main/java/org/apache/druid/sql/HttpStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/HttpStatement.java @@ -23,7 +23,6 @@ import org.apache.druid.server.security.AuthorizationUtils; import 
org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.calcite.planner.DruidPlanner; -import org.apache.druid.sql.http.SqlQuery; import javax.servlet.http.HttpServletRequest; import java.util.Set; @@ -45,15 +44,13 @@ public class HttpStatement extends DirectStatement public HttpStatement( final SqlToolbox lifecycleToolbox, - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { super( lifecycleToolbox, - SqlQueryPlus.builder(sqlQuery) - .auth(AuthorizationUtils.authenticationResultFromRequest(req)) - .build(), + sqlQueryPlus, req.getRemoteAddr() ); this.req = req; @@ -65,9 +62,9 @@ protected DruidPlanner createPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } diff --git a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java index 7c8ab63d7c7f..505ec2f5f86a 100644 --- a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java @@ -31,7 +31,6 @@ public class PreparedStatement extends AbstractStatement { private final SqlQueryPlus originalRequest; - private PrepareResult prepareResult; public PreparedStatement( final SqlToolbox lifecycleToolbox, @@ -70,8 +69,7 @@ public PrepareResult prepare() authorize(planner, authorizer()); // Do the prepare step. 
- this.prepareResult = planner.prepare(); - return prepareResult; + return planner.prepare(); } catch (RuntimeException e) { reporter.failed(e); @@ -92,7 +90,7 @@ public DirectStatement execute(List parameters) { return new DirectStatement( sqlToolbox, - originalRequest.withParameters(parameters) + originalRequest.freshCopy().withParameters(parameters) ); } @@ -101,9 +99,9 @@ protected DruidPlanner getPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.freshCopy().sqlNode(), queryContext, - hook, - false + hook ); } } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java index 8d06a65fa35c..38d8e64c8423 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java @@ -21,25 +21,37 @@ import com.google.common.base.Preconditions; import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.sql.SqlNode; +import org.apache.druid.error.DruidException; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; +import org.apache.druid.sql.calcite.parser.StatementAndSetContext; +import org.apache.druid.sql.calcite.planner.CalcitePlanner; import org.apache.druid.sql.http.SqlParameter; import org.apache.druid.sql.http.SqlQuery; +import javax.annotation.Nullable; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** - * Captures the inputs to a SQL execution request: the statement,the context, - * parameters, and the authorization result. Pass this around rather than the - * quad of items. The request can evolve: the context and parameters can be - * filled in later as needed. 
+ * Captures the inputs to a SQL execution request: the statement (as a string), + * the parsed statement, the context, parameters, and the authorization result. + * The request can evolve: the context and parameters can be filled in later + * as needed. *

* SQL requests come from a variety of sources in a variety of formats. Use * the {@link Builder} class to create an instance from the information * available at each point in the code. *

+ * Each instance of SqlQueryPlus can only be used once, because {@link #sqlNode} + * is a mutable data structure, modified during {@link CalcitePlanner#validate}. + * If you need to use one again, call {@link #freshCopy()} to create a fresh + * copy with a new {@link SqlNode}. + *

* The query context has a complex lifecycle. The copy here is immutable: * it is the set of values which the user requested. Planning will * add (and sometimes remove) values: that work should be done on a copy of the @@ -50,24 +62,31 @@ public class SqlQueryPlus { private final String sql; + @Nullable + private final SqlNode sqlNode; + private boolean allowSetStatements; private final Map queryContext; private final List parameters; private final AuthenticationResult authResult; - public SqlQueryPlus( + private SqlQueryPlus( String sql, + SqlNode sqlNode, + boolean allowSetStatements, Map queryContext, List parameters, AuthenticationResult authResult ) { this.sql = Preconditions.checkNotNull(sql); + this.sqlNode = sqlNode; + this.allowSetStatements = allowSetStatements; this.queryContext = queryContext == null - ? Collections.emptyMap() - : Collections.unmodifiableMap(new HashMap<>(queryContext)); + ? Collections.emptyMap() + : Collections.unmodifiableMap(new HashMap<>(queryContext)); this.parameters = parameters == null - ? Collections.emptyList() - : parameters; + ? 
Collections.emptyList() + : parameters; this.authResult = Preconditions.checkNotNull(authResult); } @@ -81,14 +100,18 @@ public static Builder builder(String sql) return new Builder().sql(sql); } - public static Builder builder(SqlQuery sqlQuery) + public String sql() { - return new Builder().query(sqlQuery); + return sql; } - public String sql() + public SqlNode sqlNode() { - return sql; + if (sqlNode == null) { + throw DruidException.defensive("sqlNode not set"); + } + + return sqlNode; } public Map context() @@ -108,19 +131,40 @@ public AuthenticationResult authResult() public SqlQueryPlus withContext(Map context) { - return new SqlQueryPlus(sql, context, parameters, authResult); + return new SqlQueryPlus(sql, sqlNode, allowSetStatements, context, parameters, authResult); } public SqlQueryPlus withParameters(List parameters) { - return new SqlQueryPlus(sql, queryContext, parameters, authResult); + return new SqlQueryPlus(sql, sqlNode, allowSetStatements, queryContext, parameters, authResult); + } + + /** + * Returns a copy of this instance where everything is shared, except the {@link #sqlNode}, which is re-parsed from + * the SQL statement. 
+ */ + public SqlQueryPlus freshCopy() + { + return new SqlQueryPlus( + sql, + DruidSqlParser.parse(sql, allowSetStatements).getMainStatement(), + allowSetStatements, + queryContext, + parameters, + authResult + ); } @Override public String toString() { - return "SqlQueryPlus {queryContext=" + queryContext + ", parameters=" + parameters - + ", authResult=" + authResult + ", sql=" + sql + " }"; + return "SqlQueryPlus{" + + "sql='" + sql + '\'' + + ", sqlNode=" + sqlNode + + ", queryContext=" + queryContext + + ", parameters=" + parameters + + ", authResult=" + authResult + + '}'; } public static class Builder @@ -136,14 +180,6 @@ public Builder sql(String sql) return this; } - public Builder query(SqlQuery sqlQuery) - { - this.sql = sqlQuery.getQuery(); - this.queryContext = sqlQuery.getContext(); - this.parameters = sqlQuery.getParameterList(); - return this; - } - public Builder context(Map queryContext) { this.queryContext = queryContext; @@ -168,14 +204,38 @@ public Builder auth(final AuthenticationResult authResult) return this; } + /** + * Parses the provided {@link #sql} and builds a {@link SqlQueryPlus} with SET statements folded into the + * context, and with the parsed SQL in {@link #sqlNode}. + * + * When using this method, the {@link #sqlNode()} must only be run through validation once. (The validator + * mutates the {@link SqlNode}). + */ public SqlQueryPlus build() { + final StatementAndSetContext statementAndSetContext = DruidSqlParser.parse(sql, true); return new SqlQueryPlus( sql, - queryContext, + statementAndSetContext.getMainStatement(), + true, + statementAndSetContext.getSetContext().isEmpty() + ? queryContext + : QueryContexts.override(queryContext, statementAndSetContext.getSetContext()), parameters, authResult ); } + + /** + * Builds a {@link SqlQueryPlus} with no {@link SqlNode} and with {@link #allowSetStatements} set to false. 
+ This is done for JDBC because it can run each {@link SqlQueryPlus} multiple times, and it needs to keep + re-parsing and re-validating the query on each run. + * + * When using this method, you must create a copy with {@link #freshCopy()} prior to calling {@link #sqlNode()}. + */ + public SqlQueryPlus buildJdbc() + { + return new SqlQueryPlus(sql, null, false, queryContext, parameters, authResult); + } + } } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java index c8450d621661..9d31441befb1 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java @@ -19,8 +19,6 @@ package org.apache.druid.sql; -import org.apache.druid.sql.http.SqlQuery; - import javax.servlet.http.HttpServletRequest; /** @@ -44,11 +42,11 @@ public SqlStatementFactory(SqlToolbox lifecycleToolbox) } public HttpStatement httpStatement( - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { - return new HttpStatement(lifecycleToolbox, sqlQuery, req); + return new HttpStatement(lifecycleToolbox, sqlQueryPlus, req); } public DirectStatement directStatement(final SqlQueryPlus sqlRequest) diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java index 006a46a8f8ef..966067bb6741 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java @@ -57,7 +57,7 @@ public DruidJdbcStatement( public synchronized void execute(SqlQueryPlus queryPlus, long maxRowCount) { closeResultSet(); - this.sqlQuery = queryPlus.withContext(queryContext); + this.sqlQuery = queryPlus.withContext(queryContext).freshCopy(); DirectStatement stmt = lifecycleFactory.directStatement(this.sqlQuery); resultSet = new 
DruidJdbcResultSet(this, stmt, Long.MAX_VALUE, fetcherFactory); try { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index e271a6f60ab5..f60b7d4ef854 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -271,12 +271,11 @@ public StatementHandle prepare( { try { final DruidConnection druidConnection = getDruidConnection(ch.id); - final SqlQueryPlus sqlReq = new SqlQueryPlus( - sql, - druidConnection.sessionContext(), - null, // No parameters in this path - doAuthenticate(druidConnection) - ); + final SqlQueryPlus sqlReq = SqlQueryPlus.builder() + .sql(sql) + .context(druidConnection.sessionContext()) + .auth(doAuthenticate(druidConnection)) + .buildJdbc(); final DruidJdbcPreparedStatement stmt = getDruidConnection(ch.id).createPreparedStatement( sqlStatementFactory, sqlReq, @@ -351,7 +350,7 @@ public ExecuteResult prepareAndExecute( final AuthenticationResult authenticationResult = doAuthenticate(druidConnection); final SqlQueryPlus sqlRequest = SqlQueryPlus.builder(sql) .auth(authenticationResult) - .build(); + .buildJdbc(); druidStatement.execute(sqlRequest, maxRowCount); final ExecuteResult result = doFetch(druidStatement, maxRowsInFirstFrame); LOG.debug("Successfully prepared statement [%s] and started execution", druidStatement.getStatementId()); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java new file mode 100644 index 000000000000..fe0f0a681c4d --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import com.google.common.base.Joiner; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.avatica.util.Quoting; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSetOption; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.dialect.CalciteSqlDialect; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.NlsString; +import org.apache.calcite.util.SourceStringReader; +import org.apache.druid.error.DruidException; +import org.apache.druid.error.InvalidSqlInput; +import org.apache.druid.query.QueryContext; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.DruidConformance; +import org.apache.druid.sql.calcite.planner.PlannerContext; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; 
+ +/** + * Contains the utility function {@link #parse(String, boolean)}, for parsing Druid SQL statements. + */ +public class DruidSqlParser +{ + public static final SqlParser.Config PARSER_CONFIG = SqlParser + .config() + .withCaseSensitive(true) + .withUnquotedCasing(Casing.UNCHANGED) + .withQuotedCasing(Casing.UNCHANGED) + .withQuoting(Quoting.DOUBLE_QUOTE) + .withConformance(DruidConformance.instance()) + .withParserFactory(new DruidSqlParserImplFactory()); + + private static final Joiner SPACE_JOINER = Joiner.on(" "); + private static final Joiner COMMA_JOINER = Joiner.on(", "); + + private DruidSqlParser() + { + // No instantiation. + } + + public static StatementAndSetContext parse(final String sql, final boolean allowSetStatements) + { + try { + SqlParser parser = SqlParser.create(new SourceStringReader(sql), PARSER_CONFIG); + SqlNode sqlNode = parser.parseStmtList(); + return processStatementList(sqlNode, allowSetStatements); + } + catch (SqlParseException e) { + throw translateParseException(e); + } + } + + /** + * If an {@link SqlNode} is a {@link SqlNodeList}, it must consist of 0 or more {@link SqlSetOption} followed by a + * single {@link SqlNode} which is NOT a {@link SqlSetOption}. All {@link SqlSetOption} will be converted into a + * context parameters {@link Map} and added to the {@link PlannerContext} with + * {@link PlannerContext#addAllToQueryContext(Map)}. The final {@link SqlNode} of the {@link SqlNodeList} is returned + * by this method as the {@link SqlNode} which should actually be validated and executed, and will have access to the + * modified query context through the {@link PlannerContext}. {@link SqlSetOption} override any existing query + * context parameter values. 
+ */ + private static StatementAndSetContext processStatementList( + SqlNode root, + final boolean allowSetStatements + ) + { + if (root instanceof SqlNodeList) { + final Map setContext = new LinkedHashMap<>(); + final SqlNodeList nodeList = (SqlNodeList) root; + if (!allowSetStatements && nodeList.size() > 1) { + throw InvalidSqlInput.exception("SQL query string must contain only a single statement"); + } + boolean isMissingDruidStatementNode = true; + // convert 0 or more SET statements into a Map of stuff to add to the query context + for (int i = 0; i < nodeList.size(); i++) { + SqlNode sqlNode = nodeList.get(i); + if (sqlNode instanceof SqlSetOption) { + final SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; + if (!(sqlSetOption.getValue() instanceof SqlLiteral)) { + throw InvalidSqlInput.exception( + "Assigned value must be a literal for SET statement[%s]", + sqlSetOption.toSqlString(CalciteSqlDialect.DEFAULT) + ); + } + setContext.put( + sqlSetOption.getName().getSimple(), + sqlLiteralToContextValue((SqlLiteral) sqlSetOption.getValue()) + ); + } else if (i < nodeList.size() - 1) { + // only SET statements can appear before the last statement + throw InvalidSqlInput.exception( + "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[%s]", + sqlNode.toSqlString(CalciteSqlDialect.DEFAULT) + ); + } else { + // last SqlNode + root = sqlNode; + isMissingDruidStatementNode = false; + } + } + if (isMissingDruidStatementNode) { + throw InvalidSqlInput.exception("Statement list is missing a non-SET statement to execute"); + } + + return new StatementAndSetContext(root, setContext); + } else { + return new StatementAndSetContext(root, Collections.emptyMap()); + } + } + + /** + * Coerces a SQL literal from a SET statement to a form acceptable for {@link QueryContext}. 
+ */ + @Nullable + static Object sqlLiteralToContextValue(final SqlLiteral literal) + { + if (SqlUtil.isNullLiteral(literal, false)) { + return null; + } else if (SqlTypeName.CHAR_TYPES.contains(literal.getTypeName())) { + return ((NlsString) literal.getValue()).getValue(); + } else if (SqlTypeName.BOOLEAN_TYPES.contains(literal.getTypeName())) { + return literal.getValue(); + } else if (SqlTypeName.NUMERIC_TYPES.contains(literal.getTypeName())) { + final Number number = (Number) literal.getValue(); + if (number instanceof BigDecimal && number.equals(BigDecimal.valueOf(number.longValue()))) { + return number.longValue(); + } else if (number instanceof BigInteger && number.equals(BigInteger.valueOf(number.longValue()))) { + return number.longValue(); + } else if (number instanceof BigDecimal && number.equals(BigDecimal.valueOf(number.doubleValue()))) { + return number.doubleValue(); + } else { + return number.toString(); + } + } else if (literal.getTypeName() == SqlTypeName.DATE) { + return Calcites.CALCITE_DATE_PARSER.parse(literal.getValue().toString()).toString(); + } else if (literal.getTypeName() == SqlTypeName.TIMESTAMP) { + return Calcites.CALCITE_TIMESTAMP_PARSER.parse(literal.getValue().toString()).toString(); + } else { + throw InvalidSqlInput.exception("Unsupported type for SET[%s]", literal.getTypeName()); + } + } + + /** + * Constructs a user-friendly {@link DruidException} from a Calcite {@link SqlParseException}. + */ + private static DruidException translateParseException(SqlParseException e) + { + final Throwable cause = e.getCause(); + if (cause instanceof DruidException) { + return (DruidException) cause; + } + + if (cause instanceof ParseException) { + ParseException parseException = (ParseException) cause; + final SqlParserPos failurePosition = e.getPos(); + // When calcite catches a syntax error at the top level + // expected token sequences can be null. 
+ // In such a case return the syntax error to the user + // wrapped in a DruidException with invalid input + if (parseException.expectedTokenSequences == null) { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode("invalidInput") + .build(e, "%s", e.getMessage()) + .withContext("sourceType", "sql"); + } else { + final String theUnexpectedToken = getUnexpectedTokenString(parseException); + + final String[] tokenDictionary = e.getTokenImages(); + final int[][] expectedTokenSequences = e.getExpectedTokenSequences(); + final ArrayList expectedTokens = new ArrayList<>(expectedTokenSequences.length); + for (int[] expectedTokenSequence : expectedTokenSequences) { + String[] strings = new String[expectedTokenSequence.length]; + for (int i = 0; i < expectedTokenSequence.length; ++i) { + strings[i] = tokenDictionary[expectedTokenSequence[i]]; + } + expectedTokens.add(SPACE_JOINER.join(strings)); + } + + return InvalidSqlInput + .exception( + e, + "Received an unexpected token [%s] (line [%s], column [%s]), acceptable options: [%s]", + theUnexpectedToken, + failurePosition.getLineNum(), + failurePosition.getColumnNum(), + COMMA_JOINER.join(expectedTokens) + ) + .withContext("line", failurePosition.getLineNum()) + .withContext("column", failurePosition.getColumnNum()) + .withContext("endLine", failurePosition.getEndLineNum()) + .withContext("endColumn", failurePosition.getEndColumnNum()) + .withContext("token", theUnexpectedToken) + .withContext("expected", expectedTokens); + + } + } + + return InvalidSqlInput.exception(e.getMessage()); + } + + /** + * Grabs the unexpected token string. This code is borrowed with minimal adjustments from + * {@link ParseException#getMessage()}. It is possible that if that code changes, we need to also + * change this code to match it. 
+ * + * @param parseException the parse exception to extract from + * + * @return the String representation of the unexpected token string + */ + private static String getUnexpectedTokenString(ParseException parseException) + { + int maxSize = 0; + for (int[] ints : parseException.expectedTokenSequences) { + if (maxSize < ints.length) { + maxSize = ints.length; + } + } + + StringBuilder bob = new StringBuilder(); + Token tok = parseException.currentToken.next; + for (int i = 0; i < maxSize; i++) { + if (i != 0) { + bob.append(" "); + } + if (tok.kind == 0) { + bob.append("<EOF>"); + break; + } + char ch; + for (int i1 = 0; i1 < tok.image.length(); i1++) { + switch (tok.image.charAt(i1)) { + case 0: + continue; + case '\b': + bob.append("\\b"); + continue; + case '\t': + bob.append("\\t"); + continue; + case '\n': + bob.append("\\n"); + continue; + case '\f': + bob.append("\\f"); + continue; + case '\r': + bob.append("\\r"); + continue; + case '\"': + bob.append("\\\""); + continue; + case '\'': + bob.append("\\\'"); + continue; + case '\\': + bob.append("\\\\"); + continue; + default: + if ((ch = tok.image.charAt(i1)) < 0x20 || ch > 0x7e) { + String s = "0000" + Integer.toString(ch, 16); + bob.append("\\u").append(s.substring(s.length() - 4, s.length())); + } else { + bob.append(ch); + } + } + } + tok = tok.next; + } + return bob.toString(); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java new file mode 100644 index 000000000000..2f00cb70e0a2 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import org.apache.calcite.sql.SqlNode; + +import java.util.Map; +import java.util.Objects; + +/** + * Represents a parsed "main" (not SET) SQL statement, plus context derived from SET statements. + */ +public class StatementAndSetContext +{ + private final SqlNode mainStatement; + private final Map<String, Object> setContext; + + public StatementAndSetContext(SqlNode mainStatement, Map<String, Object> setContext) + { + this.mainStatement = mainStatement; + this.setContext = setContext; + } + + public SqlNode getMainStatement() + { + return mainStatement; + } + + public Map<String, Object> getSetContext() + { + return setContext; + } + + @Override + public boolean equals(Object o) + { + if (o == null || getClass() != o.getClass()) { + return false; + } + StatementAndSetContext that = (StatementAndSetContext) o; + return Objects.equals(mainStatement, that.mainStatement) + && Objects.equals(setContext, that.setContext); + } + + @Override + public int hashCode() + { + return Objects.hash(mainStatement, setContext); + } + + @Override + public String toString() + { + return "StatementAndSetContext{" + + "mainStatement=" + mainStatement + + ", setContext=" + setContext + + '}'; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java index baf257632573..fe8dfd509a30
100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java @@ -60,6 +60,7 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.ValidationException; import org.apache.calcite.util.Pair; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import javax.annotation.Nullable; import java.io.Reader; @@ -232,6 +233,25 @@ public SqlNode parse(final Reader reader) throws SqlParseException return sqlNode; } + /** + * Skip parsing, moving state along to a {@link State#STATE_3_PARSED}. We have this because we + * parse before the planner is even created, using {@link DruidSqlParser}, in order to capture + * SET statements as early as possible. + */ + public void skipParse() + { + switch (state) { + case STATE_0_CLOSED: + case STATE_1_RESET: + ready(); + break; + default: + break; + } + ensure(CalcitePlanner.State.STATE_2_READY); + state = CalcitePlanner.State.STATE_3_PARSED; + } + @Override public SqlNode validate(SqlNode sqlNode) throws ValidationException { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index 42d6af585738..7f90f0a9aa28 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -77,8 +77,8 @@ */ public class Calcites { - private static final DateTimes.UtcFormatter CALCITE_DATE_PARSER = DateTimes.wrapFormatter(ISODateTimeFormat.dateParser()); - private static final DateTimes.UtcFormatter CALCITE_TIMESTAMP_PARSER = DateTimes.wrapFormatter( + public static final DateTimes.UtcFormatter CALCITE_DATE_PARSER = DateTimes.wrapFormatter(ISODateTimeFormat.dateParser()); + public static final DateTimes.UtcFormatter CALCITE_TIMESTAMP_PARSER = DateTimes.wrapFormatter( new DateTimeFormatterBuilder() 
.append(ISODateTimeFormat.dateParser()) .appendLiteral(' ') diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java index 2842eb76e94a..f361e43797b4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java @@ -20,20 +20,13 @@ package org.apache.druid.sql.calcite.planner; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.SqlExplain; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlSetOption; -import org.apache.calcite.sql.dialect.CalciteSqlDialect; -import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.ValidationException; import org.apache.druid.error.DruidException; @@ -44,16 +37,11 @@ import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlReplace; -import org.apache.druid.sql.calcite.parser.ParseException; -import org.apache.druid.sql.calcite.parser.Token; import org.apache.druid.sql.calcite.run.SqlEngine; -import org.apache.druid.sql.calcite.run.SqlResults; import org.joda.time.DateTimeZone; import java.io.Closeable; -import java.util.ArrayList; import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -70,9 +58,6 @@ */ 
public class DruidPlanner implements Closeable { - public static final Joiner SPACE_JOINER = Joiner.on(" "); - public static final Joiner COMMA_JOINER = Joiner.on(", "); - public enum State { START, VALIDATED, PREPARED, PLANNED @@ -112,7 +97,6 @@ public AuthResult( private final PlannerContext plannerContext; private final SqlEngine engine; private final PlannerHook hook; - private final boolean allowSetStatementsToBuildContext; private State state = State.START; private SqlStatementHandler handler; private boolean authorized; @@ -121,8 +105,7 @@ public AuthResult( final FrameworkConfig frameworkConfig, final PlannerContext plannerContext, final SqlEngine engine, - final PlannerHook hook, - final boolean allowSetStatementsToBuildContext + final PlannerHook hook ) { this.frameworkConfig = frameworkConfig; @@ -130,7 +113,6 @@ public AuthResult( this.plannerContext = plannerContext; this.engine = engine; this.hook = hook == null ? NoOpPlannerHook.INSTANCE : hook; - this.allowSetStatementsToBuildContext = allowSetStatementsToBuildContext; } /** @@ -145,19 +127,8 @@ public void validate() // Validate query context. engine.validateContext(plannerContext.queryContextMap()); - - // Parse the query string. - String sql = plannerContext.getSql(); - hook.captureSql(sql); - SqlNode root; - try { - root = planner.parse(sql); - } - catch (SqlParseException e1) { - throw translateException(e1); - } - root = processStatementList(root); - root = rewriteParameters(root); + planner.skipParse(); + final SqlNode root = rewriteParameters(plannerContext.getSqlNode()); hook.captureSqlNode(root); handler = createHandler(root); handler.validate(); @@ -293,66 +264,6 @@ public void close() planner.close(); } - /** - * If an {@link SqlNode} is a {@link SqlNodeList}, it must consist of 0 or more {@link SqlSetOption} followed by a - * single {@link SqlNode} which is NOT a {@link SqlSetOption}. 
All {@link SqlSetOption} will be converted into a - * context parameters {@link Map} and added to the {@link PlannerContext} with - * {@link PlannerContext#addAllToQueryContext(Map)}. The final {@link SqlNode} of the {@link SqlNodeList} is returned - * by this method as the {@link SqlNode} which should actually be validated and executed, and will have access to the - * modified query context through the {@link PlannerContext}. {@link SqlSetOption} override any existing query - * context parameter values. - */ - private SqlNode processStatementList(SqlNode root) - { - if (root instanceof SqlNodeList) { - final SqlNodeList nodeList = (SqlNodeList) root; - if (!allowSetStatementsToBuildContext && nodeList.size() > 1) { - throw InvalidSqlInput.exception("SQL query string must contain only a single statement"); - } - final Map contextMap = new LinkedHashMap<>(); - boolean isMissingDruidStatementNode = true; - // convert 0 or more SET statements into a Map of stuff to add to the query context - for (int i = 0; i < nodeList.size(); i++) { - SqlNode sqlNode = nodeList.get(i); - if (sqlNode instanceof SqlSetOption) { - final SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; - if (!(sqlSetOption.getValue() instanceof SqlLiteral)) { - throw InvalidSqlInput.exception( - "Assigned value must be a literal for SET statement[%s]", - sqlSetOption.toSqlString(CalciteSqlDialect.DEFAULT) - ); - } - final SqlLiteral value = (SqlLiteral) sqlSetOption.getValue(); - contextMap.put( - sqlSetOption.getName().getSimple(), - SqlResults.coerce( - plannerContext.getJsonMapper(), - SqlResults.Context.fromPlannerContext(plannerContext), - value.getValue(), - value.getTypeName(), - "set" - ) - ); - } else if (i < nodeList.size() - 1) { - // only SET statements can appear before the last statement - throw InvalidSqlInput.exception( - "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[%s]", - sqlNode.toSqlString(CalciteSqlDialect.DEFAULT) - 
); - } else { - // last SqlNode - root = sqlNode; - isMissingDruidStatementNode = false; - } - } - if (isMissingDruidStatementNode) { - throw InvalidSqlInput.exception("Statement list is missing a non-SET statement to execute"); - } - plannerContext.addAllToQueryContext(contextMap); - } - return root; - } - protected class HandlerContextImpl implements SqlStatementHandler.HandlerContext { @Override @@ -421,59 +332,6 @@ public static DruidException translateException(Exception e) catch (ValidationException inner) { return parseValidationMessage(inner); } - catch (SqlParseException inner) { - final Throwable cause = inner.getCause(); - if (cause instanceof DruidException) { - return (DruidException) cause; - } - - if (cause instanceof ParseException) { - ParseException parseException = (ParseException) cause; - final SqlParserPos failurePosition = inner.getPos(); - // When calcite catches a syntax error at the top level - // expected token sequences can be null. - // In such a case return the syntax error to the user - // wrapped in a DruidException with invalid input - if (parseException.expectedTokenSequences == null) { - return DruidException.forPersona(DruidException.Persona.USER) - .ofCategory(DruidException.Category.INVALID_INPUT) - .withErrorCode("invalidInput") - .build(inner, "%s", inner.getMessage()).withContext("sourceType", "sql"); - } else { - final String theUnexpectedToken = getUnexpectedTokenString(parseException); - - final String[] tokenDictionary = inner.getTokenImages(); - final int[][] expectedTokenSequences = inner.getExpectedTokenSequences(); - final ArrayList expectedTokens = new ArrayList<>(expectedTokenSequences.length); - for (int[] expectedTokenSequence : expectedTokenSequences) { - String[] strings = new String[expectedTokenSequence.length]; - for (int i = 0; i < expectedTokenSequence.length; ++i) { - strings[i] = tokenDictionary[expectedTokenSequence[i]]; - } - expectedTokens.add(SPACE_JOINER.join(strings)); - } - - return 
InvalidSqlInput - .exception( - inner, - "Received an unexpected token [%s] (line [%s], column [%s]), acceptable options: [%s]", - theUnexpectedToken, - failurePosition.getLineNum(), - failurePosition.getColumnNum(), - COMMA_JOINER.join(expectedTokens) - ) - .withContext("line", failurePosition.getLineNum()) - .withContext("column", failurePosition.getColumnNum()) - .withContext("endLine", failurePosition.getEndLineNum()) - .withContext("endColumn", failurePosition.getEndColumnNum()) - .withContext("token", theUnexpectedToken) - .withContext("expected", expectedTokens); - - } - } - - return InvalidSqlInput.exception(inner.getMessage()); - } catch (RelOptPlanner.CannotPlanException inner) { return DruidException.forPersona(DruidException.Persona.USER) .ofCategory(DruidException.Category.INVALID_INPUT) @@ -524,76 +382,4 @@ private static DruidException parseValidationMessage(Exception e) .build(e, "Uncategorized calcite error message: [%s]", e.getMessage()); } } - - /** - * Grabs the unexpected token string. This code is borrowed with minimal adjustments from - * {@link ParseException#getMessage()}. It is possible that if that code changes, we need to also - * change this code to match it. 
- * - * @param parseException the parse exception to extract from - * @return the String representation of the unexpected token string - */ - private static String getUnexpectedTokenString(ParseException parseException) - { - int maxSize = 0; - for (int[] ints : parseException.expectedTokenSequences) { - if (maxSize < ints.length) { - maxSize = ints.length; - } - } - - StringBuilder bob = new StringBuilder(); - Token tok = parseException.currentToken.next; - for (int i = 0; i < maxSize; i++) { - if (i != 0) { - bob.append(" "); - } - if (tok.kind == 0) { - bob.append(""); - break; - } - char ch; - for (int i1 = 0; i1 < tok.image.length(); i1++) { - switch (tok.image.charAt(i1)) { - case 0: - continue; - case '\b': - bob.append("\\b"); - continue; - case '\t': - bob.append("\\t"); - continue; - case '\n': - bob.append("\\n"); - continue; - case '\f': - bob.append("\\f"); - continue; - case '\r': - bob.append("\\r"); - continue; - case '\"': - bob.append("\\\""); - continue; - case '\'': - bob.append("\\\'"); - continue; - case '\\': - bob.append("\\\\"); - continue; - default: - if ((ch = tok.image.charAt(i1)) < 0x20 || ch > 0x7e) { - String s = "0000" + Integer.toString(ch, 16); - bob.append("\\u").append(s.substring(s.length() - 4, s.length())); - } else { - bob.append(ch); - } - continue; - } - } - tok = tok.next; - } - return bob.toString(); - } - } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index a2bf29868778..18c56dc3a91a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -28,6 +28,7 @@ import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.linq4j.QueryProvider; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlNode; import org.apache.druid.error.InvalidSqlInput; import 
org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; @@ -128,6 +129,7 @@ public class PlannerContext private final PlannerToolbox plannerToolbox; private final ExpressionParser expressionParser; private final String sql; + private final SqlNode sqlNode; private final SqlEngine engine; private final Map queryContext; private final CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(); @@ -163,6 +165,7 @@ public class PlannerContext private PlannerContext( final PlannerToolbox plannerToolbox, final String sql, + final SqlNode sqlNode, final SqlEngine engine, final Map queryContext, final PlannerHook hook @@ -171,6 +174,7 @@ private PlannerContext( this.plannerToolbox = plannerToolbox; this.expressionParser = new ExpressionParserImpl(plannerToolbox.exprMacroTable()); this.sql = sql; + this.sqlNode = sqlNode; this.engine = engine; this.queryContext = new LinkedHashMap<>(queryContext); this.hook = hook == null ? NoOpPlannerHook.INSTANCE : hook; @@ -180,6 +184,7 @@ private PlannerContext( public static PlannerContext create( final PlannerToolbox plannerToolbox, final String sql, + final SqlNode sqlNode, final SqlEngine engine, final Map queryContext, final PlannerHook hook @@ -188,6 +193,7 @@ public static PlannerContext create( return new PlannerContext( plannerToolbox, sql, + sqlNode, engine, queryContext, hook @@ -393,6 +399,11 @@ public String getSql() return sql; } + public SqlNode getSqlNode() + { + return sqlNode; + } + public PlannerHook getPlannerHook() { return hook; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index 2f7dce4ce614..dcc245704423 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -23,8 +23,6 @@ import com.google.common.annotations.VisibleForTesting; import 
com.google.common.collect.ImmutableSet; import com.google.inject.Inject; -import org.apache.calcite.avatica.util.Casing; -import org.apache.calcite.avatica.util.Quoting; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; @@ -32,7 +30,7 @@ import org.apache.calcite.plan.ConventionTraitDef; import org.apache.calcite.plan.volcano.DruidVolcanoCost; import org.apache.calcite.rel.RelCollationTraitDef; -import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.SqlConformance; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.tools.FrameworkConfig; @@ -46,7 +44,8 @@ import org.apache.druid.server.security.AuthorizationResult; import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.NoopEscalator; -import org.apache.druid.sql.calcite.parser.DruidSqlParserImplFactory; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; +import org.apache.druid.sql.calcite.parser.StatementAndSetContext; import org.apache.druid.sql.calcite.planner.convertlet.DruidConvertletTable; import org.apache.druid.sql.calcite.run.SqlEngine; import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; @@ -59,16 +58,6 @@ public class PlannerFactory extends PlannerToolbox { - static final SqlParser.Config PARSER_CONFIG = SqlParser - .configBuilder() - .setCaseSensitive(true) - .setUnquotedCasing(Casing.UNCHANGED) - .setQuotedCasing(Casing.UNCHANGED) - .setQuoting(Quoting.DOUBLE_QUOTE) - .setConformance(DruidConformance.instance()) - .setParserFactory(new DruidSqlParserImplFactory()) // Custom SQL parser factory - .build(); - @Inject public PlannerFactory( final DruidSchemaCatalog rootSchema, @@ -108,25 +97,33 @@ public PlannerFactory( * the parser is allowed to parse multi-part SQL statements where all statements in the list except 
the last one are * SET statements, for example 'SET x = 'y'; SET foo = 123; SELECT ...', where these values will be added to the * {@link org.apache.druid.query.QueryContext} of the final statement. + * + * @param engine current SQL engine + * @param sql sql query string + * @param sqlNode parsed sql query, from {@link DruidSqlParser#parse(String, boolean)}. This is the main + * statement from {@link StatementAndSetContext#getMainStatement()}. + * @param queryContext query context including {@link StatementAndSetContext#getSetContext()} + * @param hook calcite planner hook */ public DruidPlanner createPlanner( final SqlEngine engine, final String sql, + final SqlNode sqlNode, final Map queryContext, - final PlannerHook hook, - boolean allowSetStatementsToBuildContext + final PlannerHook hook ) { final PlannerContext context = PlannerContext.create( this, sql, + sqlNode, engine, queryContext, hook ); context.dispatchHook(DruidHook.SQL, sql); - return new DruidPlanner(buildFrameworkConfig(context), context, engine, hook, allowSetStatementsToBuildContext); + return new DruidPlanner(buildFrameworkConfig(context), context, engine, hook); } /** @@ -140,7 +137,16 @@ public DruidPlanner createPlannerForTesting( final Map queryContext ) { - final DruidPlanner thePlanner = createPlanner(engine, sql, queryContext, null, true); + final StatementAndSetContext statementAndSetContext = DruidSqlParser.parse(sql, true); + final DruidPlanner thePlanner = createPlanner( + engine, + sql, + statementAndSetContext.getMainStatement(), + statementAndSetContext.getSetContext().isEmpty() + ? 
queryContext + : QueryContexts.override(queryContext, statementAndSetContext.getSetContext()), + null + ); thePlanner.getPlannerContext() .setAuthenticationResult(NoopEscalator.getInstance().createEscalatedAuthenticationResult()); thePlanner.validate(); @@ -166,7 +172,7 @@ private FrameworkConfig buildFrameworkConfig(PlannerContext plannerContext) Frameworks.ConfigBuilder frameworkConfigBuilder = Frameworks .newConfigBuilder() - .parserConfig(PARSER_CONFIG) + .parserConfig(DruidSqlParser.PARSER_CONFIG) .traitDefs(ConventionTraitDef.INSTANCE, RelCollationTraitDef.INSTANCE) .convertletTable(new DruidConvertletTable(plannerContext)) .operatorTable(operatorTable) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java index 33787001dd71..d17c9c5765a5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java @@ -28,6 +28,7 @@ import org.apache.calcite.schema.TableMacro; import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.schema.impl.ViewTable; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.DruidPlanner; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.schema.DruidSchemaName; @@ -61,9 +62,9 @@ public TranslatableTable apply(final List arguments) plannerFactory.createPlanner( ViewSqlEngine.INSTANCE, viewSql, + DruidSqlParser.parse(viewSql, false).getMainStatement(), // views cannot embed SET Collections.emptyMap(), - null, - false + null ) ) { planner.validate(); diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java b/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java index 1cf5b23ef24e..1e70c08c0ccb 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java +++ 
b/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java @@ -20,8 +20,8 @@ package org.apache.druid.sql.http; import com.google.inject.Inject; +import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.query.QueryContexts; -import org.apache.druid.server.initialization.jetty.BadRequestException; import org.apache.druid.sql.calcite.run.SqlEngine; import javax.validation.constraints.NotNull; @@ -46,7 +46,7 @@ public SqlEngine getEngine(final String engineName) { SqlEngine engine = engines.getOrDefault(engineName == null ? QueryContexts.DEFAULT_ENGINE : engineName, null); if (engine == null) { - throw new BadRequestException("Unsupported engine"); + throw InvalidSqlInput.exception("Unsupported engine[%s]", engineName); } return engine; } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index 2c4ed3c38a0e..3640ce7b32e4 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -22,12 +22,16 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import com.sun.jersey.api.core.HttpContext; import org.apache.druid.common.exception.SanitizableException; +import org.apache.druid.error.DruidException; import org.apache.druid.guice.annotations.Self; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryLifecycle; import org.apache.druid.server.QueryResource; @@ -46,6 +50,7 @@ import org.apache.druid.sql.HttpStatement; import org.apache.druid.sql.SqlLifecycleManager; import 
org.apache.druid.sql.SqlLifecycleManager.Cancelable; +import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.calcite.run.SqlEngine; @@ -178,17 +183,25 @@ public Response doPost( final HttpServletRequest req ) { - final String engineName = sqlQuery.queryContext().getEngine(); - final SqlEngine engine = sqlEngineRegistry.getEngine(engineName); - final HttpStatement stmt = engine.getSqlStatementFactory().httpStatement(sqlQuery, req); - final String sqlQueryId = stmt.sqlQueryId(); - final String currThreadName = Thread.currentThread().getName(); + final HttpStatement stmt; + final QueryContext queryContext; try { - Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); + final SqlQueryPlus sqlQueryPlus = makeSqlQueryPlus(sqlQuery, req); + queryContext = new QueryContext(sqlQueryPlus.context()); // Redefine queryContext to include SET parameters + final String engineName = queryContext.getEngine(); + final SqlEngine engine = sqlEngineRegistry.getEngine(engineName); + stmt = engine.getSqlStatementFactory().httpStatement(sqlQueryPlus, req); + } + catch (Exception e) { + // Can't use the queryContext with SETs since it might not have been created yet. Use the original one. 
+ return handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } - QueryResultPusher pusher = makePusher(req, stmt, sqlQuery); - return pusher.push(); + final String currThreadName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(StringUtils.format("sql[%s]", stmt.sqlQueryId())); + return makePusher(req, stmt, sqlQuery, queryContext).push(); } finally { Thread.currentThread().setName(currThreadName); @@ -249,7 +262,12 @@ public void incrementTimedOut() } } - private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStatement stmt, SqlQuery sqlQuery) + private SqlResourceQueryResultPusher makePusher( + HttpServletRequest req, + HttpStatement stmt, + SqlQuery sqlQuery, + QueryContext queryContext + ) { final String sqlQueryId = stmt.sqlQueryId(); Map headers = new LinkedHashMap<>(); @@ -259,7 +277,7 @@ private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStat headers.put(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); } - return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, headers); + return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, queryContext, headers); } private class SqlResourceQueryResultPusher extends QueryResultPusher @@ -268,11 +286,17 @@ private class SqlResourceQueryResultPusher extends QueryResultPusher private final HttpStatement stmt; private final SqlQuery sqlQuery; + /** + * Context to use for pushing results. May be different from the context in SqlQuery due to SET statements. 
+ */ + private final QueryContext queryContext; + public SqlResourceQueryResultPusher( HttpServletRequest req, String sqlQueryId, HttpStatement stmt, SqlQuery sqlQuery, + QueryContext queryContext, Map<String, String> headers ) { @@ -288,6 +312,7 @@ public SqlResourceQueryResultPusher( ); this.sqlQueryId = sqlQueryId; this.stmt = stmt; + this.queryContext = queryContext; this.sqlQuery = sqlQuery; } @@ -374,7 +399,7 @@ public void recordSuccess(long numBytes) @Override public void recordFailure(Exception e) { - if (QueryLifecycle.shouldLogStackTrace(e, sqlQuery.queryContext())) { + if (QueryLifecycle.shouldLogStackTrace(e, queryContext)) { log.warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); } else { log.noStackTrace().warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); @@ -418,4 +443,46 @@ public AuthorizationResult authorizeCancellation(final HttpServletRequest req, f authorizerMapper ); } + + /** + * Create a {@link SqlQueryPlus}, which involves parsing the query from {@link SqlQuery#getQuery()} and + * extracting any SET parameters into the query context. + */ + public static SqlQueryPlus makeSqlQueryPlus(final SqlQuery sqlQuery, final HttpServletRequest req) + { + return SqlQueryPlus.builder() + .sql(sqlQuery.getQuery()) + .context(sqlQuery.getContext()) + .parameters(sqlQuery.getParameterList()) + .auth(AuthorizationUtils.authenticationResultFromRequest(req)) + .build(); + } + + /** + * Generates a response for a {@link DruidException} that occurs prior to the {@link HttpStatement} being created. + */ + public static Response handleExceptionBeforeStatementCreated(final Exception e, final QueryContext queryContext) + { + if (e instanceof DruidException) { + final String sqlQueryId = queryContext.getString(QueryContexts.CTX_SQL_QUERY_ID); + return QueryResultPusher.handleDruidExceptionBeforeResponseStarted( + (DruidException) e, + MediaType.APPLICATION_JSON_TYPE, + sqlQueryId != null + ? 
ImmutableMap.<String, String>builder() + .put(QueryResource.QUERY_ID_RESPONSE_HEADER, sqlQueryId) + .put(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId) + .build() + : Collections.emptyMap() + ); + } else { + return QueryResultPusher.handleDruidExceptionBeforeResponseStarted( + DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build(e, "Cannot handle query"), + MediaType.APPLICATION_JSON_TYPE, + Collections.emptyMap() + ); + } + } } diff --git a/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java b/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java new file mode 100644 index 000000000000..688453de26c1 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.sql; + +import org.apache.druid.error.DruidException; +import org.apache.druid.error.DruidExceptionMatcher; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Test; + +public class SqlQueryPlusTest +{ + @Test + public void testSyntaxError() + { + // SqlQueryPlus throws parse errors on build() if the statement is invalid + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> SqlQueryPlus.builder("SELECT COUNT(*) AS cnt, 'foo' AS") + .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) + .build() + ); + + MatcherAssert.assertThat( + e, + DruidExceptionMatcher + .invalidSqlInput() + .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") + ); + } + + @Test + public void testSyntaxErrorJdbc() + { + // SqlQueryPlus does not throw parse errors on buildJdbc(), because parsing is deferred + final SqlQueryPlus sqlQueryPlus = + SqlQueryPlus.builder("SELECT COUNT(*) AS cnt, 'foo' AS") + .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) + .buildJdbc(); + + // It does throw exceptions on freshCopy(), though. 
+ final DruidException e = Assert.assertThrows( + DruidException.class, + sqlQueryPlus::freshCopy + ); + + MatcherAssert.assertThat( + e, + DruidExceptionMatcher + .invalidSqlInput() + .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") + ); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java index dfbad004f3ca..ff21f9853d3b 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java @@ -63,7 +63,6 @@ import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.hook.DruidHookDispatcher; -import org.apache.druid.sql.http.SqlQuery; import org.easymock.EasyMock; import org.hamcrest.MatcherAssert; import org.junit.After; @@ -280,28 +279,6 @@ public void testDirectPolicyEnforcerValidatesWithPolicy() assertResultsEquals("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", expectedResults, results); } - @Test - public void testDirectSyntaxError() - { - SqlQueryPlus sqlReq = queryPlus( - "SELECT COUNT(*) AS cnt, 'foo' AS", - CalciteTests.REGULAR_USER_AUTH_RESULT - ); - DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); - try { - stmt.execute(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testDirectValidationError() { @@ -344,17 +321,13 @@ public void testDirectPermissionError() //----------------------------------------------------------------- // HTTP statements: using a servlet request for verification. 
- private SqlQuery makeQuery(String sql) + /** + * Creates a {@link SqlQueryPlus} with auth result {@link CalciteTests#REGULAR_USER_AUTH_RESULT}, which matches + * the auth result used by {@link #request(boolean)}. + */ + private SqlQueryPlus makeQuery(String sql) { - return new SqlQuery( - sql, - null, - false, - false, - false, - null, - null - ); + return SqlQueryPlus.builder(sql).auth(CalciteTests.REGULAR_USER_AUTH_RESULT).build(); } @Test @@ -370,27 +343,6 @@ public void testHttpHappyPath() assertEquals("foo", results.get(0)[1]); } - @Test - public void testHttpSyntaxError() - { - HttpStatement stmt = sqlStatementFactory.httpStatement( - makeQuery("SELECT COUNT(*) AS cnt, 'foo' AS"), - request(true) - ); - try { - stmt.execute(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testHttpValidationError() { @@ -494,28 +446,6 @@ public void testPreparedHappyPath() } } - @Test - public void testPrepareSyntaxError() - { - SqlQueryPlus sqlReq = queryPlus( - "SELECT COUNT(*) AS cnt, 'foo' AS", - CalciteTests.REGULAR_USER_AUTH_RESULT - ); - PreparedStatement stmt = sqlStatementFactory.preparedStatement(sqlReq); - try { - stmt.prepare(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testPrepareValidationError() { diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index f971482ecc81..78bdbf4a15a0 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -152,12 +152,11 @@ private 
DruidJdbcStatement jdbcStatement() @Test public void testSubQueryWithOrderByDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -173,12 +172,11 @@ public void testSubQueryWithOrderByDirect() @Test public void testFetchPastEOFDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -221,12 +219,11 @@ public void testSkipExecuteDirect() @Test public void testFetchAfterResultCloseDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. 
statement.execute(queryPlus, -1); @@ -243,12 +240,11 @@ public void testFetchAfterResultCloseDirect() @Test public void testSubQueryWithOrderByDirectTwice() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { statement.execute(queryPlus, -1); Meta.Frame frame = statement.nextFrame(AbstractDruidJdbcStatement.START_OFFSET, 6); @@ -288,12 +284,11 @@ private Meta.Frame subQueryWithOrderByResults() @Test public void testSelectAllInFirstFrameDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -330,12 +325,11 @@ public void testSelectAllInFirstFrameDirect() @Test public void testSelectSplitOverTwoFramesDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. 
@@ -367,12 +361,11 @@ public void testSelectSplitOverTwoFramesDirect() @Test public void testTwoFramesAutoCloseDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. statement.execute(queryPlus, -1); @@ -409,12 +402,11 @@ public void testTwoFramesAutoCloseDirect() @Test public void testTwoFramesCloseWithResultSetDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. statement.execute(queryPlus, -1); @@ -464,12 +456,11 @@ private Meta.Frame secondFrameResults() @Test public void testSignatureDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_STAR_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_STAR_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // Check signature. 
statement.execute(queryPlus, -1); @@ -533,12 +524,11 @@ private DruidJdbcPreparedStatement jdbcPreparedStatement(SqlQueryPlus queryPlus) public void testSubQueryWithOrderByPrepared() { final String sql = "select T20.F13 as F22 from (SELECT DISTINCT dim1 as F13 FROM druid.foo T10) T20 order by T20.F13 ASC"; - SqlQueryPlus queryPlus = new SqlQueryPlus( - sql, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(sql) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); // First frame, ask for all rows. @@ -556,12 +546,11 @@ public void testSubQueryWithOrderByPrepared() public void testSubQueryWithOrderByPreparedTwice() { final String sql = "select T20.F13 as F22 from (SELECT DISTINCT dim1 as F13 FROM druid.foo T10) T20 order by T20.F13 ASC"; - SqlQueryPlus queryPlus = new SqlQueryPlus( - sql, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(sql) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); statement.execute(Collections.emptyList()); @@ -586,12 +575,11 @@ public void testSubQueryWithOrderByPreparedTwice() @Test public void testSignaturePrepared() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_STAR_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_STAR_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); verifySignature(statement.getSignature()); @@ -601,12 +589,11 @@ public void testSignaturePrepared() @Test public void testParameters() { - SqlQueryPlus queryPlus = new 
SqlQueryPlus( - "SELECT COUNT(*) AS cnt FROM sys.servers WHERE servers.host = ?", - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql("SELECT COUNT(*) AS cnt FROM sys.servers WHERE servers.host = ?") + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); Meta.Frame expected = Meta.Frame.create(0, true, Collections.singletonList(new Object[] {1L})); List matchingParams = Collections.singletonList(TypedValue.ofLocal(ColumnMetaData.Rep.STRING, "dummy")); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index d87d9e00cc24..65fb0d3e88bf 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -15976,4 +15976,52 @@ public void testMultiStatementSetsInvalidTooManyNonSetStatements() "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[SELECT 1]" ); } + + @Test + public void testSetUseApproximateCountDistinctFalse() + { + testBuilder().sql( + "SET useApproximateCountDistinct = FALSE;\n" + + "SELECT COUNT(DISTINCT dim2) FROM druid.foo" + ).expectedQueries( + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setContext( + ImmutableMap.builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put("useApproximateCountDistinct", false) + .build() + ) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(aggregators( + new 
FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + notNull("d0") + ) + )) + .setContext( + ImmutableMap.builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put("useApproximateCountDistinct", false) + .build() + ) + .build() + ) + ).expectedResults( + ImmutableList.of( + new Object[]{3L} + ) + ).run(); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java index 27071dcf7b2c..bbc0c2987cac 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java @@ -263,10 +263,6 @@ public void run() BaseCalciteQueryTest.log.info("SQL: %s", builder.sql); final SqlStatementFactory sqlStatementFactory = builder.statementFactory(); - final SqlQueryPlus sqlQuery = SqlQueryPlus.builder(builder.sql) - .sqlParameters(builder.parameters) - .auth(builder.authenticationResult) - .build(); final List vectorizeValues = new ArrayList<>(); vectorizeValues.add("false"); @@ -275,7 +271,14 @@ public void run() } for (final String vectorize : vectorizeValues) { - final Map theQueryContext = new HashMap<>(builder.queryContext); + // Need to create sqlQuery inside the loop, because SqlQueryPlus can only be used once + final SqlQueryPlus sqlQuery = SqlQueryPlus.builder(builder.sql) + .sqlParameters(builder.parameters) + .auth(builder.authenticationResult) + .context(builder.queryContext) + .build(); + + final Map theQueryContext = new HashMap<>(sqlQuery.context()); theQueryContext.put(QueryContexts.VECTORIZE_KEY, vectorize); theQueryContext.put(QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index 20d7c0fa0786..5ec8f67bb424 100644 --- 
a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -106,6 +106,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) public static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( PLANNER_TOOLBOX, "SELECT 1", // The actual query isn't important for this test + null, /* Don't need SQL node */ null, /* Don't need engine */ Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java index 4f30cc1feb11..cccaffe8ec4f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; import org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -79,7 +80,8 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) ); final PlannerContext plannerContext = PlannerContext.create( toolbox, - "DUMMY", // The actual query isn't important for this test + "SELECT 1", // The actual query isn't important for this test + DruidSqlParser.parse("SELECT 1", false).getMainStatement(), engine, Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java new file mode 100644 index 000000000000..1765ecc2f5cf --- /dev/null 
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.DateString; +import org.apache.calcite.util.TimestampString; +import org.apache.druid.error.DruidException; +import org.junit.Assert; +import org.junit.Test; + +public class DruidSqlParserTest +{ + @Test + public void test_sqlLiteralToContextValue_null() + { + final SqlLiteral literal = SqlLiteral.createNull(SqlParserPos.ZERO); + Assert.assertNull(DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_string() + { + final SqlLiteral literal = SqlLiteral.createCharString("abc", SqlParserPos.ZERO); + Assert.assertEquals("abc", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_stringWithSpecialChars() + { + final SqlLiteral literal = SqlLiteral.createCharString("hello\nworld\t\"test\"", SqlParserPos.ZERO); + 
Assert.assertEquals("hello\nworld\t\"test\"", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_integer() + { + // Numbers within Long range are converted to Long. + final SqlLiteral literal = SqlLiteral.createExactNumeric("42", SqlParserPos.ZERO); + Assert.assertEquals(42L, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_negativeInteger() + { + final SqlLiteral literal = SqlLiteral.createExactNumeric("-123", SqlParserPos.ZERO); + Assert.assertEquals(-123L, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_decimal() + { + // Decimals are converted to Double. + final SqlLiteral literal = SqlLiteral.createExactNumeric("3.14159", SqlParserPos.ZERO); + Assert.assertEquals(3.14159, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_largeNumber() + { + // Integers outside Long range are retained as strings. 
+ final SqlLiteral literal = SqlLiteral.createExactNumeric("123456789012345678901234567890", SqlParserPos.ZERO); + Assert.assertEquals("123456789012345678901234567890", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_approximateNumeric() + { + final SqlLiteral literal = SqlLiteral.createApproxNumeric("1.23E10", SqlParserPos.ZERO); + Assert.assertEquals(1.23E10, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_booleanTrue() + { + final SqlLiteral literal = SqlLiteral.createBoolean(true, SqlParserPos.ZERO); + Assert.assertEquals(true, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_booleanFalse() + { + final SqlLiteral literal = SqlLiteral.createBoolean(false, SqlParserPos.ZERO); + Assert.assertEquals(false, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_timestamp() + { + // Timestamps are returned as strings in ISO8601 format + final TimestampString timestampString = new TimestampString("2023-01-15 14:30:00"); + final SqlLiteral literal = + SqlLiteral.createTimestamp(SqlTypeName.TIMESTAMP, timestampString, 0, SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T14:30:00.000Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_timestampWithFractionalSeconds() + { + final TimestampString timestampString = new TimestampString("2023-01-15 14:30:00.123"); + final SqlLiteral literal = + SqlLiteral.createTimestamp(SqlTypeName.TIMESTAMP, timestampString, 3, SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T14:30:00.123Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_date() + { + final DateString dateString = new DateString("2023-01-15"); + final SqlLiteral literal = SqlLiteral.createDate(dateString, 
SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T00:00:00.000Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_unsupportedType() + { + final SqlLiteral literal = SqlLiteral.createSymbol(SqlTypeName.BINARY, SqlParserPos.ZERO); + final DruidException exception = Assert.assertThrows( + DruidException.class, + () -> DruidSqlParser.sqlLiteralToContextValue(literal) + ); + Assert.assertTrue(exception.getMessage().contains("Unsupported type for SET")); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java index eb803186d4c5..b1809ba87556 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java @@ -44,6 +44,7 @@ import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.rule.ExtensionCalciteRuleProvider; import org.apache.druid.sql.calcite.run.NativeSqlEngine; import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; @@ -184,9 +185,11 @@ public void testExtensionCalciteRule() ObjectMapper mapper = new DefaultObjectMapper(); PlannerToolbox toolbox = injector.getInstance(PlannerFactory.class); + final String sql = "SELECT 1"; PlannerContext context = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.emptyMap(), null @@ -204,9 +207,11 @@ public void testConfigurableBloat() ObjectMapper mapper = new DefaultObjectMapper(); PlannerToolbox toolbox = 
injector.getInstance(PlannerFactory.class); + final String sql = "SELECT 1"; PlannerContext contextWithBloat = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.singletonMap(BLOAT_PROPERTY, BLOAT), null @@ -214,7 +219,8 @@ public void testConfigurableBloat() PlannerContext contextWithoutBloat = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java index 816ce3fd1a1d..fc0e7a43edc7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java @@ -112,6 +112,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) private static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( PLANNER_TOOLBOX, "SELECT 1", // The actual query isn't important for this test + null, /* Don't need a SQL node */ null, /* Don't need an engine */ Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java index a4ff866f0ddd..0d0deb14a5a6 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java @@ -368,9 +368,9 @@ protected DruidPlanner createPlanner() return plannerFactory.createPlanner( engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } }; @@ -387,9 +387,9 @@ 
protected DruidPlanner getPlanner() return plannerFactory.createPlanner( engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index b1aa19753088..ac0419485615 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -292,13 +292,13 @@ public void add(String sqlQueryId, Cancelable lifecycle) { @Override public HttpStatement httpStatement( - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { TestHttpStatement stmt = new TestHttpStatement( sqlToolbox.withEngine(engine), - sqlQuery, + sqlQueryPlus, req, validateAndAuthorizeLatchSupplier, planLatchSupplier, @@ -1481,7 +1481,7 @@ public void testCannotParse() throws Exception errorResponse, "Incorrect syntax near the keyword 'FROM' at line 1, column 1" ); - checkSqlRequestLog(false); + Assert.assertEquals(0, testRequestLogger.getSqlQueryLogs().size()); // Invalid queries are not logged Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1590,15 +1590,10 @@ public void testUnsupportedQueryThrowsException() throws Exception ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"), null ), - 501 + DruidException.Category.INVALID_INPUT.getExpectedStatus() ); - validateLegacyQueryExceptionErrorResponse( - exception, - QueryException.QUERY_UNSUPPORTED_ERROR_CODE, - QueryUnsupportedException.class.getName(), - "" - ); + validateInvalidSqlError(exception, "Incorrect syntax near the keyword 'TO'"); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1623,32 +1618,31 @@ public void testErrorResponseReturnSameQueryIdWhenSetInContext() // This is checked in the common method that returns the response, but checking it again just protects // from changes there breaking the checks, so doesn't hurt. 
- assertStatusAndCommonHeaders(response, 501); + assertStatusAndCommonHeaders(response, DruidException.Category.INVALID_INPUT.getExpectedStatus()); Assert.assertEquals(queryId, getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER)); Assert.assertEquals(queryId, getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); } @Test - public void testErrorResponseReturnNewQueryIdWhenNotSetInContext() + public void testErrorResponseReturnNoQueryIdWhenNotSetInContext() { String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final Response response = postForSyncResponse( - new SqlQuery( - "SELECT ANSWER TO LIFE", - ResultFormat.OBJECT, - false, - false, - false, - ImmutableMap.of(), - null - ), - req + final SqlQuery sqlQuery = new SqlQuery( + "SELECT ANSWER TO LIFE", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of(), + null ); - // This is checked in the common method that returns the response, but checking it again just protects - // from changes there breaking the checks, so doesn't hurt. - assertStatusAndCommonHeaders(response, 501); + final Response response = resource.doPost(sqlQuery, req); + + // Query ID won't be set, but we can look for other aspects of the response that we expect. 
+ Assert.assertEquals(DruidException.Category.INVALID_INPUT.getExpectedStatus(), response.getStatus()); + Assert.assertEquals("application/json", getContentType(response)); } @Test @@ -1681,7 +1675,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy() failOnExecute(errorMessage); ErrorResponse exception = postSyncForException( new SqlQuery( - "SELECT ANSWER TO LIFE", + "SELECT 1", ResultFormat.OBJECT, false, false, @@ -2297,7 +2291,7 @@ private static class TestHttpStatement extends HttpStatement private TestHttpStatement( final SqlToolbox lifecycleContext, - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req, SettableSupplier> validateAndAuthorizeLatchSupplier, SettableSupplier> planLatchSupplier, @@ -2307,7 +2301,7 @@ private TestHttpStatement( final Consumer onAuthorize ) { - super(lifecycleContext, sqlQuery, req); + super(lifecycleContext, sqlQueryPlus, req); this.validateAndAuthorizeLatchSupplier = validateAndAuthorizeLatchSupplier; this.planLatchSupplier = planLatchSupplier; this.executeLatchSupplier = executeLatchSupplier; From e21bb3b3ba300ac56a9f7e39aa789ffbbecf1046 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 9 Jun 2025 14:16:25 -0700 Subject: [PATCH 2/2] Fix style --- .../org/apache/druid/msq/sql/resources/SqlTaskResource.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index 6f2f6123804d..1dd84a83826b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -133,7 +133,8 @@ public Response doPost( try { sqlQueryPlus = SqlResource.makeSqlQueryPlus(sqlQuery, req); stmt = 
sqlStatementFactory.httpStatement(sqlQueryPlus, req); - } catch (Exception e) { + } + catch (Exception e) { return SqlResource.handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); }