diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 6c614ca32ddc7..40bde4936ba4e 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -143,13 +143,13 @@ var ( ) // pdHTTPRequest defines the interface to send a request to pd and return the result in bytes. -type pdHTTPRequest func(ctx context.Context, addr string, prefix string, cli *http.Client, method string, body io.Reader) ([]byte, error) +type pdHTTPRequest func(ctx context.Context, addr string, prefix string, cli *http.Client, method string, body []byte) ([]byte, error) // pdRequest is a func to send an HTTP to pd and return the result bytes. func pdRequest( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) ([]byte, error) { + cli *http.Client, method string, body []byte) ([]byte, error) { _, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body) return respBody, err } @@ -157,7 +157,7 @@ func pdRequest( func pdRequestWithCode( ctx context.Context, addr string, prefix string, - cli *http.Client, method string, body io.Reader) (int, []byte, error) { + cli *http.Client, method string, body []byte) (int, []byte, error) { u, err := url.Parse(addr) if err != nil { return 0, nil, errors.Trace(err) @@ -167,9 +167,12 @@ func pdRequestWithCode( req *http.Request resp *http.Response ) + if body == nil { + body = []byte("") + } count := 0 for { - req, err = http.NewRequestWithContext(ctx, method, reqURL, body) + req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body)) if err != nil { return 0, nil, errors.Trace(err) } @@ -196,6 +199,8 @@ func pdRequestWithCode( (err != nil && !common.IsRetryableError(err)) { break } + log.Warn("request failed, will retry later", + zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err)) if resp != nil { _ = resp.Body.Close() } @@ -451,7 +456,7 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, schedulers []strin for _, scheduler := range schedulers { prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.getAllPDAddrs() { - _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) + _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body) if err == nil { removedSchedulers = append(removedSchedulers, scheduler) break @@ -534,7 +539,7 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str for _, scheduler := range schedulers { prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler) for _, addr := range p.getAllPDAddrs() { - _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body)) + _, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body) if err == nil { break } @@ -623,7 +628,7 @@ func (p *PdController) doUpdatePDScheduleConfig( return errors.Trace(err) } _, e := post(ctx, addr, prefix, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if e == nil { return nil } @@ -859,7 +864,7 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error }) var err error for _, addr := range p.getAllPDAddrs() { - _, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + _, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, reqData) if e != nil { log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e)) err = e @@ -883,7 +888,7 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { }) var err error for _, addr := range p.getAllPDAddrs() { - code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, reqData) if e != nil { // for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden. // it's not an error for br @@ -960,7 +965,7 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L addrs := p.getAllPDAddrs() for i, addr := range addrs { _, lastErr = pdRequest(ctx, addr, regionLabelPrefix, - p.cli, http.MethodPost, bytes.NewBuffer(reqData)) + p.cli, http.MethodPost, reqData) if lastErr == nil { return nil } diff --git a/br/pkg/pdutil/pd_serial_test.go b/br/pkg/pdutil/pd_serial_test.go index 39c2fae8dd014..271ca8ee2ebae 100644 --- a/br/pkg/pdutil/pd_serial_test.go +++ b/br/pkg/pdutil/pd_serial_test.go @@ -3,7 +3,6 @@ package pdutil import ( - "bytes" "context" "encoding/hex" "encoding/json" @@ -31,7 +30,7 @@ func TestScheduler(t *testing.T) { defer cancel() scheduler := "balance-leader-scheduler" - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("failed") } schedulerPauseCh := make(chan struct{}) @@ -66,7 +65,7 @@ func TestScheduler(t *testing.T) { _, err = pdController.listSchedulersWith(ctx, mock) require.EqualError(t, err, "failed") - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return []byte(`["` + scheduler + `"]`), nil } @@ -86,7 +85,7 @@ func TestScheduler(t *testing.T) { func TestGetClusterVersion(t *testing.T) { pdController := &PdController{addrs: []string{"", ""}} // two endpoints counter := 0 - mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { counter++ if counter <= 1 { return nil, errors.New("mock error") @@ -99,7 +98,7 @@ func TestGetClusterVersion(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", respString) - mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) { + mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) { return nil, errors.New("mock error") } _, err = pdController.getClusterVersionWith(ctx, mock) @@ -129,7 +128,7 @@ func TestRegionCount(t *testing.T) { require.Equal(t, 3, len(regions.Regions)) mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) u, e := url.Parse(query) @@ -180,6 +179,9 @@ func TestPDRequestRetry(t *testing.T) { count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { count++ + bytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, "test", string(bytes)) if count <= pdRequestRetryTime-1 { w.WriteHeader(http.StatusGatewayTimeout) return @@ -195,8 +197,7 @@ func TestPDRequestRetry(t *testing.T) { cli.Transport.(*http.Transport).DisableKeepAlives = true taddr := ts.URL - body := bytes.NewBuffer([]byte("test")) - _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body) + _, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test")) require.NoError(t, reqErr) ts.Close() count = 0 @@ -268,7 +269,7 @@ func TestStoreInfo(t *testing.T) { }, } mock := func( - _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader, + _ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte, ) ([]byte, error) { query := fmt.Sprintf("%s/%s", addr, prefix) require.Equal(t, "http://mock/pd/api/v1/store/1", query)