From e4bcebcafa8bdb432d5b9138150b97a15c65aadc Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 4 Feb 2021 16:00:50 +0800 Subject: [PATCH 1/3] migrate code in ddl package from using Execute/ExecRestrictedSQL to safe API(2) Signed-off-by: AilinKid <314806019@qq.com> --- ddl/delete_range.go | 18 ++++---- ddl/partition.go | 109 +++++++++++++++++++++++++++++++++----------- ddl/reorg.go | 8 +++- 3 files changed, 99 insertions(+), 36 deletions(-) diff --git a/ddl/delete_range.go b/ddl/delete_range.go index 805f600df9cae..e4a9bc1eace2d 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -16,7 +16,6 @@ package ddl import ( "context" "encoding/hex" - "fmt" "math" "sync" "sync/atomic" @@ -36,7 +35,7 @@ import ( const ( insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` - insertDeleteRangeSQLValue = `("%d", "%d", "%s", "%s", "%d")` + insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` insertDeleteRangeSQL = insertDeleteRangeSQLPrefix + insertDeleteRangeSQLValue delBatchSize = 65536 @@ -404,17 +403,19 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error func doBatchDeleteIndiceRange(s sqlexec.SQLExecutor, jobID, tableID int64, indexIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", indexIDs)) sql := insertDeleteRangeSQLPrefix + paramsList := make([]interface{}, 0, len(indexIDs)*5) for i, indexID := range indexIDs { startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) + sql += insertDeleteRangeSQLValue if i != len(indexIDs)-1 { sql += "," } + paramsList = append(paramsList, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), sql, paramsList...) return errors.Trace(err) } @@ -422,25 +423,26 @@ func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, end logutil.BgLogger().Info("[ddl] insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64("elementID", elementID)) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql := fmt.Sprintf(insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) return errors.Trace(err) } func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) sql := insertDeleteRangeSQLPrefix + paramsList := make([]interface{}, 0, len(tableIDs)*5) for i, tableID := range tableIDs { startKey := tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) + sql += insertDeleteRangeSQLValue if i != len(tableIDs)-1 { sql += "," } + paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), sql, paramsList...) return errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index 0862e8e732dfc..aeb8df360db4b 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1337,6 +1337,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, index int, schemaName, tableName model.CIStr) error { var sql string + var paramList []interface{} pi := pt.Partition @@ -1345,7 +1346,12 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde if pi.Num == 1 { return nil } - sql = fmt.Sprintf("select 1 from `%s`.`%s` where mod(%s, %d) != %d limit 1", schemaName.L, tableName.L, pi.Expr, pi.Num, index) + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where mod(") + buf.WriteString(pi.Expr) + buf.WriteString(", %?) != %? limit 1") + sql = buf.String() + paramList = append(paramList, schemaName.L, tableName.L, pi.Num, index) case model.PartitionTypeRange: // Table has only one partition and has the maximum value if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { @@ -1353,15 +1359,15 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } // For range expression and range columns if len(pi.Columns) == 0 { - sql = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName) } else if len(pi.Columns) == 1 { - sql = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName) } case model.PartitionTypeList: if len(pi.Columns) == 0 { - sql = buildCheckSQLForListPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForListPartition(pi, index, schemaName, tableName) } else if len(pi.Columns) == 1 { - sql = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName) } default: return errUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O) @@ -1374,7 +1380,11 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), sql, paramList...) + if err != nil { + return errors.Trace(err) + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } @@ -1385,46 +1395,93 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde return nil } -func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + var buf strings.Builder + paramList := make([]interface{}, 0, 4) + // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) + // So we write it to the origin sql string here. if index == 0 { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s >= %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" >= %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index].LessThan[0])) + return buf.String(), paramList } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s < %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index-1].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" < %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index-1].LessThan[0])) + return buf.String(), paramList } else { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s < %s or %s >= %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index-1].LessThan[0], pi.Expr, pi.Definitions[index].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" < %? or ") + buf.WriteString(pi.Expr) + buf.WriteString(" >= %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index-1].LessThan[0]), trimQuotation(pi.Definitions[index].LessThan[0])) + return buf.String(), paramList } } -func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func trimQuotation(str string) string { + return strings.Trim(str, "\"") +} + +func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + paramList := make([]interface{}, 0, 6) colName := pi.Columns[0].L if index == 0 { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` >= %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index].LessThan[0])) + return "select 1 from %n.%n where %n >= %? limit 1", paramList } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` < %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index-1].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index-1].LessThan[0])) + return "select 1 from %n.%n where %n < %? limit 1", paramList } else { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` < %s or `%s` >= %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index-1].LessThan[0], colName, pi.Definitions[index].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index-1].LessThan[0]), colName, trimQuotation(pi.Definitions[index].LessThan[0])) + return "select 1 from %n.%n where %n < %? or %n >= %? limit 1", paramList } } -func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { - inValues := getInValues(pi, index) - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s not in (%s) limit 1", schemaName.L, tableName.L, pi.Expr, inValues) - return sql +func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" not in (") + inValues := getInValues(&buf, pi, index) + buf.WriteString(") limit 1") + + paramList := make([]interface{}, 0, 2+len(inValues)) + paramList = append(paramList, schemaName.L, tableName.L) + paramList = append(paramList, inValues...) + return buf.String(), paramList } -func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { colName := pi.Columns[0].L - inValues := getInValues(pi, index) - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s not in (%s) limit 1", schemaName.L, tableName.L, colName, inValues) - return sql + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where %n not in (") + inValues := getInValues(&buf, pi, index) + buf.WriteString(") limit 1") + + paramList := make([]interface{}, 0, 3+len(inValues)) + paramList = append(paramList, schemaName.L, tableName.L, colName) + paramList = append(paramList, inValues...) + return buf.String(), paramList } -func getInValues(pi *model.PartitionInfo, index int) string { - inValues := make([]string, 0, len(pi.Definitions[index].InValues)) +func getInValues(buf *strings.Builder, pi *model.PartitionInfo, index int) []interface{} { + inValues := make([]interface{}, 0, len(pi.Definitions[index].InValues)) for _, inValue := range pi.Definitions[index].InValues { - inValues = append(inValues, inValue...) + for _, one := range inValue { + if len(inValues) == 0 { + buf.WriteString("%?") + } else { + buf.WriteString(", %?") + } + inValues = append(inValues, one) + } } - return strings.Join(inValues, ",") + return inValues } func checkAddPartitionTooManyPartitions(piDefs uint64) error { diff --git a/ddl/reorg.go b/ddl/reorg.go index 9c72040a0221c..6528a484bb834 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -319,8 +319,12 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { if !ok { return statistics.PseudoRowCount } - sql := fmt.Sprintf("select table_rows from information_schema.tables where tidb_table_id=%v;", tblInfo.ID) - rows, _, err := executor.ExecRestrictedSQL(sql) + sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" + stmt, err := executor.ParseWithParams(context.Background(), sql, tblInfo.ID) + if err != nil { + return statistics.PseudoRowCount + } + rows, _, err := executor.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return statistics.PseudoRowCount } From 105ceb80bf071025c32e492c47023b641b1594e8 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Thu, 4 Feb 2021 16:10:35 +0800 Subject: [PATCH 2/3] use string builder rather string concat Signed-off-by: AilinKid <314806019@qq.com> --- ddl/delete_range.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ddl/delete_range.go b/ddl/delete_range.go index e4a9bc1eace2d..e64c122d41e4d 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -17,6 +17,7 @@ import ( "context" "encoding/hex" "math" + "strings" "sync" "sync/atomic" @@ -402,20 +403,21 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error func doBatchDeleteIndiceRange(s sqlexec.SQLExecutor, jobID, tableID int64, indexIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", indexIDs)) - sql := insertDeleteRangeSQLPrefix paramsList := make([]interface{}, 0, len(indexIDs)*5) + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) for i, indexID := range indexIDs { startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += insertDeleteRangeSQLValue + buf.WriteString(insertDeleteRangeSQLValue) if i != len(indexIDs)-1 { - sql += "," + buf.WriteString(",") } paramsList = append(paramsList, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.ExecuteInternal(context.Background(), sql, paramsList...) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } @@ -429,20 +431,21 @@ func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, end func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) - sql := insertDeleteRangeSQLPrefix + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) paramsList := make([]interface{}, 0, len(tableIDs)*5) for i, tableID := range tableIDs { startKey := tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += insertDeleteRangeSQLValue + buf.WriteString(insertDeleteRangeSQLValue) if i != len(tableIDs)-1 { - sql += "," + buf.WriteString(",") } paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.ExecuteInternal(context.Background(), sql, paramsList...) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } From a0578bcc3c06c58567e3b38dacb8f72270fe0a1a Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Fri, 5 Feb 2021 13:38:03 +0800 Subject: [PATCH 3/3] address comment Signed-off-by: AilinKid <314806019@qq.com> --- ddl/partition.go | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/ddl/partition.go b/ddl/partition.go index aeb8df360db4b..58832c4007455 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1446,38 +1446,29 @@ func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaNam var buf strings.Builder buf.WriteString("select 1 from %n.%n where ") buf.WriteString(pi.Expr) - buf.WriteString(" not in (") - inValues := getInValues(&buf, pi, index) - buf.WriteString(") limit 1") + buf.WriteString(" not in (%?) limit 1") + inValues := getInValues(pi, index) - paramList := make([]interface{}, 0, 2+len(inValues)) - paramList = append(paramList, schemaName.L, tableName.L) - paramList = append(paramList, inValues...) + paramList := make([]interface{}, 0, 3) + paramList = append(paramList, schemaName.L, tableName.L, inValues) return buf.String(), paramList } func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { colName := pi.Columns[0].L var buf strings.Builder - buf.WriteString("select 1 from %n.%n where %n not in (") - inValues := getInValues(&buf, pi, index) - buf.WriteString(") limit 1") + buf.WriteString("select 1 from %n.%n where %n not in (%?) limit 1") + inValues := getInValues(pi, index) - paramList := make([]interface{}, 0, 3+len(inValues)) - paramList = append(paramList, schemaName.L, tableName.L, colName) - paramList = append(paramList, inValues...) + paramList := make([]interface{}, 0, 4) + paramList = append(paramList, schemaName.L, tableName.L, colName, inValues) return buf.String(), paramList } -func getInValues(buf *strings.Builder, pi *model.PartitionInfo, index int) []interface{} { - inValues := make([]interface{}, 0, len(pi.Definitions[index].InValues)) +func getInValues(pi *model.PartitionInfo, index int) []string { + inValues := make([]string, 0, len(pi.Definitions[index].InValues)) for _, inValue := range pi.Definitions[index].InValues { for _, one := range inValue { - if len(inValues) == 0 { - buf.WriteString("%?") - } else { - buf.WriteString(", %?") - } inValues = append(inValues, one) } }