Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions pkg/backup/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"time"

"github.com/pingcap/errors"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tipb/go-tipb"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"

"github.com/pingcap/br/pkg/checksum"
"github.com/pingcap/br/pkg/glue"
Expand All @@ -35,15 +35,13 @@ type Schemas struct {
schemas map[string]backup.Schema
backupSchemaCh chan backup.Schema
errCh chan error
wg *sync.WaitGroup
}

func newBackupSchemas() *Schemas {
return &Schemas{
schemas: make(map[string]backup.Schema),
backupSchemaCh: make(chan backup.Schema),
errCh: make(chan error),
wg: new(sync.WaitGroup),
}
}

Expand All @@ -65,28 +63,24 @@ func (pending *Schemas) Start(
updateCh glue.Progress,
) {
workerPool := utils.NewWorkerPool(concurrency, "Schemas")
errg, ectx := errgroup.WithContext(ctx)
go func() {
startAll := time.Now()
for n, s := range pending.schemas {
log.Info("table checksum start", zap.String("table", n))
name := n
schema := s
pending.wg.Add(1)
workerPool.Apply(func() {
defer pending.wg.Done()

workerPool.ApplyOnErrorGroup(errg, func() error {
start := time.Now()
table := model.TableInfo{}
err := json.Unmarshal(schema.Table, &table)
if err != nil {
pending.errCh <- err
return
return err
}
checksumResp, err := calculateChecksum(
ctx, &table, store.GetClient(), backupTS)
ectx, &table, store.GetClient(), backupTS)
if err != nil {
pending.errCh <- err
return
return err
}
schema.Crc64Xor = checksumResp.Checksum
schema.TotalKvs = checksumResp.TotalKvs
Expand All @@ -100,9 +94,12 @@ func (pending *Schemas) Start(
pending.backupSchemaCh <- schema

updateCh.Inc()
return nil
})
}
pending.wg.Wait()
if err := errg.Wait(); err != nil {
pending.errCh <- err
}
close(pending.backupSchemaCh)
log.Info("backup checksum",
zap.Duration("take", time.Since(startAll)))
Expand Down
182 changes: 77 additions & 105 deletions pkg/restore/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"math"
"sort"
"strconv"
"sync"
"time"

"github.com/gogo/protobuf/proto"
Expand All @@ -31,6 +30,7 @@ import (
"github.com/pingcap/tidb/util/codec"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -180,7 +180,7 @@ func (rc *Client) InitBackupMeta(backupMeta *backup.BackupMeta, backend *backup.

metaClient := NewSplitClient(rc.pdClient, rc.tlsConf)
importClient := NewImportClient(metaClient, rc.tlsConf)
rc.fileImporter = NewFileImporter(rc.ctx, metaClient, importClient, backend, backupMeta.IsRawKv, rc.rateLimit)
rc.fileImporter = NewFileImporter(metaClient, importClient, backend, backupMeta.IsRawKv, rc.rateLimit)
return nil
}

Expand Down Expand Up @@ -383,10 +383,6 @@ func (rc *Client) createTable(
table *utils.Table,
newTS uint64,
) (CreatedTable, error) {
if db == nil {
db = rc.db
}

if rc.IsSkipCreateSQL() {
log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name))
} else {
Expand Down Expand Up @@ -423,14 +419,15 @@ func (rc *Client) GoCreateTables(
errCh chan<- error,
) <-chan CreatedTable {
// Could we have a smaller size of tables?
log.Info("start create tables")
outCh := make(chan CreatedTable, len(tables))
createOneTable := func(db *DB, t *utils.Table) error {
createOneTable := func(c context.Context, db *DB, t *utils.Table) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.Done():
return c.Err()
default:
}
rt, err := rc.createTable(ctx, db, dom, t, newTS)
rt, err := rc.createTable(c, db, dom, t, newTS)
if err != nil {
log.Error("create table failed",
zap.Error(err),
Expand All @@ -445,50 +442,48 @@ func (rc *Client) GoCreateTables(
outCh <- rt
return nil
}
startWork := func(t *utils.Table, done func()) {
defer done()
if err := createOneTable(nil, t); err != nil {
errCh <- err
return
}
}
if len(dbPool) > 0 {
workers := utils.NewWorkerPool(uint(len(dbPool)), "DDL workers")
startWork = func(t *utils.Table, done func()) {
workers.ApplyWithID(func(id uint64) {
defer done()
selectedDB := int(id) % len(dbPool)
if err := createOneTable(dbPool[selectedDB], t); err != nil {
errCh <- err
return
}
})
}
}

go func() {
// TODO replace it with an errgroup
wg := new(sync.WaitGroup)
defer close(outCh)
defer log.Info("all tables created")
defer func() {
if len(dbPool) > 0 {
for _, db := range dbPool {
db.Close()
}
}
}()

for _, table := range tables {
tbl := table
wg.Add(1)
startWork(tbl, wg.Done)
defer log.Debug("all tables are created")
var err error
if len(dbPool) > 0 {
err = rc.createTablesWithDBPool(ctx, createOneTable, tables, dbPool)
} else {
err = rc.createTablesWithSoleDB(ctx, createOneTable, tables)
}
if err != nil {
errCh <- err
}
wg.Wait()
}()
return outCh
}

func (rc *Client) createTablesWithSoleDB(ctx context.Context,
createOneTable func(ctx context.Context, db *DB, t *utils.Table) error,
tables []*utils.Table) error {
for _, t := range tables {
if err := createOneTable(ctx, rc.db, t); err != nil {
return err
}
}
return nil
}

func (rc *Client) createTablesWithDBPool(ctx context.Context,
createOneTable func(ctx context.Context, db *DB, t *utils.Table) error,
tables []*utils.Table, dbPool []*DB) error {
eg, ectx := errgroup.WithContext(ctx)
workers := utils.NewWorkerPool(uint(len(dbPool)), "DDL workers")
for _, t := range tables {
table := t
workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error {
db := dbPool[id%uint64(len(dbPool))]
return createOneTable(ectx, db, table)
})
}
return eg.Wait()
}

// makeTiFlashOfTableRecord make a 'record' repsenting TiFlash of a table that has been removed.
// We doesn't record table ID here because when restore TiFlash replicas,
// we use `ALTER TABLE db.tbl SET TIFLASH_REPLICA = xxx` DDL, instead of use some internal TiDB API.
Expand Down Expand Up @@ -617,7 +612,7 @@ func (rc *Client) setSpeedLimit() error {
return err
}
for _, store := range stores {
err = rc.fileImporter.setDownloadSpeedLimit(store.GetId())
err = rc.fileImporter.setDownloadSpeedLimit(rc.ctx, store.GetId())
if err != nil {
return err
}
Expand Down Expand Up @@ -647,42 +642,28 @@ func (rc *Client) RestoreFiles(
log.Debug("start to restore files",
zap.Int("files", len(files)),
)
errCh := make(chan error, len(files))
wg := new(sync.WaitGroup)
defer close(errCh)
eg, ectx := errgroup.WithContext(rc.ctx)
err = rc.setSpeedLimit()
if err != nil {
return err
}

for _, file := range files {
wg.Add(1)
fileReplica := file
rc.workerPool.Apply(
func() {
defer wg.Done()
select {
case <-rc.ctx.Done():
errCh <- rc.ctx.Err()
case errCh <- rc.fileImporter.Import(fileReplica, rejectStoreMap, rewriteRules):
updateCh.Inc()
}
rc.workerPool.ApplyOnErrorGroup(eg,
func() error {
defer updateCh.Inc()
return rc.fileImporter.Import(ectx, fileReplica, rejectStoreMap, rewriteRules)
})
}
for i := range files {
err := <-errCh
if err != nil {
summary.CollectFailureUnit(fmt.Sprintf("file:%d", i), err)
rc.cancel()
wg.Wait()
log.Error(
"restore files failed",
zap.Error(err),
)
return err
}
if err := eg.Wait(); err != nil {
summary.CollectFailureUnit("file", err)
log.Error(
"restore files failed",
zap.Error(err),
)
return err
}
wg.Wait()
return nil
}

Expand All @@ -697,7 +678,7 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil
zap.Duration("take", elapsed))
}()
errCh := make(chan error, len(files))
wg := new(sync.WaitGroup)
eg, ectx := errgroup.WithContext(rc.ctx)
defer close(errCh)

err := rc.fileImporter.SetRawRange(startKey, endKey)
Expand All @@ -707,32 +688,21 @@ func (rc *Client) RestoreRaw(startKey []byte, endKey []byte, files []*backup.Fil

emptyRules := &RewriteRules{}
for _, file := range files {
wg.Add(1)
fileReplica := file
rc.workerPool.Apply(
func() {
defer wg.Done()
select {
case <-rc.ctx.Done():
errCh <- rc.ctx.Err()
case errCh <- rc.fileImporter.Import(fileReplica, nil, emptyRules):
updateCh.Inc()
}
rc.workerPool.ApplyOnErrorGroup(eg,
func() error {
defer updateCh.Inc()
return rc.fileImporter.Import(ectx, fileReplica, nil, emptyRules)
})
}
for range files {
err := <-errCh
if err != nil {
rc.cancel()
wg.Wait()
log.Error(
"restore raw range failed",
zap.String("startKey", hex.EncodeToString(startKey)),
zap.String("endKey", hex.EncodeToString(endKey)),
zap.Error(err),
)
return err
}
if err := eg.Wait(); err != nil {
log.Error(
"restore raw range failed",
zap.String("startKey", hex.EncodeToString(startKey)),
zap.String("endKey", hex.EncodeToString(endKey)),
zap.Error(err),
)
return err
}
log.Info(
"finish to restore raw range",
Expand Down Expand Up @@ -832,31 +802,33 @@ func (rc *Client) GoValidateChecksum(
workers := utils.NewWorkerPool(defaultChecksumConcurrency, "RestoreChecksum")
go func() {
start := time.Now()
wg := new(sync.WaitGroup)
wg, ectx := errgroup.WithContext(ctx)
defer func() {
log.Info("all checksum ended")
wg.Wait()
if err := wg.Wait(); err != nil {
errCh <- err
}
elapsed := time.Since(start)
summary.CollectDuration("restore checksum", elapsed)
outCh <- struct{}{}
close(outCh)
}()
for {
select {
// if we use ectx here, maybe canceled will mask real error.
case <-ctx.Done():
errCh <- ctx.Err()
case tbl, ok := <-tableStream:
if !ok {
return
}
wg.Add(1)
workers.Apply(func() {
err := rc.execChecksum(ctx, tbl, kvClient)
workers.ApplyOnErrorGroup(wg, func() error {
err := rc.execChecksum(ectx, tbl, kvClient)
if err != nil {
errCh <- err
return err
}
updateCh.Inc()
wg.Done()
return nil
})
}
}
Expand Down
Loading