diff --git a/br/pkg/restore/BUILD.bazel b/br/pkg/restore/BUILD.bazel index 7d0071047d8e8..d4f70e278fd2c 100644 --- a/br/pkg/restore/BUILD.bazel +++ b/br/pkg/restore/BUILD.bazel @@ -61,6 +61,7 @@ go_library( "//util/collate", "//util/hack", "//util/mathutil", + "//util/sqlexec", "//util/table-filter", "@com_github_emirpasic_gods//maps/treemap", "@com_github_go_sql_driver_mysql//:mysql", diff --git a/br/pkg/restore/client.go b/br/pkg/restore/client.go index e218ca7b014bc..c0f58817ec0af 100644 --- a/br/pkg/restore/client.go +++ b/br/pkg/restore/client.go @@ -53,6 +53,7 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/mathutil" + "github.com/pingcap/tidb/util/sqlexec" filter "github.com/pingcap/tidb/util/table-filter" "github.com/tikv/client-go/v2/oracle" pd "github.com/tikv/pd/client" @@ -1126,6 +1127,18 @@ func (rc *Client) SplitRanges(ctx context.Context, return SplitRanges(ctx, rc, ranges, rewriteRules, updateCh, isRawKv) } +func (rc *Client) WrapLogFilesIterWithSplitHelper(iter LogIter, rules map[int64]*RewriteRules, g glue.Glue, store kv.Storage) (LogIter, error) { + se, err := g.CreateSession(store) + if err != nil { + return nil, errors.Trace(err) + } + execCtx := se.GetSessionCtx().(sqlexec.RestrictedSQLExecutor) + splitSize, splitKeys := utils.GetRegionSplitInfo(execCtx) + log.Info("get split threshold from tikv config", zap.Uint64("split-size", splitSize), zap.Int64("split-keys", splitKeys)) + client := split.NewSplitClient(rc.GetPDClient(), rc.GetTLSConfig(), false) + return NewLogFilesIterWithSplitHelper(iter, rules, client, splitSize, splitKeys), nil +} + // RestoreSSTFiles tries to restore the files. func (rc *Client) RestoreSSTFiles( ctx context.Context, diff --git a/br/pkg/restore/split.go b/br/pkg/restore/split.go index a707d0f086ce9..17e04486587b9 100644 --- a/br/pkg/restore/split.go +++ b/br/pkg/restore/split.go @@ -5,12 +5,15 @@ package restore import ( "bytes" "context" + "sort" "strconv" "strings" + "sync" "time" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" sst "github.com/pingcap/kvproto/pkg/import_sstpb" "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" @@ -19,9 +22,12 @@ import ( "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/rtree" "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/utils/iter" + "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/codec" "go.uber.org/multierr" "go.uber.org/zap" + "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -428,3 +434,426 @@ func replacePrefix(s []byte, rewriteRules *RewriteRules) ([]byte, *sst.RewriteRu return s, nil } + +type rewriteSplitter struct { + rewriteKey []byte + tableID int64 + rule *RewriteRules + splitter *split.SplitHelper +} + +type splitHelperIterator struct { + tableSplitters []*rewriteSplitter +} + +func (iter *splitHelperIterator) Traverse(fn func(v split.Valued, endKey []byte, rule *RewriteRules) bool) { + for _, entry := range iter.tableSplitters { + endKey := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(entry.tableID+1)) + rule := entry.rule + entry.splitter.Traverse(func(v split.Valued) bool { + return fn(v, endKey, rule) + }) + } +} + +func NewSplitHelperIteratorForTest(helper *split.SplitHelper, tableID int64, rule *RewriteRules) *splitHelperIterator { + return &splitHelperIterator{ + tableSplitters: []*rewriteSplitter{ + { + tableID: tableID, + rule: rule, + splitter: helper, + }, + }, + } +} + +type LogSplitHelper struct { + tableSplitter map[int64]*split.SplitHelper + rules map[int64]*RewriteRules + client split.SplitClient + pool *utils.WorkerPool + eg *errgroup.Group + regionsCh chan []*split.RegionInfo + + splitThreSholdSize uint64 + splitThreSholdKeys int64 +} + +func NewLogSplitHelper(rules map[int64]*RewriteRules, client split.SplitClient, splitSize uint64, splitKeys int64) *LogSplitHelper { + return &LogSplitHelper{ + tableSplitter: make(map[int64]*split.SplitHelper), + rules: rules, + client: client, + pool: utils.NewWorkerPool(128, "split region"), + eg: nil, + + splitThreSholdSize: splitSize, + splitThreSholdKeys: splitKeys, + } +} + +func (helper *LogSplitHelper) iterator() *splitHelperIterator { + tableSplitters := make([]*rewriteSplitter, 0, len(helper.tableSplitter)) + for tableID, splitter := range helper.tableSplitter { + delete(helper.tableSplitter, tableID) + rewriteRule, exists := helper.rules[tableID] + if !exists { + log.Info("skip splitting due to no table id matched", zap.Int64("tableID", tableID)) + continue + } + newTableID := GetRewriteTableID(tableID, rewriteRule) + if newTableID == 0 { + log.Warn("failed to get the rewrite table id", zap.Int64("tableID", tableID)) + continue + } + tableSplitters = append(tableSplitters, &rewriteSplitter{ + rewriteKey: codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(newTableID)), + tableID: newTableID, + rule: rewriteRule, + splitter: splitter, + }) + } + sort.Slice(tableSplitters, func(i, j int) bool { + return bytes.Compare(tableSplitters[i].rewriteKey, tableSplitters[j].rewriteKey) < 0 + }) + return &splitHelperIterator{ + tableSplitters: tableSplitters, + } +} + +const splitFileThreshold = 1024 * 1024 // 1 MB + +func (helper *LogSplitHelper) skipFile(file *backuppb.DataFileInfo) bool { + _, exist := helper.rules[file.TableId] + return file.Length < splitFileThreshold || file.IsMeta || !exist +} + +func (helper *LogSplitHelper) Merge(file *backuppb.DataFileInfo) { + if helper.skipFile(file) { + return + } + splitHelper, exist := helper.tableSplitter[file.TableId] + if !exist { + splitHelper = split.NewSplitHelper() + helper.tableSplitter[file.TableId] = splitHelper + } + + splitHelper.Merge(split.Valued{ + Key: split.Span{ + StartKey: file.StartKey, + EndKey: file.EndKey, + }, + Value: split.Value{ + Size: file.Length, + Number: file.NumberOfEntries, + }, + }) +} + +type splitFunc = func(context.Context, *RegionSplitter, uint64, int64, *split.RegionInfo, []split.Valued) error + +func (helper *LogSplitHelper) splitRegionByPoints( + ctx context.Context, + regionSplitter *RegionSplitter, + initialLength uint64, + initialNumber int64, + region *split.RegionInfo, + valueds []split.Valued, +) error { + var ( + splitPoints [][]byte = make([][]byte, 0) + lastKey []byte = region.Region.StartKey + length uint64 = initialLength + number int64 = initialNumber + ) + for _, v := range valueds { + // decode will discard ts behind the key, which results in the same key for consecutive ranges + if !bytes.Equal(lastKey, v.GetStartKey()) && (v.Value.Size+length > helper.splitThreSholdSize || v.Value.Number+number > helper.splitThreSholdKeys) { + _, rawKey, _ := codec.DecodeBytes(v.GetStartKey(), nil) + splitPoints = append(splitPoints, rawKey) + length = 0 + number = 0 + } + lastKey = v.GetStartKey() + length += v.Value.Size + number += v.Value.Number + } + + if len(splitPoints) == 0 { + return nil + } + + helper.pool.ApplyOnErrorGroup(helper.eg, func() error { + newRegions, errSplit := regionSplitter.splitAndScatterRegions(ctx, region, splitPoints) + if errSplit != nil { + log.Warn("failed to split the scaned region", zap.Error(errSplit)) + _, startKey, _ := codec.DecodeBytes(region.Region.StartKey, nil) + ranges := make([]rtree.Range, 0, len(splitPoints)) + for _, point := range splitPoints { + ranges = append(ranges, rtree.Range{StartKey: startKey, EndKey: point}) + startKey = point + } + + return regionSplitter.Split(ctx, ranges, nil, false, func([][]byte) {}) + } + select { + case <-ctx.Done(): + return nil + case helper.regionsCh <- newRegions: + } + log.Info("split the region", zap.Uint64("region-id", region.Region.Id), zap.Int("split-point-number", len(splitPoints))) + return nil + }) + return nil +} + +// GetRewriteTableID gets rewrite table id by the rewrite rule and original table id +func GetRewriteTableID(tableID int64, rewriteRules *RewriteRules) int64 { + tableKey := tablecodec.GenTableRecordPrefix(tableID) + rule := matchOldPrefix(tableKey, rewriteRules) + if rule == nil { + return 0 + } + + return tablecodec.DecodeTableID(rule.GetNewKeyPrefix()) +} + +// SplitPoint selects ranges overlapped with each region, and calls `splitF` to split the region +func SplitPoint( + ctx context.Context, + iter *splitHelperIterator, + client split.SplitClient, + splitF splitFunc, +) (err error) { + // common status + var ( + regionSplitter *RegionSplitter = NewRegionSplitter(client) + ) + // region traverse status + var ( + // the region buffer of each scan + regions []*split.RegionInfo = nil + regionIndex int = 0 + ) + // region split status + var ( + // range span +----------------+------+---+-------------+ + // region span +------------------------------------+ + // +initial length+ +end valued+ + // regionValueds is the ranges array overlapped with `regionInfo` + regionValueds []split.Valued = nil + // regionInfo is the region to be split + regionInfo *split.RegionInfo = nil + // intialLength is the length of the part of the first range overlapped with the region + initialLength uint64 = 0 + initialNumber int64 = 0 + ) + // range status + var ( + // regionOverCount is the number of regions overlapped with the range + regionOverCount uint64 = 0 + ) + + iter.Traverse(func(v split.Valued, endKey []byte, rule *RewriteRules) bool { + if v.Value.Number == 0 || v.Value.Size == 0 { + return true + } + var ( + vStartKey []byte + vEndKey []byte + ) + // use `vStartKey` and `vEndKey` to compare with region's key + vStartKey, vEndKey, err = GetRewriteEncodedKeys(v, rule) + if err != nil { + return false + } + // traverse to the first region overlapped with the range + for ; regionIndex < len(regions); regionIndex++ { + if bytes.Compare(vStartKey, regions[regionIndex].Region.EndKey) < 0 { + break + } + } + // cannot find any regions overlapped with the range + // need to scan regions again + if regionIndex == len(regions) { + regions = nil + } + regionOverCount = 0 + for { + if regionIndex >= len(regions) { + var startKey []byte + if len(regions) > 0 { + // has traversed over the region buffer, should scan from the last region's end-key of the region buffer + startKey = regions[len(regions)-1].Region.EndKey + } else { + // scan from the range's start-key + startKey = vStartKey + } + // scan at most 64 regions into the region buffer + regions, err = split.ScanRegionsWithRetry(ctx, client, startKey, endKey, 64) + if err != nil { + return false + } + regionIndex = 0 + } + + region := regions[regionIndex] + // this region must be overlapped with the range + regionOverCount++ + // the region is the last one overlapped with the range, + // should split the last recorded region, + // and then record this region as the region to be split + if bytes.Compare(vEndKey, region.Region.EndKey) < 0 { + endLength := v.Value.Size / regionOverCount + endNumber := v.Value.Number / int64(regionOverCount) + if len(regionValueds) > 0 && regionInfo != region { + // add a part of the range as the end part + if bytes.Compare(vStartKey, regionInfo.Region.EndKey) < 0 { + regionValueds = append(regionValueds, split.NewValued(vStartKey, regionInfo.Region.EndKey, split.Value{Size: endLength, Number: endNumber})) + } + // try to split the region + err = splitF(ctx, regionSplitter, initialLength, initialNumber, regionInfo, regionValueds) + if err != nil { + return false + } + regionValueds = make([]split.Valued, 0) + } + if regionOverCount == 1 { + // the region completely contains the range + regionValueds = append(regionValueds, split.Valued{ + Key: split.Span{ + StartKey: vStartKey, + EndKey: vEndKey, + }, + Value: v.Value, + }) + } else { + // the region is overlapped with the last part of the range + initialLength = endLength + initialNumber = endNumber + } + regionInfo = region + // try the next range + return true + } + + // try the next region + regionIndex++ + } + }) + + if err != nil { + return errors.Trace(err) + } + if len(regionValueds) > 0 { + // try to split the region + err = splitF(ctx, regionSplitter, initialLength, initialNumber, regionInfo, regionValueds) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (helper *LogSplitHelper) Split(ctx context.Context) error { + var ectx context.Context + var wg sync.WaitGroup + helper.eg, ectx = errgroup.WithContext(ctx) + helper.regionsCh = make(chan []*split.RegionInfo, 1024) + wg.Add(1) + go func() { + defer wg.Done() + scatterRegions := make([]*split.RegionInfo, 0) + receiveNewRegions: + for { + select { + case <-ectx.Done(): + return + case newRegions, ok := <-helper.regionsCh: + if !ok { + break receiveNewRegions + } + + scatterRegions = append(scatterRegions, newRegions...) + } + } + + startTime := time.Now() + regionSplitter := NewRegionSplitter(helper.client) + for _, region := range scatterRegions { + regionSplitter.waitForScatterRegion(ctx, region) + // It is too expensive to stop recovery and wait for a small number of regions + // to complete scatter, so the maximum waiting time is reduced to 1 minute. + if time.Since(startTime) > time.Minute { + break + } + } + }() + + iter := helper.iterator() + if err := SplitPoint(ectx, iter, helper.client, helper.splitRegionByPoints); err != nil { + return errors.Trace(err) + } + + // wait for completion of splitting regions + if err := helper.eg.Wait(); err != nil { + return errors.Trace(err) + } + + // wait for completion of scattering regions + close(helper.regionsCh) + wg.Wait() + + return nil +} + +type LogFilesIterWithSplitHelper struct { + iter LogIter + helper *LogSplitHelper + buffer []*backuppb.DataFileInfo + next int +} + +const SplitFilesBufferSize = 4096 + +func NewLogFilesIterWithSplitHelper(iter LogIter, rules map[int64]*RewriteRules, client split.SplitClient, splitSize uint64, splitKeys int64) LogIter { + return &LogFilesIterWithSplitHelper{ + iter: iter, + helper: NewLogSplitHelper(rules, client, splitSize, splitKeys), + buffer: nil, + next: 0, + } +} + +func (splitIter *LogFilesIterWithSplitHelper) TryNext(ctx context.Context) iter.IterResult[*backuppb.DataFileInfo] { + if splitIter.next >= len(splitIter.buffer) { + splitIter.buffer = make([]*backuppb.DataFileInfo, 0, SplitFilesBufferSize) + for r := splitIter.iter.TryNext(ctx); !r.Finished; r = splitIter.iter.TryNext(ctx) { + if r.Err != nil { + return r + } + f := r.Item + splitIter.helper.Merge(f) + splitIter.buffer = append(splitIter.buffer, f) + if len(splitIter.buffer) >= SplitFilesBufferSize { + break + } + } + splitIter.next = 0 + if len(splitIter.buffer) == 0 { + return iter.Done[*backuppb.DataFileInfo]() + } + log.Info("start to split the regions") + startTime := time.Now() + if err := splitIter.helper.Split(ctx); err != nil { + return iter.Throw[*backuppb.DataFileInfo](errors.Trace(err)) + } + log.Info("end to split the regions", zap.Duration("takes", time.Since(startTime))) + } + + res := iter.Emit(splitIter.buffer[splitIter.next]) + splitIter.next += 1 + return res +} diff --git a/br/pkg/restore/split/BUILD.bazel b/br/pkg/restore/split/BUILD.bazel index 49fbec82c543c..ac9eb50eb4d20 100644 --- a/br/pkg/restore/split/BUILD.bazel +++ b/br/pkg/restore/split/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "split", @@ -6,6 +6,7 @@ go_library( "client.go", "region.go", "split.go", + "sum_sorted.go", ], importpath = "github.com/pingcap/tidb/br/pkg/restore/split", visibility = ["//visibility:public"], @@ -16,7 +17,9 @@ go_library( "//br/pkg/logutil", "//br/pkg/redact", "//br/pkg/utils", + "//kv", "//store/pdtypes", + "@com_github_google_btree//:btree", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/errorpb", @@ -34,3 +37,12 @@ go_library( "@org_uber_go_zap//:zap", ], ) + +go_test( + name = "split_test", + srcs = ["sum_sorted_test.go"], + deps = [ + ":split", + "@com_github_stretchr_testify//require", + ], +) diff --git a/br/pkg/restore/split/split.go b/br/pkg/restore/split/split.go index bd00c445e1184..e06c8ab1c93d5 100644 --- a/br/pkg/restore/split/split.go +++ b/br/pkg/restore/split/split.go @@ -121,6 +121,65 @@ func PaginateScanRegion( return regions, err } +// CheckPartRegionConsistency only checks the continuity of regions and the first region consistency. +func CheckPartRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "first region's startKey > startKey, startKey: %s, regionStartKey: %s", + redact.Key(startKey), redact.Key(regions[0].Region.StartKey)) + } + + cur := regions[0] + for _, r := range regions[1:] { + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "region endKey not equal to next region startKey, endKey: %s, startKey: %s", + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey)) + } + cur = r + } + + return nil +} + +func ScanRegionsWithRetry( + ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, +) ([]*RegionInfo, error) { + if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { + return nil, errors.Annotatef(berrors.ErrRestoreInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + + var regions []*RegionInfo + var err error + // we don't need to return multierr. since there only 3 times retry. + // in most case 3 times retry have the same error. so we just return the last error. + // actually we'd better remove all multierr in br/lightning. + // because it's not easy to check multierr equals normal error. + // see https://github.com/pingcap/tidb/issues/33419. + _ = utils.WithRetry(ctx, func() error { + regions, err = client.ScanRegions(ctx, startKey, endKey, limit) + if err != nil { + err = errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan regions from start-key:%s, err: %s", + redact.Key(startKey), err.Error()) + return err + } + + if err = CheckPartRegionConsistency(startKey, endKey, regions); err != nil { + log.Warn("failed to scan region, retrying", logutil.ShortError(err)) + return err + } + + return nil + }, newScanRegionBackoffer()) + + return regions, err +} + type scanRegionBackoffer struct { attempt int } diff --git a/br/pkg/restore/split/sum_sorted.go b/br/pkg/restore/split/sum_sorted.go new file mode 100644 index 0000000000000..c4e9657900e35 --- /dev/null +++ b/br/pkg/restore/split/sum_sorted.go @@ -0,0 +1,204 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. +package split + +import ( + "bytes" + "fmt" + + "github.com/google/btree" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/kv" +) + +// Value is the value type of stored in the span tree. +type Value struct { + Size uint64 + Number int64 +} + +// join finds the upper bound of two values. +func join(a, b Value) Value { + return Value{ + Size: a.Size + b.Size, + Number: a.Number + b.Number, + } +} + +// Span is the type of an adjacent sub key space. +type Span = kv.KeyRange + +// Valued is span binding to a value, which is the entry type of span tree. +type Valued struct { + Key Span + Value Value +} + +func NewValued(startKey, endKey []byte, value Value) Valued { + return Valued{ + Key: Span{ + StartKey: startKey, + EndKey: endKey, + }, + Value: value, + } +} + +func (v Valued) String() string { + return fmt.Sprintf("(%s, %.2f MB, %d)", logutil.StringifyRange(v.Key), float64(v.Value.Size)/1024/1024, v.Value.Number) +} + +func (v Valued) Less(other btree.Item) bool { + return bytes.Compare(v.Key.StartKey, other.(Valued).Key.StartKey) < 0 +} + +// implement for `AppliedFile` +func (v Valued) GetStartKey() []byte { + return v.Key.StartKey +} + +// implement for `AppliedFile` +func (v Valued) GetEndKey() []byte { + return v.Key.EndKey +} + +// SplitHelper represents a set of valued ranges, which doesn't overlap and union of them all is the full key space. +type SplitHelper struct { + inner *btree.BTree +} + +// NewSplitHelper creates a set of a subset of spans, with the full key space as initial status +func NewSplitHelper() *SplitHelper { + t := btree.New(16) + t.ReplaceOrInsert(Valued{Value: Value{Size: 0, Number: 0}, Key: Span{StartKey: []byte(""), EndKey: []byte("")}}) + return &SplitHelper{inner: t} +} + +func (f *SplitHelper) Merge(val Valued) { + if len(val.Key.StartKey) == 0 || len(val.Key.EndKey) == 0 { + return + } + overlaps := make([]Valued, 0, 8) + f.overlapped(val.Key, &overlaps) + f.mergeWithOverlap(val, overlaps) +} + +// traverse the items in ascend order +func (f *SplitHelper) Traverse(m func(Valued) bool) { + f.inner.Ascend(func(item btree.Item) bool { + return m(item.(Valued)) + }) +} + +func (f *SplitHelper) mergeWithOverlap(val Valued, overlapped []Valued) { + // There isn't any range overlaps with the input range, perhaps the input range is empty. + // do nothing for this case. + if len(overlapped) == 0 { + return + } + + for _, r := range overlapped { + f.inner.Delete(r) + } + // Assert All overlapped ranges are deleted. + + // the new valued item's Value is equally dividedd into `len(overlapped)` shares + appendValue := Value{ + Size: val.Value.Size / uint64(len(overlapped)), + Number: val.Value.Number / int64(len(overlapped)), + } + var ( + rightTrail *Valued + leftTrail *Valued + // overlapped ranges +-------------+----------+ + // new valued item +-------------+ + // a b c d e + // the part [a,b] is `standalone` because it is not overlapped with the new valued item + // the part [a,b] and [b,c] are `split` because they are from range [a,c] + emitToCollected = func(rng Valued, standalone bool, split bool) { + merged := rng.Value + if split { + merged.Size /= 2 + merged.Number /= 2 + } + if !standalone { + merged = join(appendValue, merged) + } + rng.Value = merged + f.inner.ReplaceOrInsert(rng) + } + ) + + leftmost := overlapped[0] + if bytes.Compare(leftmost.Key.StartKey, val.Key.StartKey) < 0 { + leftTrail = &Valued{ + Key: Span{StartKey: leftmost.Key.StartKey, EndKey: val.Key.StartKey}, + Value: leftmost.Value, + } + overlapped[0].Key.StartKey = val.Key.StartKey + } + + rightmost := overlapped[len(overlapped)-1] + if utils.CompareBytesExt(rightmost.Key.EndKey, true, val.Key.EndKey, true) > 0 { + rightTrail = &Valued{ + Key: Span{StartKey: val.Key.EndKey, EndKey: rightmost.Key.EndKey}, + Value: rightmost.Value, + } + overlapped[len(overlapped)-1].Key.EndKey = val.Key.EndKey + if len(overlapped) == 1 && leftTrail != nil { + // (split) (split) (split) + // overlapped ranges +-----------------------------+ + // new valued item +-------------+ + // a b c d + // now the overlapped range should be divided into 3 equal parts + // so modify the value to the 2/3x to be compatible with function `emitToCollected` + val := Value{ + Size: rightTrail.Value.Size * 2 / 3, + Number: rightTrail.Value.Number * 2 / 3, + } + leftTrail.Value = val + overlapped[0].Value = val + rightTrail.Value = val + } + } + + if leftTrail != nil { + emitToCollected(*leftTrail, true, true) + } + + for i, rng := range overlapped { + split := (i == 0 && leftTrail != nil) || (i == len(overlapped)-1 && rightTrail != nil) + emitToCollected(rng, false, split) + } + + if rightTrail != nil { + emitToCollected(*rightTrail, true, true) + } +} + +// overlapped inserts the overlapped ranges of the span into the `result` slice. +func (f *SplitHelper) overlapped(k Span, result *[]Valued) { + var first Span + f.inner.DescendLessOrEqual(Valued{Key: k}, func(item btree.Item) bool { + first = item.(Valued).Key + return false + }) + + f.inner.AscendGreaterOrEqual(Valued{Key: first}, func(item btree.Item) bool { + r := item.(Valued) + if !checkOverlaps(r.Key, k) { + return false + } + *result = append(*result, r) + return true + }) +} + +// checkOverlaps checks whether two spans have overlapped part. +// `ap` should be a finite range +func checkOverlaps(a, ap Span) bool { + if len(a.EndKey) == 0 { + return bytes.Compare(ap.EndKey, a.StartKey) > 0 + } + return bytes.Compare(a.StartKey, ap.EndKey) < 0 && bytes.Compare(ap.StartKey, a.EndKey) < 0 +} diff --git a/br/pkg/restore/split/sum_sorted_test.go b/br/pkg/restore/split/sum_sorted_test.go new file mode 100644 index 0000000000000..3a3b3db6d90eb --- /dev/null +++ b/br/pkg/restore/split/sum_sorted_test.go @@ -0,0 +1,141 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. +package split_test + +import ( + "testing" + + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/stretchr/testify/require" +) + +func v(s, e string, val split.Value) split.Valued { + return split.Valued{ + Key: split.Span{ + StartKey: []byte(s), + EndKey: []byte(e), + }, + Value: val, + } +} + +func mb(b uint64) split.Value { + return split.Value{ + Size: b * 1024 * 1024, + Number: int64(b), + } +} + +func TestSumSorted(t *testing.T) { + cases := []struct { + values []split.Valued + result []uint64 + }{ + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("d", "g", mb(100)), + }, + result: []uint64{0, 250, 25, 75, 50, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("d", "f", mb(100)), + }, + result: []uint64{0, 250, 25, 125, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + }, + result: []uint64{0, 250, 150, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + }, + result: []uint64{0, 250, 50, 150, 50, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("cb", "db", mb(100)), + }, + result: []uint64{0, 250, 25, 75, 200, 50, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("cb", "f", mb(150)), + }, + result: []uint64{0, 250, 25, 75, 200, 100, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("cb", "df", mb(150)), + }, + result: []uint64{0, 250, 25, 75, 200, 75, 25, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("cb", "df", mb(150)), + }, + result: []uint64{0, 250, 25, 75, 200, 75, 25, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("c", "df", mb(150)), + }, + result: []uint64{0, 250, 100, 200, 75, 25, 0}, + }, + { + values: []split.Valued{ + v("a", "f", mb(100)), + v("a", "c", mb(200)), + v("c", "f", mb(100)), + v("da", "db", mb(100)), + v("c", "f", mb(150)), + }, + result: []uint64{0, 250, 100, 200, 100, 0}, + }, + } + + for _, ca := range cases { + full := split.NewSplitHelper() + for _, v := range ca.values { + full.Merge(v) + } + + i := 0 + full.Traverse(func(v split.Valued) bool { + require.Equal(t, mb(ca.result[i]), v.Value) + i++ + return true + }) + } +} diff --git a/br/pkg/restore/split_test.go b/br/pkg/restore/split_test.go index b726a5ec78729..1b560a4e1474d 100644 --- a/br/pkg/restore/split_test.go +++ b/br/pkg/restore/split_test.go @@ -5,6 +5,7 @@ package restore_test import ( "bytes" "context" + "fmt" "sync" "testing" "time" @@ -22,7 +23,9 @@ import ( "github.com/pingcap/tidb/br/pkg/restore/split" "github.com/pingcap/tidb/br/pkg/rtree" "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/utils/iter" "github.com/pingcap/tidb/store/pdtypes" + "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/codec" "github.com/stretchr/testify/require" "go.uber.org/multierr" @@ -729,3 +732,316 @@ func TestSplitFailed(t *testing.T) { require.GreaterOrEqual(t, len(r.splitRanges), 2) require.Len(t, r.restoredFiles, 0) } + +func keyWithTablePrefix(tableID int64, key string) []byte { + rawKey := append(tablecodec.GenTableRecordPrefix(tableID), []byte(key)...) + return codec.EncodeBytes([]byte{}, rawKey) +} + +func TestSplitPoint(t *testing.T) { + ctx := context.Background() + var oldTableID int64 = 50 + var tableID int64 = 100 + rewriteRules := &restore.RewriteRules{ + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), + NewKeyPrefix: tablecodec.EncodeTablePrefix(tableID), + }, + }, + } + + // range: b c d e g i + // +---+ +---+ +---------+ + // +-------------+----------+---------+ + // region: a f h j + splitHelper := split.NewSplitHelper() + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "b"), EndKey: keyWithTablePrefix(oldTableID, "c")}, Value: split.Value{Size: 100, Number: 100}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "d"), EndKey: keyWithTablePrefix(oldTableID, "e")}, Value: split.Value{Size: 200, Number: 200}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "g"), EndKey: keyWithTablePrefix(oldTableID, "i")}, Value: split.Value{Size: 300, Number: 300}}) + client := NewFakeSplitClient() + client.AppendRegion(keyWithTablePrefix(tableID, "a"), keyWithTablePrefix(tableID, "f")) + client.AppendRegion(keyWithTablePrefix(tableID, "f"), keyWithTablePrefix(tableID, "h")) + client.AppendRegion(keyWithTablePrefix(tableID, "h"), keyWithTablePrefix(tableID, "j")) + client.AppendRegion(keyWithTablePrefix(tableID, "j"), keyWithTablePrefix(tableID+1, "a")) + + iter := restore.NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) + err := restore.SplitPoint(ctx, iter, client, func(ctx context.Context, rs *restore.RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { + require.Equal(t, u, uint64(0)) + require.Equal(t, o, int64(0)) + require.Equal(t, ri.Region.StartKey, keyWithTablePrefix(tableID, "a")) + require.Equal(t, ri.Region.EndKey, keyWithTablePrefix(tableID, "f")) + require.EqualValues(t, v[0].Key.StartKey, keyWithTablePrefix(tableID, "b")) + require.EqualValues(t, v[0].Key.EndKey, keyWithTablePrefix(tableID, "c")) + require.EqualValues(t, v[1].Key.StartKey, keyWithTablePrefix(tableID, "d")) + require.EqualValues(t, v[1].Key.EndKey, keyWithTablePrefix(tableID, "e")) + require.Equal(t, len(v), 2) + return nil + }) + require.NoError(t, err) +} + +func getCharFromNumber(prefix string, i int) string { + c := '1' + (i % 10) + b := '1' + (i%100)/10 + a := '1' + i/100 + return fmt.Sprintf("%s%c%c%c", prefix, a, b, c) +} + +func TestSplitPoint2(t *testing.T) { + ctx := context.Background() + var oldTableID int64 = 50 + var tableID int64 = 100 + rewriteRules := &restore.RewriteRules{ + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), + NewKeyPrefix: tablecodec.EncodeTablePrefix(tableID), + }, + }, + } + + // range: b c d e f i j k l n + // +---+ +---+ +-----------------+ +----+ +--------+ + // +---------------+--+.....+----+------------+---------+ + // region: a g >128 h m o + splitHelper := split.NewSplitHelper() + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "b"), EndKey: keyWithTablePrefix(oldTableID, "c")}, Value: split.Value{Size: 100, Number: 100}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "d"), EndKey: keyWithTablePrefix(oldTableID, "e")}, Value: split.Value{Size: 200, Number: 200}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "f"), EndKey: keyWithTablePrefix(oldTableID, "i")}, Value: split.Value{Size: 300, Number: 300}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "j"), EndKey: keyWithTablePrefix(oldTableID, "k")}, Value: split.Value{Size: 200, Number: 200}}) + splitHelper.Merge(split.Valued{Key: split.Span{StartKey: keyWithTablePrefix(oldTableID, "l"), EndKey: keyWithTablePrefix(oldTableID, "n")}, Value: split.Value{Size: 200, Number: 200}}) + client := NewFakeSplitClient() + client.AppendRegion(keyWithTablePrefix(tableID, "a"), keyWithTablePrefix(tableID, "g")) + client.AppendRegion(keyWithTablePrefix(tableID, "g"), keyWithTablePrefix(tableID, getCharFromNumber("g", 0))) + for i := 0; i < 256; i++ { + client.AppendRegion(keyWithTablePrefix(tableID, getCharFromNumber("g", i)), keyWithTablePrefix(tableID, getCharFromNumber("g", i+1))) + } + client.AppendRegion(keyWithTablePrefix(tableID, getCharFromNumber("g", 256)), keyWithTablePrefix(tableID, "h")) + client.AppendRegion(keyWithTablePrefix(tableID, "h"), keyWithTablePrefix(tableID, "m")) + client.AppendRegion(keyWithTablePrefix(tableID, "m"), keyWithTablePrefix(tableID, "o")) + client.AppendRegion(keyWithTablePrefix(tableID, "o"), keyWithTablePrefix(tableID+1, "a")) + + firstSplit := true + iter := restore.NewSplitHelperIteratorForTest(splitHelper, tableID, rewriteRules) + err := restore.SplitPoint(ctx, iter, client, func(ctx context.Context, rs *restore.RegionSplitter, u uint64, o int64, ri *split.RegionInfo, v []split.Valued) error { + if firstSplit { + require.Equal(t, u, uint64(0)) + require.Equal(t, o, int64(0)) + require.Equal(t, ri.Region.StartKey, keyWithTablePrefix(tableID, "a")) + require.Equal(t, ri.Region.EndKey, keyWithTablePrefix(tableID, "g")) + require.EqualValues(t, v[0].Key.StartKey, keyWithTablePrefix(tableID, "b")) + require.EqualValues(t, v[0].Key.EndKey, keyWithTablePrefix(tableID, "c")) + require.EqualValues(t, v[1].Key.StartKey, keyWithTablePrefix(tableID, "d")) + require.EqualValues(t, v[1].Key.EndKey, keyWithTablePrefix(tableID, "e")) + require.EqualValues(t, v[2].Key.StartKey, keyWithTablePrefix(tableID, "f")) + require.EqualValues(t, v[2].Key.EndKey, keyWithTablePrefix(tableID, "g")) + require.Equal(t, v[2].Value.Size, uint64(1)) + require.Equal(t, v[2].Value.Number, int64(1)) + require.Equal(t, len(v), 3) + firstSplit = false + } else { + require.Equal(t, u, uint64(1)) + require.Equal(t, o, int64(1)) + require.Equal(t, ri.Region.StartKey, keyWithTablePrefix(tableID, "h")) + require.Equal(t, ri.Region.EndKey, keyWithTablePrefix(tableID, "m")) + require.EqualValues(t, v[0].Key.StartKey, keyWithTablePrefix(tableID, "j")) + require.EqualValues(t, v[0].Key.EndKey, keyWithTablePrefix(tableID, "k")) + require.EqualValues(t, v[1].Key.StartKey, keyWithTablePrefix(tableID, "l")) + require.EqualValues(t, v[1].Key.EndKey, keyWithTablePrefix(tableID, "m")) + require.Equal(t, v[1].Value.Size, uint64(100)) + require.Equal(t, v[1].Value.Number, int64(100)) + require.Equal(t, len(v), 2) + } + return nil + }) + require.NoError(t, err) +} + +type fakeSplitClient struct { + regions []*split.RegionInfo +} + +func NewFakeSplitClient() *fakeSplitClient { + return &fakeSplitClient{ + regions: make([]*split.RegionInfo, 0), + } +} + +func (f *fakeSplitClient) AppendRegion(startKey, endKey []byte) { + f.regions = append(f.regions, &split.RegionInfo{ + Region: &metapb.Region{ + StartKey: startKey, + EndKey: endKey, + }, + }) +} + +func (*fakeSplitClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { + return nil, nil +} +func (*fakeSplitClient) GetRegion(ctx context.Context, key []byte) (*split.RegionInfo, error) { + return nil, nil +} +func (*fakeSplitClient) GetRegionByID(ctx context.Context, regionID uint64) (*split.RegionInfo, error) { + return nil, nil +} +func (*fakeSplitClient) SplitRegion(ctx context.Context, regionInfo *split.RegionInfo, key []byte) (*split.RegionInfo, error) { + return nil, nil +} +func (*fakeSplitClient) BatchSplitRegions(ctx context.Context, regionInfo *split.RegionInfo, keys [][]byte) ([]*split.RegionInfo, error) { + return nil, nil +} +func (*fakeSplitClient) BatchSplitRegionsWithOrigin(ctx context.Context, regionInfo *split.RegionInfo, keys [][]byte) (*split.RegionInfo, []*split.RegionInfo, error) { + return nil, nil, nil +} +func (*fakeSplitClient) ScatterRegion(ctx context.Context, regionInfo *split.RegionInfo) error { + return nil +} +func (*fakeSplitClient) ScatterRegions(ctx context.Context, regionInfo []*split.RegionInfo) error { + return nil +} +func (*fakeSplitClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { + return nil, nil +} +func (f *fakeSplitClient) ScanRegions(ctx context.Context, startKey, endKey []byte, limit int) ([]*split.RegionInfo, error) { + result := make([]*split.RegionInfo, 0) + count := 0 + for _, rng := range f.regions { + if bytes.Compare(rng.Region.StartKey, endKey) <= 0 && bytes.Compare(rng.Region.EndKey, startKey) > 0 { + result = append(result, rng) + count++ + } + if count >= limit { + break + } + } + return result, nil +} +func (*fakeSplitClient) GetPlacementRule(ctx context.Context, groupID, ruleID string) (pdtypes.Rule, error) { + return pdtypes.Rule{}, nil +} +func (*fakeSplitClient) SetPlacementRule(ctx context.Context, rule pdtypes.Rule) error { return nil } +func (*fakeSplitClient) DeletePlacementRule(ctx context.Context, groupID, ruleID string) error { + return nil +} +func (*fakeSplitClient) SetStoresLabel(ctx context.Context, stores []uint64, labelKey, labelValue string) error { + return nil +} + +func TestGetRewriteTableID(t *testing.T) { + var tableID int64 = 76 + var oldTableID int64 = 80 + { + rewriteRules := &restore.RewriteRules{ + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), + NewKeyPrefix: tablecodec.EncodeTablePrefix(tableID), + }, + }, + } + + newTableID := restore.GetRewriteTableID(oldTableID, rewriteRules) + require.Equal(t, tableID, newTableID) + } + + { + rewriteRules := &restore.RewriteRules{ + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.GenTableRecordPrefix(oldTableID), + NewKeyPrefix: tablecodec.GenTableRecordPrefix(tableID), + }, + }, + } + + newTableID := restore.GetRewriteTableID(oldTableID, rewriteRules) + require.Equal(t, tableID, newTableID) + } +} + +type mockLogIter struct { + next int +} + +func (m *mockLogIter) TryNext(ctx context.Context) iter.IterResult[*backuppb.DataFileInfo] { + if m.next > 10000 { + return iter.Done[*backuppb.DataFileInfo]() + } + m.next += 1 + return iter.Emit(&backuppb.DataFileInfo{ + StartKey: []byte(fmt.Sprintf("a%d", m.next)), + EndKey: []byte("b"), + Length: 1024, // 1 KB + }) +} + +func TestLogFilesIterWithSplitHelper(t *testing.T) { + var tableID int64 = 76 + var oldTableID int64 = 80 + rewriteRules := &restore.RewriteRules{ + Data: []*import_sstpb.RewriteRule{ + { + OldKeyPrefix: tablecodec.EncodeTablePrefix(oldTableID), + NewKeyPrefix: tablecodec.EncodeTablePrefix(tableID), + }, + }, + } + rewriteRulesMap := map[int64]*restore.RewriteRules{ + oldTableID: rewriteRules, + } + mockIter := &mockLogIter{} + ctx := context.Background() + logIter := restore.NewLogFilesIterWithSplitHelper(mockIter, rewriteRulesMap, NewFakeSplitClient(), 144*1024*1024, 1440000) + next := 0 + for r := logIter.TryNext(ctx); !r.Finished; r = logIter.TryNext(ctx) { + require.NoError(t, r.Err) + next += 1 + require.Equal(t, []byte(fmt.Sprintf("a%d", next)), r.Item.StartKey) + } +} + +func regionInfo(startKey, endKey string) *split.RegionInfo { + return &split.RegionInfo{ + Region: &metapb.Region{ + StartKey: []byte(startKey), + EndKey: []byte(endKey), + }, + } +} + +func TestSplitCheckPartRegionConsistency(t *testing.T) { + var ( + startKey []byte = []byte("a") + endKey []byte = []byte("f") + err error + ) + err = split.CheckPartRegionConsistency(startKey, endKey, nil) + require.Error(t, err) + err = split.CheckPartRegionConsistency(startKey, endKey, []*split.RegionInfo{ + regionInfo("b", "c"), + }) + require.Error(t, err) + err = split.CheckPartRegionConsistency(startKey, endKey, []*split.RegionInfo{ + regionInfo("a", "c"), + regionInfo("d", "e"), + }) + require.Error(t, err) + err = split.CheckPartRegionConsistency(startKey, endKey, []*split.RegionInfo{ + regionInfo("a", "c"), + regionInfo("c", "d"), + }) + require.NoError(t, err) + err = split.CheckPartRegionConsistency(startKey, endKey, []*split.RegionInfo{ + regionInfo("a", "c"), + regionInfo("c", "d"), + regionInfo("d", "f"), + }) + require.NoError(t, err) + err = split.CheckPartRegionConsistency(startKey, endKey, []*split.RegionInfo{ + regionInfo("a", "c"), + regionInfo("c", "z"), + }) + require.NoError(t, err) +} diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index ced88b27e3f3d..2ffa7bc7dd9af 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -1243,9 +1243,16 @@ func restoreStream( updateRewriteRules(rewriteRules, schemasReplace) logFilesIter, err := client.LoadDMLFiles(ctx) + if err != nil { + return errors.Trace(err) + } + logFilesIterWithSplit, err := client.WrapLogFilesIterWithSplitHelper(logFilesIter, rewriteRules, g, mgr.GetStorage()) + if err != nil { + return errors.Trace(err) + } pd := g.StartProgress(ctx, "Restore KV Files", int64(dataFileCount), !cfg.LogProgress) err = withProgress(pd, func(p glue.Progress) error { - return client.RestoreKVFiles(ctx, rewriteRules, logFilesIter, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) + return client.RestoreKVFiles(ctx, rewriteRules, logFilesIterWithSplit, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) }) if err != nil { return errors.Annotate(err, "failed to restore kv files") diff --git a/br/pkg/utils/BUILD.bazel b/br/pkg/utils/BUILD.bazel index c3bcc629183d5..1cad8d5628dee 100644 --- a/br/pkg/utils/BUILD.bazel +++ b/br/pkg/utils/BUILD.bazel @@ -38,6 +38,7 @@ go_library( "//util", "//util/sqlexec", "@com_github_cheggaaa_pb_v3//:pb", + "@com_github_docker_go_units//:go-units", "@com_github_google_uuid//:uuid", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", diff --git a/br/pkg/utils/db.go b/br/pkg/utils/db.go index 9574c06670573..060df603d16cb 100644 --- a/br/pkg/utils/db.go +++ b/br/pkg/utils/db.go @@ -5,11 +5,14 @@ package utils import ( "context" "database/sql" + "strconv" "strings" "sync" + "github.com/docker/go-units" "github.com/pingcap/errors" "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/sqlexec" @@ -99,6 +102,72 @@ func IsLogBackupEnabled(ctx sqlexec.RestrictedSQLExecutor) (bool, error) { return true, nil } +func GetRegionSplitInfo(ctx sqlexec.RestrictedSQLExecutor) (uint64, int64) { + return GetSplitSize(ctx), GetSplitKeys(ctx) +} + +func GetSplitSize(ctx sqlexec.RestrictedSQLExecutor) uint64 { + const defaultSplitSize = 96 * 1024 * 1024 + varStr := "show config where name = 'coprocessor.region-split-size' and type = 'tikv'" + rows, fields, err := ctx.ExecRestrictedSQL( + kv.WithInternalSourceType(context.Background(), kv.InternalTxnBR), + nil, + varStr, + ) + if err != nil { + log.Warn("failed to get split size, use default value", logutil.ShortError(err)) + return defaultSplitSize + } + if len(rows) == 0 { + // use the default value + return defaultSplitSize + } + + d := rows[0].GetDatum(3, &fields[3].Column.FieldType) + splitSizeStr, err := d.ToString() + if err != nil { + log.Warn("failed to get split size, use default value", logutil.ShortError(err)) + return defaultSplitSize + } + splitSize, err := units.FromHumanSize(splitSizeStr) + if err != nil { + log.Warn("failed to get split size, use default value", logutil.ShortError(err)) + return defaultSplitSize + } + return uint64(splitSize) +} + +func GetSplitKeys(ctx sqlexec.RestrictedSQLExecutor) int64 { + const defaultSplitKeys = 960000 + varStr := "show config where name = 'coprocessor.region-split-keys' and type = 'tikv'" + rows, fields, err := ctx.ExecRestrictedSQL( + kv.WithInternalSourceType(context.Background(), kv.InternalTxnBR), + nil, + varStr, + ) + if err != nil { + log.Warn("failed to get split keys, use default value", logutil.ShortError(err)) + return defaultSplitKeys + } + if len(rows) == 0 { + // use the default value + return defaultSplitKeys + } + + d := rows[0].GetDatum(3, &fields[3].Column.FieldType) + splitKeysStr, err := d.ToString() + if err != nil { + log.Warn("failed to get split keys, use default value", logutil.ShortError(err)) + return defaultSplitKeys + } + splitKeys, err := strconv.ParseInt(splitKeysStr, 10, 64) + if err != nil { + log.Warn("failed to get split keys, use default value", logutil.ShortError(err)) + return defaultSplitKeys + } + return splitKeys +} + func GetGcRatio(ctx sqlexec.RestrictedSQLExecutor) (string, error) { valStr := "show config where name = 'gc.ratio-threshold' and type = 'tikv'" rows, fields, errSQL := ctx.ExecRestrictedSQL( diff --git a/br/pkg/utils/db_test.go b/br/pkg/utils/db_test.go index 1334d868641f0..1004764b0d206 100644 --- a/br/pkg/utils/db_test.go +++ b/br/pkg/utils/db_test.go @@ -168,3 +168,44 @@ func TestGc(t *testing.T) { require.Nil(t, err) require.Equal(t, ratio, "-1.0") } + +func TestRegionSplitInfo(t *testing.T) { + // config format: + // MySQL [(none)]> show config where name = 'coprocessor.region-split-size'; + // +------+-------------------+-------------------------------+-------+ + // | Type | Instance | Name | Value | + // +------+-------------------+-------------------------------+-------+ + // | tikv | 127.0.0.1:20161 | coprocessor.region-split-size | 10MB | + // +------+-------------------+-------------------------------+-------+ + // MySQL [(none)]> show config where name = 'coprocessor.region-split-keys'; + // +------+-------------------+-------------------------------+--------+ + // | Type | Instance | Name | Value | + // +------+-------------------+-------------------------------+--------+ + // | tikv | 127.0.0.1:20161 | coprocessor.region-split-keys | 100000 | + // +------+-------------------+-------------------------------+--------+ + + fields := make([]*ast.ResultField, 4) + tps := []*types.FieldType{ + types.NewFieldType(mysql.TypeString), + types.NewFieldType(mysql.TypeString), + types.NewFieldType(mysql.TypeString), + types.NewFieldType(mysql.TypeString), + } + for i := 0; i < len(tps); i++ { + rf := new(ast.ResultField) + rf.Column = new(model.ColumnInfo) + rf.Column.FieldType = *tps[i] + fields[i] = rf + } + rows := make([]chunk.Row, 0, 1) + row := chunk.MutRowFromValues("tikv", "127.0.0.1:20161", "coprocessor.region-split-size", "10MB").ToRow() + rows = append(rows, row) + s := &mockRestrictedSQLExecutor{rows: rows, fields: fields} + require.Equal(t, utils.GetSplitSize(s), uint64(10000000)) + + rows = make([]chunk.Row, 0, 1) + row = chunk.MutRowFromValues("tikv", "127.0.0.1:20161", "coprocessor.region-split-keys", "100000").ToRow() + rows = append(rows, row) + s = &mockRestrictedSQLExecutor{rows: rows, fields: fields} + require.Equal(t, utils.GetSplitKeys(s), int64(100000)) +}