diff --git a/pkg/backup/client.go b/pkg/backup/client.go index e694f092e..566bc4c69 100644 --- a/pkg/backup/client.go +++ b/pkg/backup/client.go @@ -400,8 +400,6 @@ func (bc *Client) BackupRanges( updateCh glue.Progress, ) ([]*kvproto.File, error) { errCh := make(chan error) - ctx, cancel := context.WithCancel(ctx) - defer cancel() // we collect all files in a single goroutine to avoid thread safety issues. filesCh := make(chan []*kvproto.File, concurrency) @@ -477,8 +475,6 @@ func (bc *Client) BackupRange( zap.Stringer("EndKey", utils.WrapKey(endKey)), zap.Uint64("RateLimit", req.RateLimit), zap.Uint32("Concurrency", req.Concurrency)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() var allStores []*metapb.Store allStores, err = conn.GetAllTiKVStores(ctx, bc.mgr.GetPDClient(), conn.SkipTiFlash) @@ -491,10 +487,10 @@ func (bc *Client) BackupRange( req.EndKey = endKey req.StorageBackend = bc.backend - push := newPushDown(ctx, bc.mgr, len(allStores)) + push := newPushDown(bc.mgr, len(allStores)) var results rtree.RangeTree - results, err = push.pushBackup(req, allStores, updateCh) + results, err = push.pushBackup(ctx, req, allStores, updateCh) if err != nil { return nil, err } @@ -801,8 +797,6 @@ func SendBackup( zap.Stringer("EndKey", utils.WrapKey(req.EndKey)), zap.Uint64("storeID", storeID), ) - ctx, cancel := context.WithCancel(ctx) - defer cancel() bcli, err := client.Backup(ctx, &req) if err != nil { log.Error("fail to backup", zap.Uint64("StoreID", storeID)) diff --git a/pkg/backup/push.go b/pkg/backup/push.go index b94464c62..fd1903ab9 100644 --- a/pkg/backup/push.go +++ b/pkg/backup/push.go @@ -18,17 +18,14 @@ import ( // pushDown warps a backup task. type pushDown struct { - ctx context.Context mgr ClientMgr respCh chan *backup.BackupResponse errCh chan error } // newPushDown creates a push down backup. -func newPushDown(ctx context.Context, mgr ClientMgr, cap int) *pushDown { - log.Info("new backup client") +func newPushDown(mgr ClientMgr, cap int) *pushDown { return &pushDown{ - ctx: ctx, mgr: mgr, respCh: make(chan *backup.BackupResponse, cap), errCh: make(chan error, cap), @@ -37,6 +34,7 @@ func newPushDown(ctx context.Context, mgr ClientMgr, cap int) *pushDown { // FullBackup make a full backup of a tikv cluster. func (push *pushDown) pushBackup( + ctx context.Context, req backup.BackupRequest, stores []*metapb.Store, updateCh glue.Progress, @@ -50,7 +48,7 @@ func (push *pushDown) pushBackup( log.Warn("skip store", zap.Uint64("StoreID", storeID), zap.Stringer("State", s.GetState())) continue } - client, err := push.mgr.GetBackupClient(push.ctx, storeID) + client, err := push.mgr.GetBackupClient(ctx, storeID) if err != nil { log.Error("fail to connect store", zap.Uint64("StoreID", storeID)) return res, errors.Trace(err) @@ -59,7 +57,7 @@ func (push *pushDown) pushBackup( go func() { defer wg.Done() err := SendBackup( - push.ctx, storeID, client, req, + ctx, storeID, client, req, func(resp *backup.BackupResponse) error { // Forward all responses (including error). push.respCh <- resp diff --git a/pkg/restore/client.go b/pkg/restore/client.go index 49e941224..fceba96a0 100644 --- a/pkg/restore/client.go +++ b/pkg/restore/client.go @@ -47,9 +47,6 @@ const defaultChecksumConcurrency = 64 // Client sends requests to restore files. type Client struct { - ctx context.Context - cancel context.CancelFunc - pdClient pd.Client toolClient SplitClient fileImporter FileImporter @@ -84,22 +81,17 @@ type Client struct { // NewRestoreClient returns a new RestoreClient. func NewRestoreClient( - ctx context.Context, g glue.Glue, pdClient pd.Client, store kv.Storage, tlsConf *tls.Config, ) (*Client, error) { - ctx, cancel := context.WithCancel(ctx) db, err := NewDB(g, store) if err != nil { - cancel() return nil, errors.Trace(err) } return &Client{ - ctx: ctx, - cancel: cancel, pdClient: pdClient, toolClient: NewSplitClient(pdClient, tlsConf), db: db, @@ -145,7 +137,6 @@ func (rc *Client) Close() { if rc.db != nil { rc.db.Close() } - rc.cancel() log.Info("Restore client closed") } @@ -258,11 +249,11 @@ func (rc *Client) GetTS(ctx context.Context) (uint64, error) { } // ResetTS resets the timestamp of PD to a bigger value. -func (rc *Client) ResetTS(pdAddrs []string) error { +func (rc *Client) ResetTS(ctx context.Context, pdAddrs []string) error { restoreTS := rc.backupMeta.GetEndVersion() log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS)) i := 0 - return utils.WithRetry(rc.ctx, func() error { + return utils.WithRetry(ctx, func() error { idx := i % len(pdAddrs) i++ return utils.ResetTS(pdAddrs[idx], restoreTS, rc.tlsConf) @@ -270,10 +261,10 @@ func (rc *Client) ResetTS(pdAddrs []string) error { } // GetPlacementRules return the current placement rules. -func (rc *Client) GetPlacementRules(pdAddrs []string) ([]placement.Rule, error) { +func (rc *Client) GetPlacementRules(ctx context.Context, pdAddrs []string) ([]placement.Rule, error) { var placementRules []placement.Rule i := 0 - errRetry := utils.WithRetry(rc.ctx, func() error { + errRetry := utils.WithRetry(ctx, func() error { var err error idx := i % len(pdAddrs) i++ @@ -317,12 +308,12 @@ func (rc *Client) GetTableSchema( } // CreateDatabase creates a database. -func (rc *Client) CreateDatabase(db *model.DBInfo) error { +func (rc *Client) CreateDatabase(ctx context.Context, db *model.DBInfo) error { if rc.IsSkipCreateSQL() { log.Info("skip create database", zap.Stringer("database", db.Name)) return nil } - return rc.db.CreateDatabase(rc.ctx, db) + return rc.db.CreateDatabase(ctx, db) } // CreateTables creates multiple tables, and returns their rewrite rules. @@ -472,14 +463,14 @@ func (rc *Client) createTablesWithDBPool(ctx context.Context, } // ExecDDLs executes the queries of the ddl jobs. -func (rc *Client) ExecDDLs(ddlJobs []*model.Job) error { +func (rc *Client) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error { // Sort the ddl jobs by schema version in ascending order. sort.Slice(ddlJobs, func(i, j int) bool { return ddlJobs[i].BinlogInfo.SchemaVersion < ddlJobs[j].BinlogInfo.SchemaVersion }) for _, job := range ddlJobs { - err := rc.db.ExecDDL(rc.ctx, job) + err := rc.db.ExecDDL(ctx, job) if err != nil { return errors.Trace(err) } @@ -491,14 +482,14 @@ func (rc *Client) ExecDDLs(ddlJobs []*model.Job) error { return nil } -func (rc *Client) setSpeedLimit() error { +func (rc *Client) setSpeedLimit(ctx context.Context) error { if !rc.hasSpeedLimited && rc.rateLimit != 0 { - stores, err := conn.GetAllTiKVStores(rc.ctx, rc.pdClient, conn.SkipTiFlash) + stores, err := conn.GetAllTiKVStores(ctx, rc.pdClient, conn.SkipTiFlash) if err != nil { return err } for _, store := range stores { - err = rc.fileImporter.setDownloadSpeedLimit(rc.ctx, store.GetId()) + err = rc.fileImporter.setDownloadSpeedLimit(ctx, store.GetId()) if err != nil { return err } @@ -510,6 +501,7 @@ func (rc *Client) setSpeedLimit() error { // RestoreFiles tries to restore the files. func (rc *Client) RestoreFiles( + ctx context.Context, files []*backup.File, rewriteRules *RewriteRules, updateCh glue.Progress, @@ -527,8 +519,8 @@ func (rc *Client) RestoreFiles( log.Debug("start to restore files", zap.Int("files", len(files)), ) - eg, ectx := errgroup.WithContext(rc.ctx) - err = rc.setSpeedLimit() + eg, ectx := errgroup.WithContext(ctx) + err = rc.setSpeedLimit(ctx) if err != nil { return err } @@ -553,7 +545,9 @@ func (rc *Client) RestoreFiles( } // RestoreRaw tries to restore raw keys in the specified range. -func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.File, updateCh glue.Progress) error { +func (rc *Client) RestoreRaw( + ctx context.Context, startKey []byte, endKey []byte, files []*backup.File, updateCh glue.Progress, +) error { start := time.Now() defer func() { elapsed := time.Since(start) @@ -563,7 +557,7 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil zap.Duration("take", elapsed)) }() errCh := make(chan error, len(files)) - eg, ectx := errgroup.WithContext(rc.ctx) + eg, ectx := errgroup.WithContext(ctx) defer close(errCh) err := rc.fileImporter.SetRawRange(startKey, endKey) diff --git a/pkg/restore/client_test.go b/pkg/restore/client_test.go index d61a74b46..1bac5843d 100644 --- a/pkg/restore/client_test.go +++ b/pkg/restore/client_test.go @@ -3,7 +3,6 @@ package restore_test import ( - "context" "math" "strconv" @@ -40,7 +39,7 @@ func (s *testRestoreClientSuite) TestCreateTables(c *C) { c.Assert(s.mock.Start(), IsNil) defer s.mock.Stop() - client, err := restore.NewRestoreClient(context.Background(), gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil) + client, err := restore.NewRestoreClient(gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil) c.Assert(err, IsNil) info, err := s.mock.Domain.GetSnapshotInfoSchema(math.MaxInt64) @@ -98,7 +97,7 @@ func (s *testRestoreClientSuite) TestIsOnline(c *C) { c.Assert(s.mock.Start(), IsNil) defer s.mock.Stop() - client, err := restore.NewRestoreClient(context.Background(), gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil) + client, err := restore.NewRestoreClient(gluetidb.New(), s.mock.PDClient, s.mock.Storage, nil) c.Assert(err, IsNil) c.Assert(client.IsOnline(), IsFalse) diff --git a/pkg/restore/pipeline_items.go b/pkg/restore/pipeline_items.go index 83242cb2b..aca8c1efd 100644 --- a/pkg/restore/pipeline_items.go +++ b/pkg/restore/pipeline_items.go @@ -164,7 +164,7 @@ func (b *tikvSender) RestoreBatch(ctx context.Context, ranges []rtree.Range, rew files = append(files, fs.Files...) } - if err := b.client.RestoreFiles(files, rewriteRules, b.updateCh); err != nil { + if err := b.client.RestoreFiles(ctx, files, rewriteRules, b.updateCh); err != nil { return err } diff --git a/pkg/task/restore.go b/pkg/task/restore.go index f3e85f355..c91e56abf 100644 --- a/pkg/task/restore.go +++ b/pkg/task/restore.go @@ -99,7 +99,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf } defer mgr.Close() - client, err := restore.NewRestoreClient(ctx, g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig()) + client, err := restore.NewRestoreClient(g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig()) if err != nil { return err } @@ -158,7 +158,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf defer restoreDBConfig() // execute DDL first - err = client.ExecDDLs(ddlJobs) + err = client.ExecDDLs(ctx, ddlJobs) if err != nil { return errors.Trace(err) } @@ -172,7 +172,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf } for _, db := range dbs { - err = client.CreateDatabase(db.Info) + err = client.CreateDatabase(ctx, db.Info) if err != nil { return err } @@ -226,7 +226,7 @@ func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf // Do not reset timestamp if we are doing incremental restore, because // we are not allowed to decrease timestamp. if !client.IsIncremental() { - if err = client.ResetTS(cfg.PD); err != nil { + if err = client.ResetTS(ctx, cfg.PD); err != nil { log.Error("reset pd TS failed", zap.Error(err)) return err } diff --git a/pkg/task/restore_raw.go b/pkg/task/restore_raw.go index a1a6f0513..a4ca0a446 100644 --- a/pkg/task/restore_raw.go +++ b/pkg/task/restore_raw.go @@ -56,7 +56,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR } defer mgr.Close() - client, err := restore.NewRestoreClient(ctx, g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig()) + client, err := restore.NewRestoreClient(g, mgr.GetPDClient(), mgr.GetTiKV(), mgr.GetTLSConfig()) if err != nil { return err } @@ -116,7 +116,7 @@ func RunRestoreRaw(c context.Context, g glue.Glue, cmdName string, cfg *RestoreR } defer restorePostWork(ctx, client, restoreSchedulers) - err = client.RestoreRaw(cfg.StartKey, cfg.EndKey, files, updateCh) + err = client.RestoreRaw(ctx, cfg.StartKey, cfg.EndKey, files, updateCh) if err != nil { return errors.Trace(err) }