Skip to content

Commit

Permalink
pdutil: fix retry reusing body reader (#48312) (#48319)
Browse files Browse the repository at this point in the history
close #48307
  • Loading branch information
ti-chi-bot authored Nov 7, 2023
1 parent 1bdee71 commit 5a6c8c8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
25 changes: 15 additions & 10 deletions br/pkg/pdutil/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ var (
)

// pdHTTPRequest defines the interface to send a request to pd and return the result in bytes.
type pdHTTPRequest func(ctx context.Context, addr string, prefix string, cli *http.Client, method string, body io.Reader) ([]byte, error)
type pdHTTPRequest func(ctx context.Context, addr string, prefix string, cli *http.Client, method string, body []byte) ([]byte, error)

// pdRequest is a func to send an HTTP to pd and return the result bytes.
func pdRequest(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) ([]byte, error) {
cli *http.Client, method string, body []byte) ([]byte, error) {
_, respBody, err := pdRequestWithCode(ctx, addr, prefix, cli, method, body)
return respBody, err
}

func pdRequestWithCode(
ctx context.Context,
addr string, prefix string,
cli *http.Client, method string, body io.Reader) (int, []byte, error) {
cli *http.Client, method string, body []byte) (int, []byte, error) {
u, err := url.Parse(addr)
if err != nil {
return 0, nil, errors.Trace(err)
Expand All @@ -167,9 +167,12 @@ func pdRequestWithCode(
req *http.Request
resp *http.Response
)
if body == nil {
body = []byte("")
}
count := 0
for {
req, err = http.NewRequestWithContext(ctx, method, reqURL, body)
req, err = http.NewRequestWithContext(ctx, method, reqURL, bytes.NewBuffer(body))
if err != nil {
return 0, nil, errors.Trace(err)
}
Expand All @@ -196,6 +199,8 @@ func pdRequestWithCode(
(err != nil && !common.IsRetryableError(err)) {
break
}
log.Warn("request failed, will retry later",
zap.String("url", reqURL), zap.Int("retry-count", count), zap.Error(err))
if resp != nil {
_ = resp.Body.Close()
}
Expand Down Expand Up @@ -451,7 +456,7 @@ func (p *PdController) doPauseSchedulers(ctx context.Context, schedulers []strin
for _, scheduler := range schedulers {
prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler)
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body)
if err == nil {
removedSchedulers = append(removedSchedulers, scheduler)
break
Expand Down Expand Up @@ -534,7 +539,7 @@ func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []str
for _, scheduler := range schedulers {
prefix := fmt.Sprintf("%s/%s", schedulerPrefix, scheduler)
for _, addr := range p.getAllPDAddrs() {
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, bytes.NewBuffer(body))
_, err = post(ctx, addr, prefix, p.cli, http.MethodPost, body)
if err == nil {
break
}
Expand Down Expand Up @@ -623,7 +628,7 @@ func (p *PdController) doUpdatePDScheduleConfig(
return errors.Trace(err)
}
_, e := post(ctx, addr, prefix,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if e == nil {
return nil
}
Expand Down Expand Up @@ -859,7 +864,7 @@ func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error
})
var err error
for _, addr := range p.getAllPDAddrs() {
_, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
_, e := pdRequest(ctx, addr, baseAllocIDPrefix, p.cli, http.MethodPost, reqData)
if e != nil {
log.Warn("failed to recover base alloc id", zap.String("addr", addr), zap.Error(e))
err = e
Expand All @@ -883,7 +888,7 @@ func (p *PdController) ResetTS(ctx context.Context, ts uint64) error {
})
var err error
for _, addr := range p.getAllPDAddrs() {
code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, bytes.NewBuffer(reqData))
code, _, e := pdRequestWithCode(ctx, addr, resetTSPrefix, p.cli, http.MethodPost, reqData)
if e != nil {
// for pd version <= 6.2, if the given ts < current ts of pd, pd returns StatusForbidden.
// it's not an error for br
Expand Down Expand Up @@ -960,7 +965,7 @@ func (p *PdController) CreateOrUpdateRegionLabelRule(ctx context.Context, rule L
addrs := p.getAllPDAddrs()
for i, addr := range addrs {
_, lastErr = pdRequest(ctx, addr, regionLabelPrefix,
p.cli, http.MethodPost, bytes.NewBuffer(reqData))
p.cli, http.MethodPost, reqData)
if lastErr == nil {
return nil
}
Expand Down
19 changes: 10 additions & 9 deletions br/pkg/pdutil/pd_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package pdutil

import (
"bytes"
"context"
"encoding/hex"
"encoding/json"
Expand Down Expand Up @@ -31,7 +30,7 @@ func TestScheduler(t *testing.T) {
defer cancel()

scheduler := "balance-leader-scheduler"
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("failed")
}
schedulerPauseCh := make(chan struct{})
Expand Down Expand Up @@ -66,7 +65,7 @@ func TestScheduler(t *testing.T) {
_, err = pdController.listSchedulersWith(ctx, mock)
require.EqualError(t, err, "failed")

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return []byte(`["` + scheduler + `"]`), nil
}

Expand All @@ -86,7 +85,7 @@ func TestScheduler(t *testing.T) {
func TestGetClusterVersion(t *testing.T) {
pdController := &PdController{addrs: []string{"", ""}} // two endpoints
counter := 0
mock := func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock := func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
counter++
if counter <= 1 {
return nil, errors.New("mock error")
Expand All @@ -99,7 +98,7 @@ func TestGetClusterVersion(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "test", respString)

mock = func(context.Context, string, string, *http.Client, string, io.Reader) ([]byte, error) {
mock = func(context.Context, string, string, *http.Client, string, []byte) ([]byte, error) {
return nil, errors.New("mock error")
}
_, err = pdController.getClusterVersionWith(ctx, mock)
Expand Down Expand Up @@ -129,7 +128,7 @@ func TestRegionCount(t *testing.T) {
require.Equal(t, 3, len(regions.Regions))

mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
query := fmt.Sprintf("%s/%s", addr, prefix)
u, e := url.Parse(query)
Expand Down Expand Up @@ -180,6 +179,9 @@ func TestPDRequestRetry(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count++
bytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, "test", string(bytes))
if count <= pdRequestRetryTime-1 {
w.WriteHeader(http.StatusGatewayTimeout)
return
Expand All @@ -195,8 +197,7 @@ func TestPDRequestRetry(t *testing.T) {
cli.Transport.(*http.Transport).DisableKeepAlives = true

taddr := ts.URL
body := bytes.NewBuffer([]byte("test"))
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, body)
_, reqErr := pdRequest(ctx, taddr, "", cli, http.MethodPost, []byte("test"))
require.NoError(t, reqErr)
ts.Close()
count = 0
Expand Down Expand Up @@ -268,7 +269,7 @@ func TestStoreInfo(t *testing.T) {
},
}
mock := func(
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ io.Reader,
_ context.Context, addr string, prefix string, _ *http.Client, _ string, _ []byte,
) ([]byte, error) {
query := fmt.Sprintf("%s/%s", addr, prefix)
require.Equal(t, "http://mock/pd/api/v1/store/1", query)
Expand Down

0 comments on commit 5a6c8c8

Please sign in to comment.