diff --git a/ddl/delete_range.go b/ddl/delete_range.go index 805f600df9cae..e64c122d41e4d 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -16,8 +16,8 @@ package ddl import ( "context" "encoding/hex" - "fmt" "math" + "strings" "sync" "sync/atomic" @@ -36,7 +36,7 @@ import ( const ( insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` - insertDeleteRangeSQLValue = `("%d", "%d", "%s", "%s", "%d")` + insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` insertDeleteRangeSQL = insertDeleteRangeSQLPrefix + insertDeleteRangeSQLValue delBatchSize = 65536 @@ -403,18 +403,21 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error func doBatchDeleteIndiceRange(s sqlexec.SQLExecutor, jobID, tableID int64, indexIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", indexIDs)) - sql := insertDeleteRangeSQLPrefix + paramsList := make([]interface{}, 0, len(indexIDs)*5) + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) for i, indexID := range indexIDs { startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) + buf.WriteString(insertDeleteRangeSQLValue) if i != len(indexIDs)-1 { - sql += "," + buf.WriteString(",") } + paramsList = append(paramsList, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } @@ -422,25 +425,27 @@ func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, end logutil.BgLogger().Info("[ddl] insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64("elementID", elementID)) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql := fmt.Sprintf(insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) return errors.Trace(err) } func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) - sql := insertDeleteRangeSQLPrefix + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + paramsList := make([]interface{}, 0, len(tableIDs)*5) for i, tableID := range tableIDs { startKey := tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - sql += fmt.Sprintf(insertDeleteRangeSQLValue, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) + buf.WriteString(insertDeleteRangeSQLValue) if i != len(tableIDs)-1 { - sql += "," + buf.WriteString(",") } + paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.Execute(context.Background(), sql) + _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) return errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index 0862e8e732dfc..58832c4007455 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1337,6 +1337,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, index int, schemaName, tableName model.CIStr) error { var sql string + var paramList []interface{} pi := pt.Partition @@ -1345,7 +1346,12 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde if pi.Num == 1 { return nil } - sql = fmt.Sprintf("select 1 from `%s`.`%s` where mod(%s, %d) != %d limit 1", schemaName.L, tableName.L, pi.Expr, pi.Num, index) + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where mod(") + buf.WriteString(pi.Expr) + buf.WriteString(", %?) != %? limit 1") + sql = buf.String() + paramList = append(paramList, schemaName.L, tableName.L, pi.Num, index) case model.PartitionTypeRange: // Table has only one partition and has the maximum value if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { @@ -1353,15 +1359,15 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } // For range expression and range columns if len(pi.Columns) == 0 { - sql = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForRangeExprPartition(pi, index, schemaName, tableName) } else if len(pi.Columns) == 1 { - sql = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForRangeColumnsPartition(pi, index, schemaName, tableName) } case model.PartitionTypeList: if len(pi.Columns) == 0 { - sql = buildCheckSQLForListPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForListPartition(pi, index, schemaName, tableName) } else if len(pi.Columns) == 1 { - sql = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName) + sql, paramList = buildCheckSQLForListColumnsPartition(pi, index, schemaName, tableName) } default: return errUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O) @@ -1374,7 +1380,11 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), sql, paramList...) + if err != nil { + return errors.Trace(err) + } + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) if err != nil { return errors.Trace(err) } @@ -1385,46 +1395,84 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde return nil } -func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + var buf strings.Builder + paramList := make([]interface{}, 0, 4) + // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) + // So we write it to the origin sql string here. if index == 0 { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s >= %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" >= %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index].LessThan[0])) + return buf.String(), paramList } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s < %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index-1].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" < %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index-1].LessThan[0])) + return buf.String(), paramList } else { - return fmt.Sprintf("select 1 from `%s`.`%s` where %s < %s or %s >= %s limit 1", schemaName.L, tableName.L, pi.Expr, pi.Definitions[index-1].LessThan[0], pi.Expr, pi.Definitions[index].LessThan[0]) + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" < %? or ") + buf.WriteString(pi.Expr) + buf.WriteString(" >= %? limit 1") + paramList = append(paramList, schemaName.L, tableName.L, trimQuotation(pi.Definitions[index-1].LessThan[0]), trimQuotation(pi.Definitions[index].LessThan[0])) + return buf.String(), paramList } } -func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func trimQuotation(str string) string { + return strings.Trim(str, "\"") +} + +func buildCheckSQLForRangeColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + paramList := make([]interface{}, 0, 6) colName := pi.Columns[0].L if index == 0 { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` >= %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index].LessThan[0])) + return "select 1 from %n.%n where %n >= %? limit 1", paramList } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` < %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index-1].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index-1].LessThan[0])) + return "select 1 from %n.%n where %n < %? limit 1", paramList } else { - return fmt.Sprintf("select 1 from `%s`.`%s` where `%s` < %s or `%s` >= %s limit 1", schemaName.L, tableName.L, colName, pi.Definitions[index-1].LessThan[0], colName, pi.Definitions[index].LessThan[0]) + paramList = append(paramList, schemaName.L, tableName.L, colName, trimQuotation(pi.Definitions[index-1].LessThan[0]), colName, trimQuotation(pi.Definitions[index].LessThan[0])) + return "select 1 from %n.%n where %n < %? or %n >= %? limit 1", paramList } } -func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func buildCheckSQLForListPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + buf.WriteString(pi.Expr) + buf.WriteString(" not in (%?) limit 1") inValues := getInValues(pi, index) - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s not in (%s) limit 1", schemaName.L, tableName.L, pi.Expr, inValues) - return sql + + paramList := make([]interface{}, 0, 3) + paramList = append(paramList, schemaName.L, tableName.L, inValues) + return buf.String(), paramList } -func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) string { +func buildCheckSQLForListColumnsPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { colName := pi.Columns[0].L + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where %n not in (%?) limit 1") inValues := getInValues(pi, index) - sql := fmt.Sprintf("select 1 from `%s`.`%s` where %s not in (%s) limit 1", schemaName.L, tableName.L, colName, inValues) - return sql + + paramList := make([]interface{}, 0, 4) + paramList = append(paramList, schemaName.L, tableName.L, colName, inValues) + return buf.String(), paramList } -func getInValues(pi *model.PartitionInfo, index int) string { +func getInValues(pi *model.PartitionInfo, index int) []string { inValues := make([]string, 0, len(pi.Definitions[index].InValues)) for _, inValue := range pi.Definitions[index].InValues { - inValues = append(inValues, inValue...) + for _, one := range inValue { + inValues = append(inValues, one) + } } - return strings.Join(inValues, ",") + return inValues } func checkAddPartitionTooManyPartitions(piDefs uint64) error { diff --git a/ddl/reorg.go b/ddl/reorg.go index 9c72040a0221c..6528a484bb834 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -319,8 +319,12 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { if !ok { return statistics.PseudoRowCount } - sql := fmt.Sprintf("select table_rows from information_schema.tables where tidb_table_id=%v;", tblInfo.ID) - rows, _, err := executor.ExecRestrictedSQL(sql) + sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" + stmt, err := executor.ParseWithParams(context.Background(), sql, tblInfo.ID) + if err != nil { + return statistics.PseudoRowCount + } + rows, _, err := executor.ExecRestrictedStmt(context.Background(), stmt) if err != nil { return statistics.PseudoRowCount }