Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
Expand Down Expand Up @@ -102,6 +103,10 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
LOG.debug("add prepared statement {}, isBinaryProtocol {}",
name, ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE);
}
if (logicalPlan instanceof InsertIntoTableCommand
&& ((InsertIntoTableCommand) logicalPlan).getLabelName().isPresent()) {
throw new org.apache.doris.common.UserException("Only support prepare InsertStmt without label now");
}
ctx.addPreparedStatementContext(name,
new PreparedStatementContext(this, ctx, ctx.getStatementContext(), name));
if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.doris.catalog.Type;
import org.apache.doris.cluster.ClusterNamespace;
import org.apache.doris.common.Config;
import org.apache.doris.common.UserException;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.TimeUtils;
import org.apache.doris.datasource.CatalogIf;
Expand Down Expand Up @@ -394,12 +395,18 @@ public void addPreparedStmt(String stmtName, PrepareStmtContext ctx) {
this.preparedStmtCtxs.put(stmtName, ctx);
}

public void addPreparedStatementContext(String stmtName, PreparedStatementContext ctx) {
public void addPreparedStatementContext(String stmtName, PreparedStatementContext ctx) throws UserException {
if (this.preparedStatementContextMap.size() > sessionVariable.maxPreparedStmtCount) {
throw new UserException("Failed to create a server prepared statement"
+ "possibly because there are too many active prepared statements on server already."
+ "set max_prepared_stmt_count with larger number than " + sessionVariable.maxPreparedStmtCount);
}
this.preparedStatementContextMap.put(stmtName, ctx);
}

public void removePrepareStmt(String stmtName) {
this.preparedStmtCtxs.remove(stmtName);
this.preparedStatementContextMap.remove(stmtName);
}

public PrepareStmtContext getPreparedStmt(String stmtName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ public void executeQuery(MysqlCommand mysqlCommand, String originStmt) throws Ex
List<StatementBase> cachedStmts = null;
// Currently we add a config to decide whether using PREPARED/EXECUTE command for nereids
// TODO: after implemented full prepared, we could remove this flag
boolean nereidsUseServerPrep = sessionVariable.enableServeSidePreparedStatement
boolean nereidsUseServerPrep = (sessionVariable.enableServeSidePreparedStatement
&& !sessionVariable.isEnableInsertGroupCommit())
|| mysqlCommand == MysqlCommand.COM_QUERY;
if (nereidsUseServerPrep && sessionVariable.isEnableNereidsPlanner()) {
if (wantToParseSqlFromSqlCache) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.Config;
import org.apache.doris.common.Status;
Expand Down Expand Up @@ -149,7 +151,13 @@ void addKeyTuples(
KeyTuple.Builder kBuilder = KeyTuple.newBuilder();
for (Expr expr : shortCircuitQueryContext.scanNode.getConjuncts()) {
BinaryPredicate predicate = (BinaryPredicate) expr;
kBuilder.addKeyColumnRep(predicate.getChild(1).getStringValue());
Expr left = predicate.getChild(0);
Expr right = predicate.getChild(1);
// ignore delete sign conjuncts only collect key conjuncts
if (left instanceof SlotRef && ((SlotRef) left).getColumnName().equalsIgnoreCase(Column.DELETE_SIGN)) {
continue;
}
kBuilder.addKeyColumnRep(right.getStringValue());
}
requestBuilder.addKeyTuples(kBuilder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_EXCHANGE_NODE_PARALLEL_MERGE = "enable_exchange_node_parallel_merge";

public static final String ENABLE_SERVER_SIDE_PREPARED_STATEMENT = "enable_server_side_prepared_statement";
public static final String MAX_PREPARED_STMT_COUNT = "max_prepared_stmt_count";
public static final String PREFER_JOIN_METHOD = "prefer_join_method";

public static final String ENABLE_FOLD_CONSTANT_BY_BE = "enable_fold_constant_by_be";
Expand Down Expand Up @@ -1361,7 +1362,12 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {

@VariableMgr.VarAttr(name = ENABLE_SERVER_SIDE_PREPARED_STATEMENT, needForward = true, description = {
"是否启用开启服务端prepared statement", "Set whether to enable server side prepared statement."})
public boolean enableServeSidePreparedStatement = false;
public boolean enableServeSidePreparedStatement = true;

@VariableMgr.VarAttr(name = MAX_PREPARED_STMT_COUNT, flag = VariableMgr.GLOBAL,
needForward = true, description = {
"服务端prepared statement最大个数", "the maximum prepared statements server holds."})
public int maxPreparedStmtCount = 100000;

// Default value is false, which means the group by and having clause
// should first use column name not alias. According to mysql.
Expand Down
36 changes: 36 additions & 0 deletions regression-test/data/prepared_stmt_p0/prepared_stmt.out
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@
-- !select5 --
1

-- !select5 --
a

-- !select5 --
2

-- !select5 --
-2

-- !select5 --
6 MySQL Connector/J

-- !select5 --
0 0 0 0 0

-- !select6 --
2 1 user1 \N 1234.1111 xxxlalala

Expand All @@ -66,3 +81,24 @@
-- !select9 --
2

-- !select13 --
1

-- !select14 --
1

-- !select15 --
1

-- !sql --
1231 119291.110000000 ddd Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 \N 1022-01-01 ["2022-01-01 11:30:38", "2022-01-01 11:30:38", "2022-01-01 11:30:38"]
1232 12222.991211350 xxx Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 522.762 2022-01-01 ["2023-01-01 11:30:38", "2023-01-01 11:30:38"]
1233 1.392932911 yyy Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 52.862 3022-01-01 ["2024-01-01 11:30:38", "2024-01-01 11:30:38", "2024-01-01 11:30:38"]
1234 12919291.129191137 xxddd Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 552.872 4022-01-01 ["2025-01-01 11:30:38", "2025-01-01 11:30:38", "2025-01-01 11:30:38"]
1235 991129292901.111380000 dd Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 652.692 5022-01-01 []
1236 100320.111390000 laa ddd Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 2.7692 6022-01-01 [null]
1237 120939.111300000 a ddd Will we ignore LIMIT ?,? 2021-01-01 2020-01-01T12:36:38 22.822 7022-01-01 ["2025-01-01 11:30:38"]

-- !select16 --
mytable1 CREATE TABLE `mytable1` (\n `siteid` INT NULL DEFAULT "10",\n `citycode` SMALLINT NULL,\n `username` VARCHAR(32) NULL DEFAULT "",\n `pv` BIGINT SUM NULL DEFAULT "0"\n) ENGINE=OLAP\nAGGREGATE KEY(`siteid`, `citycode`, `username`)\nDISTRIBUTED BY HASH(`siteid`) BUCKETS 10\nPROPERTIES (\n"replication_allocation" = "tag.location.default: 1",\n"min_load_replica_num" = "-1",\n"is_being_synced" = "false",\n"storage_medium" = "hdd",\n"storage_format" = "V2",\n"inverted_index_storage_format" = "V2",\n"light_schema_change" = "true",\n"disable_auto_compaction" = "false",\n"enable_single_replica_compaction" = "false",\n"group_commit_interval_ms" = "10000",\n"group_commit_data_bytes" = "134217728"\n);

Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ suite("test_compaction_uniq_keys_row_store", "nonConcurrent") {
stmt.setInt(8, sex)
}

sql "set global enable_server_side_prepared_statement = true"

try {
String backend_id;
def backendId_to_backendIP = [:]
Expand Down Expand Up @@ -213,5 +211,4 @@ suite("test_compaction_uniq_keys_row_store", "nonConcurrent") {
} finally {
// try_sql("DROP TABLE IF EXISTS ${tableName}")
}
sql "set global enable_server_side_prepared_statement = false"
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ import java.sql.DriverManager
import java.sql.Statement
import java.sql.PreparedStatement

suite("insert_group_commit_with_exception") {
suite("insert_group_commit_with_exception", "nonConcurrent") {
def table = "insert_group_commit_with_exception"

def getRowCount = { expectedRowCount ->
def retry = 0
while (retry < 30) {
Expand Down Expand Up @@ -74,9 +73,10 @@ suite("insert_group_commit_with_exception") {
if (item == "nereids") {
sql """ set enable_nereids_dml = true; """
sql """ set enable_nereids_planner=true; """
//sql """ set enable_fallback_to_original_planner=false; """
sql "set global enable_server_side_prepared_statement = true"
} else {
sql """ set enable_nereids_dml = false; """
sql "set global enable_server_side_prepared_statement = false"
}

// insert into without column
Expand Down Expand Up @@ -161,9 +161,11 @@ suite("insert_group_commit_with_exception") {
if (item == "nereids") {
statement.execute("set enable_nereids_dml = true;");
statement.execute("set enable_nereids_planner=true;");
//statement.execute("set enable_fallback_to_original_planner=false;");
statement.execute("set enable_fallback_to_original_planner=false;");
sql "set global enable_server_side_prepared_statement = true"
} else {
statement.execute("set enable_nereids_dml = false;");
sql "set global enable_server_side_prepared_statement = false"
}
// without column
try (PreparedStatement ps = connection.prepareStatement("insert into ${table} values(?, ?, ?, ?)")) {
Expand Down Expand Up @@ -287,7 +289,13 @@ suite("insert_group_commit_with_exception") {
result = ps.executeBatch()
assertTrue(false)
} catch (Exception e) {
assertTrue(e.getMessage().contains("Column count doesn't match value count"))
logger.info("exception : " + e)
if (item == "legacy") {
assertTrue(e.getMessage().contains("Column count doesn't match value count"))
}
if (item == "nereids") {
assertTrue(e.getMessage().contains("insert into cols should be corresponding to the query output"))
}
}
}
getRowCount(14)
Expand Down Expand Up @@ -317,4 +325,5 @@ suite("insert_group_commit_with_exception") {
// try_sql("DROP TABLE ${table}")
}
}
sql "set global enable_server_side_prepared_statement = true"
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ suite("insert_group_commit_with_prepare_stmt") {
def table = realDb + ".insert_group_commit_with_prepare_stmt"

sql "CREATE DATABASE IF NOT EXISTS ${realDb}"

def getRowCount = { expectedRowCount ->
def retry = 0
while (retry < 30) {
Expand Down Expand Up @@ -89,7 +88,8 @@ suite("insert_group_commit_with_prepare_stmt") {
}
assertTrue(serverInfo.contains("'status':'PREPARE'"))
assertTrue(serverInfo.contains("'label':'group_commit_"))
assertEquals(reuse_plan, serverInfo.contains("reuse_group_commit_plan"))
// TODO: currently if enable_server_side_prepared_statement = true, will not reuse plan
// assertEquals(reuse_plan, serverInfo.contains("reuse_group_commit_plan"))
} else {
// for batch insert
ConnectionImpl connection = (ConnectionImpl) stmt.getConnection()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ suite("test_point_query", "nonConcurrent") {
// nereids do not support point query now
sql "set global enable_fallback_to_original_planner = false"
sql """set global enable_nereids_planner=true"""
sql "set global enable_server_side_prepared_statement = true"
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
Expand Down Expand Up @@ -279,6 +278,5 @@ suite("test_point_query", "nonConcurrent") {
set_be_config.call("disable_storage_row_cache", "true")
sql """set global enable_nereids_planner=true"""
sql "set global enable_fallback_to_original_planner = true"
sql "set global enable_server_side_prepared_statement = false"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ suite("test_point_query_cluster_key") {
try {
set_be_config.call("disable_storage_row_cache", "false")
// nereids do not support point query now
sql """set enable_nereids_planner=false"""
sql """set enable_nereids_planner=true"""

def user = context.config.jdbcUser
def password = context.config.jdbcPassword
Expand Down Expand Up @@ -139,7 +139,7 @@ suite("test_point_query_cluster_key") {
sql """ INSERT INTO ${tableName} VALUES(298, 120939.11130, "${generateString(298)}", "laooq", "2030-01-02", "2020-01-01 12:36:38", 298, "7022-01-01 11:30:38", 1, 90696620686827832.374, [], []) """

def result1 = connect(user=user, password=password, url=prepare_url) {
def stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=false) */ * from ${tableName} where k1 = ? and k2 = ? and k3 = ?"
def stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from ${tableName} where k1 = ? and k2 = ? and k3 = ?"
assertEquals(stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
stmt.setInt(1, 1231)
stmt.setBigDecimal(2, new BigDecimal("119291.11"))
Expand Down Expand Up @@ -175,13 +175,14 @@ suite("test_point_query_cluster_key") {
qe_point_select stmt
stmt.close()

stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=false) */ * from ${tableName} where k1 = 1235 and k2 = ? and k3 = ?"
stmt = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ * from ${tableName} where k1 = ? and k2 = ? and k3 = ?"
assertEquals(stmt.class, com.mysql.cj.jdbc.ServerPreparedStatement);
stmt.setBigDecimal(1, new BigDecimal("991129292901.11138"))
stmt.setString(2, "dd")
stmt.setInt(1, 1235)
stmt.setBigDecimal(2, new BigDecimal("991129292901.11138"))
stmt.setString(3, "dd")
qe_point_select stmt

def stmt_fn = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=false) */ hex(k3), hex(k4) from ${tableName} where k1 = ? and k2 =? and k3 = ?"
def stmt_fn = prepareStatement "select /*+ SET_VAR(enable_nereids_planner=true) */ hex(k3), hex(k4) from ${tableName} where k1 = ? and k2 =? and k3 = ?"
assertEquals(stmt_fn.class, com.mysql.cj.jdbc.ServerPreparedStatement);
stmt_fn.setInt(1, 1231)
stmt_fn.setBigDecimal(2, new BigDecimal("119291.11"))
Expand All @@ -195,8 +196,8 @@ suite("test_point_query_cluster_key") {
"""
sleep(1);
nprep_sql """ INSERT INTO ${tableName} VALUES(1235, 120939.11130, "a ddd", "laooq", "2030-01-02", "2020-01-01 12:36:38", 22.822, "7022-01-01 11:30:38", 1, 1.1111299, [119291.19291], ["111", "222", "333"], 1) """
stmt.setBigDecimal(1, new BigDecimal("120939.11130"))
stmt.setString(2, "a ddd")
stmt.setBigDecimal(2, new BigDecimal("120939.11130"))
stmt.setString(3, "a ddd")
qe_point_select stmt
qe_point_select stmt
// invalidate cache
Expand All @@ -222,9 +223,9 @@ suite("test_point_query_cluster_key") {
}
// disable useServerPrepStmts
def result2 = connect(user=user, password=password, url=context.config.jdbcUrl) {
qt_sql """select /*+ SET_VAR(enable_nereids_planner=false) */ * from ${tableName} where k1 = 1231 and k2 = 119291.11 and k3 = 'ddd'"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=false) */ * from ${tableName} where k1 = 1237 and k2 = 120939.11130 and k3 = 'a ddd'"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=false) */ hex(k3), hex(k4), k7 + 10.1 from ${tableName} where k1 = 1237 and k2 = 120939.11130 and k3 = 'a ddd'"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=true) */ * from ${tableName} where k1 = 1231 and k2 = 119291.11 and k3 = 'ddd'"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=true) */ * from ${tableName} where k1 = 1237 and k2 = 120939.11130 and k3 = 'a ddd'"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=true) */ hex(k3), hex(k4), k7 + 10.1 from ${tableName} where k1 = 1237 and k2 = 120939.11130 and k3 = 'a ddd'"""
// prepared text
// sql """ prepare stmt1 from select * from ${tableName} where k1 = % and k2 = % and k3 = % """
// qt_sql """execute stmt1 using (1231, 119291.11, 'ddd')"""
Expand Down Expand Up @@ -254,7 +255,7 @@ suite("test_point_query_cluster_key") {
"disable_auto_compaction" = "false"
);"""
sql """insert into ${tableName} values (0, "1", "2", "3")"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=false) */ * from ${tableName} where customer_key = 0"""
qt_sql """select /*+ SET_VAR(enable_nereids_planner=true) */ * from ${tableName} where customer_key = 0"""
}
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

import java.math.BigDecimal;

suite("test_point_query_partition", "nonConcurrent") {
suite("test_point_query_partition") {
def user = context.config.jdbcUser
def password = context.config.jdbcPassword
def realDb = "regression_test_serving_p0"
def tableName = realDb + ".tbl_point_query_partition"
sql "CREATE DATABASE IF NOT EXISTS ${realDb}"
sql "set global enable_server_side_prepared_statement = true"
// Parse url
String jdbcUrl = context.config.jdbcUrl
String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
Expand Down Expand Up @@ -150,5 +149,4 @@ suite("test_point_query_partition", "nonConcurrent") {
qe_point_selectmmm stmt
qe_point_selecteee stmt
}
sql "set global enable_server_side_prepared_statement = false"
}
Loading