Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.
23 changes: 10 additions & 13 deletions pkg/backup/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"time"

"github.com/pingcap/errors"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tipb/go-tipb"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"

"github.com/pingcap/br/pkg/checksum"
"github.com/pingcap/br/pkg/glue"
Expand All @@ -35,15 +35,13 @@ type Schemas struct {
schemas map[string]backup.Schema
backupSchemaCh chan backup.Schema
errCh chan error
wg *sync.WaitGroup
}

func newBackupSchemas() *Schemas {
return &Schemas{
schemas: make(map[string]backup.Schema),
backupSchemaCh: make(chan backup.Schema),
errCh: make(chan error),
wg: new(sync.WaitGroup),
}
}

Expand All @@ -65,28 +63,24 @@ func (pending *Schemas) Start(
updateCh glue.Progress,
) {
workerPool := utils.NewWorkerPool(concurrency, "Schemas")
errg, ectx := errgroup.WithContext(ctx)
go func() {
startAll := time.Now()
for n, s := range pending.schemas {
log.Info("table checksum start", zap.String("table", n))
name := n
schema := s
pending.wg.Add(1)
workerPool.Apply(func() {
defer pending.wg.Done()

workerPool.ApplyOnErrorGroup(errg, func() error {
start := time.Now()
table := model.TableInfo{}
err := json.Unmarshal(schema.Table, &table)
if err != nil {
pending.errCh <- err
return
return err
}
checksumResp, err := calculateChecksum(
ctx, &table, store.GetClient(), backupTS)
ectx, &table, store.GetClient(), backupTS)
if err != nil {
pending.errCh <- err
return
return err
}
schema.Crc64Xor = checksumResp.Checksum
schema.TotalKvs = checksumResp.TotalKvs
Expand All @@ -100,9 +94,12 @@ func (pending *Schemas) Start(
pending.backupSchemaCh <- schema

updateCh.Inc()
return nil
})
}
pending.wg.Wait()
if err := errg.Wait(); err != nil {
pending.errCh <- err
}
close(pending.backupSchemaCh)
log.Info("backup checksum",
zap.Duration("take", time.Since(startAll)))
Expand Down
182 changes: 77 additions & 105 deletions pkg/restore/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"fmt"
"sort"
"strconv"
"sync"
"time"

"github.com/gogo/protobuf/proto"
Expand All @@ -30,6 +29,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -179,7 +179,7 @@ func (rc *Client) InitBackupMeta(backupMeta *backup.BackupMeta, backend *backup.

metaClient := NewSplitClient(rc.pdClient, rc.tlsConf)
importClient := NewImportClient(metaClient, rc.tlsConf)
rc.fileImporter = NewFileImporter(rc.ctx, metaClient, importClient, backend, backupMeta.IsRawKv, rc.rateLimit)
rc.fileImporter = NewFileImporter(metaClient, importClient, backend, backupMeta.IsRawKv, rc.rateLimit)
return nil
}

Expand Down Expand Up @@ -379,10 +379,6 @@ func (rc *Client) createTable(
table *utils.Table,
newTS uint64,
) (CreatedTable, error) {
if db == nil {
db = rc.db
}

if rc.IsSkipCreateSQL() {
log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name))
} else {
Expand Down Expand Up @@ -419,14 +415,15 @@ func (rc *Client) GoCreateTables(
errCh chan<- error,
) <-chan CreatedTable {
// Could we have a smaller size of tables?
log.Info("start create tables")
outCh := make(chan CreatedTable, len(tables))
createOneTable := func(db *DB, t *utils.Table) error {
createOneTable := func(c context.Context, db *DB, t *utils.Table) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.Done():
return c.Err()
default:
}
rt, err := rc.createTable(ctx, db, dom, t, newTS)
rt, err := rc.createTable(c, db, dom, t, newTS)
if err != nil {
log.Error("create table failed",
zap.Error(err),
Expand All @@ -441,50 +438,48 @@ func (rc *Client) GoCreateTables(
outCh <- rt
return nil
}
startWork := func(t *utils.Table, done func()) {
defer done()
if err := createOneTable(nil, t); err != nil {
errCh <- err
return
}
}
if len(dbPool) > 0 {
workers := utils.NewWorkerPool(uint(len(dbPool)), "DDL workers")
startWork = func(t *utils.Table, done func()) {
workers.ApplyWithID(func(id uint64) {
defer done()
selectedDB := int(id) % len(dbPool)
if err := createOneTable(dbPool[selectedDB], t); err != nil {
errCh <- err
return
}
})
}
}

go func() {
// TODO replace it with an errgroup
wg := new(sync.WaitGroup)
defer close(outCh)
defer log.Info("all tables created")
defer func() {
if len(dbPool) > 0 {
for _, db := range dbPool {
db.Close()
}
}
}()

for _, table := range tables {
tbl := table
wg.Add(1)
startWork(tbl, wg.Done)
defer log.Debug("all tables are created")
var err error
if len(dbPool) > 0 {
err = rc.createTablesWithDBPool(ctx, createOneTable, tables, dbPool)
} else {
err = rc.createTablesWithSoleDB(ctx, createOneTable, tables)
}
if err != nil {
errCh <- err
}
wg.Wait()
}()
return outCh
}

func (rc *Client) createTablesWithSoleDB(ctx context.Context,
createOneTable func(ctx context.Context, db *DB, t *utils.Table) error,
tables []*utils.Table) error {
for _, t := range tables {
if err := createOneTable(ctx, rc.db, t); err != nil {
return err
}
}
return nil
}

func (rc *Client) createTablesWithDBPool(ctx context.Context,
createOneTable func(ctx context.Context, db *DB, t *utils.Table) error,
tables []*utils.Table, dbPool []*DB) error {
eg, ectx := errgroup.WithContext(ctx)
workers := utils.NewWorkerPool(uint(len(dbPool)), "DDL workers")
for _, t := range tables {
table := t
workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error {
db := dbPool[id%uint64(len(dbPool))]
return createOneTable(ectx, db, table)
})
}
return eg.Wait()
}

// makeTiFlashOfTableRecord make a 'record' repsenting TiFlash of a table that has been removed.
// We doesn't record table ID here because when restore TiFlash replicas,
// we use `ALTER TABLE db.tbl SET TIFLASH_REPLICA = xxx` DDL, instead of use some internal TiDB API.
Expand Down Expand Up @@ -613,7 +608,7 @@ func (rc *Client) setSpeedLimit() error {
return err
}
for _, store := range stores {
err = rc.fileImporter.setDownloadSpeedLimit(store.GetId())
err = rc.fileImporter.setDownloadSpeedLimit(rc.ctx, store.GetId())
if err != nil {
return err
}
Expand Down Expand Up @@ -643,42 +638,28 @@ func (rc *Client) RestoreFiles(
log.Debug("start to restore files",
zap.Int("files", len(files)),
)
errCh := make(chan error, len(files))
wg := new(sync.WaitGroup)
defer close(errCh)
eg, ectx := errgroup.WithContext(rc.ctx)
err = rc.setSpeedLimit()
if err != nil {
return err
}

for _, file := range files {
wg.Add(1)
fileReplica := file
rc.workerPool.Apply(
func() {
defer wg.Done()
select {
case <-rc.ctx.Done():
errCh <- rc.ctx.Err()
case errCh <- rc.fileImporter.Import(fileReplica, rejectStoreMap, rewriteRules):
updateCh.Inc()
}
rc.workerPool.ApplyOnErrorGroup(eg,
func() error {
defer updateCh.Inc()
return rc.fileImporter.Import(ectx, fileReplica, rejectStoreMap, rewriteRules)
})
}
for i := range files {
err := <-errCh
if err != nil {
summary.CollectFailureUnit(fmt.Sprintf("file:%d", i), err)
rc.cancel()
wg.Wait()
log.Error(
"restore files failed",
zap.Error(err),
)
return err
}
if err := eg.Wait(); err != nil {
summary.CollectFailureUnit("file", err)
log.Error(
"restore files failed",
zap.Error(err),
)
return err
}
wg.Wait()
return nil
}

Expand All @@ -693,7 +674,7 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil
zap.Duration("take", elapsed))
}()
errCh := make(chan error, len(files))
wg := new(sync.WaitGroup)
eg, ectx := errgroup.WithContext(rc.ctx)
defer close(errCh)

err := rc.fileImporter.SetRawRange(startKey, endKey)
Expand All @@ -703,32 +684,21 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil

emptyRules := &RewriteRules{}
for _, file := range files {
wg.Add(1)
fileReplica := file
rc.workerPool.Apply(
func() {
defer wg.Done()
select {
case <-rc.ctx.Done():
errCh <- rc.ctx.Err()
case errCh <- rc.fileImporter.Import(fileReplica, nil, emptyRules):
updateCh.Inc()
}
rc.workerPool.ApplyOnErrorGroup(eg,
func() error {
defer updateCh.Inc()
return rc.fileImporter.Import(ectx, fileReplica, nil, emptyRules)
})
}
for range files {
err := <-errCh
if err != nil {
rc.cancel()
wg.Wait()
log.Error(
"restore raw range failed",
zap.String("startKey", hex.EncodeToString(startKey)),
zap.String("endKey", hex.EncodeToString(endKey)),
zap.Error(err),
)
return err
}
if err := eg.Wait(); err != nil {
log.Error(
"restore raw range failed",
zap.String("startKey", hex.EncodeToString(startKey)),
zap.String("endKey", hex.EncodeToString(endKey)),
zap.Error(err),
)
return err
}
log.Info(
"finish to restore raw range",
Expand Down Expand Up @@ -828,31 +798,33 @@ func (rc *Client) GoValidateChecksum(
workers := utils.NewWorkerPool(defaultChecksumConcurrency, "RestoreChecksum")
go func() {
start := time.Now()
wg := new(sync.WaitGroup)
wg, ectx := errgroup.WithContext(ctx)
defer func() {
log.Info("all checksum ended")
wg.Wait()
if err := wg.Wait(); err != nil {
errCh <- err
}
elapsed := time.Since(start)
summary.CollectDuration("restore checksum", elapsed)
outCh <- struct{}{}
close(outCh)
}()
for {
select {
// if we use ectx here, maybe canceled will mask real error.
case <-ctx.Done():
errCh <- ctx.Err()
case tbl, ok := <-tableStream:
if !ok {
return
}
wg.Add(1)
workers.Apply(func() {
err := rc.execChecksum(ctx, tbl, kvClient)
workers.ApplyOnErrorGroup(wg, func() error {
err := rc.execChecksum(ectx, tbl, kvClient)
if err != nil {
errCh <- err
return err
}
updateCh.Inc()
wg.Done()
return nil
})
}
}
Expand Down
Loading