diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java index a45d77949228..99b16e0d2ecb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java @@ -85,6 +85,7 @@ import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.HttpStatement; +import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.http.ResultFormat; @@ -174,17 +175,29 @@ public Response doPost(@Context final HttpServletRequest req, } @VisibleForTesting - Response doPost(final SqlQuery sqlQuery, - final HttpServletRequest req) + Response doPost( + SqlQuery sqlQuery, // Not final: reassigned using createModifiedSqlQuery + final HttpServletRequest req + ) { - SqlQuery modifiedQuery = createModifiedSqlQuery(sqlQuery); + final SqlQueryPlus sqlQueryPlus; + final HttpStatement stmt; + final QueryContext queryContext; + + try { + sqlQuery = createModifiedSqlQuery(sqlQuery); + sqlQueryPlus = SqlResource.makeSqlQueryPlus(sqlQuery, req); + queryContext = QueryContext.of(sqlQueryPlus.context()); + stmt = msqSqlStatementFactory.httpStatement(SqlResource.makeSqlQueryPlus(sqlQuery, req), req); + } + catch (Exception e) { + return SqlResource.handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } - final HttpStatement stmt = msqSqlStatementFactory.httpStatement(modifiedQuery, req); final String sqlQueryId = stmt.sqlQueryId(); final String currThreadName = Thread.currentThread().getName(); boolean isDebug = false; try { - QueryContext queryContext = QueryContext.of(modifiedQuery.getContext()); isDebug = queryContext.isDebug(); contextChecks(queryContext); @@ -202,7 +215,7 @@ Response doPost(final SqlQuery sqlQuery, return buildTaskResponse(sequence, stmt.query().authResult()); } else { // Used for EXPLAIN - return buildStandardResponse(sequence, modifiedQuery, sqlQueryId, rowTransformer); + return buildStandardResponse(sequence, sqlQuery, sqlQueryId, rowTransformer); } } catch (DruidException e) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index 11e39bfa1267..1dd84a83826b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -43,6 +43,7 @@ import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.HttpStatement; +import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.http.ResultFormat; @@ -127,7 +128,16 @@ public Response doPost( ) { // Queries run as MSQ tasks look like regular queries, but return the task ID as their only output. - final HttpStatement stmt = sqlStatementFactory.httpStatement(sqlQuery, req); + final SqlQueryPlus sqlQueryPlus; + final HttpStatement stmt; + try { + sqlQueryPlus = SqlResource.makeSqlQueryPlus(sqlQuery, req); + stmt = sqlStatementFactory.httpStatement(sqlQueryPlus, req); + } + catch (Exception e) { + return SqlResource.handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } + final String sqlQueryId = stmt.sqlQueryId(); try { final DirectStatement.ResultSet plan = stmt.plan(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java index 084bf8e92fff..3a6b23fe8701 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/ResultsContextSerdeTest.java @@ -31,6 +31,7 @@ import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; import org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -86,9 +87,11 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) EasyMock.createMock(QueryRunnerFactoryConglomerate.class) ); + final String sql = "SELECT 1"; PlannerContext plannerContext = PlannerContext.create( toolbox, - "DUMMY", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), engine, Collections.emptyMap(), null diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 2495251dec06..c05b6c384435 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -810,12 +810,12 @@ private String runMultiStageQuery( ) { final DirectStatement stmt = sqlStatementFactory.directStatement( - new SqlQueryPlus( - query, - context, - parameters, - authenticationResult - ) + SqlQueryPlus.builder() + .sql(query) + .context(context) + .parameters(parameters) + .auth(authenticationResult) + .build() ); final List sequence = stmt.execute().getResults().toList(); diff --git a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java index b4a57bea8e1c..4208f984cee1 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java +++ b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.io.CountingOutputStream; import org.apache.druid.client.DirectDruidClient; import org.apache.druid.error.DruidException; @@ -270,17 +271,14 @@ private Response handleDruidException(ResultsWriter resultsWriter, DruidExceptio } if (response == null) { - final Response.ResponseBuilder bob = Response - .status(e.getStatusCode()) - .type(contentType) - .entity(new ErrorResponse(e)); - - bob.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); - for (Map.Entry entry : extraHeaders.entrySet()) { - bob.header(entry.getKey(), entry.getValue()); - } - - return bob.build(); + return handleDruidExceptionBeforeResponseStarted( + e, + contentType, + ImmutableMap.builder() + .putAll(extraHeaders) + .put(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId) + .build() + ); } else { if (response.isCommitted()) { QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?"); @@ -302,6 +300,27 @@ private Response handleDruidException(ResultsWriter resultsWriter, DruidExceptio } } + /** + * Generates a response for a {@link DruidException} that occurs prior to any query results being sent out. + */ + public static Response handleDruidExceptionBeforeResponseStarted( + final DruidException e, + final MediaType contentType, + final Map extraHeaders + ) + { + final Response.ResponseBuilder bob = Response + .status(e.getStatusCode()) + .type(contentType) + .entity(new ErrorResponse(e)); + + for (Map.Entry entry : extraHeaders.entrySet()) { + bob.header(entry.getKey(), entry.getValue()); + } + + return bob.build(); + } + public interface ResultsWriter extends Closeable { /** diff --git a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java index f4f36c28b73f..010b326a94ea 100644 --- a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java @@ -241,9 +241,9 @@ protected DruidPlanner createPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - false + hook ); } diff --git a/sql/src/main/java/org/apache/druid/sql/HttpStatement.java b/sql/src/main/java/org/apache/druid/sql/HttpStatement.java index 8094f1f3890e..0a3011d4de15 100644 --- a/sql/src/main/java/org/apache/druid/sql/HttpStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/HttpStatement.java @@ -23,7 +23,6 @@ import org.apache.druid.server.security.AuthorizationUtils; import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.calcite.planner.DruidPlanner; -import org.apache.druid.sql.http.SqlQuery; import javax.servlet.http.HttpServletRequest; import java.util.Set; @@ -45,15 +44,13 @@ public class HttpStatement extends DirectStatement public HttpStatement( final SqlToolbox lifecycleToolbox, - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { super( lifecycleToolbox, - SqlQueryPlus.builder(sqlQuery) - .auth(AuthorizationUtils.authenticationResultFromRequest(req)) - .build(), + sqlQueryPlus, req.getRemoteAddr() ); this.req = req; @@ -65,9 +62,9 @@ protected DruidPlanner createPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } diff --git a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java index 7c8ab63d7c7f..505ec2f5f86a 100644 --- a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java @@ -31,7 +31,6 @@ public class PreparedStatement extends AbstractStatement { private final SqlQueryPlus originalRequest; - private PrepareResult prepareResult; public PreparedStatement( final SqlToolbox lifecycleToolbox, @@ -70,8 +69,7 @@ public PrepareResult prepare() authorize(planner, authorizer()); // Do the prepare step. - this.prepareResult = planner.prepare(); - return prepareResult; + return planner.prepare(); } catch (RuntimeException e) { reporter.failed(e); @@ -92,7 +90,7 @@ public DirectStatement execute(List parameters) { return new DirectStatement( sqlToolbox, - originalRequest.withParameters(parameters) + originalRequest.freshCopy().withParameters(parameters) ); } @@ -101,9 +99,9 @@ protected DruidPlanner getPlanner() return sqlToolbox.plannerFactory.createPlanner( sqlToolbox.engine, queryPlus.sql(), + queryPlus.freshCopy().sqlNode(), queryContext, - hook, - false + hook ); } } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java index 8d06a65fa35c..38d8e64c8423 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlQueryPlus.java @@ -21,25 +21,37 @@ import com.google.common.base.Preconditions; import org.apache.calcite.avatica.remote.TypedValue; +import org.apache.calcite.sql.SqlNode; +import org.apache.druid.error.DruidException; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; +import org.apache.druid.sql.calcite.parser.StatementAndSetContext; +import org.apache.druid.sql.calcite.planner.CalcitePlanner; import org.apache.druid.sql.http.SqlParameter; import org.apache.druid.sql.http.SqlQuery; +import javax.annotation.Nullable; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; /** - * Captures the inputs to a SQL execution request: the statement,the context, - * parameters, and the authorization result. Pass this around rather than the - * quad of items. The request can evolve: the context and parameters can be - * filled in later as needed. + * Captures the inputs to a SQL execution request: the statement (as a string), + * the parsed statement, the context, parameters, and the authorization result. + * The request can evolve: the context and parameters can be filled in later + * as needed. *

* SQL requests come from a variety of sources in a variety of formats. Use * the {@link Builder} class to create an instance from the information * available at each point in the code. *

+ * Each instance of SqlQueryPlus can only be used once, because {@link #sqlNode} + * is a mutable data structure, modified during {@link CalcitePlanner#validate}. + * If you need to use one again, call {@link #freshCopy()} to create a fresh + * copy with a new {@link SqlNode}. + *

* The query context has a complex lifecycle. The copy here is immutable: * it is the set of values which the user requested. Planning will * add (and sometimes remove) values: that work should be done on a copy of the @@ -50,24 +62,31 @@ public class SqlQueryPlus { private final String sql; + @Nullable + private final SqlNode sqlNode; + private boolean allowSetStatements; private final Map queryContext; private final List parameters; private final AuthenticationResult authResult; - public SqlQueryPlus( + private SqlQueryPlus( String sql, + SqlNode sqlNode, + boolean allowSetStatements, Map queryContext, List parameters, AuthenticationResult authResult ) { this.sql = Preconditions.checkNotNull(sql); + this.sqlNode = sqlNode; + this.allowSetStatements = allowSetStatements; this.queryContext = queryContext == null - ? Collections.emptyMap() - : Collections.unmodifiableMap(new HashMap<>(queryContext)); + ? Collections.emptyMap() + : Collections.unmodifiableMap(new HashMap<>(queryContext)); this.parameters = parameters == null - ? Collections.emptyList() - : parameters; + ? Collections.emptyList() + : parameters; this.authResult = Preconditions.checkNotNull(authResult); } @@ -81,14 +100,18 @@ public static Builder builder(String sql) return new Builder().sql(sql); } - public static Builder builder(SqlQuery sqlQuery) + public String sql() { - return new Builder().query(sqlQuery); + return sql; } - public String sql() + public SqlNode sqlNode() { - return sql; + if (sqlNode == null) { + throw DruidException.defensive("sqlNode not set"); + } + + return sqlNode; } public Map context() @@ -108,19 +131,40 @@ public AuthenticationResult authResult() public SqlQueryPlus withContext(Map context) { - return new SqlQueryPlus(sql, context, parameters, authResult); + return new SqlQueryPlus(sql, sqlNode, allowSetStatements, context, parameters, authResult); } public SqlQueryPlus withParameters(List parameters) { - return new SqlQueryPlus(sql, queryContext, parameters, authResult); + return new SqlQueryPlus(sql, sqlNode, allowSetStatements, queryContext, parameters, authResult); + } + + /** + * Returns a copy of this instance where everything is shared, except the {@link #sqlNode}, which is re-parsed from + * the SQL statement. + */ + public SqlQueryPlus freshCopy() + { + return new SqlQueryPlus( + sql, + DruidSqlParser.parse(sql, allowSetStatements).getMainStatement(), + allowSetStatements, + queryContext, + parameters, + authResult + ); } @Override public String toString() { - return "SqlQueryPlus {queryContext=" + queryContext + ", parameters=" + parameters - + ", authResult=" + authResult + ", sql=" + sql + " }"; + return "SqlQueryPlus{" + + "sql='" + sql + '\'' + + ", sqlNode=" + sqlNode + + ", queryContext=" + queryContext + + ", parameters=" + parameters + + ", authResult=" + authResult + + '}'; } public static class Builder @@ -136,14 +180,6 @@ public Builder sql(String sql) return this; } - public Builder query(SqlQuery sqlQuery) - { - this.sql = sqlQuery.getQuery(); - this.queryContext = sqlQuery.getContext(); - this.parameters = sqlQuery.getParameterList(); - return this; - } - public Builder context(Map queryContext) { this.queryContext = queryContext; @@ -168,14 +204,38 @@ public Builder auth(final AuthenticationResult authResult) return this; } + /** + * Parses the provided {@link #sql} and builds a {@link SqlQueryPlus} with SET statements folded into the + * context, and with the parsed SQL in {@link #sqlNode}. + * + * When using this method, the {@link #sqlNode()} must only be run through validation once. (The validator + * mutates the {@link SqlNode}). + */ public SqlQueryPlus build() { + final StatementAndSetContext statementAndSetContext = DruidSqlParser.parse(sql, true); return new SqlQueryPlus( sql, - queryContext, + statementAndSetContext.getMainStatement(), + true, + statementAndSetContext.getSetContext().isEmpty() + ? queryContext + : QueryContexts.override(queryContext, statementAndSetContext.getSetContext()), parameters, authResult ); } + + /** + * Builds a {@link SqlQueryPlus} with no {@link SqlNode} and with {@link #allowSetStatements} set to false. + * This is done for JDBC becauase it can runs each {@link SqlQueryPlus} multiple times, and it needs to keep + * re-parsing and re-validating the query on each run. + * + * When using this method, you must create a copy with {@link #freshCopy()} prior to calling {@link #sqlNode()}. + */ + public SqlQueryPlus buildJdbc() + { + return new SqlQueryPlus(sql, null, false, queryContext, parameters, authResult); + } } } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java index c8450d621661..9d31441befb1 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java @@ -19,8 +19,6 @@ package org.apache.druid.sql; -import org.apache.druid.sql.http.SqlQuery; - import javax.servlet.http.HttpServletRequest; /** @@ -44,11 +42,11 @@ public SqlStatementFactory(SqlToolbox lifecycleToolbox) } public HttpStatement httpStatement( - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { - return new HttpStatement(lifecycleToolbox, sqlQuery, req); + return new HttpStatement(lifecycleToolbox, sqlQueryPlus, req); } public DirectStatement directStatement(final SqlQueryPlus sqlRequest) diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java index 006a46a8f8ef..966067bb6741 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidJdbcStatement.java @@ -57,7 +57,7 @@ public DruidJdbcStatement( public synchronized void execute(SqlQueryPlus queryPlus, long maxRowCount) { closeResultSet(); - this.sqlQuery = queryPlus.withContext(queryContext); + this.sqlQuery = queryPlus.withContext(queryContext).freshCopy(); DirectStatement stmt = lifecycleFactory.directStatement(this.sqlQuery); resultSet = new DruidJdbcResultSet(this, stmt, Long.MAX_VALUE, fetcherFactory); try { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index e271a6f60ab5..f60b7d4ef854 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -271,12 +271,11 @@ public StatementHandle prepare( { try { final DruidConnection druidConnection = getDruidConnection(ch.id); - final SqlQueryPlus sqlReq = new SqlQueryPlus( - sql, - druidConnection.sessionContext(), - null, // No parameters in this path - doAuthenticate(druidConnection) - ); + final SqlQueryPlus sqlReq = SqlQueryPlus.builder() + .sql(sql) + .context(druidConnection.sessionContext()) + .auth(doAuthenticate(druidConnection)) + .buildJdbc(); final DruidJdbcPreparedStatement stmt = getDruidConnection(ch.id).createPreparedStatement( sqlStatementFactory, sqlReq, @@ -351,7 +350,7 @@ public ExecuteResult prepareAndExecute( final AuthenticationResult authenticationResult = doAuthenticate(druidConnection); final SqlQueryPlus sqlRequest = SqlQueryPlus.builder(sql) .auth(authenticationResult) - .build(); + .buildJdbc(); druidStatement.execute(sqlRequest, maxRowCount); final ExecuteResult result = doFetch(druidStatement, maxRowsInFirstFrame); LOG.debug("Successfully prepared statement [%s] and started execution", druidStatement.getStatementId()); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java new file mode 100644 index 000000000000..fe0f0a681c4d --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/parser/DruidSqlParser.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import com.google.common.base.Joiner; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.avatica.util.Quoting; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSetOption; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.dialect.CalciteSqlDialect; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.NlsString; +import org.apache.calcite.util.SourceStringReader; +import org.apache.druid.error.DruidException; +import org.apache.druid.error.InvalidSqlInput; +import org.apache.druid.query.QueryContext; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.DruidConformance; +import org.apache.druid.sql.calcite.planner.PlannerContext; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Contains the utility function {@link #parse(String, boolean)}, for parsing Druid SQL statements. + */ +public class DruidSqlParser +{ + public static final SqlParser.Config PARSER_CONFIG = SqlParser + .config() + .withCaseSensitive(true) + .withUnquotedCasing(Casing.UNCHANGED) + .withQuotedCasing(Casing.UNCHANGED) + .withQuoting(Quoting.DOUBLE_QUOTE) + .withConformance(DruidConformance.instance()) + .withParserFactory(new DruidSqlParserImplFactory()); + + private static final Joiner SPACE_JOINER = Joiner.on(" "); + private static final Joiner COMMA_JOINER = Joiner.on(", "); + + private DruidSqlParser() + { + // No instantiation. + } + + public static StatementAndSetContext parse(final String sql, final boolean allowSetStatements) + { + try { + SqlParser parser = SqlParser.create(new SourceStringReader(sql), PARSER_CONFIG); + SqlNode sqlNode = parser.parseStmtList(); + return processStatementList(sqlNode, allowSetStatements); + } + catch (SqlParseException e) { + throw translateParseException(e); + } + } + + /** + * If an {@link SqlNode} is a {@link SqlNodeList}, it must consist of 0 or more {@link SqlSetOption} followed by a + * single {@link SqlNode} which is NOT a {@link SqlSetOption}. All {@link SqlSetOption} will be converted into a + * context parameters {@link Map} and added to the {@link PlannerContext} with + * {@link PlannerContext#addAllToQueryContext(Map)}. The final {@link SqlNode} of the {@link SqlNodeList} is returned + * by this method as the {@link SqlNode} which should actually be validated and executed, and will have access to the + * modified query context through the {@link PlannerContext}. {@link SqlSetOption} override any existing query + * context parameter values. + */ + private static StatementAndSetContext processStatementList( + SqlNode root, + final boolean allowSetStatements + ) + { + if (root instanceof SqlNodeList) { + final Map setContext = new LinkedHashMap<>(); + final SqlNodeList nodeList = (SqlNodeList) root; + if (!allowSetStatements && nodeList.size() > 1) { + throw InvalidSqlInput.exception("SQL query string must contain only a single statement"); + } + boolean isMissingDruidStatementNode = true; + // convert 0 or more SET statements into a Map of stuff to add to the query context + for (int i = 0; i < nodeList.size(); i++) { + SqlNode sqlNode = nodeList.get(i); + if (sqlNode instanceof SqlSetOption) { + final SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; + if (!(sqlSetOption.getValue() instanceof SqlLiteral)) { + throw InvalidSqlInput.exception( + "Assigned value must be a literal for SET statement[%s]", + sqlSetOption.toSqlString(CalciteSqlDialect.DEFAULT) + ); + } + setContext.put( + sqlSetOption.getName().getSimple(), + sqlLiteralToContextValue((SqlLiteral) sqlSetOption.getValue()) + ); + } else if (i < nodeList.size() - 1) { + // only SET statements can appear before the last statement + throw InvalidSqlInput.exception( + "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[%s]", + sqlNode.toSqlString(CalciteSqlDialect.DEFAULT) + ); + } else { + // last SqlNode + root = sqlNode; + isMissingDruidStatementNode = false; + } + } + if (isMissingDruidStatementNode) { + throw InvalidSqlInput.exception("Statement list is missing a non-SET statement to execute"); + } + + return new StatementAndSetContext(root, setContext); + } else { + return new StatementAndSetContext(root, Collections.emptyMap()); + } + } + + /** + * Coerces a SQL literal from a SET statement to a form acceptable for {@link QueryContext}. + */ + @Nullable + static Object sqlLiteralToContextValue(final SqlLiteral literal) + { + if (SqlUtil.isNullLiteral(literal, false)) { + return null; + } else if (SqlTypeName.CHAR_TYPES.contains(literal.getTypeName())) { + return ((NlsString) literal.getValue()).getValue(); + } else if (SqlTypeName.BOOLEAN_TYPES.contains(literal.getTypeName())) { + return literal.getValue(); + } else if (SqlTypeName.NUMERIC_TYPES.contains(literal.getTypeName())) { + final Number number = (Number) literal.getValue(); + if (number instanceof BigDecimal && number.equals(BigDecimal.valueOf(number.longValue()))) { + return number.longValue(); + } else if (number instanceof BigInteger && number.equals(BigInteger.valueOf(number.longValue()))) { + return number.longValue(); + } else if (number instanceof BigDecimal && number.equals(BigDecimal.valueOf(number.doubleValue()))) { + return number.doubleValue(); + } else { + return number.toString(); + } + } else if (literal.getTypeName() == SqlTypeName.DATE) { + return Calcites.CALCITE_DATE_PARSER.parse(literal.getValue().toString()).toString(); + } else if (literal.getTypeName() == SqlTypeName.TIMESTAMP) { + return Calcites.CALCITE_TIMESTAMP_PARSER.parse(literal.getValue().toString()).toString(); + } else { + throw InvalidSqlInput.exception("Unsupported type for SET[%s]", literal.getTypeName()); + } + } + + /** + * Constructs a user-friendly {@link DruidException} from a Calcite {@link SqlParseException}. + */ + private static DruidException translateParseException(SqlParseException e) + { + final Throwable cause = e.getCause(); + if (cause instanceof DruidException) { + return (DruidException) cause; + } + + if (cause instanceof ParseException) { + ParseException parseException = (ParseException) cause; + final SqlParserPos failurePosition = e.getPos(); + // When calcite catches a syntax error at the top level + // expected token sequences can be null. + // In such a case return the syntax error to the user + // wrapped in a DruidException with invalid input + if (parseException.expectedTokenSequences == null) { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode("invalidInput") + .build(e, "%s", e.getMessage()) + .withContext("sourceType", "sql"); + } else { + final String theUnexpectedToken = getUnexpectedTokenString(parseException); + + final String[] tokenDictionary = e.getTokenImages(); + final int[][] expectedTokenSequences = e.getExpectedTokenSequences(); + final ArrayList expectedTokens = new ArrayList<>(expectedTokenSequences.length); + for (int[] expectedTokenSequence : expectedTokenSequences) { + String[] strings = new String[expectedTokenSequence.length]; + for (int i = 0; i < expectedTokenSequence.length; ++i) { + strings[i] = tokenDictionary[expectedTokenSequence[i]]; + } + expectedTokens.add(SPACE_JOINER.join(strings)); + } + + return InvalidSqlInput + .exception( + e, + "Received an unexpected token [%s] (line [%s], column [%s]), acceptable options: [%s]", + theUnexpectedToken, + failurePosition.getLineNum(), + failurePosition.getColumnNum(), + COMMA_JOINER.join(expectedTokens) + ) + .withContext("line", failurePosition.getLineNum()) + .withContext("column", failurePosition.getColumnNum()) + .withContext("endLine", failurePosition.getEndLineNum()) + .withContext("endColumn", failurePosition.getEndColumnNum()) + .withContext("token", theUnexpectedToken) + .withContext("expected", expectedTokens); + + } + } + + return InvalidSqlInput.exception(e.getMessage()); + } + + /** + * Grabs the unexpected token string. This code is borrowed with minimal adjustments from + * {@link ParseException#getMessage()}. It is possible that if that code changes, we need to also + * change this code to match it. + * + * @param parseException the parse exception to extract from + * + * @return the String representation of the unexpected token string + */ + private static String getUnexpectedTokenString(ParseException parseException) + { + int maxSize = 0; + for (int[] ints : parseException.expectedTokenSequences) { + if (maxSize < ints.length) { + maxSize = ints.length; + } + } + + StringBuilder bob = new StringBuilder(); + Token tok = parseException.currentToken.next; + for (int i = 0; i < maxSize; i++) { + if (i != 0) { + bob.append(" "); + } + if (tok.kind == 0) { + bob.append(""); + break; + } + char ch; + for (int i1 = 0; i1 < tok.image.length(); i1++) { + switch (tok.image.charAt(i1)) { + case 0: + continue; + case '\b': + bob.append("\\b"); + continue; + case '\t': + bob.append("\\t"); + continue; + case '\n': + bob.append("\\n"); + continue; + case '\f': + bob.append("\\f"); + continue; + case '\r': + bob.append("\\r"); + continue; + case '\"': + bob.append("\\\""); + continue; + case '\'': + bob.append("\\\'"); + continue; + case '\\': + bob.append("\\\\"); + continue; + default: + if ((ch = tok.image.charAt(i1)) < 0x20 || ch > 0x7e) { + String s = "0000" + Integer.toString(ch, 16); + bob.append("\\u").append(s.substring(s.length() - 4, s.length())); + } else { + bob.append(ch); + } + } + } + tok = tok.next; + } + return bob.toString(); + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java new file mode 100644 index 000000000000..2f00cb70e0a2 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/parser/StatementAndSetContext.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import org.apache.calcite.sql.SqlNode; + +import java.util.Map; +import java.util.Objects; + +/** + * Represents a parsed "main" (not SET) SQL statement, plus context derived from SET statements. + */ +public class StatementAndSetContext +{ + private final SqlNode mainStatement; + private final Map setContext; + + public StatementAndSetContext(SqlNode mainStatement, Map setContext) + { + this.mainStatement = mainStatement; + this.setContext = setContext; + } + + public SqlNode getMainStatement() + { + return mainStatement; + } + + public Map getSetContext() + { + return setContext; + } + + @Override + public boolean equals(Object o) + { + if (o == null || getClass() != o.getClass()) { + return false; + } + StatementAndSetContext that = (StatementAndSetContext) o; + return Objects.equals(mainStatement, that.mainStatement) + && Objects.equals(setContext, that.setContext); + } + + @Override + public int hashCode() + { + return Objects.hash(mainStatement, setContext); + } + + @Override + public String toString() + { + return "StatementAndSetContext{" + + "mainStatement=" + mainStatement + + ", setContext=" + setContext + + '}'; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java index baf257632573..fe8dfd509a30 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalcitePlanner.java @@ -60,6 +60,7 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.ValidationException; import org.apache.calcite.util.Pair; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import javax.annotation.Nullable; import java.io.Reader; @@ -232,6 +233,25 @@ public SqlNode parse(final Reader reader) throws SqlParseException return sqlNode; } + /** + * Skip parsing, moving state along to a {@link State#STATE_3_PARSED}. We have this because we + * parse before the planner is even created, using {@link DruidSqlParser}, in order to capture + * SET statements as early as possible. + */ + public void skipParse() + { + switch (state) { + case STATE_0_CLOSED: + case STATE_1_RESET: + ready(); + break; + default: + break; + } + ensure(CalcitePlanner.State.STATE_2_READY); + state = CalcitePlanner.State.STATE_3_PARSED; + } + @Override public SqlNode validate(SqlNode sqlNode) throws ValidationException { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java index 42d6af585738..7f90f0a9aa28 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/Calcites.java @@ -77,8 +77,8 @@ */ public class Calcites { - private static final DateTimes.UtcFormatter CALCITE_DATE_PARSER = DateTimes.wrapFormatter(ISODateTimeFormat.dateParser()); - private static final DateTimes.UtcFormatter CALCITE_TIMESTAMP_PARSER = DateTimes.wrapFormatter( + public static final DateTimes.UtcFormatter CALCITE_DATE_PARSER = DateTimes.wrapFormatter(ISODateTimeFormat.dateParser()); + public static final DateTimes.UtcFormatter CALCITE_TIMESTAMP_PARSER = DateTimes.wrapFormatter( new DateTimeFormatterBuilder() .append(ISODateTimeFormat.dateParser()) .appendLiteral(' ') diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java index 2842eb76e94a..f361e43797b4 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidPlanner.java @@ -20,20 +20,13 @@ package org.apache.druid.sql.calcite.planner; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.SqlExplain; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.SqlNodeList; -import org.apache.calcite.sql.SqlSetOption; -import org.apache.calcite.sql.dialect.CalciteSqlDialect; -import org.apache.calcite.sql.parser.SqlParseException; -import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.ValidationException; import org.apache.druid.error.DruidException; @@ -44,16 +37,11 @@ import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlReplace; -import org.apache.druid.sql.calcite.parser.ParseException; -import org.apache.druid.sql.calcite.parser.Token; import org.apache.druid.sql.calcite.run.SqlEngine; -import org.apache.druid.sql.calcite.run.SqlResults; import org.joda.time.DateTimeZone; import java.io.Closeable; -import java.util.ArrayList; import java.util.HashSet; -import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -70,9 +58,6 @@ */ public class DruidPlanner implements Closeable { - public static final Joiner SPACE_JOINER = Joiner.on(" "); - public static final Joiner COMMA_JOINER = Joiner.on(", "); - public enum State { START, VALIDATED, PREPARED, PLANNED @@ -112,7 +97,6 @@ public AuthResult( private final PlannerContext plannerContext; private final SqlEngine engine; private final PlannerHook hook; - private final boolean allowSetStatementsToBuildContext; private State state = State.START; private SqlStatementHandler handler; private boolean authorized; @@ -121,8 +105,7 @@ public AuthResult( final FrameworkConfig frameworkConfig, final PlannerContext plannerContext, final SqlEngine engine, - final PlannerHook hook, - final boolean allowSetStatementsToBuildContext + final PlannerHook hook ) { this.frameworkConfig = frameworkConfig; @@ -130,7 +113,6 @@ public AuthResult( this.plannerContext = plannerContext; this.engine = engine; this.hook = hook == null ? NoOpPlannerHook.INSTANCE : hook; - this.allowSetStatementsToBuildContext = allowSetStatementsToBuildContext; } /** @@ -145,19 +127,8 @@ public void validate() // Validate query context. engine.validateContext(plannerContext.queryContextMap()); - - // Parse the query string. - String sql = plannerContext.getSql(); - hook.captureSql(sql); - SqlNode root; - try { - root = planner.parse(sql); - } - catch (SqlParseException e1) { - throw translateException(e1); - } - root = processStatementList(root); - root = rewriteParameters(root); + planner.skipParse(); + final SqlNode root = rewriteParameters(plannerContext.getSqlNode()); hook.captureSqlNode(root); handler = createHandler(root); handler.validate(); @@ -293,66 +264,6 @@ public void close() planner.close(); } - /** - * If an {@link SqlNode} is a {@link SqlNodeList}, it must consist of 0 or more {@link SqlSetOption} followed by a - * single {@link SqlNode} which is NOT a {@link SqlSetOption}. All {@link SqlSetOption} will be converted into a - * context parameters {@link Map} and added to the {@link PlannerContext} with - * {@link PlannerContext#addAllToQueryContext(Map)}. The final {@link SqlNode} of the {@link SqlNodeList} is returned - * by this method as the {@link SqlNode} which should actually be validated and executed, and will have access to the - * modified query context through the {@link PlannerContext}. {@link SqlSetOption} override any existing query - * context parameter values. - */ - private SqlNode processStatementList(SqlNode root) - { - if (root instanceof SqlNodeList) { - final SqlNodeList nodeList = (SqlNodeList) root; - if (!allowSetStatementsToBuildContext && nodeList.size() > 1) { - throw InvalidSqlInput.exception("SQL query string must contain only a single statement"); - } - final Map contextMap = new LinkedHashMap<>(); - boolean isMissingDruidStatementNode = true; - // convert 0 or more SET statements into a Map of stuff to add to the query context - for (int i = 0; i < nodeList.size(); i++) { - SqlNode sqlNode = nodeList.get(i); - if (sqlNode instanceof SqlSetOption) { - final SqlSetOption sqlSetOption = (SqlSetOption) sqlNode; - if (!(sqlSetOption.getValue() instanceof SqlLiteral)) { - throw InvalidSqlInput.exception( - "Assigned value must be a literal for SET statement[%s]", - sqlSetOption.toSqlString(CalciteSqlDialect.DEFAULT) - ); - } - final SqlLiteral value = (SqlLiteral) sqlSetOption.getValue(); - contextMap.put( - sqlSetOption.getName().getSimple(), - SqlResults.coerce( - plannerContext.getJsonMapper(), - SqlResults.Context.fromPlannerContext(plannerContext), - value.getValue(), - value.getTypeName(), - "set" - ) - ); - } else if (i < nodeList.size() - 1) { - // only SET statements can appear before the last statement - throw InvalidSqlInput.exception( - "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[%s]", - sqlNode.toSqlString(CalciteSqlDialect.DEFAULT) - ); - } else { - // last SqlNode - root = sqlNode; - isMissingDruidStatementNode = false; - } - } - if (isMissingDruidStatementNode) { - throw InvalidSqlInput.exception("Statement list is missing a non-SET statement to execute"); - } - plannerContext.addAllToQueryContext(contextMap); - } - return root; - } - protected class HandlerContextImpl implements SqlStatementHandler.HandlerContext { @Override @@ -421,59 +332,6 @@ public static DruidException translateException(Exception e) catch (ValidationException inner) { return parseValidationMessage(inner); } - catch (SqlParseException inner) { - final Throwable cause = inner.getCause(); - if (cause instanceof DruidException) { - return (DruidException) cause; - } - - if (cause instanceof ParseException) { - ParseException parseException = (ParseException) cause; - final SqlParserPos failurePosition = inner.getPos(); - // When calcite catches a syntax error at the top level - // expected token sequences can be null. - // In such a case return the syntax error to the user - // wrapped in a DruidException with invalid input - if (parseException.expectedTokenSequences == null) { - return DruidException.forPersona(DruidException.Persona.USER) - .ofCategory(DruidException.Category.INVALID_INPUT) - .withErrorCode("invalidInput") - .build(inner, "%s", inner.getMessage()).withContext("sourceType", "sql"); - } else { - final String theUnexpectedToken = getUnexpectedTokenString(parseException); - - final String[] tokenDictionary = inner.getTokenImages(); - final int[][] expectedTokenSequences = inner.getExpectedTokenSequences(); - final ArrayList expectedTokens = new ArrayList<>(expectedTokenSequences.length); - for (int[] expectedTokenSequence : expectedTokenSequences) { - String[] strings = new String[expectedTokenSequence.length]; - for (int i = 0; i < expectedTokenSequence.length; ++i) { - strings[i] = tokenDictionary[expectedTokenSequence[i]]; - } - expectedTokens.add(SPACE_JOINER.join(strings)); - } - - return InvalidSqlInput - .exception( - inner, - "Received an unexpected token [%s] (line [%s], column [%s]), acceptable options: [%s]", - theUnexpectedToken, - failurePosition.getLineNum(), - failurePosition.getColumnNum(), - COMMA_JOINER.join(expectedTokens) - ) - .withContext("line", failurePosition.getLineNum()) - .withContext("column", failurePosition.getColumnNum()) - .withContext("endLine", failurePosition.getEndLineNum()) - .withContext("endColumn", failurePosition.getEndColumnNum()) - .withContext("token", theUnexpectedToken) - .withContext("expected", expectedTokens); - - } - } - - return InvalidSqlInput.exception(inner.getMessage()); - } catch (RelOptPlanner.CannotPlanException inner) { return DruidException.forPersona(DruidException.Persona.USER) .ofCategory(DruidException.Category.INVALID_INPUT) @@ -524,76 +382,4 @@ private static DruidException parseValidationMessage(Exception e) .build(e, "Uncategorized calcite error message: [%s]", e.getMessage()); } } - - /** - * Grabs the unexpected token string. This code is borrowed with minimal adjustments from - * {@link ParseException#getMessage()}. It is possible that if that code changes, we need to also - * change this code to match it. - * - * @param parseException the parse exception to extract from - * @return the String representation of the unexpected token string - */ - private static String getUnexpectedTokenString(ParseException parseException) - { - int maxSize = 0; - for (int[] ints : parseException.expectedTokenSequences) { - if (maxSize < ints.length) { - maxSize = ints.length; - } - } - - StringBuilder bob = new StringBuilder(); - Token tok = parseException.currentToken.next; - for (int i = 0; i < maxSize; i++) { - if (i != 0) { - bob.append(" "); - } - if (tok.kind == 0) { - bob.append(""); - break; - } - char ch; - for (int i1 = 0; i1 < tok.image.length(); i1++) { - switch (tok.image.charAt(i1)) { - case 0: - continue; - case '\b': - bob.append("\\b"); - continue; - case '\t': - bob.append("\\t"); - continue; - case '\n': - bob.append("\\n"); - continue; - case '\f': - bob.append("\\f"); - continue; - case '\r': - bob.append("\\r"); - continue; - case '\"': - bob.append("\\\""); - continue; - case '\'': - bob.append("\\\'"); - continue; - case '\\': - bob.append("\\\\"); - continue; - default: - if ((ch = tok.image.charAt(i1)) < 0x20 || ch > 0x7e) { - String s = "0000" + Integer.toString(ch, 16); - bob.append("\\u").append(s.substring(s.length() - 4, s.length())); - } else { - bob.append(ch); - } - continue; - } - } - tok = tok.next; - } - return bob.toString(); - } - } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index a2bf29868778..18c56dc3a91a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -28,6 +28,7 @@ import org.apache.calcite.avatica.remote.TypedValue; import org.apache.calcite.linq4j.QueryProvider; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.sql.SqlNode; import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.ISE; @@ -128,6 +129,7 @@ public class PlannerContext private final PlannerToolbox plannerToolbox; private final ExpressionParser expressionParser; private final String sql; + private final SqlNode sqlNode; private final SqlEngine engine; private final Map queryContext; private final CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(); @@ -163,6 +165,7 @@ public class PlannerContext private PlannerContext( final PlannerToolbox plannerToolbox, final String sql, + final SqlNode sqlNode, final SqlEngine engine, final Map queryContext, final PlannerHook hook @@ -171,6 +174,7 @@ private PlannerContext( this.plannerToolbox = plannerToolbox; this.expressionParser = new ExpressionParserImpl(plannerToolbox.exprMacroTable()); this.sql = sql; + this.sqlNode = sqlNode; this.engine = engine; this.queryContext = new LinkedHashMap<>(queryContext); this.hook = hook == null ? NoOpPlannerHook.INSTANCE : hook; @@ -180,6 +184,7 @@ private PlannerContext( public static PlannerContext create( final PlannerToolbox plannerToolbox, final String sql, + final SqlNode sqlNode, final SqlEngine engine, final Map queryContext, final PlannerHook hook @@ -188,6 +193,7 @@ public static PlannerContext create( return new PlannerContext( plannerToolbox, sql, + sqlNode, engine, queryContext, hook @@ -393,6 +399,11 @@ public String getSql() return sql; } + public SqlNode getSqlNode() + { + return sqlNode; + } + public PlannerHook getPlannerHook() { return hook; diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java index 2f7dce4ce614..dcc245704423 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerFactory.java @@ -23,8 +23,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; -import org.apache.calcite.avatica.util.Casing; -import org.apache.calcite.avatica.util.Quoting; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; @@ -32,7 +30,7 @@ import org.apache.calcite.plan.ConventionTraitDef; import org.apache.calcite.plan.volcano.DruidVolcanoCost; import org.apache.calcite.rel.RelCollationTraitDef; -import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.SqlConformance; import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.tools.FrameworkConfig; @@ -46,7 +44,8 @@ import org.apache.druid.server.security.AuthorizationResult; import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.NoopEscalator; -import org.apache.druid.sql.calcite.parser.DruidSqlParserImplFactory; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; +import org.apache.druid.sql.calcite.parser.StatementAndSetContext; import org.apache.druid.sql.calcite.planner.convertlet.DruidConvertletTable; import org.apache.druid.sql.calcite.run.SqlEngine; import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; @@ -59,16 +58,6 @@ public class PlannerFactory extends PlannerToolbox { - static final SqlParser.Config PARSER_CONFIG = SqlParser - .configBuilder() - .setCaseSensitive(true) - .setUnquotedCasing(Casing.UNCHANGED) - .setQuotedCasing(Casing.UNCHANGED) - .setQuoting(Quoting.DOUBLE_QUOTE) - .setConformance(DruidConformance.instance()) - .setParserFactory(new DruidSqlParserImplFactory()) // Custom SQL parser factory - .build(); - @Inject public PlannerFactory( final DruidSchemaCatalog rootSchema, @@ -108,25 +97,33 @@ public PlannerFactory( * the parser is allowed to parse multi-part SQL statements where all statements in the list except the last one are * SET statements, for example 'SET x = 'y'; SET foo = 123; SELECT ...', where these values will be added to the * {@link org.apache.druid.query.QueryContext} of the final statement. + * + * @param engine current SQL engine + * @param sql sql query string + * @param sqlNode parsed sql query, from {@link DruidSqlParser#parse(String, boolean)}. This is the main + * statement from {@link StatementAndSetContext#getMainStatement()}. + * @param queryContext query context including {@link StatementAndSetContext#getSetContext()} + * @param hook calcite planner hook */ public DruidPlanner createPlanner( final SqlEngine engine, final String sql, + final SqlNode sqlNode, final Map queryContext, - final PlannerHook hook, - boolean allowSetStatementsToBuildContext + final PlannerHook hook ) { final PlannerContext context = PlannerContext.create( this, sql, + sqlNode, engine, queryContext, hook ); context.dispatchHook(DruidHook.SQL, sql); - return new DruidPlanner(buildFrameworkConfig(context), context, engine, hook, allowSetStatementsToBuildContext); + return new DruidPlanner(buildFrameworkConfig(context), context, engine, hook); } /** @@ -140,7 +137,16 @@ public DruidPlanner createPlannerForTesting( final Map queryContext ) { - final DruidPlanner thePlanner = createPlanner(engine, sql, queryContext, null, true); + final StatementAndSetContext statementAndSetContext = DruidSqlParser.parse(sql, true); + final DruidPlanner thePlanner = createPlanner( + engine, + sql, + statementAndSetContext.getMainStatement(), + statementAndSetContext.getSetContext().isEmpty() + ? queryContext + : QueryContexts.override(queryContext, statementAndSetContext.getSetContext()), + null + ); thePlanner.getPlannerContext() .setAuthenticationResult(NoopEscalator.getInstance().createEscalatedAuthenticationResult()); thePlanner.validate(); @@ -166,7 +172,7 @@ private FrameworkConfig buildFrameworkConfig(PlannerContext plannerContext) Frameworks.ConfigBuilder frameworkConfigBuilder = Frameworks .newConfigBuilder() - .parserConfig(PARSER_CONFIG) + .parserConfig(DruidSqlParser.PARSER_CONFIG) .traitDefs(ConventionTraitDef.INSTANCE, RelCollationTraitDef.INSTANCE) .convertletTable(new DruidConvertletTable(plannerContext)) .operatorTable(operatorTable) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java index 33787001dd71..d17c9c5765a5 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/DruidViewMacro.java @@ -28,6 +28,7 @@ import org.apache.calcite.schema.TableMacro; import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.schema.impl.ViewTable; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.DruidPlanner; import org.apache.druid.sql.calcite.planner.PlannerFactory; import org.apache.druid.sql.calcite.schema.DruidSchemaName; @@ -61,9 +62,9 @@ public TranslatableTable apply(final List arguments) plannerFactory.createPlanner( ViewSqlEngine.INSTANCE, viewSql, + DruidSqlParser.parse(viewSql, false).getMainStatement(), // views cannot embed SET Collections.emptyMap(), - null, - false + null ) ) { planner.validate(); diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java b/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java index 1cf5b23ef24e..1e70c08c0ccb 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlEngineRegistry.java @@ -20,8 +20,8 @@ package org.apache.druid.sql.http; import com.google.inject.Inject; +import org.apache.druid.error.InvalidSqlInput; import org.apache.druid.query.QueryContexts; -import org.apache.druid.server.initialization.jetty.BadRequestException; import org.apache.druid.sql.calcite.run.SqlEngine; import javax.validation.constraints.NotNull; @@ -46,7 +46,7 @@ public SqlEngine getEngine(final String engineName) { SqlEngine engine = engines.getOrDefault(engineName == null ? QueryContexts.DEFAULT_ENGINE : engineName, null); if (engine == null) { - throw new BadRequestException("Unsupported engine"); + throw InvalidSqlInput.exception("Unsupported engine[%s]", engineName); } return engine; } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index 2c4ed3c38a0e..3640ce7b32e4 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -22,12 +22,16 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import com.sun.jersey.api.core.HttpContext; import org.apache.druid.common.exception.SanitizableException; +import org.apache.druid.error.DruidException; import org.apache.druid.guice.annotations.Self; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryLifecycle; import org.apache.druid.server.QueryResource; @@ -46,6 +50,7 @@ import org.apache.druid.sql.HttpStatement; import org.apache.druid.sql.SqlLifecycleManager; import org.apache.druid.sql.SqlLifecycleManager.Cancelable; +import org.apache.druid.sql.SqlQueryPlus; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.calcite.run.SqlEngine; @@ -178,17 +183,25 @@ public Response doPost( final HttpServletRequest req ) { - final String engineName = sqlQuery.queryContext().getEngine(); - final SqlEngine engine = sqlEngineRegistry.getEngine(engineName); - final HttpStatement stmt = engine.getSqlStatementFactory().httpStatement(sqlQuery, req); - final String sqlQueryId = stmt.sqlQueryId(); - final String currThreadName = Thread.currentThread().getName(); + final HttpStatement stmt; + final QueryContext queryContext; try { - Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); + final SqlQueryPlus sqlQueryPlus = makeSqlQueryPlus(sqlQuery, req); + queryContext = new QueryContext(sqlQueryPlus.context()); // Redefine queryContext to include SET parameters + final String engineName = queryContext.getEngine(); + final SqlEngine engine = sqlEngineRegistry.getEngine(engineName); + stmt = engine.getSqlStatementFactory().httpStatement(sqlQueryPlus, req); + } + catch (Exception e) { + // Can't use the queryContext with SETs since it might not have been created yet. Use the original one. + return handleExceptionBeforeStatementCreated(e, sqlQuery.queryContext()); + } - QueryResultPusher pusher = makePusher(req, stmt, sqlQuery); - return pusher.push(); + final String currThreadName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(StringUtils.format("sql[%s]", stmt.sqlQueryId())); + return makePusher(req, stmt, sqlQuery, queryContext).push(); } finally { Thread.currentThread().setName(currThreadName); @@ -249,7 +262,12 @@ public void incrementTimedOut() } } - private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStatement stmt, SqlQuery sqlQuery) + private SqlResourceQueryResultPusher makePusher( + HttpServletRequest req, + HttpStatement stmt, + SqlQuery sqlQuery, + QueryContext queryContext + ) { final String sqlQueryId = stmt.sqlQueryId(); Map headers = new LinkedHashMap<>(); @@ -259,7 +277,7 @@ private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStat headers.put(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); } - return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, headers); + return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, queryContext, headers); } private class SqlResourceQueryResultPusher extends QueryResultPusher @@ -268,11 +286,17 @@ private class SqlResourceQueryResultPusher extends QueryResultPusher private final HttpStatement stmt; private final SqlQuery sqlQuery; + /** + * Context to use for pushing results. May be different from the context in SqlQuery due to SET statements. + */ + private final QueryContext queryContext; + public SqlResourceQueryResultPusher( HttpServletRequest req, String sqlQueryId, HttpStatement stmt, SqlQuery sqlQuery, + QueryContext queryContext, Map headers ) { @@ -288,6 +312,7 @@ public SqlResourceQueryResultPusher( ); this.sqlQueryId = sqlQueryId; this.stmt = stmt; + this.queryContext = queryContext; this.sqlQuery = sqlQuery; } @@ -374,7 +399,7 @@ public void recordSuccess(long numBytes) @Override public void recordFailure(Exception e) { - if (QueryLifecycle.shouldLogStackTrace(e, sqlQuery.queryContext())) { + if (QueryLifecycle.shouldLogStackTrace(e, queryContext)) { log.warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); } else { log.noStackTrace().warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); @@ -418,4 +443,46 @@ public AuthorizationResult authorizeCancellation(final HttpServletRequest req, f authorizerMapper ); } + + /** + * Create a {@link SqlQueryPlus}, which involves parsing the query from {@link SqlQuery#getQuery()} and + * extracing any SET parameters into the query context. + */ + public static SqlQueryPlus makeSqlQueryPlus(final SqlQuery sqlQuery, final HttpServletRequest req) + { + return SqlQueryPlus.builder() + .sql(sqlQuery.getQuery()) + .context(sqlQuery.getContext()) + .parameters(sqlQuery.getParameterList()) + .auth(AuthorizationUtils.authenticationResultFromRequest(req)) + .build(); + } + + /** + * Generates a response for a {@link DruidException} that occurs prior to the {@link HttpStatement} being created. + */ + public static Response handleExceptionBeforeStatementCreated(final Exception e, final QueryContext queryContext) + { + if (e instanceof DruidException) { + final String sqlQueryId = queryContext.getString(QueryContexts.CTX_SQL_QUERY_ID); + return QueryResultPusher.handleDruidExceptionBeforeResponseStarted( + (DruidException) e, + MediaType.APPLICATION_JSON_TYPE, + sqlQueryId != null + ? ImmutableMap.builder() + .put(QueryResource.QUERY_ID_RESPONSE_HEADER, sqlQueryId) + .put(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId) + .build() + : Collections.emptyMap() + ); + } else { + return QueryResultPusher.handleDruidExceptionBeforeResponseStarted( + DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build(e, "Cannot handle query"), + MediaType.APPLICATION_JSON_TYPE, + Collections.emptyMap() + ); + } + } } diff --git a/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java b/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java new file mode 100644 index 000000000000..688453de26c1 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/SqlQueryPlusTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql; + +import org.apache.druid.error.DruidException; +import org.apache.druid.error.DruidExceptionMatcher; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Test; + +public class SqlQueryPlusTest +{ + @Test + public void testSyntaxError() + { + // SqlQueryPlus throws parse errors on build() if the statement is invalid + final DruidException e = Assert.assertThrows( + DruidException.class, + () -> SqlQueryPlus.builder("SELECT COUNT(*) AS cnt, 'foo' AS") + .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) + .build() + ); + + MatcherAssert.assertThat( + e, + DruidExceptionMatcher + .invalidSqlInput() + .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") + ); + } + + @Test + public void testSyntaxErrorJdbc() + { + // SqlQueryPlus does not throw parse errors on buildJdbc(), because parsing is deferred + final SqlQueryPlus sqlQueryPlus = + SqlQueryPlus.builder("SELECT COUNT(*) AS cnt, 'foo' AS") + .auth(CalciteTests.REGULAR_USER_AUTH_RESULT) + .buildJdbc(); + + // It does throw exceptions on freshCopy(), though. + final DruidException e = Assert.assertThrows( + DruidException.class, + sqlQueryPlus::freshCopy + ); + + MatcherAssert.assertThat( + e, + DruidExceptionMatcher + .invalidSqlInput() + .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") + ); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java index dfbad004f3ca..ff21f9853d3b 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlStatementTest.java @@ -63,7 +63,6 @@ import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.hook.DruidHookDispatcher; -import org.apache.druid.sql.http.SqlQuery; import org.easymock.EasyMock; import org.hamcrest.MatcherAssert; import org.junit.After; @@ -280,28 +279,6 @@ public void testDirectPolicyEnforcerValidatesWithPolicy() assertResultsEquals("SELECT COUNT(*) AS cnt FROM druid.restrictedDatasource_m1_is_6", expectedResults, results); } - @Test - public void testDirectSyntaxError() - { - SqlQueryPlus sqlReq = queryPlus( - "SELECT COUNT(*) AS cnt, 'foo' AS", - CalciteTests.REGULAR_USER_AUTH_RESULT - ); - DirectStatement stmt = sqlStatementFactory.directStatement(sqlReq); - try { - stmt.execute(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testDirectValidationError() { @@ -344,17 +321,13 @@ public void testDirectPermissionError() //----------------------------------------------------------------- // HTTP statements: using a servlet request for verification. - private SqlQuery makeQuery(String sql) + /** + * Creates a {@link SqlQueryPlus} with auth result {@link CalciteTests#REGULAR_USER_AUTH_RESULT}, which matches + * the auth result used by {@link #request(boolean)}. + */ + private SqlQueryPlus makeQuery(String sql) { - return new SqlQuery( - sql, - null, - false, - false, - false, - null, - null - ); + return SqlQueryPlus.builder(sql).auth(CalciteTests.REGULAR_USER_AUTH_RESULT).build(); } @Test @@ -370,27 +343,6 @@ public void testHttpHappyPath() assertEquals("foo", results.get(0)[1]); } - @Test - public void testHttpSyntaxError() - { - HttpStatement stmt = sqlStatementFactory.httpStatement( - makeQuery("SELECT COUNT(*) AS cnt, 'foo' AS"), - request(true) - ); - try { - stmt.execute(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testHttpValidationError() { @@ -494,28 +446,6 @@ public void testPreparedHappyPath() } } - @Test - public void testPrepareSyntaxError() - { - SqlQueryPlus sqlReq = queryPlus( - "SELECT COUNT(*) AS cnt, 'foo' AS", - CalciteTests.REGULAR_USER_AUTH_RESULT - ); - PreparedStatement stmt = sqlStatementFactory.preparedStatement(sqlReq); - try { - stmt.prepare(); - fail(); - } - catch (DruidException e) { - MatcherAssert.assertThat( - e, - DruidExceptionMatcher - .invalidSqlInput() - .expectMessageContains("Incorrect syntax near the keyword 'AS' at line 1, column 31") - ); - } - } - @Test public void testPrepareValidationError() { diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java index f971482ecc81..78bdbf4a15a0 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidStatementTest.java @@ -152,12 +152,11 @@ private DruidJdbcStatement jdbcStatement() @Test public void testSubQueryWithOrderByDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -173,12 +172,11 @@ public void testSubQueryWithOrderByDirect() @Test public void testFetchPastEOFDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -221,12 +219,11 @@ public void testSkipExecuteDirect() @Test public void testFetchAfterResultCloseDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -243,12 +240,11 @@ public void testFetchAfterResultCloseDirect() @Test public void testSubQueryWithOrderByDirectTwice() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SUB_QUERY_WITH_ORDER_BY, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SUB_QUERY_WITH_ORDER_BY) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { statement.execute(queryPlus, -1); Meta.Frame frame = statement.nextFrame(AbstractDruidJdbcStatement.START_OFFSET, 6); @@ -288,12 +284,11 @@ private Meta.Frame subQueryWithOrderByResults() @Test public void testSelectAllInFirstFrameDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for all rows. statement.execute(queryPlus, -1); @@ -330,12 +325,11 @@ public void testSelectAllInFirstFrameDirect() @Test public void testSelectSplitOverTwoFramesDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. @@ -367,12 +361,11 @@ public void testSelectSplitOverTwoFramesDirect() @Test public void testTwoFramesAutoCloseDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. statement.execute(queryPlus, -1); @@ -409,12 +402,11 @@ public void testTwoFramesAutoCloseDirect() @Test public void testTwoFramesCloseWithResultSetDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // First frame, ask for 2 rows. statement.execute(queryPlus, -1); @@ -464,12 +456,11 @@ private Meta.Frame secondFrameResults() @Test public void testSignatureDirect() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_STAR_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_STAR_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcStatement statement = jdbcStatement()) { // Check signature. statement.execute(queryPlus, -1); @@ -533,12 +524,11 @@ private DruidJdbcPreparedStatement jdbcPreparedStatement(SqlQueryPlus queryPlus) public void testSubQueryWithOrderByPrepared() { final String sql = "select T20.F13 as F22 from (SELECT DISTINCT dim1 as F13 FROM druid.foo T10) T20 order by T20.F13 ASC"; - SqlQueryPlus queryPlus = new SqlQueryPlus( - sql, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(sql) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); // First frame, ask for all rows. @@ -556,12 +546,11 @@ public void testSubQueryWithOrderByPrepared() public void testSubQueryWithOrderByPreparedTwice() { final String sql = "select T20.F13 as F22 from (SELECT DISTINCT dim1 as F13 FROM druid.foo T10) T20 order by T20.F13 ASC"; - SqlQueryPlus queryPlus = new SqlQueryPlus( - sql, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(sql) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); statement.execute(Collections.emptyList()); @@ -586,12 +575,11 @@ public void testSubQueryWithOrderByPreparedTwice() @Test public void testSignaturePrepared() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - SELECT_STAR_FROM_FOO, - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql(SELECT_STAR_FROM_FOO) + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { statement.prepare(); verifySignature(statement.getSignature()); @@ -601,12 +589,11 @@ public void testSignaturePrepared() @Test public void testParameters() { - SqlQueryPlus queryPlus = new SqlQueryPlus( - "SELECT COUNT(*) AS cnt FROM sys.servers WHERE servers.host = ?", - null, - null, - AllowAllAuthenticator.ALLOW_ALL_RESULT - ); + SqlQueryPlus queryPlus = + SqlQueryPlus.builder() + .sql("SELECT COUNT(*) AS cnt FROM sys.servers WHERE servers.host = ?") + .auth(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .buildJdbc(); Meta.Frame expected = Meta.Frame.create(0, true, Collections.singletonList(new Object[] {1L})); List matchingParams = Collections.singletonList(TypedValue.ofLocal(ColumnMetaData.Rep.STRING, "dummy")); try (final DruidJdbcPreparedStatement statement = jdbcPreparedStatement(queryPlus)) { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index d87d9e00cc24..65fb0d3e88bf 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -15976,4 +15976,52 @@ public void testMultiStatementSetsInvalidTooManyNonSetStatements() "Only SET statements can appear before the final statement in a statement list, but found non-SET statement[SELECT 1]" ); } + + @Test + public void testSetUseApproximateCountDistinctFalse() + { + testBuilder().sql( + "SET useApproximateCountDistinct = FALSE;\n" + + "SELECT COUNT(DISTINCT dim2) FROM druid.foo" + ).expectedQueries( + ImmutableList.of( + GroupByQuery.builder() + .setDataSource( + new QueryDataSource( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions(new DefaultDimensionSpec("dim2", "d0"))) + .setContext( + ImmutableMap.builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put("useApproximateCountDistinct", false) + .build() + ) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setAggregatorSpecs(aggregators( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + notNull("d0") + ) + )) + .setContext( + ImmutableMap.builder() + .putAll(QUERY_CONTEXT_DEFAULT) + .put("useApproximateCountDistinct", false) + .build() + ) + .build() + ) + ).expectedResults( + ImmutableList.of( + new Object[]{3L} + ) + ).run(); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java index 27071dcf7b2c..bbc0c2987cac 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/QueryTestRunner.java @@ -263,10 +263,6 @@ public void run() BaseCalciteQueryTest.log.info("SQL: %s", builder.sql); final SqlStatementFactory sqlStatementFactory = builder.statementFactory(); - final SqlQueryPlus sqlQuery = SqlQueryPlus.builder(builder.sql) - .sqlParameters(builder.parameters) - .auth(builder.authenticationResult) - .build(); final List vectorizeValues = new ArrayList<>(); vectorizeValues.add("false"); @@ -275,7 +271,14 @@ public void run() } for (final String vectorize : vectorizeValues) { - final Map theQueryContext = new HashMap<>(builder.queryContext); + // Need to create sqlQuery inside the loop, because SqlQueryPlus can only be used once + final SqlQueryPlus sqlQuery = SqlQueryPlus.builder(builder.sql) + .sqlParameters(builder.parameters) + .auth(builder.authenticationResult) + .context(builder.queryContext) + .build(); + + final Map theQueryContext = new HashMap<>(sqlQuery.context()); theQueryContext.put(QueryContexts.VECTORIZE_KEY, vectorize); theQueryContext.put(QueryContexts.VECTORIZE_VIRTUAL_COLUMNS_KEY, vectorize); diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java index 20d7c0fa0786..5ec8f67bb424 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/expression/ExpressionTestHelper.java @@ -106,6 +106,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) public static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( PLANNER_TOOLBOX, "SELECT 1", // The actual query isn't important for this test + null, /* Don't need SQL node */ null, /* Don't need engine */ Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java index 4f30cc1feb11..cccaffe8ec4f 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/external/ExternalTableScanRuleTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.QuerySegmentWalker; import org.apache.druid.query.policy.NoopPolicyEnforcer; import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.planner.CalciteRulesManager; import org.apache.druid.sql.calcite.planner.CatalogResolver; import org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -79,7 +80,8 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) ); final PlannerContext plannerContext = PlannerContext.create( toolbox, - "DUMMY", // The actual query isn't important for this test + "SELECT 1", // The actual query isn't important for this test + DruidSqlParser.parse("SELECT 1", false).getMainStatement(), engine, Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java new file mode 100644 index 000000000000..1765ecc2f5cf --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/parser/DruidSqlParserTest.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql.calcite.parser; + +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.DateString; +import org.apache.calcite.util.TimestampString; +import org.apache.druid.error.DruidException; +import org.junit.Assert; +import org.junit.Test; + +public class DruidSqlParserTest +{ + @Test + public void test_sqlLiteralToContextValue_null() + { + final SqlLiteral literal = SqlLiteral.createNull(SqlParserPos.ZERO); + Assert.assertNull(DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_string() + { + final SqlLiteral literal = SqlLiteral.createCharString("abc", SqlParserPos.ZERO); + Assert.assertEquals("abc", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_stringWithSpecialChars() + { + final SqlLiteral literal = SqlLiteral.createCharString("hello\nworld\t\"test\"", SqlParserPos.ZERO); + Assert.assertEquals("hello\nworld\t\"test\"", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_integer() + { + // Numbers within Long range are converted to Long. + final SqlLiteral literal = SqlLiteral.createExactNumeric("42", SqlParserPos.ZERO); + Assert.assertEquals(42L, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_negativeInteger() + { + final SqlLiteral literal = SqlLiteral.createExactNumeric("-123", SqlParserPos.ZERO); + Assert.assertEquals(-123L, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_decimal() + { + // Decimals are converted to Double. + final SqlLiteral literal = SqlLiteral.createExactNumeric("3.14159", SqlParserPos.ZERO); + Assert.assertEquals(3.14159, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_largeNumber() + { + // Integers outside Long range are retained as strings. + final SqlLiteral literal = SqlLiteral.createExactNumeric("123456789012345678901234567890", SqlParserPos.ZERO); + Assert.assertEquals("123456789012345678901234567890", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_approximateNumeric() + { + final SqlLiteral literal = SqlLiteral.createApproxNumeric("1.23E10", SqlParserPos.ZERO); + Assert.assertEquals(1.23E10, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_booleanTrue() + { + final SqlLiteral literal = SqlLiteral.createBoolean(true, SqlParserPos.ZERO); + Assert.assertEquals(true, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_booleanFalse() + { + final SqlLiteral literal = SqlLiteral.createBoolean(false, SqlParserPos.ZERO); + Assert.assertEquals(false, DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_timestamp() + { + // Timestamps are returned as strings in ISO8601 format + final TimestampString timestampString = new TimestampString("2023-01-15 14:30:00"); + final SqlLiteral literal = + SqlLiteral.createTimestamp(SqlTypeName.TIMESTAMP, timestampString, 0, SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T14:30:00.000Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_timestampWithFractionalSeconds() + { + final TimestampString timestampString = new TimestampString("2023-01-15 14:30:00.123"); + final SqlLiteral literal = + SqlLiteral.createTimestamp(SqlTypeName.TIMESTAMP, timestampString, 3, SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T14:30:00.123Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_date() + { + final DateString dateString = new DateString("2023-01-15"); + final SqlLiteral literal = SqlLiteral.createDate(dateString, SqlParserPos.ZERO); + Assert.assertEquals("2023-01-15T00:00:00.000Z", DruidSqlParser.sqlLiteralToContextValue(literal)); + } + + @Test + public void test_sqlLiteralToContextValue_unsupportedType() + { + final SqlLiteral literal = SqlLiteral.createSymbol(SqlTypeName.BINARY, SqlParserPos.ZERO); + final DruidException exception = Assert.assertThrows( + DruidException.class, + () -> DruidSqlParser.sqlLiteralToContextValue(literal) + ); + Assert.assertTrue(exception.getMessage().contains("Unsupported type for SET")); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java index eb803186d4c5..b1809ba87556 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/CalcitePlannerModuleTest.java @@ -44,6 +44,7 @@ import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.calcite.aggregation.SqlAggregator; import org.apache.druid.sql.calcite.expression.SqlOperatorConversion; +import org.apache.druid.sql.calcite.parser.DruidSqlParser; import org.apache.druid.sql.calcite.rule.ExtensionCalciteRuleProvider; import org.apache.druid.sql.calcite.run.NativeSqlEngine; import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; @@ -184,9 +185,11 @@ public void testExtensionCalciteRule() ObjectMapper mapper = new DefaultObjectMapper(); PlannerToolbox toolbox = injector.getInstance(PlannerFactory.class); + final String sql = "SELECT 1"; PlannerContext context = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.emptyMap(), null @@ -204,9 +207,11 @@ public void testConfigurableBloat() ObjectMapper mapper = new DefaultObjectMapper(); PlannerToolbox toolbox = injector.getInstance(PlannerFactory.class); + final String sql = "SELECT 1"; PlannerContext contextWithBloat = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.singletonMap(BLOAT_PROPERTY, BLOAT), null @@ -214,7 +219,8 @@ public void testConfigurableBloat() PlannerContext contextWithoutBloat = PlannerContext.create( toolbox, - "SELECT 1", + sql, + DruidSqlParser.parse(sql, false).getMainStatement(), new NativeSqlEngine(queryLifecycleFactory, mapper, (SqlStatementFactory) null), Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java index 816ce3fd1a1d..fc0e7a43edc7 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/planner/DruidRexExecutorTest.java @@ -112,6 +112,7 @@ NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class)) private static final PlannerContext PLANNER_CONTEXT = PlannerContext.create( PLANNER_TOOLBOX, "SELECT 1", // The actual query isn't important for this test + null, /* Don't need a SQL node */ null, /* Don't need an engine */ Collections.emptyMap(), null diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java index a4ff866f0ddd..0d0deb14a5a6 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/QueryFrameworkUtils.java @@ -368,9 +368,9 @@ protected DruidPlanner createPlanner() return plannerFactory.createPlanner( engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } }; @@ -387,9 +387,9 @@ protected DruidPlanner getPlanner() return plannerFactory.createPlanner( engine, queryPlus.sql(), + queryPlus.sqlNode(), queryContext, - hook, - true + hook ); } diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index b1aa19753088..ac0419485615 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -292,13 +292,13 @@ public void add(String sqlQueryId, Cancelable lifecycle) { @Override public HttpStatement httpStatement( - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req ) { TestHttpStatement stmt = new TestHttpStatement( sqlToolbox.withEngine(engine), - sqlQuery, + sqlQueryPlus, req, validateAndAuthorizeLatchSupplier, planLatchSupplier, @@ -1481,7 +1481,7 @@ public void testCannotParse() throws Exception errorResponse, "Incorrect syntax near the keyword 'FROM' at line 1, column 1" ); - checkSqlRequestLog(false); + Assert.assertEquals(0, testRequestLogger.getSqlQueryLogs().size()); // Invalid queries are not logged Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1590,15 +1590,10 @@ public void testUnsupportedQueryThrowsException() throws Exception ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"), null ), - 501 + DruidException.Category.INVALID_INPUT.getExpectedStatus() ); - validateLegacyQueryExceptionErrorResponse( - exception, - QueryException.QUERY_UNSUPPORTED_ERROR_CODE, - QueryUnsupportedException.class.getName(), - "" - ); + validateInvalidSqlError(exception, "Incorrect syntax near the keyword 'TO'"); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1623,32 +1618,31 @@ public void testErrorResponseReturnSameQueryIdWhenSetInContext() // This is checked in the common method that returns the response, but checking it again just protects // from changes there breaking the checks, so doesn't hurt. - assertStatusAndCommonHeaders(response, 501); + assertStatusAndCommonHeaders(response, DruidException.Category.INVALID_INPUT.getExpectedStatus()); Assert.assertEquals(queryId, getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER)); Assert.assertEquals(queryId, getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); } @Test - public void testErrorResponseReturnNewQueryIdWhenNotSetInContext() + public void testErrorResponseReturnNoQueryIdWhenNotSetInContext() { String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final Response response = postForSyncResponse( - new SqlQuery( - "SELECT ANSWER TO LIFE", - ResultFormat.OBJECT, - false, - false, - false, - ImmutableMap.of(), - null - ), - req + final SqlQuery sqlQuery = new SqlQuery( + "SELECT ANSWER TO LIFE", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of(), + null ); - // This is checked in the common method that returns the response, but checking it again just protects - // from changes there breaking the checks, so doesn't hurt. - assertStatusAndCommonHeaders(response, 501); + final Response response = resource.doPost(sqlQuery, req); + + // Query ID won't be set, but we can look for other aspects of the response that we expect. + Assert.assertEquals(DruidException.Category.INVALID_INPUT.getExpectedStatus(), response.getStatus()); + Assert.assertEquals("application/json", getContentType(response)); } @Test @@ -1681,7 +1675,7 @@ public ErrorResponseTransformStrategy getErrorResponseTransformStrategy() failOnExecute(errorMessage); ErrorResponse exception = postSyncForException( new SqlQuery( - "SELECT ANSWER TO LIFE", + "SELECT 1", ResultFormat.OBJECT, false, false, @@ -2297,7 +2291,7 @@ private static class TestHttpStatement extends HttpStatement private TestHttpStatement( final SqlToolbox lifecycleContext, - final SqlQuery sqlQuery, + final SqlQueryPlus sqlQueryPlus, final HttpServletRequest req, SettableSupplier> validateAndAuthorizeLatchSupplier, SettableSupplier> planLatchSupplier, @@ -2307,7 +2301,7 @@ private TestHttpStatement( final Consumer onAuthorize ) { - super(lifecycleContext, sqlQuery, req); + super(lifecycleContext, sqlQueryPlus, req); this.validateAndAuthorizeLatchSupplier = validateAndAuthorizeLatchSupplier; this.planLatchSupplier = planLatchSupplier; this.executeLatchSupplier = executeLatchSupplier;