Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pdutil: fix retry reusing body reader (#48312) #48320

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions br/pkg/pdutil/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ var (

// pdHTTPRequest defines the interface to send a request to pd and return the result in bytes.
type pdHTTPRequest func(ctx context.Context, addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error)
cli *http.Client, method string, body []byte) ([]byte, error)

// pdRequest sends an HTTP request to pd and returns the result bytes.
func pdRequest(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error) {
cli *http.Client, method string, body []byte) ([]byte, error) {
_, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body)
return respBody, err
}

func pdRequestWithCode(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) (int, []byte, error) {
cli *http.Client, method string, body []byte) (int, []byte, error) {
u, err := url.Parse(addr)
if err != nil {
return 0, nil, errors.Trace(err)
Expand All @@ -167,10 +167,13 @@ func pdRequestWithCode(
req *http.Request
resp *http.Response
)
if body == nil {
body = []byte("")
}
count := 0
// the total retry duration: 120*1 = 2min
for {
req, err = http.NewRequestWithContext(ctx, method, reqURL, body)
req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body))
if err != nil {
return 0, nil, errors.Trace(err)
}
Expand All @@ -197,6 +200,8 @@ func pdRequestWithCode(
(err != nil && !common.IsRetryableError(err)) {
break
}
log.Warn("request failed, will retry later",
zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err))
if resp != nil {
_ = resp.Body.Close()
}
Expand Down Expand Up @@ -454,7 +459,7 @@ func (p *PdController) doPauseSchedulers(ctx context.Context,
for _, scheduler := range schedulers {
prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler)
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body)
if err == nil {
removedSchedulers = append(removedSchedulers, scheduler)
break
Expand Down Expand Up @@ -537,7 +542,7 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str
for _, scheduler := range schedulers {
prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler)
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body)
if err == nil {
break
}
Expand Down Expand Up @@ -626,7 +631,7 @@ func (p *PdController) doUpdatePDScheduleConfig(
return errors.Trace(err)
}
_, e := post(ctx, addr, prefix,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if e == nil {
return nil
}
Expand Down Expand Up @@ -883,7 +888,7 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error
})
var err error
for _, addr := range p.getAllPDAddrs() {
_, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
_, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, reqData)
if e != nil {
log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e))
err = e
Expand All @@ -907,7 +912,7 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error {
})
var err error
for _, addr := range p.getAllPDAddrs() {
code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, reqData)
if e != nil {
// for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden.
// it's not an error for br
Expand Down Expand Up @@ -984,7 +989,7 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L
addrs := p.getAllPDAddrs()
for i, addr := range addrs {
_, lastErr = pdRequest(ctx, addr, regionLabelPrefix,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if lastErr == nil {
return nil
}
Expand Down
19 changes: 10 additions & 9 deletions br/pkg/pdutil/pd_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package pdutil

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -31,7 +30,7 @@ func TestScheduler(t *testing.T) {
defer cancel()

scheduler := "balance-leader-scheduler"
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("failed")
}
schedulerPauseCh := make(chan struct{})
Expand Down Expand Up @@ -66,7 +65,7 @@ func TestScheduler(t *testing.T) {
_, err = pdController.listSchedulersWith(ctx, mock)
require.EqualError(t, err, "failed")

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return []byte(`["` + scheduler + `"]`), nil
}

Expand All @@ -86,7 +85,7 @@ func TestScheduler(t *testing.T) {
func TestGetClusterVersion(t *testing.T) {
pdController := &PdController{addrs: []string{"", ""}} // two endpoints
counter := 0
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
counter++
if counter <= 1 {
return nil, errors.New("mock error")
Expand All @@ -99,7 +98,7 @@ func TestGetClusterVersion(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "test", respString)

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("mock error")
}
_, err = pdController.getClusterVersionWith(ctx, mock)
Expand Down Expand Up @@ -129,7 +128,7 @@ func TestRegionCount(t *testing.T) {
require.Equal(t, 3, len(regions.Regions))

mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
query := fmt.Sprintf("%s/%s", addr, prefix)
u, e := url.Parse(query)
Expand Down Expand Up @@ -180,6 +179,9 @@ func TestPDRequestRetry(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
bytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "test", string(bytes))
if count <= pdRequestRetryTime-1 {
w.WriteHeader(http.StatusGatewayTimeout)
return
Expand All @@ -195,8 +197,7 @@ func TestPDRequestRetry(t *testing.T) {
cli.Transport.(*http.Transport).DisableKeepAlives = true

taddr := ts.URL
body := bytes.NewBuffer([]byte("test"))
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body)
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test"))
require.NoError(t, reqErr)
ts.Close()
count = 0
Expand Down Expand Up @@ -268,7 +269,7 @@ func TestStoreInfo(t *testing.T) {
},
}
mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
query := fmt.Sprintf("%s/%s", addr, prefix)
require.Equal(t, "http://mock/pd/api/v1/store/1", query)
Expand Down