From 751c539a6ce8f2012f6089b77782e4e89471e9cf Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Mon, 1 Feb 2021 13:12:10 -0700 Subject: [PATCH 1/6] session, util: update session to use new APIs --- session/bench_test.go | 100 ++++++++++++------------ session/bootstrap.go | 144 +++++++++++++++++----------------- session/bootstrap_test.go | 13 ++-- session/session.go | 145 +++++++++++++++++------------------ session/session_fail_test.go | 2 +- session/session_test.go | 36 ++++----- session/tidb_test.go | 21 +---- util/testkit/testkit.go | 2 +- 8 files changed, 217 insertions(+), 246 deletions(-) diff --git a/session/bench_test.go b/session/bench_test.go index 02d1889f9b73d..2e4e29b5f8f14 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -55,17 +55,17 @@ func prepareBenchSession() (Session, *domain.Domain, kv.Storage) { func prepareBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s, index idx (col))", colType)) + mustExecute(se, "create table t (pk int primary key auto_increment, col %n, index idx (col))", colType) mustExecute(se, "begin") for i := 0; i < valueCount; i++ { - mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, i)+")") + mustExecute(se, "insert t (col) values (%?)", i) } mustExecute(se, "commit") } func prepareSortBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s)", colType)) + mustExecute(se, "create table t (pk int primary key auto_increment, col %n)", colType) mustExecute(se, "begin") r := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < valueCount; i++ { @@ -73,17 +73,17 @@ func prepareSortBenchData(se Session, colType string, valueFormat string, valueC mustExecute(se, "commit") mustExecute(se, "begin") } - mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, r.Intn(valueCount))+")") + mustExecute(se, "insert t (col) values (%?)", r.Intn(valueCount)) } mustExecute(se, "commit") } func prepareJoinBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s)", colType)) + mustExecute(se, "create table t (pk int primary key auto_increment, col %n)", colType) mustExecute(se, "begin") for i := 0; i < valueCount; i++ { - mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, i)+")") + mustExecute(se, "insert t (col) values (%?)", i) } mustExecute(se, "commit") } @@ -113,11 +113,11 @@ func BenchmarkBasic(b *testing.B) { }() b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select 1") + rs, err := se.ExecuteInternal(ctx, "select 1") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -133,11 +133,11 @@ func BenchmarkTableScan(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t") + rs, err := se.ExecuteInternal(ctx, "select * from t") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], smallCount) + readResult(ctx, rs, smallCount) } b.StopTimer() } @@ -153,11 +153,11 @@ func BenchmarkExplainTableScan(b *testing.B) { prepareBenchData(se, "int", "%v", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "explain select * from t") + rs, err := se.ExecuteInternal(ctx, "explain select * from t") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -173,11 +173,11 @@ func BenchmarkTableLookup(b *testing.B) { prepareBenchData(se, "int", "%d", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where pk = 64") + rs, err := se.ExecuteInternal(ctx, "select * from t where pk = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -193,11 +193,11 @@ func BenchmarkExplainTableLookup(b *testing.B) { prepareBenchData(se, "int", "%d", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "explain select * from t where pk = 64") + rs, err := se.ExecuteInternal(ctx, "explain select * from t where pk = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -213,11 +213,11 @@ func BenchmarkStringIndexScan(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col > 'hello'") + rs, err := se.ExecuteInternal(ctx, "select * from t where col > 'hello'") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], smallCount) + readResult(ctx, rs, smallCount) } b.StopTimer() } @@ -233,11 +233,11 @@ func BenchmarkExplainStringIndexScan(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "explain select * from t where col > 'hello'") + rs, err := se.ExecuteInternal(ctx, "explain select * from t where col > 'hello'") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -253,11 +253,11 @@ func BenchmarkStringIndexLookup(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col = 'hello 64'") + rs, err := se.ExecuteInternal(ctx, "select * from t where col = 'hello 64'") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -273,11 +273,11 @@ func BenchmarkIntegerIndexScan(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col >= 0") + rs, err := se.ExecuteInternal(ctx, "select * from t where col >= 0") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], smallCount) + readResult(ctx, rs, smallCount) } b.StopTimer() } @@ -293,11 +293,11 @@ func BenchmarkIntegerIndexLookup(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col = 64") + rs, err := se.ExecuteInternal(ctx, "select * from t where col = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -313,11 +313,11 @@ func BenchmarkDecimalIndexScan(b *testing.B) { prepareBenchData(se, "decimal(32,6)", "%v.1234", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col >= 0") + rs, err := se.ExecuteInternal(ctx, "select * from t where col >= 0") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], smallCount) + readResult(ctx, rs, smallCount) } b.StopTimer() } @@ -333,11 +333,11 @@ func BenchmarkDecimalIndexLookup(b *testing.B) { prepareBenchData(se, "decimal(32,6)", "%v.1234", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where col = 64.1234") + rs, err := se.ExecuteInternal(ctx, "select * from t where col = 64.1234") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -353,7 +353,7 @@ func BenchmarkInsertWithIndex(b *testing.B) { mustExecute(se, "create table t (pk int primary key, col int, index idx (col))") b.ResetTimer() for i := 0; i < b.N; i++ { - mustExecute(se, fmt.Sprintf("insert t values (%d, %d)", i, i)) + mustExecute(se, "insert t values (%d, %d)", i, i) } b.StopTimer() } @@ -369,7 +369,7 @@ func BenchmarkInsertNoIndex(b *testing.B) { mustExecute(se, "create table t (pk int primary key, col int)") b.ResetTimer() for i := 0; i < b.N; i++ { - mustExecute(se, fmt.Sprintf("insert t values (%d, %d)", i, i)) + mustExecute(se, "insert t values (%d, %d)", i, i) } b.StopTimer() } @@ -385,11 +385,11 @@ func BenchmarkSort(b *testing.B) { prepareSortBenchData(se, "int", "%v", bigCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t order by col limit 50") + rs, err := se.ExecuteInternal(ctx, "select * from t order by col limit 50") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 50) + readResult(ctx, rs, 50) } b.StopTimer() } @@ -405,11 +405,11 @@ func BenchmarkJoin(b *testing.B) { prepareJoinBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t a join t b on a.col = b.col") + rs, err := se.ExecuteInternal(ctx, "select * from t a join t b on a.col = b.col") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], smallCount) + readResult(ctx, rs, smallCount) } b.StopTimer() } @@ -425,11 +425,11 @@ func BenchmarkJoinLimit(b *testing.B) { prepareJoinBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t a join t b on a.col = b.col limit 1") + rs, err := se.ExecuteInternal(ctx, "select * from t a join t b on a.col = b.col limit 1") if err != nil { b.Fatal(err) } - readResult(ctx, rs[0], 1) + readResult(ctx, rs, 1) } b.StopTimer() } @@ -1472,11 +1472,11 @@ partition p1023 values less than (738538) )`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where dt > to_days('2019-04-01 21:00:00') and dt < to_days('2019-04-07 23:59:59')") + rs, err := se.ExecuteInternal(ctx, "select * from t where dt > to_days('2019-04-01 21:00:00') and dt < to_days('2019-04-07 23:59:59')") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } @@ -1504,11 +1504,11 @@ func BenchmarkRangeColumnPartitionPruning(b *testing.B) { mustExecute(se, build.String()) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where dt > '2020-05-01' and dt < '2020-06-07'") + rs, err := se.ExecuteInternal(ctx, "select * from t where dt > '2020-05-01' and dt < '2020-06-07'") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } @@ -1528,11 +1528,11 @@ func BenchmarkHashPartitionPruningPointSelect(b *testing.B) { mustExecute(se, `create table t (id int, dt datetime) partition by hash(id) partitions 1024;`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where id = 2330") + rs, err := se.ExecuteInternal(ctx, "select * from t where id = 2330") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } @@ -1552,27 +1552,27 @@ func BenchmarkHashPartitionPruningMultiSelect(b *testing.B) { mustExecute(se, `create table t (id int, dt datetime) partition by hash(id) partitions 1024;`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.Execute(ctx, "select * from t where id = 2330") + rs, err := se.ExecuteInternal(ctx, "select * from t where id = 2330") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } - rs, err = se.Execute(ctx, "select * from t where id = 1233 or id = 1512") + rs, err = se.ExecuteInternal(ctx, "select * from t where id = 1233 or id = 1512") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } - rs, err = se.Execute(ctx, "select * from t where id in (117, 1233, 15678)") + rs, err = se.ExecuteInternal(ctx, "select * from t where id in (117, 1233, 15678)") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs[0]) + _, err = drainRecordSet(ctx, se.(*session), rs) if err != nil { b.Fatal(err) } diff --git a/session/bootstrap.go b/session/bootstrap.go index 14f1e21c4904a..4a0df857025d4 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -539,7 +539,7 @@ var ( func checkBootstrapped(s Session) (bool, error) { // Check if system db exists. - _, err := s.Execute(context.Background(), fmt.Sprintf("USE %s;", mysql.SystemDB)) + _, err := s.ExecuteInternal(context.Background(), "USE %n", mysql.SystemDB) if err != nil && infoschema.ErrDatabaseNotExists.NotEqual(err) { logutil.BgLogger().Fatal("check bootstrap error", zap.Error(err)) @@ -565,20 +565,21 @@ func checkBootstrapped(s Session) (bool, error) { // getTiDBVar gets variable value from mysql.tidb table. // Those variables are used by TiDB server. func getTiDBVar(s Session, name string) (sVal string, isNull bool, e error) { - sql := fmt.Sprintf(`SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s"`, - mysql.SystemDB, mysql.TiDBTable, name) ctx := context.Background() - rs, err := s.Execute(ctx, sql) + rs, err := s.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME= %?`, + mysql.SystemDB, + mysql.TiDBTable, + name, + ) if err != nil { return "", true, errors.Trace(err) } - if len(rs) != 1 { + if rs == nil { return "", true, errors.New("Wrong number of Recordset") } - r := rs[0] - defer terror.Call(r.Close) - req := r.NewChunk() - err = r.Next(ctx, req) + defer terror.Call(rs.Close) + req := rs.NewChunk() + err = rs.Next(ctx, req) if err != nil || req.NumRows() == 0 { return "", true, errors.Trace(err) } @@ -604,7 +605,7 @@ func upgrade(s Session) { } updateBootstrapVer(s) - _, err = s.Execute(context.Background(), "COMMIT") + _, err = s.ExecuteInternal(context.Background(), "COMMIT") if err != nil { sleepTime := 1 * time.Second @@ -651,9 +652,7 @@ func upgradeToVer3(s Session, ver int64) { return } // Version 3 fix tx_read_only variable value. - sql := fmt.Sprintf("UPDATE HIGH_PRIORITY %s.%s SET variable_value = '0' WHERE variable_name = 'tx_read_only';", - mysql.SystemDB, mysql.GlobalVariablesTable) - mustExecute(s, sql) + mustExecute(s, "UPDATE HIGH_PRIORITY %n.%n SET variable_value = '0' WHERE variable_name = 'tx_read_only';", mysql.SystemDB, mysql.GlobalVariablesTable) } // upgradeToVer4 updates to version 4. @@ -661,8 +660,7 @@ func upgradeToVer4(s Session, ver int64) { if ver >= version4 { return } - sql := CreateStatsMetaTable - mustExecute(s, sql) + mustExecute(s, CreateStatsMetaTable) } func upgradeToVer5(s Session, ver int64) { @@ -696,7 +694,7 @@ func upgradeToVer8(s Session, ver int64) { return } // This is a dummy upgrade, it checks whether upgradeToVer7 success, if not, do it again. - if _, err := s.Execute(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` FROM mysql.user LIMIT 0"); err == nil { + if _, err := s.ExecuteInternal(context.Background(), "SELECT HIGH_PRIORITY `Process_priv` FROM mysql.user LIMIT 0"); err == nil { return } upgradeToVer7(s, ver) @@ -712,7 +710,7 @@ func upgradeToVer9(s Session, ver int64) { } func doReentrantDDL(s Session, sql string, ignorableErrs ...error) { - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), sql) for _, ignorableErr := range ignorableErrs { if terror.ErrorEqual(err, ignorableErr) { return @@ -738,7 +736,7 @@ func upgradeToVer11(s Session, ver int64) { if ver >= version11 { return } - _, err := s.Execute(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` ENUM('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") + _, err := s.ExecuteInternal(context.Background(), "ALTER TABLE mysql.user ADD COLUMN `References_priv` ENUM('N','Y') CHARACTER SET utf8 NOT NULL DEFAULT 'N' AFTER `Grant_priv`") if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { return @@ -753,21 +751,20 @@ func upgradeToVer12(s Session, ver int64) { return } ctx := context.Background() - _, err := s.Execute(ctx, "BEGIN") + _, err := s.ExecuteInternal(ctx, "BEGIN") terror.MustNil(err) sql := "SELECT HIGH_PRIORITY user, host, password FROM mysql.user WHERE password != ''" - rs, err := s.Execute(ctx, sql) + rs, err := s.ExecuteInternal(ctx, sql) if terror.ErrorEqual(err, core.ErrUnknownColumn) { sql := "SELECT HIGH_PRIORITY user, host, authentication_string FROM mysql.user WHERE authentication_string != ''" - rs, err = s.Execute(ctx, sql) + rs, err = s.ExecuteInternal(ctx, sql) } terror.MustNil(err) - r := rs[0] sqls := make([]string, 0, 1) - defer terror.Call(r.Close) - req := r.NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) - err = r.Next(ctx, req) + err = rs.Next(ctx, req) for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { user := row.GetString(0) @@ -779,7 +776,7 @@ func upgradeToVer12(s Session, ver int64) { updateSQL := fmt.Sprintf(`UPDATE HIGH_PRIORITY mysql.user SET password = "%s" WHERE user="%s" AND host="%s"`, newPass, user, host) sqls = append(sqls, updateSQL) } - err = r.Next(ctx, req) + err = rs.Next(ctx, req) } terror.MustNil(err) @@ -809,7 +806,7 @@ func upgradeToVer13(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -838,7 +835,7 @@ func upgradeToVer14(s Session, ver int64) { } ctx := context.Background() for _, sql := range sqls { - _, err := s.Execute(ctx, sql) + _, err := s.ExecuteInternal(ctx, sql) if err != nil { if terror.ErrorEqual(err, infoschema.ErrColumnExists) { continue @@ -853,7 +850,7 @@ func upgradeToVer15(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGCDeleteRangeTable) + _, err = s.ExecuteInternal(context.Background(), CreateGCDeleteRangeTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer15 error", zap.Error(err)) } @@ -923,9 +920,13 @@ func upgradeToVer23(s Session, ver int64) { // writeSystemTZ writes system timezone info into mysql.tidb func writeSystemTZ(s Session) { - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%s", "TiDB Global System Timezone.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, tidbSystemTZ, timeutil.InferSystemTZ(), timeutil.InferSystemTZ()) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB Global System Timezone.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`, + mysql.SystemDB, + mysql.TiDBTable, + tidbSystemTZ, + timeutil.InferSystemTZ(), + timeutil.InferSystemTZ(), + ) } // upgradeToVer24 initializes `System` timezone according to docs/design/2018-09-10-adding-tz-env.md @@ -1054,7 +1055,7 @@ func upgradeToVer38(s Session, ver int64) { return } var err error - _, err = s.Execute(context.Background(), CreateGlobalPrivTable) + _, err = s.ExecuteInternal(context.Background(), CreateGlobalPrivTable) if err != nil { logutil.BgLogger().Fatal("upgradeToVer38 error", zap.Error(err)) } @@ -1066,9 +1067,9 @@ func writeNewCollationParameter(s Session, flag bool) { if flag { b = varTrue } - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%s', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%s'`, - mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbNewCollationEnabled, b, comment, b, + ) } func upgradeToVer40(s Session, ver int64) { @@ -1104,14 +1105,14 @@ func upgradeToVer42(s Session, ver int64) { // Convert statement summary global variables to non-empty values. func writeStmtSummaryVars(s Session) { - sql := fmt.Sprintf("UPDATE %s.%s SET variable_value='%%s' WHERE variable_name='%%s' AND variable_value=''", mysql.SystemDB, mysql.GlobalVariablesTable) + sql := "UPDATE %n.%n SET variable_value= %? WHERE variable_name= %? AND variable_value=''" stmtSummaryConfig := config.GetGlobalConfig().StmtSummary - mustExecute(s, fmt.Sprintf(sql, variable.BoolToOnOff(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary)) - mustExecute(s, fmt.Sprintf(sql, variable.BoolToOnOff(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery)) - mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval)) - mustExecute(s, fmt.Sprintf(sql, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize)) - mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount)) - mustExecute(s, fmt.Sprintf(sql, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength)) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, variable.BoolToOnOff(stmtSummaryConfig.Enable), variable.TiDBEnableStmtSummary) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, variable.BoolToOnOff(stmtSummaryConfig.EnableInternalQuery), variable.TiDBStmtSummaryInternalQuery) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, strconv.Itoa(stmtSummaryConfig.RefreshInterval), variable.TiDBStmtSummaryRefreshInterval) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, strconv.Itoa(stmtSummaryConfig.HistorySize), variable.TiDBStmtSummaryHistorySize) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, strconv.FormatUint(uint64(stmtSummaryConfig.MaxStmtCount), 10), variable.TiDBStmtSummaryMaxStmtCount) + mustExecute(s, sql, mysql.SystemDB, mysql.GlobalVariablesTable, strconv.FormatUint(uint64(stmtSummaryConfig.MaxSQLLength), 10), variable.TiDBStmtSummaryMaxSQLLength) } func upgradeToVer43(s Session, ver int64) { @@ -1219,13 +1220,12 @@ func upgradeToVer55(s Session, ver int64) { selectSQL := "select HIGH_PRIORITY * from mysql.global_variables where variable_name in ('" + strings.Join(names, quoteCommaQuote) + "')" ctx := context.Background() - rs, err := s.Execute(ctx, selectSQL) + rs, err := s.ExecuteInternal(ctx, selectSQL) terror.MustNil(err) - r := rs[0] - defer terror.Call(r.Close) - req := r.NewChunk() + defer terror.Call(rs.Close) + req := rs.NewChunk() it := chunk.NewIterator4Chunk(req) - err = r.Next(ctx, req) + err = rs.Next(ctx, req) for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { n := strings.ToLower(row.GetString(0)) @@ -1234,7 +1234,7 @@ func upgradeToVer55(s Session, ver int64) { return } } - err = r.Next(ctx, req) + err = rs.Next(ctx, req) } terror.MustNil(err) @@ -1270,9 +1270,9 @@ func initBindInfoTable(s Session) { } func insertBuiltinBindInfoRow(s Session) { - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES ("%s", "%s", "mysql", "%s", "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", "%s")`, - bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO mysql.bind_info VALUES (%?, %?, "mysql", %?, "0000-00-00 00:00:00", "0000-00-00 00:00:00", "", "", %?)`, + bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.BuiltinPseudoSQL4BindLock, bindinfo.Builtin, bindinfo.Builtin, + ) } func upgradeToVer59(s Session, ver int64) { @@ -1400,9 +1400,9 @@ func updateBindInfo(iter *chunk.Iterator4Chunk, p *parser.Parser, bindMap map[st func writeMemoryQuotaQuery(s Session) { comment := "memory_quota_query is 32GB by default in v3.0.x, 1GB by default in v4.0.x+" - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%d', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%d'`, - mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbDefMemoryQuotaQuery, 32<<30, comment, 32<<30, + ) } func upgradeToVer62(s Session, ver int64) { @@ -1431,17 +1431,17 @@ func upgradeToVer64(s Session, ver int64) { func writeOOMAction(s Session) { comment := "oom-action is `log` by default in v3.0.x, `cancel` by default in v4.0.11+" - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", '%s', '%s') ON DUPLICATE KEY UPDATE VARIABLE_VALUE='%s'`, - mysql.SystemDB, mysql.TiDBTable, tidbDefOOMAction, config.OOMActionLog, comment, config.OOMActionLog) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, %?) ON DUPLICATE KEY UPDATE VARIABLE_VALUE= %?`, + mysql.SystemDB, mysql.TiDBTable, tidbDefOOMAction, config.OOMActionLog, comment, config.OOMActionLog, + ) } // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. - sql := fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES ("%s", "%d", "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%d"`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES (%?, %?, "TiDB bootstrap version.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, currentBootstrapVersion, + ) } // getBootstrapVersion gets bootstrap version from mysql.tidb table; @@ -1461,7 +1461,7 @@ func doDDLWorks(s Session) { // Create a test database. mustExecute(s, "CREATE DATABASE IF NOT EXISTS test") // Create system db. - mustExecute(s, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysql.SystemDB)) + mustExecute(s, "CREATE DATABASE IF NOT EXISTS %n", mysql.SystemDB) // Create user table. mustExecute(s, CreateUserTable) // Create privilege tables. @@ -1507,6 +1507,7 @@ func doDDLWorks(s Session) { // doDMLWorks executes DML statements in bootstrap stage. // All the statements run in a single transaction. +// TODO: sanitize. func doDMLWorks(s Session) { mustExecute(s, "BEGIN") @@ -1548,14 +1549,13 @@ func doDMLWorks(s Session) { strings.Join(values, ", ")) mustExecute(s, sql) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%s", "Bootstrap flag. Do not delete.") - ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, - mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap flag. Do not delete.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE=%?`, + mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, varTrue, varTrue, + ) - sql = fmt.Sprintf(`INSERT HIGH_PRIORITY INTO %s.%s VALUES("%s", "%d", "Bootstrap version. Do not delete.")`, - mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion) - mustExecute(s, sql) + mustExecute(s, `INSERT HIGH_PRIORITY INTO %n.%n VALUES(%?, %?, "Bootstrap version. Do not delete.")`, + mysql.SystemDB, mysql.TiDBTable, tidbServerVersionVar, currentBootstrapVersion, + ) writeSystemTZ(s) @@ -1565,7 +1565,7 @@ func doDMLWorks(s Session) { writeStmtSummaryVars(s) - _, err := s.Execute(context.Background(), "COMMIT") + _, err := s.ExecuteInternal(context.Background(), "COMMIT") if err != nil { sleepTime := 1 * time.Second logutil.BgLogger().Info("doDMLWorks failed", zap.Error(err), zap.Duration("sleeping time", sleepTime)) @@ -1582,8 +1582,8 @@ func doDMLWorks(s Session) { } } -func mustExecute(s Session, sql string) { - _, err := s.ExecuteInternal(context.Background(), sql) +func mustExecute(s Session, sql string, args ...interface{}) { + _, err := s.ExecuteInternal(context.Background(), sql, args...) if err != nil { debug.PrintStack() logutil.BgLogger().Fatal("mustExecute error", zap.Error(err)) diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index 5470082eff101..bec967310e048 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -81,7 +81,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { se, err = CreateSession4Test(store) c.Assert(err, IsNil) mustExecSQL(c, se, "USE test;") - mustExecSQL(c, se, "insert t values (?)", 3) + mustExecSQL(c, se, "insert t values (%?)", 3) se, err = CreateSession4Test(store) c.Assert(err, IsNil) mustExecSQL(c, se, "USE test;") @@ -227,8 +227,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) mustExecSQL(c, se1, `delete from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - mustExecSQL(c, se1, fmt.Sprintf(`delete from mysql.global_variables where VARIABLE_NAME="%s";`, - variable.TiDBDistSQLScanConcurrency)) + mustExecSQL(c, se1, `delete from mysql.global_variables where VARIABLE_NAME= %?`, variable.TiDBDistSQLScanConcurrency) mustExecSQL(c, se1, `commit;`) unsetStoreBootstrapped(store.UUID()) // Make sure the version is downgraded. @@ -263,7 +262,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { c.Assert(ver, Equals, int64(currentBootstrapVersion)) // Verify that 'new_collation_enabled' is false. - r = mustExecSQL(c, se2, fmt.Sprintf(`SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME='%s';`, tidbNewCollationEnabled)) + r = mustExecSQL(c, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME= %?;`, tidbNewCollationEnabled) req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) @@ -541,13 +540,11 @@ func (s *testBootstrapSuite) TestUpdateBindInfo(c *C) { defer dom.Close() se := newSession(c, store, s.dbName) for _, bindCase := range bindCases { - sql := fmt.Sprintf("insert into mysql.bind_info values('%s', '%s', '%s', 'using', '2021-01-04 14:50:58.257', '2021-01-04 14:50:58.257', 'utf8', 'utf8_general_ci', 'manual')", + mustExecSQL(c, se, "insert into mysql.bind_info values(%?, %?, %?, 'using', '2021-01-04 14:50:58.257', '2021-01-04 14:50:58.257', 'utf8', 'utf8_general_ci', 'manual')", bindCase.originText, bindCase.bindText, bindCase.db, ) - mustExecSQL(c, se, sql) - upgradeToVer61(se, version60) r := mustExecSQL(c, se, `select original_sql, bind_sql, default_db, status from mysql.bind_info where source != 'builtin'`) req := r.NewChunk() @@ -558,7 +555,7 @@ func (s *testBootstrapSuite) TestUpdateBindInfo(c *C) { c.Assert(row.GetString(2), Equals, "") c.Assert(row.GetString(3), Equals, "using") c.Assert(r.Close(), IsNil) - sql = fmt.Sprintf("drop global binding for %s", bindCase.deleteText) + sql := fmt.Sprintf("drop global binding for %s", bindCase.deleteText) mustExecSQL(c, se, sql) r = mustExecSQL(c, se, `select original_sql, bind_sql, status from mysql.bind_info where source != 'builtin'`) c.Assert(r.Next(ctx, req), IsNil) diff --git a/session/session.go b/session/session.go index fc4794fcb1a15..46efa01d59d3f 100644 --- a/session/session.go +++ b/session/session.go @@ -99,15 +99,6 @@ var ( tiKVGCAutoConcurrency = "tikv_gc_auto_concurrency" ) -var gcVariableComments = map[string]string{ - variable.TiDBGCRunInterval: "GC run interval, at least 10m, in Go format.", - variable.TiDBGCLifetime: "All versions within life time will not be collected by GC, at least 10m, in Go format.", - variable.TiDBGCConcurrency: "How many goroutines used to do GC parallel, [1, 128], default 2", - variable.TiDBGCEnable: "Current GC enable status", - tiKVGCAutoConcurrency: "Let TiDB pick the concurrency automatically. If set false, tikv_gc_concurrency will be used", - variable.TiDBGCScanLockMode: "Mode of scanning locks, \"physical\" or \"legacy\"", -} - var gcVariableMap = map[string]string{ variable.TiDBGCRunInterval: "tikv_gc_run_interval", variable.TiDBGCLifetime: "tikv_gc_life_time", @@ -978,15 +969,19 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet) ([]c } } -// getExecRet executes restricted sql and the result is one column. +// getTableValue executes restricted sql and the result is one column. // It returns a string value. -func (s *session) getExecRet(ctx sessionctx.Context, sql string) (string, error) { - rows, fields, err := s.ExecRestrictedSQL(sql) +func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { + stmt, err := s.ParseWithParams(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + if err != nil { + return "", err + } + rows, fields, err := s.ExecRestrictedStmt(ctx, stmt) if err != nil { return "", err } if len(rows) == 0 { - return "", executor.ErrResultIsEmpty + return "", errResultIsEmpty } d := rows[0].GetDatum(0, &fields[0].Column.FieldType) value, err := d.ToString() @@ -996,6 +991,34 @@ func (s *session) getExecRet(ctx sessionctx.Context, sql string) (string, error) return value, nil } +var gcVariableComments = map[string]string{ + variable.TiDBGCRunInterval: "GC run interval, at least 10m, in Go format.", + variable.TiDBGCLifetime: "All versions within life time will not be collected by GC, at least 10m, in Go format.", + variable.TiDBGCConcurrency: "How many goroutines used to do GC parallel, [1, 128], default 2", + variable.TiDBGCEnable: "Current GC enable status", + tiKVGCAutoConcurrency: "Let TiDB pick the concurrency automatically. If set false, tikv_gc_concurrency will be used", + variable.TiDBGCScanLockMode: "Mode of scanning locks, \"physical\" or \"legacy\"", +} + +// replaceTableValue executes restricted sql updates the variable value +func (s *session) replaceTableValue(ctx context.Context, tblName string, varName, val string) error { + if tblName == mysql.TiDBTable { // maintain comment metadata + comment := gcVariableComments[varName] + stmt, err := s.ParseWithParams(ctx, `REPLACE INTO %n.%n (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, mysql.SystemDB, tblName, varName, val, comment) + if err != nil { + return err + } + _, _, err = s.ExecRestrictedStmt(ctx, stmt) + return err + } + stmt, err := s.ParseWithParams(ctx, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, tblName, varName, val) + if err != nil { + return err + } + _, _, err = s.ExecRestrictedStmt(ctx, stmt) + return err +} + func (s *session) varFromTiDBTable(name string) bool { switch name { case variable.TiDBGCConcurrency, variable.TiDBGCEnable, variable.TiDBGCRunInterval, variable.TiDBGCLifetime, variable.TiDBGCScanLockMode: @@ -1013,11 +1036,9 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { // When running bootstrap or upgrade, we should not access global storage. return "", nil } - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, - mysql.SystemDB, mysql.GlobalVariablesTable, name) - sysVar, err := s.getExecRet(s, sql) + sysVar, err := s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) if err != nil { - if executor.ErrResultIsEmpty.Equal(err) { + if errResultIsEmpty.Equal(err) { sv := variable.GetSysVar(name) if sv != nil { return sv.Value, nil @@ -1064,18 +1085,14 @@ func (s *session) SetGlobalSysVar(name, value string) error { } } variable.CheckDeprecationSetSystemVar(s.sessionVars, name) - sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, name, escapeUserString(sVal)) - _, _, err = s.ExecRestrictedSQL(sql) + stmt, err := s.ParseWithParams(context.TODO(), "REPLACE %n.%n VALUES (%?, %?)", mysql.SystemDB, mysql.GlobalVariablesTable, name, sVal) + if err != nil { + return err + } + _, _, err = s.ExecRestrictedStmt(context.TODO(), stmt) return err } -// escape user supplied string for internal SQL. Not safe for all cases, since it doesn't -// handle quote-type, sql-mode, character set breakout. -func escapeUserString(str string) string { - return strings.ReplaceAll(str, `'`, `\'`) -} - // setTiDBTableValue handles tikv_* sysvars which need to update mysql.tidb // for backwards compatibility. Validation has already been performed. func (s *session) setTiDBTableValue(name, val string) error { @@ -1084,17 +1101,13 @@ func (s *session) setTiDBTableValue(name, val string) error { if val == "-1" { autoConcurrency = "true" } - sql := fmt.Sprintf(`INSERT INTO mysql.tidb (variable_name, variable_value, comment) VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY UPDATE variable_value = '%[2]s'`, tiKVGCAutoConcurrency, autoConcurrency, gcVariableComments[name]) - _, _, err := s.ExecRestrictedSQL(sql) + err := s.replaceTableValue(context.TODO(), mysql.TiDBTable, tiKVGCAutoConcurrency, autoConcurrency) if err != nil { return err } } val = onOffToTrueFalse(val) - sql := fmt.Sprintf(`INSERT INTO mysql.tidb (variable_name, variable_value, comment) VALUES ('%[1]s', '%[2]s', '%[3]s') - ON DUPLICATE KEY UPDATE variable_value = '%[2]s'`, gcVariableMap[name], escapeUserString(val), gcVariableComments[name]) - _, _, err := s.ExecRestrictedSQL(sql) + err := s.replaceTableValue(context.TODO(), mysql.TiDBTable, gcVariableMap[name], val) return err } @@ -1125,14 +1138,12 @@ func onOffToTrueFalse(str string) string { func (s *session) getTiDBTableValue(name, val string) (string, error) { if name == variable.TiDBGCConcurrency { // Check if autoconcurrency is set - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME='%s';`, tiKVGCAutoConcurrency) - autoConcurrencyVal, err := s.getExecRet(s, sql) + autoConcurrencyVal, err := s.getTableValue(context.TODO(), mysql.TiDBTable, tiKVGCAutoConcurrency) if err == nil && strings.EqualFold(autoConcurrencyVal, "true") { return "-1", nil // convention for "AUTO" } } - sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME='%s';`, gcVariableMap[name]) - tblValue, err := s.getExecRet(s, sql) + tblValue, err := s.getTableValue(context.TODO(), mysql.TiDBTable, gcVariableMap[name]) if err != nil { return val, nil // mysql.tidb value does not exist. } @@ -1147,24 +1158,24 @@ func (s *session) getTiDBTableValue(name, val string) (string, error) { zap.String("tblName", gcVariableMap[name]), zap.String("tblValue", tblValue), zap.String("restoredValue", val)) - sql := fmt.Sprintf(`REPLACE INTO mysql.tidb (variable_name, variable_value, comment) - VALUES ('%s', '%s', '%s')`, gcVariableMap[name], escapeUserString(val), gcVariableComments[name]) - _, _, err = s.ExecRestrictedSQL(sql) + err = s.replaceTableValue(context.TODO(), mysql.TiDBTable, gcVariableMap[name], val) return val, err } if validatedVal != val { // The sysvar value is out of sync. - sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, gcVariableMap[name], escapeUserString(validatedVal)) - _, _, err = s.ExecRestrictedSQL(sql) + err = s.replaceTableValue(context.TODO(), mysql.GlobalVariablesTable, gcVariableMap[name], validatedVal) return validatedVal, err } return validatedVal, nil } func (s *session) ensureFullGlobalStats() error { - rows, _, err := s.ExecRestrictedSQL(`select count(1) from information_schema.tables t where t.create_options = 'partitioned' - and not exists (select 1 from mysql.stats_meta m where m.table_id = t.tidb_table_id)`) + stmt, err := s.ParseWithParams(context.TODO(), `select count(1) from information_schema.tables t where t.create_options = 'partitioned' + and not exists (select 1 from mysql.stats_meta m where m.table_id = t.tidb_table_id)`) + if err != nil { + return err + } + rows, _, err := s.ExecRestrictedStmt(context.TODO(), stmt) if err != nil { return err } @@ -2150,11 +2161,6 @@ func CreateSessionWithOpt(store kv.Storage, opt *Opt) (Session, error) { return s, nil } -// loadSystemTZ loads systemTZ from mysql.tidb -func loadSystemTZ(se *session) (string, error) { - return loadParameter(se, "system_tz") -} - // loadCollationParameter loads collation parameter from mysql.tidb func loadCollationParameter(se *session) (bool, error) { para, err := loadParameter(se, tidbNewCollationEnabled) @@ -2210,25 +2216,7 @@ var ( // loadParameter loads read-only parameter from mysql.tidb func loadParameter(se *session, name string) (string, error) { - sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'" - rs, errLoad := se.ExecuteInternal(context.Background(), sql) - if errLoad != nil { - return "", errLoad - } - // the record of mysql.tidb under where condition: variable_name = $name should shall only be one. - defer func() { - if err := rs.Close(); err != nil { - logutil.BgLogger().Error("close result set error", zap.Error(err)) - } - }() - req := rs.NewChunk() - if err := rs.Next(context.Background(), req); err != nil { - return "", err - } - if req.NumRows() == 0 { - return "", errResultIsEmpty - } - return req.GetRow(0).GetString(0), nil + return se.getTableValue(context.TODO(), mysql.TiDBTable, name) } // BootstrapSession runs the first time when the TiDB server start. @@ -2245,8 +2233,6 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { } } - initLoadCommonGlobalVarsSQL() - ver := getStoreBootstrapVersion(store) if ver == notBootstrapped { runInBootstrapSession(store, bootstrap) @@ -2258,8 +2244,11 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { if err != nil { return nil, err } + + se.initLoadCommonGlobalVarsSQL() + // get system tz from mysql.tidb - tz, err := loadSystemTZ(se) + tz, err := se.getTableValue(context.TODO(), mysql.TiDBTable, "system_tz") if err != nil { return nil, err } @@ -2624,22 +2613,26 @@ var builtinGlobalVariable = []string{ var ( loadCommonGlobalVarsSQLOnce sync.Once - loadCommonGlobalVarsSQL string + loadCommonGlobalVarsStmt ast.StmtNode ) -func initLoadCommonGlobalVarsSQL() { +func (s *session) initLoadCommonGlobalVarsSQL() { loadCommonGlobalVarsSQLOnce.Do(func() { vars := append(make([]string, 0, len(builtinGlobalVariable)+len(variable.PluginVarNames)), builtinGlobalVariable...) if len(variable.PluginVarNames) > 0 { vars = append(vars, variable.PluginVarNames...) } - loadCommonGlobalVarsSQL = "select HIGH_PRIORITY * from mysql.global_variables where variable_name in ('" + strings.Join(vars, quoteCommaQuote) + "')" + var err error + loadCommonGlobalVarsStmt, err = s.ParseWithParams(context.TODO(), "SELECT HIGH_PRIORITY * from mysql.global_variables where variable_name in (%?)", vars) + if err != nil { + loadCommonGlobalVarsStmt = nil + } }) } // loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. func (s *session) loadCommonGlobalVariablesIfNeeded() error { - initLoadCommonGlobalVarsSQL() + s.initLoadCommonGlobalVarsSQL() vars := s.sessionVars if vars.CommonGlobalLoaded { return nil @@ -2654,7 +2647,7 @@ func (s *session) loadCommonGlobalVariablesIfNeeded() error { // When a lot of connections connect to TiDB simultaneously, it can protect TiKV meta region from overload. gvc := domain.GetDomain(s).GetGlobalVarsCache() loadFunc := func() ([]chunk.Row, []*ast.ResultField, error) { - return s.ExecRestrictedSQL(loadCommonGlobalVarsSQL) + return s.ExecRestrictedStmt(context.TODO(), loadCommonGlobalVarsStmt) } rows, fields, err := gvc.LoadGlobalVariables(loadFunc) if err != nil { diff --git a/session/session_fail_test.go b/session/session_fail_test.go index 661dabae718a4..2d53625a31b29 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -49,7 +49,7 @@ func (s *testSessionSerialSuite) TestGetTSFailDirtyState(c *C) { ctx := failpoint.WithHook(context.Background(), func(ctx context.Context, fpname string) bool { return fpname == "github.com/pingcap/tidb/session/mockGetTSFail" }) - _, err := tk.Se.Execute(ctx, "select * from t") + _, err := tk.Se.ExecuteInternal(ctx, "select * from t") c.Assert(err, NotNil) // Fix a bug that active txn fail set TxnState.fail to error, and then the following write diff --git a/session/session_test.go b/session/session_test.go index 9d356dddf0e0e..ee1f24561f2d2 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -319,11 +319,11 @@ func (s *testSessionSuite) TestQueryString(c *C) { c.Assert(queryStr, Equals, "create table multi2 (a int)") // Test execution of DDL through the "ExecutePreparedStmt" interface. - _, err := tk.Se.Execute(context.Background(), "use test;") + _, err := tk.Se.ExecuteInternal(context.Background(), "use test;") c.Assert(err, IsNil) - _, err = tk.Se.Execute(context.Background(), "CREATE TABLE t (id bigint PRIMARY KEY, age int)") + _, err = tk.Se.ExecuteInternal(context.Background(), "CREATE TABLE t (id bigint PRIMARY KEY, age int)") c.Assert(err, IsNil) - _, err = tk.Se.Execute(context.Background(), "show create table t") + _, err = tk.Se.ExecuteInternal(context.Background(), "show create table t") c.Assert(err, IsNil) id, _, _, err := tk.Se.PrepareStmt("CREATE TABLE t2(id bigint PRIMARY KEY, age int)") c.Assert(err, IsNil) @@ -334,13 +334,13 @@ func (s *testSessionSuite) TestQueryString(c *C) { c.Assert(qs.(string), Equals, "CREATE TABLE t2(id bigint PRIMARY KEY, age int)") // Test execution of DDL through the "Execute" interface. - _, err = tk.Se.Execute(context.Background(), "use test;") + _, err = tk.Se.ExecuteInternal(context.Background(), "use test;") c.Assert(err, IsNil) - _, err = tk.Se.Execute(context.Background(), "drop table t2") + _, err = tk.Se.ExecuteInternal(context.Background(), "drop table t2") c.Assert(err, IsNil) - _, err = tk.Se.Execute(context.Background(), "prepare stmt from 'CREATE TABLE t2(id bigint PRIMARY KEY, age int)'") + _, err = tk.Se.ExecuteInternal(context.Background(), "prepare stmt from 'CREATE TABLE t2(id bigint PRIMARY KEY, age int)'") c.Assert(err, IsNil) - _, err = tk.Se.Execute(context.Background(), "execute stmt") + _, err = tk.Se.ExecuteInternal(context.Background(), "execute stmt") c.Assert(err, IsNil) qs = tk.Se.Value(sessionctx.QueryString) c.Assert(qs.(string), Equals, "CREATE TABLE t2(id bigint PRIMARY KEY, age int)") @@ -2624,17 +2624,17 @@ func (s *testSessionSuite3) TestSetTransactionIsolationOneShot(c *C) { ctx := context.WithValue(context.Background(), "CheckSelectRequestHook", func(req *kv.Request) { c.Assert(req.IsolationLevel, Equals, kv.SI) }) - tk.Se.Execute(ctx, "select * from t where k = 1") + tk.Se.ExecuteInternal(ctx, "select * from t where k = 1") // Check it just take effect for one time. ctx = context.WithValue(context.Background(), "CheckSelectRequestHook", func(req *kv.Request) { c.Assert(req.IsolationLevel, Equals, kv.SI) }) - tk.Se.Execute(ctx, "select * from t where k = 1") + tk.Se.ExecuteInternal(ctx, "select * from t where k = 1") // Can't change isolation level when it's inside a transaction. tk.MustExec("begin") - _, err := tk.Se.Execute(ctx, "set transaction isolation level read committed") + _, err := tk.Se.ExecuteInternal(ctx, "set transaction isolation level read committed") c.Assert(err, NotNil) } @@ -2692,7 +2692,7 @@ func (s *testSessionSuite2) TestCommitRetryCount(c *C) { tk2.MustExec("commit") // No auto retry because retry limit is set to 0. - _, err := tk1.Se.Execute(context.Background(), "commit") + _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") c.Assert(err, NotNil) } @@ -2723,7 +2723,7 @@ func (s *testSessionSerialSuite) TestTxnRetryErrMsg(c *C) { tk2.MustExec("update no_retry set id = id + 1") tk1.MustExec("update no_retry set id = id + 1") c.Assert(tikv.MockRetryableErrorResp.Enable(`return(true)`), IsNil) - _, err := tk1.Se.Execute(context.Background(), "commit") + _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") tikv.MockRetryableErrorResp.Disable() c.Assert(err, NotNil) c.Assert(kv.ErrTxnRetryable.Equal(err), IsTrue, Commentf("error: %s", err)) @@ -2746,7 +2746,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("commit") // No auto retry because tidb_disable_txn_auto_retry is set to 1. - _, err := tk1.Se.Execute(context.Background(), "commit") + _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") c.Assert(err, NotNil) // session 1 starts a transaction early. @@ -2778,7 +2778,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("update no_retry set id = 8") - _, err = tk1.Se.Execute(context.Background(), "commit") + _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -2791,7 +2791,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("alter table no_retry add index idx(id)") tk2.MustQuery("select * from no_retry").Check(testkit.Rows("8")) tk1.MustExec("update no_retry set id = 10") - _, err = tk1.Se.Execute(context.Background(), "commit") + _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") c.Assert(err, NotNil) // set autocommit to begin and commit @@ -2799,7 +2799,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk1.MustQuery("select * from no_retry").Check(testkit.Rows("8")) tk2.MustExec("update no_retry set id = 11") tk1.MustExec("update no_retry set id = 12") - _, err = tk1.Se.Execute(context.Background(), "set autocommit = 1") + _, err = tk1.Se.ExecuteInternal(context.Background(), "set autocommit = 1") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -2810,7 +2810,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk1.MustQuery("select * from no_retry").Check(testkit.Rows("11")) tk2.MustExec("update no_retry set id = 13") tk1.MustExec("update no_retry set id = 14") - _, err = tk1.Se.Execute(context.Background(), "commit") + _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -4102,7 +4102,7 @@ func (s *testSessionSerialSuite) TestTiKVSystemVars(c *C) { result.Check(testkit.Rows("true")) tk.MustExec("UPDATE mysql.tidb SET variable_value = 'false' WHERE variable_name='tikv_gc_enable'") - result = tk.MustQuery("SELECT @@tidb_gc_enable;") + result = tk.MustQuery("SELECT @@tidb_gc_enable") result.Check(testkit.Rows("0")) // reads from mysql.tidb value and changes to false tk.MustExec("SET GLOBAL tidb_gc_concurrency = -1") // sets auto concurrency and concurrency diff --git a/session/tidb_test.go b/session/tidb_test.go index eff7406332925..7ffe682f094ba 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -141,26 +141,7 @@ func removeStore(c *C, dbPath string) { func exec(se Session, sql string, args ...interface{}) (sqlexec.RecordSet, error) { ctx := context.Background() - if len(args) == 0 { - rs, err := se.Execute(ctx, sql) - if err == nil && len(rs) > 0 { - return rs[0], nil - } - return nil, err - } - stmtID, _, _, err := se.PrepareStmt(sql) - if err != nil { - return nil, err - } - params := make([]types.Datum, len(args)) - for i := 0; i < len(params); i++ { - params[i] = types.NewDatum(args[i]) - } - rs, err := se.ExecutePreparedStmt(ctx, stmtID, params) - if err != nil { - return nil, err - } - return rs, nil + return se.ExecuteInternal(ctx, sql, args...) } func mustExecSQL(c *C, se Session, sql string, args ...interface{}) sqlexec.RecordSet { diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index 1b52f78549678..db16779fdbb8a 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -147,7 +147,7 @@ func (tk *TestKit) GetConnectionID() { } } -// Exec executes a sql statement. +// Exec executes a sql statement using the prepared stmt API func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, error) { var err error if tk.Se == nil { From 08d46311364d0bc6228dca717f35de06e62207fd Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Tue, 2 Feb 2021 11:15:55 -0700 Subject: [PATCH 2/6] Address PR review comments --- session/bench_test.go | 6 +++--- session/tidb_test.go | 12 ++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/session/bench_test.go b/session/bench_test.go index 2e4e29b5f8f14..5078fdd80bdf1 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -53,7 +53,7 @@ func prepareBenchSession() (Session, *domain.Domain, kv.Storage) { return se, domain, store } -func prepareBenchData(se Session, colType string, valueFormat string, valueCount int) { +func prepareBenchData(se Session, colType string, valueCount int) { mustExecute(se, "drop table if exists t") mustExecute(se, "create table t (pk int primary key auto_increment, col %n, index idx (col))", colType) mustExecute(se, "begin") @@ -63,7 +63,7 @@ func prepareBenchData(se Session, colType string, valueFormat string, valueCount mustExecute(se, "commit") } -func prepareSortBenchData(se Session, colType string, valueFormat string, valueCount int) { +func prepareSortBenchData(se Session, colType string, valueCount int) { mustExecute(se, "drop table if exists t") mustExecute(se, "create table t (pk int primary key auto_increment, col %n)", colType) mustExecute(se, "begin") @@ -382,7 +382,7 @@ func BenchmarkSort(b *testing.B) { st.Close() do.Close() }() - prepareSortBenchData(se, "int", "%v", bigCount) + prepareSortBenchData(se, "int", bigCount) b.ResetTimer() for i := 0; i < b.N; i++ { rs, err := se.ExecuteInternal(ctx, "select * from t order by col limit 50") diff --git a/session/tidb_test.go b/session/tidb_test.go index 7ffe682f094ba..f057e88984bba 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -130,8 +130,8 @@ func newSession(c *C, store kv.Storage, dbName string) Session { se.SetConnectionID(id) c.Assert(err, IsNil) se.Auth(&auth.UserIdentity{Username: "root", Hostname: `%`}, nil, []byte("012345678901234567890")) - mustExecSQL(c, se, "create database if not exists "+dbName) - mustExecSQL(c, se, "use "+dbName) + mustExecSQL(c, se, "create database if not exists %n", dbName) + mustExecSQL(c, se, "use %n", dbName) return se } @@ -139,13 +139,9 @@ func removeStore(c *C, dbPath string) { os.RemoveAll(dbPath) } -func exec(se Session, sql string, args ...interface{}) (sqlexec.RecordSet, error) { - ctx := context.Background() - return se.ExecuteInternal(ctx, sql, args...) -} - func mustExecSQL(c *C, se Session, sql string, args ...interface{}) sqlexec.RecordSet { - rs, err := exec(se, sql, args...) + ctx := context.Background() + rs, err := se.ExecuteInternal(ctx, sql, args...) c.Assert(err, IsNil) return rs } From d7f9473c7ea5ff5dac726f83276ab3b4b6cd5638 Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Tue, 2 Feb 2021 19:38:07 -0700 Subject: [PATCH 3/6] restore valueFormat --- session/bench_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/session/bench_test.go b/session/bench_test.go index 5078fdd80bdf1..67f84e194e714 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -53,12 +53,12 @@ func prepareBenchSession() (Session, *domain.Domain, kv.Storage) { return se, domain, store } -func prepareBenchData(se Session, colType string, valueCount int) { +func prepareBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") mustExecute(se, "create table t (pk int primary key auto_increment, col %n, index idx (col))", colType) mustExecute(se, "begin") for i := 0; i < valueCount; i++ { - mustExecute(se, "insert t (col) values (%?)", i) + mustExecute(se, "insert t (col) values (%?)", fmt.Sprintf(valueFormat, i)) } mustExecute(se, "commit") } From d9931dde937d482585aac8f8016231c7a18989be Mon Sep 17 00:00:00 2001 From: Morgan Tocker Date: Wed, 3 Feb 2021 16:02:32 -0700 Subject: [PATCH 4/6] revert ExecuteInternal from tests --- session/bench_test.go | 104 +++++++++++++++++------------------ session/bootstrap_test.go | 13 +++-- session/session_fail_test.go | 2 +- session/session_test.go | 43 ++++++++------- session/tidb_test.go | 31 +++++++++-- 5 files changed, 110 insertions(+), 83 deletions(-) diff --git a/session/bench_test.go b/session/bench_test.go index 67f84e194e714..02d1889f9b73d 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -55,17 +55,17 @@ func prepareBenchSession() (Session, *domain.Domain, kv.Storage) { func prepareBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, "create table t (pk int primary key auto_increment, col %n, index idx (col))", colType) + mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s, index idx (col))", colType)) mustExecute(se, "begin") for i := 0; i < valueCount; i++ { - mustExecute(se, "insert t (col) values (%?)", fmt.Sprintf(valueFormat, i)) + mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, i)+")") } mustExecute(se, "commit") } -func prepareSortBenchData(se Session, colType string, valueCount int) { +func prepareSortBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, "create table t (pk int primary key auto_increment, col %n)", colType) + mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s)", colType)) mustExecute(se, "begin") r := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < valueCount; i++ { @@ -73,17 +73,17 @@ func prepareSortBenchData(se Session, colType string, valueCount int) { mustExecute(se, "commit") mustExecute(se, "begin") } - mustExecute(se, "insert t (col) values (%?)", r.Intn(valueCount)) + mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, r.Intn(valueCount))+")") } mustExecute(se, "commit") } func prepareJoinBenchData(se Session, colType string, valueFormat string, valueCount int) { mustExecute(se, "drop table if exists t") - mustExecute(se, "create table t (pk int primary key auto_increment, col %n)", colType) + mustExecute(se, fmt.Sprintf("create table t (pk int primary key auto_increment, col %s)", colType)) mustExecute(se, "begin") for i := 0; i < valueCount; i++ { - mustExecute(se, "insert t (col) values (%?)", i) + mustExecute(se, "insert t (col) values ("+fmt.Sprintf(valueFormat, i)+")") } mustExecute(se, "commit") } @@ -113,11 +113,11 @@ func BenchmarkBasic(b *testing.B) { }() b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select 1") + rs, err := se.Execute(ctx, "select 1") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -133,11 +133,11 @@ func BenchmarkTableScan(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t") + rs, err := se.Execute(ctx, "select * from t") if err != nil { b.Fatal(err) } - readResult(ctx, rs, smallCount) + readResult(ctx, rs[0], smallCount) } b.StopTimer() } @@ -153,11 +153,11 @@ func BenchmarkExplainTableScan(b *testing.B) { prepareBenchData(se, "int", "%v", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "explain select * from t") + rs, err := se.Execute(ctx, "explain select * from t") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -173,11 +173,11 @@ func BenchmarkTableLookup(b *testing.B) { prepareBenchData(se, "int", "%d", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where pk = 64") + rs, err := se.Execute(ctx, "select * from t where pk = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -193,11 +193,11 @@ func BenchmarkExplainTableLookup(b *testing.B) { prepareBenchData(se, "int", "%d", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "explain select * from t where pk = 64") + rs, err := se.Execute(ctx, "explain select * from t where pk = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -213,11 +213,11 @@ func BenchmarkStringIndexScan(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col > 'hello'") + rs, err := se.Execute(ctx, "select * from t where col > 'hello'") if err != nil { b.Fatal(err) } - readResult(ctx, rs, smallCount) + readResult(ctx, rs[0], smallCount) } b.StopTimer() } @@ -233,11 +233,11 @@ func BenchmarkExplainStringIndexScan(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", 0) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "explain select * from t where col > 'hello'") + rs, err := se.Execute(ctx, "explain select * from t where col > 'hello'") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -253,11 +253,11 @@ func BenchmarkStringIndexLookup(b *testing.B) { prepareBenchData(se, "varchar(255)", "'hello %d'", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col = 'hello 64'") + rs, err := se.Execute(ctx, "select * from t where col = 'hello 64'") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -273,11 +273,11 @@ func BenchmarkIntegerIndexScan(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col >= 0") + rs, err := se.Execute(ctx, "select * from t where col >= 0") if err != nil { b.Fatal(err) } - readResult(ctx, rs, smallCount) + readResult(ctx, rs[0], smallCount) } b.StopTimer() } @@ -293,11 +293,11 @@ func BenchmarkIntegerIndexLookup(b *testing.B) { prepareBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col = 64") + rs, err := se.Execute(ctx, "select * from t where col = 64") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -313,11 +313,11 @@ func BenchmarkDecimalIndexScan(b *testing.B) { prepareBenchData(se, "decimal(32,6)", "%v.1234", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col >= 0") + rs, err := se.Execute(ctx, "select * from t where col >= 0") if err != nil { b.Fatal(err) } - readResult(ctx, rs, smallCount) + readResult(ctx, rs[0], smallCount) } b.StopTimer() } @@ -333,11 +333,11 @@ func BenchmarkDecimalIndexLookup(b *testing.B) { prepareBenchData(se, "decimal(32,6)", "%v.1234", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where col = 64.1234") + rs, err := se.Execute(ctx, "select * from t where col = 64.1234") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -353,7 +353,7 @@ func BenchmarkInsertWithIndex(b *testing.B) { mustExecute(se, "create table t (pk int primary key, col int, index idx (col))") b.ResetTimer() for i := 0; i < b.N; i++ { - mustExecute(se, "insert t values (%d, %d)", i, i) + mustExecute(se, fmt.Sprintf("insert t values (%d, %d)", i, i)) } b.StopTimer() } @@ -369,7 +369,7 @@ func BenchmarkInsertNoIndex(b *testing.B) { mustExecute(se, "create table t (pk int primary key, col int)") b.ResetTimer() for i := 0; i < b.N; i++ { - mustExecute(se, "insert t values (%d, %d)", i, i) + mustExecute(se, fmt.Sprintf("insert t values (%d, %d)", i, i)) } b.StopTimer() } @@ -382,14 +382,14 @@ func BenchmarkSort(b *testing.B) { st.Close() do.Close() }() - prepareSortBenchData(se, "int", bigCount) + prepareSortBenchData(se, "int", "%v", bigCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t order by col limit 50") + rs, err := se.Execute(ctx, "select * from t order by col limit 50") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 50) + readResult(ctx, rs[0], 50) } b.StopTimer() } @@ -405,11 +405,11 @@ func BenchmarkJoin(b *testing.B) { prepareJoinBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t a join t b on a.col = b.col") + rs, err := se.Execute(ctx, "select * from t a join t b on a.col = b.col") if err != nil { b.Fatal(err) } - readResult(ctx, rs, smallCount) + readResult(ctx, rs[0], smallCount) } b.StopTimer() } @@ -425,11 +425,11 @@ func BenchmarkJoinLimit(b *testing.B) { prepareJoinBenchData(se, "int", "%v", smallCount) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t a join t b on a.col = b.col limit 1") + rs, err := se.Execute(ctx, "select * from t a join t b on a.col = b.col limit 1") if err != nil { b.Fatal(err) } - readResult(ctx, rs, 1) + readResult(ctx, rs[0], 1) } b.StopTimer() } @@ -1472,11 +1472,11 @@ partition p1023 values less than (738538) )`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where dt > to_days('2019-04-01 21:00:00') and dt < to_days('2019-04-07 23:59:59')") + rs, err := se.Execute(ctx, "select * from t where dt > to_days('2019-04-01 21:00:00') and dt < to_days('2019-04-07 23:59:59')") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } @@ -1504,11 +1504,11 @@ func BenchmarkRangeColumnPartitionPruning(b *testing.B) { mustExecute(se, build.String()) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where dt > '2020-05-01' and dt < '2020-06-07'") + rs, err := se.Execute(ctx, "select * from t where dt > '2020-05-01' and dt < '2020-06-07'") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } @@ -1528,11 +1528,11 @@ func BenchmarkHashPartitionPruningPointSelect(b *testing.B) { mustExecute(se, `create table t (id int, dt datetime) partition by hash(id) partitions 1024;`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where id = 2330") + rs, err := se.Execute(ctx, "select * from t where id = 2330") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } @@ -1552,27 +1552,27 @@ func BenchmarkHashPartitionPruningMultiSelect(b *testing.B) { mustExecute(se, `create table t (id int, dt datetime) partition by hash(id) partitions 1024;`) b.ResetTimer() for i := 0; i < b.N; i++ { - rs, err := se.ExecuteInternal(ctx, "select * from t where id = 2330") + rs, err := se.Execute(ctx, "select * from t where id = 2330") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } - rs, err = se.ExecuteInternal(ctx, "select * from t where id = 1233 or id = 1512") + rs, err = se.Execute(ctx, "select * from t where id = 1233 or id = 1512") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } - rs, err = se.ExecuteInternal(ctx, "select * from t where id in (117, 1233, 15678)") + rs, err = se.Execute(ctx, "select * from t where id in (117, 1233, 15678)") if err != nil { b.Fatal(err) } - _, err = drainRecordSet(ctx, se.(*session), rs) + _, err = drainRecordSet(ctx, se.(*session), rs[0]) if err != nil { b.Fatal(err) } diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index bec967310e048..5470082eff101 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -81,7 +81,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { se, err = CreateSession4Test(store) c.Assert(err, IsNil) mustExecSQL(c, se, "USE test;") - mustExecSQL(c, se, "insert t values (%?)", 3) + mustExecSQL(c, se, "insert t values (?)", 3) se, err = CreateSession4Test(store) c.Assert(err, IsNil) mustExecSQL(c, se, "USE test;") @@ -227,7 +227,8 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) mustExecSQL(c, se1, `delete from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - mustExecSQL(c, se1, `delete from mysql.global_variables where VARIABLE_NAME= %?`, variable.TiDBDistSQLScanConcurrency) + mustExecSQL(c, se1, fmt.Sprintf(`delete from mysql.global_variables where VARIABLE_NAME="%s";`, + variable.TiDBDistSQLScanConcurrency)) mustExecSQL(c, se1, `commit;`) unsetStoreBootstrapped(store.UUID()) // Make sure the version is downgraded. @@ -262,7 +263,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { c.Assert(ver, Equals, int64(currentBootstrapVersion)) // Verify that 'new_collation_enabled' is false. - r = mustExecSQL(c, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME= %?;`, tidbNewCollationEnabled) + r = mustExecSQL(c, se2, fmt.Sprintf(`SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME='%s';`, tidbNewCollationEnabled)) req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) @@ -540,11 +541,13 @@ func (s *testBootstrapSuite) TestUpdateBindInfo(c *C) { defer dom.Close() se := newSession(c, store, s.dbName) for _, bindCase := range bindCases { - mustExecSQL(c, se, "insert into mysql.bind_info values(%?, %?, %?, 'using', '2021-01-04 14:50:58.257', '2021-01-04 14:50:58.257', 'utf8', 'utf8_general_ci', 'manual')", + sql := fmt.Sprintf("insert into mysql.bind_info values('%s', '%s', '%s', 'using', '2021-01-04 14:50:58.257', '2021-01-04 14:50:58.257', 'utf8', 'utf8_general_ci', 'manual')", bindCase.originText, bindCase.bindText, bindCase.db, ) + mustExecSQL(c, se, sql) + upgradeToVer61(se, version60) r := mustExecSQL(c, se, `select original_sql, bind_sql, default_db, status from mysql.bind_info where source != 'builtin'`) req := r.NewChunk() @@ -555,7 +558,7 @@ func (s *testBootstrapSuite) TestUpdateBindInfo(c *C) { c.Assert(row.GetString(2), Equals, "") c.Assert(row.GetString(3), Equals, "using") c.Assert(r.Close(), IsNil) - sql := fmt.Sprintf("drop global binding for %s", bindCase.deleteText) + sql = fmt.Sprintf("drop global binding for %s", bindCase.deleteText) mustExecSQL(c, se, sql) r = mustExecSQL(c, se, `select original_sql, bind_sql, status from mysql.bind_info where source != 'builtin'`) c.Assert(r.Next(ctx, req), IsNil) diff --git a/session/session_fail_test.go b/session/session_fail_test.go index 2d53625a31b29..661dabae718a4 100644 --- a/session/session_fail_test.go +++ b/session/session_fail_test.go @@ -49,7 +49,7 @@ func (s *testSessionSerialSuite) TestGetTSFailDirtyState(c *C) { ctx := failpoint.WithHook(context.Background(), func(ctx context.Context, fpname string) bool { return fpname == "github.com/pingcap/tidb/session/mockGetTSFail" }) - _, err := tk.Se.ExecuteInternal(ctx, "select * from t") + _, err := tk.Se.Execute(ctx, "select * from t") c.Assert(err, NotNil) // Fix a bug that active txn fail set TxnState.fail to error, and then the following write diff --git a/session/session_test.go b/session/session_test.go index ee1f24561f2d2..332f64465a1a1 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -45,6 +45,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/binloginfo" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/cluster" "github.com/pingcap/tidb/store/mockstore/mocktikv" @@ -122,7 +123,7 @@ func clearStorage(store kv.Storage) error { return txn.Commit(context.Background()) } -func clearETCD(ebd tikv.EtcdBackend) error { +func clearETCD(ebd kv.EtcdBackend) error { endpoints, err := ebd.EtcdAddrs() if err != nil { return err @@ -178,7 +179,7 @@ func (s *testSessionSuiteBase) SetUpSuite(c *C) { if *withTiKV { initPdAddrs() s.pdAddr = <-pdAddrChan - var d tikv.Driver + var d store.TiKVDriver config.UpdateGlobal(func(conf *config.Config) { conf.TxnLocalLatches.Enabled = false }) @@ -186,7 +187,7 @@ func (s *testSessionSuiteBase) SetUpSuite(c *C) { c.Assert(err, IsNil) err = clearStorage(store) c.Assert(err, IsNil) - err = clearETCD(store.(tikv.EtcdBackend)) + err = clearETCD(store.(kv.EtcdBackend)) c.Assert(err, IsNil) session.ResetStoreForWithTiKVTest(store) s.store = store @@ -319,11 +320,11 @@ func (s *testSessionSuite) TestQueryString(c *C) { c.Assert(queryStr, Equals, "create table multi2 (a int)") // Test execution of DDL through the "ExecutePreparedStmt" interface. - _, err := tk.Se.ExecuteInternal(context.Background(), "use test;") + _, err := tk.Se.Execute(context.Background(), "use test;") c.Assert(err, IsNil) - _, err = tk.Se.ExecuteInternal(context.Background(), "CREATE TABLE t (id bigint PRIMARY KEY, age int)") + _, err = tk.Se.Execute(context.Background(), "CREATE TABLE t (id bigint PRIMARY KEY, age int)") c.Assert(err, IsNil) - _, err = tk.Se.ExecuteInternal(context.Background(), "show create table t") + _, err = tk.Se.Execute(context.Background(), "show create table t") c.Assert(err, IsNil) id, _, _, err := tk.Se.PrepareStmt("CREATE TABLE t2(id bigint PRIMARY KEY, age int)") c.Assert(err, IsNil) @@ -334,13 +335,13 @@ func (s *testSessionSuite) TestQueryString(c *C) { c.Assert(qs.(string), Equals, "CREATE TABLE t2(id bigint PRIMARY KEY, age int)") // Test execution of DDL through the "Execute" interface. - _, err = tk.Se.ExecuteInternal(context.Background(), "use test;") + _, err = tk.Se.Execute(context.Background(), "use test;") c.Assert(err, IsNil) - _, err = tk.Se.ExecuteInternal(context.Background(), "drop table t2") + _, err = tk.Se.Execute(context.Background(), "drop table t2") c.Assert(err, IsNil) - _, err = tk.Se.ExecuteInternal(context.Background(), "prepare stmt from 'CREATE TABLE t2(id bigint PRIMARY KEY, age int)'") + _, err = tk.Se.Execute(context.Background(), "prepare stmt from 'CREATE TABLE t2(id bigint PRIMARY KEY, age int)'") c.Assert(err, IsNil) - _, err = tk.Se.ExecuteInternal(context.Background(), "execute stmt") + _, err = tk.Se.Execute(context.Background(), "execute stmt") c.Assert(err, IsNil) qs = tk.Se.Value(sessionctx.QueryString) c.Assert(qs.(string), Equals, "CREATE TABLE t2(id bigint PRIMARY KEY, age int)") @@ -2624,17 +2625,17 @@ func (s *testSessionSuite3) TestSetTransactionIsolationOneShot(c *C) { ctx := context.WithValue(context.Background(), "CheckSelectRequestHook", func(req *kv.Request) { c.Assert(req.IsolationLevel, Equals, kv.SI) }) - tk.Se.ExecuteInternal(ctx, "select * from t where k = 1") + tk.Se.Execute(ctx, "select * from t where k = 1") // Check it just take effect for one time. ctx = context.WithValue(context.Background(), "CheckSelectRequestHook", func(req *kv.Request) { c.Assert(req.IsolationLevel, Equals, kv.SI) }) - tk.Se.ExecuteInternal(ctx, "select * from t where k = 1") + tk.Se.Execute(ctx, "select * from t where k = 1") // Can't change isolation level when it's inside a transaction. tk.MustExec("begin") - _, err := tk.Se.ExecuteInternal(ctx, "set transaction isolation level read committed") + _, err := tk.Se.Execute(ctx, "set transaction isolation level read committed") c.Assert(err, NotNil) } @@ -2692,7 +2693,7 @@ func (s *testSessionSuite2) TestCommitRetryCount(c *C) { tk2.MustExec("commit") // No auto retry because retry limit is set to 0. - _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err := tk1.Se.Execute(context.Background(), "commit") c.Assert(err, NotNil) } @@ -2723,7 +2724,7 @@ func (s *testSessionSerialSuite) TestTxnRetryErrMsg(c *C) { tk2.MustExec("update no_retry set id = id + 1") tk1.MustExec("update no_retry set id = id + 1") c.Assert(tikv.MockRetryableErrorResp.Enable(`return(true)`), IsNil) - _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err := tk1.Se.Execute(context.Background(), "commit") tikv.MockRetryableErrorResp.Disable() c.Assert(err, NotNil) c.Assert(kv.ErrTxnRetryable.Equal(err), IsTrue, Commentf("error: %s", err)) @@ -2746,7 +2747,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("commit") // No auto retry because tidb_disable_txn_auto_retry is set to 1. - _, err := tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err := tk1.Se.Execute(context.Background(), "commit") c.Assert(err, NotNil) // session 1 starts a transaction early. @@ -2778,7 +2779,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("update no_retry set id = 8") - _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err = tk1.Se.Execute(context.Background(), "commit") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -2791,7 +2792,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk2.MustExec("alter table no_retry add index idx(id)") tk2.MustQuery("select * from no_retry").Check(testkit.Rows("8")) tk1.MustExec("update no_retry set id = 10") - _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err = tk1.Se.Execute(context.Background(), "commit") c.Assert(err, NotNil) // set autocommit to begin and commit @@ -2799,7 +2800,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk1.MustQuery("select * from no_retry").Check(testkit.Rows("8")) tk2.MustExec("update no_retry set id = 11") tk1.MustExec("update no_retry set id = 12") - _, err = tk1.Se.ExecuteInternal(context.Background(), "set autocommit = 1") + _, err = tk1.Se.Execute(context.Background(), "set autocommit = 1") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -2810,7 +2811,7 @@ func (s *testSchemaSuite) TestDisableTxnAutoRetry(c *C) { tk1.MustQuery("select * from no_retry").Check(testkit.Rows("11")) tk2.MustExec("update no_retry set id = 13") tk1.MustExec("update no_retry set id = 14") - _, err = tk1.Se.ExecuteInternal(context.Background(), "commit") + _, err = tk1.Se.Execute(context.Background(), "commit") c.Assert(err, NotNil) c.Assert(kv.ErrWriteConflict.Equal(err), IsTrue, Commentf("error: %s", err)) c.Assert(strings.Contains(err.Error(), kv.TxnRetryableMark), IsTrue, Commentf("error: %s", err)) @@ -4102,7 +4103,7 @@ func (s *testSessionSerialSuite) TestTiKVSystemVars(c *C) { result.Check(testkit.Rows("true")) tk.MustExec("UPDATE mysql.tidb SET variable_value = 'false' WHERE variable_name='tikv_gc_enable'") - result = tk.MustQuery("SELECT @@tidb_gc_enable") + result = tk.MustQuery("SELECT @@tidb_gc_enable;") result.Check(testkit.Rows("0")) // reads from mysql.tidb value and changes to false tk.MustExec("SET GLOBAL tidb_gc_concurrency = -1") // sets auto concurrency and concurrency diff --git a/session/tidb_test.go b/session/tidb_test.go index f057e88984bba..eff7406332925 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -130,8 +130,8 @@ func newSession(c *C, store kv.Storage, dbName string) Session { se.SetConnectionID(id) c.Assert(err, IsNil) se.Auth(&auth.UserIdentity{Username: "root", Hostname: `%`}, nil, []byte("012345678901234567890")) - mustExecSQL(c, se, "create database if not exists %n", dbName) - mustExecSQL(c, se, "use %n", dbName) + mustExecSQL(c, se, "create database if not exists "+dbName) + mustExecSQL(c, se, "use "+dbName) return se } @@ -139,9 +139,32 @@ func removeStore(c *C, dbPath string) { os.RemoveAll(dbPath) } -func mustExecSQL(c *C, se Session, sql string, args ...interface{}) sqlexec.RecordSet { +func exec(se Session, sql string, args ...interface{}) (sqlexec.RecordSet, error) { ctx := context.Background() - rs, err := se.ExecuteInternal(ctx, sql, args...) + if len(args) == 0 { + rs, err := se.Execute(ctx, sql) + if err == nil && len(rs) > 0 { + return rs[0], nil + } + return nil, err + } + stmtID, _, _, err := se.PrepareStmt(sql) + if err != nil { + return nil, err + } + params := make([]types.Datum, len(args)) + for i := 0; i < len(params); i++ { + params[i] = types.NewDatum(args[i]) + } + rs, err := se.ExecutePreparedStmt(ctx, stmtID, params) + if err != nil { + return nil, err + } + return rs, nil +} + +func mustExecSQL(c *C, se Session, sql string, args ...interface{}) sqlexec.RecordSet { + rs, err := exec(se, sql, args...) c.Assert(err, IsNil) return rs } From e3a754012f2b5bd9cd9460293ae195369d49f0af Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Fri, 5 Feb 2021 10:48:07 +0800 Subject: [PATCH 5/6] try to fix the data race --- statistics/handle/handle.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 88da2b80a1b53..cf3aaf5ecfecd 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -267,12 +267,12 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { // UpdateSessionVar updates the necessary session variables for the stats reader. func (h *Handle) UpdateSessionVar() error { + h.mu.Lock() + defer h.mu.Unlock() verInString, err := h.mu.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBAnalyzeVersion) if err != nil { return err } - h.mu.Lock() - defer h.mu.Unlock() ver, err := strconv.ParseInt(verInString, 10, 64) if err != nil { return err From a96db15c1a53cf16ed478d1931b570cfe23f16c1 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Fri, 5 Feb 2021 17:09:24 +0800 Subject: [PATCH 6/6] fix a DATA RACE --- session/session.go | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/session/session.go b/session/session.go index e50bd55e6a2b2..944772e79e293 100644 --- a/session/session.go +++ b/session/session.go @@ -2264,8 +2264,6 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { return nil, err } - se.initLoadCommonGlobalVarsSQL() - // get system tz from mysql.tidb tz, err := se.getTableValue(context.TODO(), mysql.TiDBTable, "system_tz") if err != nil { @@ -2631,28 +2629,8 @@ var builtinGlobalVariable = []string{ variable.TiDBEnableExchangePartition, } -var ( - loadCommonGlobalVarsSQLOnce sync.Once - loadCommonGlobalVarsStmt ast.StmtNode -) - -func (s *session) initLoadCommonGlobalVarsSQL() { - loadCommonGlobalVarsSQLOnce.Do(func() { - vars := append(make([]string, 0, len(builtinGlobalVariable)+len(variable.PluginVarNames)), builtinGlobalVariable...) - if len(variable.PluginVarNames) > 0 { - vars = append(vars, variable.PluginVarNames...) - } - var err error - loadCommonGlobalVarsStmt, err = s.ParseWithParams(context.TODO(), "SELECT HIGH_PRIORITY * from mysql.global_variables where variable_name in (%?)", vars) - if err != nil { - loadCommonGlobalVarsStmt = nil - } - }) -} - // loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. func (s *session) loadCommonGlobalVariablesIfNeeded() error { - s.initLoadCommonGlobalVarsSQL() vars := s.sessionVars if vars.CommonGlobalLoaded { return nil @@ -2667,7 +2645,17 @@ func (s *session) loadCommonGlobalVariablesIfNeeded() error { // When a lot of connections connect to TiDB simultaneously, it can protect TiKV meta region from overload. gvc := domain.GetDomain(s).GetGlobalVarsCache() loadFunc := func() ([]chunk.Row, []*ast.ResultField, error) { - return s.ExecRestrictedStmt(context.TODO(), loadCommonGlobalVarsStmt) + vars := append(make([]string, 0, len(builtinGlobalVariable)+len(variable.PluginVarNames)), builtinGlobalVariable...) + if len(variable.PluginVarNames) > 0 { + vars = append(vars, variable.PluginVarNames...) + } + + stmt, err := s.ParseWithParams(context.TODO(), "select HIGH_PRIORITY * from mysql.global_variables where variable_name in (%?)", vars) + if err != nil { + return nil, nil, errors.Trace(err) + } + + return s.ExecRestrictedStmt(context.TODO(), stmt) } rows, fields, err := gvc.LoadGlobalVariables(loadFunc) if err != nil {