diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 16b8ce3dbac58..8d1003e46de4a 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -130,13 +130,30 @@ var ( ) // pdHTTPRequest defines the interface to send a request to pd and return the result in bytes. +<<<<<<< HEAD type pdHTTPRequest func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) +======= +type pdHTTPRequest func(ctx context.Context, addr string, prefix string, + cli *http.Client, method string, body []byte) ([]byte, error) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) // pdRequest is a func to send a HTTP to pd and return the result bytes. func pdRequest( ctx context.Context, addr string, prefix string, +<<<<<<< HEAD cli *http.Client, method string, body io.Reader) ([]byte, error) { +======= + cli *http.Client, method string, body []byte) ([]byte, error) { + _, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body) + return respBody, err +} + +func pdRequestWithCode( + ctx context.Context, + addr string, prefix string, + cli *http.Client, method string, body []byte) (int, []byte, error) { +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) u, err := url.Parse(addr) if err != nil { return nil, errors.Trace(err) @@ -150,13 +167,40 @@ func pdRequest( if err != nil { return nil, errors.Trace(err) } +<<<<<<< HEAD +======= + reqURL := fmt.Sprintf("%s%s", u, prefix) + var ( + req *http.Request + resp *http.Response + ) + if body == nil { + body = []byte("") + } +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) count := 0 for { +<<<<<<< HEAD +======= + req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body)) + if err != nil { + return 0, nil, errors.Trace(err) + } + resp, err = cli.Do(req) //nolint:bodyclose +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) count++ if count > pdRequestRetryTime || resp.StatusCode < 500 { break } +<<<<<<< HEAD _ = resp.Body.Close() +======= + log.Warn("request failed, will retry later", + zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err)) + if resp != nil { + _ = resp.Body.Close() + } +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) time.Sleep(pdRequestRetryInterval()) resp, err = cli.Do(req) if err != nil { @@ -388,9 +432,14 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, schedulers []strin // PauseSchedulers remove pd scheduler temporarily. removedSchedulers := make([]string, 0, len(schedulers)) for _, scheduler := range schedulers { +<<<<<<< HEAD prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.addrs { _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) +======= + for _, addr := range p.getAllPDAddrs() { + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if err == nil { removedSchedulers = append(removedSchedulers, scheduler) break @@ -471,9 +520,14 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str return errors.Trace(err) } for _, scheduler := range schedulers { +<<<<<<< HEAD prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.addrs { _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) +======= + for _, addr := range p.getAllPDAddrs() { + _, err = post(ctx, addr, pdapi.SchedulerByName(scheduler), p.cli, http.MethodPost, body) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) if err == nil { break } @@ -562,7 +616,7 @@ func (p *PdController) doUpdatePDScheduleConfig( return errors.Trace(err) } _, e := post(ctx, addr, prefix, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if e == nil { return nil } @@ -711,6 +765,255 @@ func (p *PdController) doRemoveSchedulersWith( return removedSchedulers, err } +<<<<<<< HEAD +======= +// GetMinResolvedTS get min-resolved-ts from pd +func (p *PdController) GetMinResolvedTS(ctx context.Context) (uint64, error) { + var err error + for _, addr := range p.getAllPDAddrs() { + v, e := pdRequest(ctx, addr, pdapi.MinResolvedTS, p.cli, http.MethodGet, nil) + if e != nil { + log.Warn("failed to get min resolved ts", zap.String("addr", addr), zap.Error(e)) + err = e + continue + } + log.Info("min resolved ts", zap.String("resp", string(v))) + d := struct { + IsRealTime bool `json:"is_real_time,omitempty"` + MinResolvedTS uint64 `json:"min_resolved_ts"` + }{} + err = json.Unmarshal(v, &d) + if err != nil { + return 0, errors.Trace(err) + } + if !d.IsRealTime { + message := "min resolved ts not enabled" + log.Error(message, zap.String("addr", addr)) + return 0, errors.Trace(errors.New(message)) + } + return d.MinResolvedTS, nil + } + return 0, errors.Trace(err) +} + +// RecoverBaseAllocID recover base alloc id +func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error { + reqData, _ := json.Marshal(&struct { + ID string `json:"id"` + }{ + ID: fmt.Sprintf("%d", id), + }) + var err error + for _, addr := range p.getAllPDAddrs() { + _, e := pdRequest(ctx, addr, pdapi.BaseAllocID, p.cli, http.MethodPost, reqData) + if e != nil { + log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e)) + err = e + continue + } + return nil + } + return errors.Trace(err) +} + +// ResetTS reset current ts of pd +func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { + // reset-ts of PD will never set ts < current pd ts + // we set force-use-larger=true to allow ts > current pd ts + 24h(on default) + reqData, _ := json.Marshal(&struct { + Tso string `json:"tso"` + ForceUseLarger bool `json:"force-use-larger"` + }{ + Tso: fmt.Sprintf("%d", ts), + ForceUseLarger: true, + }) + var err error + for _, addr := range p.getAllPDAddrs() { + code, _, e := pdRequestWithCode(ctx, addr, pdapi.ResetTS, p.cli, http.MethodPost, reqData) + if e != nil { + // for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden. + // it's not an error for br + if code == http.StatusForbidden { + log.Info("reset-ts returns with status forbidden, ignore") + return nil + } + log.Warn("failed to reset ts", zap.Uint64("ts", ts), zap.String("addr", addr), zap.Error(e)) + err = e + continue + } + return nil + } + return errors.Trace(err) +} + +// MarkRecovering mark pd into recovering +func (p *PdController) MarkRecovering(ctx context.Context) error { + return p.operateRecoveringMark(ctx, http.MethodPost) +} + +// UnmarkRecovering unmark pd recovering +func (p *PdController) UnmarkRecovering(ctx context.Context) error { + return p.operateRecoveringMark(ctx, http.MethodDelete) +} + +func (p *PdController) operateRecoveringMark(ctx context.Context, method string) error { + var err error + for _, addr := range p.getAllPDAddrs() { + _, e := pdRequest(ctx, addr, pdapi.SnapshotRecoveringMark, p.cli, method, nil) + if e != nil { + log.Warn("failed to operate recovering mark", zap.String("method", method), + zap.String("addr", addr), zap.Error(e)) + err = e + continue + } + return nil + } + return errors.Trace(err) +} + +// RegionLabel is the label of a region. This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L31. +type RegionLabel struct { + Key string `json:"key"` + Value string `json:"value"` + TTL string `json:"ttl,omitempty"` + StartAt string `json:"start_at,omitempty"` +} + +// LabelRule is the rule to assign labels to a region. This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L41. +type LabelRule struct { + ID string `json:"id"` + Labels []RegionLabel `json:"labels"` + RuleType string `json:"rule_type"` + Data interface{} `json:"data"` +} + +// KeyRangeRule contains the start key and end key of the LabelRule. This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L62. +type KeyRangeRule struct { + StartKeyHex string `json:"start_key"` // hex format start key, for marshal/unmarshal + EndKeyHex string `json:"end_key"` // hex format end key, for marshal/unmarshal +} + +// CreateOrUpdateRegionLabelRule creates or updates a region label rule. +func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule LabelRule) error { + reqData, err := json.Marshal(&rule) + if err != nil { + panic(err) + } + var lastErr error + addrs := p.getAllPDAddrs() + for i, addr := range addrs { + _, lastErr = pdRequest(ctx, addr, pdapi.RegionLabelRule, + p.cli, http.MethodPost, reqData) + if lastErr == nil { + return nil + } + if berrors.IsContextCanceled(lastErr) { + return errors.Trace(lastErr) + } + + if i < len(addrs) { + log.Warn("failed to create or update region label rule, will try next pd address", + zap.Error(lastErr), zap.String("pdAddr", addr)) + } + } + return errors.Trace(lastErr) +} + +// DeleteRegionLabelRule deletes a region label rule. +func (p *PdController) DeleteRegionLabelRule(ctx context.Context, ruleID string) error { + var lastErr error + addrs := p.getAllPDAddrs() + for i, addr := range addrs { + _, lastErr = pdRequest(ctx, addr, fmt.Sprintf("%s/%s", pdapi.RegionLabelRule, ruleID), + p.cli, http.MethodDelete, nil) + if lastErr == nil { + return nil + } + if berrors.IsContextCanceled(lastErr) { + return errors.Trace(lastErr) + } + + if i < len(addrs) { + log.Warn("failed to delete region label rule, will try next pd address", + zap.Error(lastErr), zap.String("pdAddr", addr)) + } + } + return errors.Trace(lastErr) +} + +// PauseSchedulersByKeyRange will pause schedulers for regions in the specific key range. +// This function will spawn a goroutine to keep pausing schedulers periodically until the context is done. +// The return done channel is used to notify the caller that the background goroutine is exited. +func (p *PdController) PauseSchedulersByKeyRange(ctx context.Context, + startKey, endKey []byte) (done <-chan struct{}, err error) { + return p.pauseSchedulerByKeyRangeWithTTL(ctx, startKey, endKey, pauseTimeout) +} + +func (p *PdController) pauseSchedulerByKeyRangeWithTTL(ctx context.Context, + startKey, endKey []byte, ttl time.Duration) (_done <-chan struct{}, err error) { + rule := LabelRule{ + ID: uuid.New().String(), + Labels: []RegionLabel{{ + Key: "schedule", + Value: "deny", + TTL: ttl.String(), + }}, + RuleType: "key-range", + // Data should be a list of KeyRangeRule when rule type is key-range. + // See https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L169. + Data: []KeyRangeRule{{ + StartKeyHex: hex.EncodeToString(startKey), + EndKeyHex: hex.EncodeToString(endKey), + }}, + } + done := make(chan struct{}) + if err := p.CreateOrUpdateRegionLabelRule(ctx, rule); err != nil { + close(done) + return nil, errors.Trace(err) + } + + go func() { + defer close(done) + ticker := time.NewTicker(ttl / 3) + defer ticker.Stop() + loop: + for { + select { + case <-ticker.C: + if err := p.CreateOrUpdateRegionLabelRule(ctx, rule); err != nil { + if berrors.IsContextCanceled(err) { + break loop + } + log.Warn("pause scheduler by key range failed, ignore it and wait next time pause", + zap.Error(err)) + } + case <-ctx.Done(): + break loop + } + } + // Use a new context to avoid the context is canceled by the caller. + recoverCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + // Set ttl to 0 to remove the rule. + rule.Labels[0].TTL = time.Duration(0).String() + if err := p.DeleteRegionLabelRule(recoverCtx, rule.ID); err != nil { + log.Warn("failed to delete region label rule, the rule will be removed after ttl expires", + zap.String("rule-id", rule.ID), zap.Duration("ttl", ttl), zap.Error(err)) + } + }() + return done, nil +} + +// CanPauseSchedulerByKeyRange returns whether the scheduler can be paused by key range. +func (p *PdController) CanPauseSchedulerByKeyRange() bool { + // We need ttl feature to ensure scheduler can recover from pause automatically. + return p.version.Compare(minVersionForRegionLabelTTL) >= 0 +} + +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) // Close close the connection to pd. func (p *PdController) Close() { p.pdClient.Close() diff --git a/br/pkg/pdutil/pd_serial_test.go b/br/pkg/pdutil/pd_serial_test.go index 05f0d34aa2ef2..7899436a20652 100644 --- a/br/pkg/pdutil/pd_serial_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -27,7 +27,7 @@ func TestScheduler(t *testing.T) { defer cancel() scheduler := "balance-leader-scheduler" - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("failed") } schedulerPauseCh := make(chan struct{}) @@ -62,7 +62,7 @@ func TestScheduler(t *testing.T) { _, err = pdController.listSchedulersWith(ctx, mock) require.EqualError(t, err, "failed") - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } @@ -82,7 +82,7 @@ func TestScheduler(t *testing.T) { func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { counter++ if counter <= 1 { return nil, errors.New("mock error") @@ -95,7 +95,7 @@ func TestGetClusterVersion(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", respString) - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) @@ -125,7 +125,7 @@ func TestRegionCount(t *testing.T) { require.Equal(t, 3, len(regions.Regions)) mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) @@ -176,6 +176,9 @@ func TestPDRequestRetry(t *testing.T) { count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ + bytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "test", string(bytes)) if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return @@ -184,7 +187,11 @@ func TestPDRequestRetry(t *testing.T) { })) cli := http.DefaultClient taddr := ts.URL +<<<<<<< HEAD _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodGet, nil) +======= + _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test")) +>>>>>>> e2d3047ca9c (pdutil: fix retry reusing body reader (#48312)) require.NoError(t, reqErr) ts.Close() count = 0 @@ -213,7 +220,7 @@ func TestStoreInfo(t *testing.T) { }, } mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) require.Equal(t, "http://mock/pd/api/v1/store/1", query)