diff --git a/util/client/httpclient/client.go b/util/client/httpclient/client.go index 520aa427ba..7869bd679b 100644 --- a/util/client/httpclient/client.go +++ b/util/client/httpclient/client.go @@ -8,13 +8,13 @@ package httpclient import ( "context" - "errors" - "fmt" + "crypto/tls" + "net" + "net/http" + "runtime" + "time" - "github.com/go-resty/resty/v2" "github.com/joomcode/errorx" - "github.com/pingcap/log" - "go.uber.org/zap" "github.com/pingcap/tidb-dashboard/util/nocopy" ) @@ -22,78 +22,54 @@ import ( var ( ErrNS = errorx.NewNamespace("http_client") ErrInvalidEndpoint = ErrNS.NewType("invalid_endpoint") - ErrServerError = ErrNS.NewType("server_error") + ErrRequestFailed = ErrNS.NewType("request_failed") ) -// Client is a lightweight wrapper over resty.Client, providing default error handling and timeout settings. -// WARN: This structure is not thread-safe. +// Client caches connections for future re-use and should be reused instead of +// created as needed. type Client struct { nocopy.NoCopy - inner *resty.Client - kindTag string - ctx context.Context + kindTag string + transport *http.Transport + defaultCtx context.Context + defaultBaseURL string } -func (c *Client) SetHeader(header, value string) *Client { - c.inner.Header.Set(header, value) - return c -} - -// LifecycleR builds a new Request with the default lifecycle context and the default timeout. -// This function is intentionally not named as `R()` to avoid being confused with `resty.Client.R()`. -func (c *Client) LifecycleR() *Request { - return newRequestFromClient(c) -} - -// ======== Below are helper functions to build the Client ======== - -var defaultRedirectPolicy = resty.FlexibleRedirectPolicy(5) - -func New(config Config) *Client { - c := &Client{ - inner: resty.New(), - kindTag: config.KindTag, - ctx: config.Context, +func newTransport(tlsConfig *tls.Config) *http.Transport { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, + TLSClientConfig: tlsConfig, } - c.inner.SetRedirectPolicy(defaultRedirectPolicy) - c.inner.OnAfterResponse(c.handleAfterResponseHook) - c.inner.OnError(c.handleErrorHook) - c.inner.SetHostURL(config.BaseURL) - c.inner.SetTLSClientConfig(config.TLS) - return c } -func (c *Client) handleAfterResponseHook(_ *resty.Client, r *resty.Response) error { - // Note: IsError != !IsSuccess - if !r.IsSuccess() { - // Turn all non success responses to an error. - return ErrServerError.New("%s %s (%s): Response status %d", - r.Request.Method, - r.Request.URL, - c.kindTag, - r.StatusCode()) +func New(config Config) *Client { + return &Client{ + kindTag: config.KindTag, + transport: newTransport(config.TLSConfig), + defaultCtx: config.DefaultCtx, + defaultBaseURL: config.DefaultBaseURL, } - return nil } -func (c *Client) handleErrorHook(req *resty.Request, err error) { - // Log all kind of errors - fields := []zap.Field{ - zap.String("kindTag", c.kindTag), - zap.String("url", req.URL), - } - var respErr *resty.ResponseError - if errors.As(err, &respErr) && respErr.Response != nil && respErr.Response.RawResponse != nil { - fields = append(fields, - zap.String("responseStatus", respErr.Response.Status()), - zap.String("responseBody", respErr.Response.String()), - ) - err = respErr.Unwrap() +func (c *Client) LR() *LazyRequest { + lReq := newRequest(c.kindTag, c.transport) + if c.defaultCtx != nil { + lReq.SetContext(c.defaultCtx) } - fields = append(fields, zap.Error(err)) - if _, hasVerboseError := err.(fmt.Formatter); !hasVerboseError { //nolint:errorlint - fields = append(fields, zap.Stack("stack")) + if c.defaultBaseURL != "" { + lReq.SetBaseURL(c.defaultBaseURL) } - log.Warn("Request failed", fields...) + return lReq } diff --git a/util/client/httpclient/client_test.go b/util/client/httpclient/client_test.go deleted file mode 100644 index da07029966..0000000000 --- a/util/client/httpclient/client_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. - -package httpclient - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSetHeader(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, r.Header.Get("X-Test")) - })) - defer ts.Close() - - client := New(Config{}) - client.SetHeader("X-Test", "foobar") - cancel, resp, err := client.LifecycleR().Get(ts.URL) - defer cancel() - require.Nil(t, err) - require.NotNil(t, resp) - require.Equal(t, http.StatusOK, resp.StatusCode()) - require.Equal(t, "foobar", resp.String()) -} - -func TestSetBaseURL(t *testing.T) { - ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "ts1"+r.URL.Path) - })) - defer ts1.Close() - - ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "ts2"+r.URL.Path) - })) - defer ts2.Close() - - client := New(Config{ - BaseURL: ts1.URL, - }) - cancel, resp, err := client.LifecycleR().Get("/foo") - defer cancel() - require.Nil(t, err) - require.NotNil(t, resp) - require.Equal(t, http.StatusOK, resp.StatusCode()) - require.Equal(t, "ts1/foo", resp.String()) - - cancel, resp, err = client.LifecycleR().Get(ts2.URL) // BaseURL can be overwritten - defer cancel() - require.Nil(t, err) - require.NotNil(t, resp) - require.Equal(t, http.StatusOK, resp.StatusCode()) - require.Equal(t, "ts2/", resp.String()) -} diff --git a/util/client/httpclient/config.go b/util/client/httpclient/config.go index 288a7a2df0..77911e59dd 100644 --- a/util/client/httpclient/config.go +++ b/util/client/httpclient/config.go @@ -9,10 +9,10 @@ import ( ) type Config struct { - BaseURL string - Context context.Context - TLS *tls.Config - KindTag string // Used to mark what kind of HttpClient it is in error messages and logs. + KindTag string + TLSConfig *tls.Config + DefaultCtx context.Context + DefaultBaseURL string } type APIClientConfig struct { @@ -38,9 +38,9 @@ func (dc APIClientConfig) IntoConfig(kindTag string) (Config, error) { schema = "http" } return Config{ - BaseURL: schema + "://" + u.Host, - Context: dc.Context, - TLS: dc.TLS, - KindTag: kindTag, + TLSConfig: dc.TLS, + KindTag: kindTag, + DefaultCtx: dc.Context, + DefaultBaseURL: schema + "://" + u.Host, }, nil } diff --git a/util/client/httpclient/info.go b/util/client/httpclient/info.go new file mode 100644 index 0000000000..e66d8befe2 --- /dev/null +++ b/util/client/httpclient/info.go @@ -0,0 +1,34 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package httpclient + +import ( + "github.com/pingcap/log" + "go.uber.org/zap" +) + +// execInfo is a copy of necessary information during the execution. +// It can be used to print logs when something happens. +type execInfo struct { + kindTag string + reqURL string + reqMethod string + respStatus string + respBody string +} + +func (e *execInfo) Warn(msg string, err error) { + fields := []zap.Field{ + zap.String("kindTag", e.kindTag), + zap.String("url", e.reqURL), + zap.String("method", e.reqMethod), + } + if e.respStatus != "" { + fields = append(fields, zap.String("responseStatus", e.respStatus)) + } + if e.respBody != "" { + fields = append(fields, zap.String("responseBody", e.respBody)) + } + fields = append(fields, zap.Error(err)) + log.Warn(msg, fields...) +} diff --git a/util/client/httpclient/request.go b/util/client/httpclient/request.go index 45a890e392..71822889b8 100644 --- a/util/client/httpclient/request.go +++ b/util/client/httpclient/request.go @@ -8,6 +8,7 @@ package httpclient import ( "context" + "net/http" "time" "github.com/go-resty/resty/v2" @@ -15,79 +16,122 @@ import ( "github.com/pingcap/tidb-dashboard/util/nocopy" ) -const ( - defaultTimeout = time.Minute * 2 // Just a default long enough timeout. -) - -// Request is a lightweight wrapper over resty.Request. -// Different to resty.Request, it enforces a timeout. -// WARN: This structure is not thread-safe. -type Request struct { +// LazyRequest can be used to compose and fire individual request from the client. +// The request will not be actually sent until reading from LazyResponse. +type LazyRequest struct { + // Note: this is a lazy struct. nocopy.NoCopy - inner *resty.Request + kindTag string + transport *http.Transport + opsR []requestUpdateFn + opsC []clientUpdateFn +} - ctx context.Context - timeout time.Duration +func newRequest(kindTag string, transport *http.Transport) *LazyRequest { + return &LazyRequest{ + kindTag: kindTag, + transport: transport, + } } -func newRequestFromClient(c *Client) *Request { - return &Request{ - inner: c.inner.R(), - ctx: c.ctx, - timeout: defaultTimeout, +func (lReq *LazyRequest) Clone() *LazyRequest { + lReqCloned := &LazyRequest{ + kindTag: lReq.kindTag, + transport: lReq.transport, // transport will never change after creation, so this is concurrent-safe + opsR: make([]requestUpdateFn, len(lReq.opsR)), + opsC: make([]clientUpdateFn, len(lReq.opsC)), } + copy(lReqCloned.opsR, lReq.opsR) + copy(lReqCloned.opsC, lReq.opsC) + return lReqCloned } -func (r *Request) SetContext(ctx context.Context) *Request { - if ctx != nil { - r.ctx = ctx +// SetContext method sets the context.Context for current Request. It allows +// to interrupt the request execution if ctx.Done() channel is closed. +// See https://blog.golang.org/context article and the "context" package +// documentation. +func (lReq *LazyRequest) SetContext(ctx context.Context) *LazyRequest { + if ctx == nil { + return lReq } - return r + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetContext(ctx) + }) + return lReq } -func (r *Request) SetTimeout(timeout time.Duration) *Request { - r.timeout = timeout - return r +// SetTimeout sets the total timeout for sending the request and reading the response. +func (lReq *LazyRequest) SetTimeout(timeout time.Duration) *LazyRequest { + lReq.opsC = append(lReq.opsC, func(c *resty.Client) { + c.SetTimeout(timeout) + }) + return lReq } -// SetJSONResult expects a JSON response from the remote endpoint and specify how response is deserialized. -func (r *Request) SetJSONResult(res interface{}) *Request { - // If we don't force a content type, when this content type is missing in the response, - // the `Response.Result()` will silently produce an empty and valid structure without any errors. - r.inner.ForceContentType("application/json") - r.inner.SetResult(res) - return r +func (lReq *LazyRequest) SetURL(url string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.URL = url + }) + return lReq } -// WARN: The returned cancelFunc must be called to avoid context leak. -func (r *Request) Get(url string) (context.CancelFunc, *resty.Response, error) { - return r.Execute(resty.MethodGet, url) +func (lReq *LazyRequest) SetMethod(method string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.Method = method + }) + return lReq } -// WARN: The returned cancelFunc must be called to avoid context leak. -func (r *Request) Post(url string) (context.CancelFunc, *resty.Response, error) { - return r.Execute(resty.MethodPost, url) +// SetBaseURL method is to set Host URL in the client instance. It will be used with request +// raised from this client with relative URL +// // Setting HTTP address +// client.SetBaseURL("http://myjeeva.com") +// +// // Setting HTTPS address +// client.SetBaseURL("https://myjeeva.com") +func (lReq *LazyRequest) SetBaseURL(baseURL string) *LazyRequest { + lReq.opsC = append(lReq.opsC, func(c *resty.Client) { + c.SetHostURL(baseURL) + }) + return lReq } -// WARN: The returned cancelFunc must be called to avoid context leak. -func (r *Request) Put(url string) (context.CancelFunc, *resty.Response, error) { - return r.Execute(resty.MethodPut, url) +// Send method lazily send the HTTP request using the method and URL already defined +// for current LazyRequest. +// resp := client.LR(). +// SetMethod("GET"). +// SetURL("http://httpbin.org/get"). +// Send() +func (lReq *LazyRequest) Send() *LazyResponse { + return newResponse(lReq.Clone()) } -// WARN: The returned cancelFunc must be called to avoid context leak. -func (r *Request) Delete(url string) (context.CancelFunc, *resty.Response, error) { - return r.Execute(resty.MethodDelete, url) +// Execute method lazily send the HTTP request with given HTTP method and URL +// for current LazyRequest. +// resp := client.LR(). +// Execute("GET", "http://httpbin.org/get") +func (lReq *LazyRequest) Execute(method, url string) *LazyResponse { + cloned := lReq.Clone() + cloned.opsR = append(cloned.opsR, func(r *resty.Request) { + r.Method = method + r.URL = url + }) + return newResponse(cloned) } -// WARN: The returned cancelFunc must be called to avoid context leak. -func (r *Request) Execute(method, url string) (context.CancelFunc, *resty.Response, error) { - baseCtx := r.ctx - if baseCtx == nil { - baseCtx = context.Background() - } - ctx, cancelFn := context.WithTimeout(baseCtx, r.timeout) - r.inner.SetContext(ctx) - resp, err := r.inner.Execute(method, url) - return cancelFn, resp, err +func (lReq *LazyRequest) Get(url string) *LazyResponse { + return lReq.Execute(resty.MethodGet, url) +} + +func (lReq *LazyRequest) Post(url string) *LazyResponse { + return lReq.Execute(resty.MethodPost, url) +} + +func (lReq *LazyRequest) Put(url string) *LazyResponse { + return lReq.Execute(resty.MethodPut, url) +} + +func (lReq *LazyRequest) Delete(url string) *LazyResponse { + return lReq.Execute(resty.MethodDelete, url) } diff --git a/util/client/httpclient/request_resty.go b/util/client/httpclient/request_resty.go index dc43f05539..0f136981b3 100644 --- a/util/client/httpclient/request_resty.go +++ b/util/client/httpclient/request_resty.go @@ -16,172 +16,450 @@ import ( "github.com/go-resty/resty/v2" ) -func (r *Request) SetHeader(header, value string) *Request { - r.inner.SetHeader(header, value) - return r -} - -func (r *Request) SetHeaders(headers map[string]string) *Request { - r.inner.SetHeaders(headers) - return r -} - -func (r *Request) SetHeaderVerbatim(header, value string) *Request { - r.inner.SetHeaderVerbatim(header, value) - return r -} - -func (r *Request) SetQueryParam(param, value string) *Request { - r.inner.SetQueryParam(param, value) - return r -} - -func (r *Request) SetQueryParams(params map[string]string) *Request { - r.inner.SetQueryParams(params) - return r -} - -func (r *Request) SetQueryParamsFromValues(params url.Values) *Request { - r.inner.SetQueryParamsFromValues(params) - return r -} - -func (r *Request) SetQueryString(query string) *Request { - r.inner.SetQueryString(query) - return r -} - -func (r *Request) SetFormData(data map[string]string) *Request { - r.inner.SetFormData(data) - return r -} - -func (r *Request) SetFormDataFromValues(data url.Values) *Request { - r.inner.SetFormDataFromValues(data) - return r -} - -func (r *Request) SetBody(body interface{}) *Request { - r.inner.SetBody(body) - return r -} - -// Note: This function is not safe to use and is deprecated. Use `Request.SetJSONResult()`. -// func (r *Request) SetResult(res interface{}) *Request { -// r.inner.SetResult(res) -// return r -// } - -func (r *Request) SetError(err interface{}) *Request { - r.inner.SetError(err) - return r -} - -func (r *Request) SetFile(param, filePath string) *Request { - r.inner.SetFile(param, filePath) - return r -} - -func (r *Request) SetFiles(files map[string]string) *Request { - r.inner.SetFiles(files) - return r -} - -func (r *Request) SetFileReader(param, fileName string, reader io.Reader) *Request { - r.inner.SetFileReader(param, fileName, reader) - return r -} - -func (r *Request) SetMultipartFormData(data map[string]string) *Request { - r.inner.SetMultipartFormData(data) - return r -} - -func (r *Request) SetMultipartField(param, fileName, contentType string, reader io.Reader) *Request { - r.inner.SetMultipartField(param, fileName, contentType, reader) - return r -} - -func (r *Request) SetMultipartFields(fields ...*resty.MultipartField) *Request { - r.inner.SetMultipartFields(fields...) - return r -} - -func (r *Request) SetContentLength(l bool) *Request { - r.inner.SetContentLength(l) - return r -} - -func (r *Request) SetBasicAuth(username, password string) *Request { - r.inner.SetBasicAuth(username, password) - return r -} - -func (r *Request) SetAuthToken(token string) *Request { - r.inner.SetAuthToken(token) - return r -} - -func (r *Request) SetAuthScheme(scheme string) *Request { - r.inner.SetAuthScheme(scheme) - return r -} - -func (r *Request) SetOutput(file string) *Request { - r.inner.SetOutput(file) - return r -} - -func (r *Request) SetSRV(srv *resty.SRVRecord) *Request { - r.inner.SetSRV(srv) - return r -} - -func (r *Request) SetDoNotParseResponse(parse bool) *Request { - r.inner.SetDoNotParseResponse(parse) - return r -} - -func (r *Request) SetPathParam(param, value string) *Request { - r.inner.SetPathParam(param, value) - return r -} - -func (r *Request) SetPathParams(params map[string]string) *Request { - r.inner.SetPathParams(params) - return r -} - -func (r *Request) ExpectContentType(contentType string) *Request { - r.inner.ExpectContentType(contentType) - return r -} - -func (r *Request) ForceContentType(contentType string) *Request { - r.inner.ForceContentType(contentType) - return r -} - -func (r *Request) SetJSONEscapeHTML(b bool) *Request { - r.inner.SetJSONEscapeHTML(b) - return r -} - -func (r *Request) SetCookie(hc *http.Cookie) *Request { - r.inner.SetCookie(hc) - return r -} - -func (r *Request) SetCookies(rs []*http.Cookie) *Request { - r.inner.SetCookies(rs) - return r -} - -func (r *Request) EnableTrace() *Request { - r.inner.EnableTrace() - return r -} - -func (r *Request) TraceInfo() resty.TraceInfo { - return r.inner.TraceInfo() +// SetHeader method is to set a single header field and its value in the current request. +// +// For Example: To set `Content-Type` and `Accept` as `application/json`. +// client.LR(). +// SetHeader("Content-Type", "application/json"). +// SetHeader("Accept", "application/json") +// +// Also you can override header value, which was set at client instance level. +func (lReq *LazyRequest) SetHeader(header, value string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetHeader(header, value) + }) + return lReq +} + +// SetHeaders method sets multiple headers field and its values at one go in the current request. +// +// For Example: To set `Content-Type` and `Accept` as `application/json` +// +// client.LR(). +// SetHeaders(map[string]string{ +// "Content-Type": "application/json", +// "Accept": "application/json", +// }) +// Also you can override header value, which was set at client instance level. +func (lReq *LazyRequest) SetHeaders(headers map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetHeaders(headers) + }) + return lReq +} + +// SetHeaderVerbatim method is to set a single header field and its value verbatim in the current request. +// +// For Example: To set `all_lowercase` and `UPPERCASE` as `available`. +// client.LR(). +// SetHeaderVerbatim("all_lowercase", "available"). +// SetHeaderVerbatim("UPPERCASE", "available") +// +// Also you can override header value, which was set at client instance level. +func (lReq *LazyRequest) SetHeaderVerbatim(header, value string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetHeaderVerbatim(header, value) + }) + return lReq +} + +// SetQueryParam method sets single parameter and its value in the current request. +// It will be formed as query string for the request. +// +// For Example: `search=kitchen%20papers&size=large` in the URL after `?` mark. +// client.LR(). +// SetQueryParam("search", "kitchen papers"). +// SetQueryParam("size", "large") +// Also you can override query params value, which was set at client instance level. +func (lReq *LazyRequest) SetQueryParam(param, value string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetQueryParam(param, value) + }) + return lReq +} + +// SetQueryParams method sets multiple parameters and its values at one go in the current request. +// It will be formed as query string for the request. +// +// For Example: `search=kitchen%20papers&size=large` in the URL after `?` mark. +// client.LR(). +// SetQueryParams(map[string]string{ +// "search": "kitchen papers", +// "size": "large", +// }) +// Also you can override query params value, which was set at client instance level. +func (lReq *LazyRequest) SetQueryParams(params map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetQueryParams(params) + }) + return lReq +} + +// SetQueryParamsFromValues method appends multiple parameters with multi-value +// (`url.Values`) at one go in the current request. It will be formed as +// query string for the request. +// +// For Example: `status=pending&status=approved&status=open` in the URL after `?` mark. +// client.LR(). +// SetQueryParamsFromValues(url.Values{ +// "status": []string{"pending", "approved", "open"}, +// }) +// Also you can override query params value, which was set at client instance level. +func (lReq *LazyRequest) SetQueryParamsFromValues(params url.Values) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetQueryParamsFromValues(params) + }) + return lReq +} + +// SetQueryString method provides ability to use string as an input to set URL query string for the request. +// +// Using String as an input +// client.LR(). +// SetQueryString("productId=232&template=fresh-sample&cat=resty&source=google&kw=buy a lot more") +func (lReq *LazyRequest) SetQueryString(query string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetQueryString(query) + }) + return lReq +} + +// SetFormData method sets Form parameters and their values in the current request. +// It's applicable only HTTP method `POST` and `PUT` and requests content type would be set as +// `application/x-www-form-urlencoded`. +// client.LR(). +// SetFormData(map[string]string{ +// "access_token": "BC594900-518B-4F7E-AC75-BD37F019E08F", +// "user_id": "3455454545", +// }) +// Also you can override form data value, which was set at client instance level. +func (lReq *LazyRequest) SetFormData(data map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetFormData(data) + }) + return lReq +} + +// SetFormDataFromValues method appends multiple form parameters with multi-value +// (`url.Values`) at one go in the current request. +// client.LR(). +// SetFormDataFromValues(url.Values{ +// "search_criteria": []string{"book", "glass", "pencil"}, +// }) +// Also you can override form data value, which was set at client instance level. +func (lReq *LazyRequest) SetFormDataFromValues(data url.Values) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetFormDataFromValues(data) + }) + return lReq +} + +// SetBody method sets the request body for the request. It supports various realtime needs as easy. +// We can say its quite handy or powerful. Supported request body data types is `string`, +// `[]byte`, `struct`, `map`, `slice` and `io.Reader`. Body value can be pointer or non-pointer. +// Automatic marshalling for JSON and XML content type, if it is `struct`, `map`, or `slice`. +// +// Note: `io.Reader` is processed as bufferless mode while sending request. +// +// For Example: Struct as a body input, based on content type, it will be marshalled. +// client.LR(). +// SetBody(User{ +// Username: "jeeva@myjeeva.com", +// Password: "welcome2resty", +// }) +// +// Map as a body input, based on content type, it will be marshalled. +// client.LR(). +// SetBody(map[string]interface{}{ +// "username": "jeeva@myjeeva.com", +// "password": "welcome2resty", +// "address": &Address{ +// Address1: "1111 This is my street", +// Address2: "Apt 201", +// City: "My City", +// State: "My State", +// ZipCode: 00000, +// }, +// }) +// +// String as a body input. Suitable for any need as a string input. +// client.LR(). +// SetBody(`{ +// "username": "jeeva@getrightcare.com", +// "password": "admin" +// }`) +// +// []byte as a body input. Suitable for raw request such as file upload, serialize & deserialize, etc. +// client.LR(). +// SetBody([]byte("This is my raw request, sent as-is")) +func (lReq *LazyRequest) SetBody(body interface{}) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetBody(body) + }) + return lReq +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) SetResult() { + panic("do not use this in LazyRequest") +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) SetError() { + panic("do not use this in LazyRequest") +} + +// SetFile method is to set single file field name and its path for multipart upload. +// client.LR(). +// SetFile("my_file", "/Users/jeeva/Gas Bill - Sep.pdf") +func (lReq *LazyRequest) SetFile(param, filePath string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetFile(param, filePath) + }) + return lReq +} + +// SetFiles method is to set multiple file field name and its path for multipart upload. +// client.LR(). +// SetFiles(map[string]string{ +// "my_file1": "/Users/jeeva/Gas Bill - Sep.pdf", +// "my_file2": "/Users/jeeva/Electricity Bill - Sep.pdf", +// "my_file3": "/Users/jeeva/Water Bill - Sep.pdf", +// }) +func (lReq *LazyRequest) SetFiles(files map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetFiles(files) + }) + return lReq +} + +// SetFileReader method is to set single file using io.Reader for multipart upload. +// client.LR(). +// SetFileReader("profile_img", "my-profile-img.png", bytes.NewReader(profileImgBytes)). +// SetFileReader("notes", "user-notes.txt", bytes.NewReader(notesBytes)) +func (lReq *LazyRequest) SetFileReader(param, fileName string, reader io.Reader) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetFileReader(param, fileName, reader) + }) + return lReq +} + +// SetMultipartFormData method allows simple form data to be attached to the request as `multipart:form-data`. +func (lReq *LazyRequest) SetMultipartFormData(data map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetMultipartFormData(data) + }) + return lReq +} + +// SetMultipartField method is to set custom data using io.Reader for multipart upload. +func (lReq *LazyRequest) SetMultipartField(param, fileName, contentType string, reader io.Reader) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetMultipartField(param, fileName, contentType, reader) + }) + return lReq +} + +// SetMultipartFields method is to set multiple data fields using io.Reader for multipart upload. +// +// For Example: +// client.LR().SetMultipartFields( +// &resty.MultipartField{ +// Param: "uploadManifest1", +// FileName: "upload-file-1.json", +// ContentType: "application/json", +// Reader: strings.NewReader(`{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`), +// }, +// &resty.MultipartField{ +// Param: "uploadManifest2", +// FileName: "upload-file-2.json", +// ContentType: "application/json", +// Reader: strings.NewReader(`{"input": {"name": "Uploaded document 2", "_filename" : ["file2.txt"]}}`), +// }) +// +// If you have slice already, then simply call- +// client.LR().SetMultipartFields(fields...) +func (lReq *LazyRequest) SetMultipartFields(fields ...*resty.MultipartField) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetMultipartFields(fields...) + }) + return lReq +} + +// SetContentLength method sets the HTTP header `Content-Length` value for current request. +// By default Resty won't set `Content-Length`. Also you have an option to enable for every +// request. +func (lReq *LazyRequest) SetContentLength(l bool) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetContentLength(l) + }) + return lReq +} + +// SetBasicAuth method sets the basic authentication header in the current HTTP request. +// +// For Example: +// Authorization: Basic +// +// To set the header for username "go-resty" and password "welcome" +// client.LR().SetBasicAuth("go-resty", "welcome") +// +// This method overrides the credentials set by method `Client.SetBasicAuth`. +func (lReq *LazyRequest) SetBasicAuth(username, password string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetBasicAuth(username, password) + }) + return lReq +} + +// SetAuthToken method sets the auth token header(Default Scheme: Bearer) in the current HTTP request. Header example: +// Authorization: Bearer +// +// For Example: To set auth token BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F +// +// client.LR().SetAuthToken("BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F") +// +// This method overrides the Auth token set by method `Client.SetAuthToken`. +func (lReq *LazyRequest) SetAuthToken(token string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetAuthToken(token) + }) + return lReq +} + +// SetAuthScheme method sets the auth token scheme type in the HTTP request. For Example: +// Authorization: +// +// For Example: To set the scheme to use OAuth +// +// client.LR().SetAuthScheme("OAuth") +// +// This auth header scheme gets added to all the request rasied from this client instance. +// Also it can be overridden or set one at the request level is supported. +// +// Information about Auth schemes can be found in RFC7235 which is linked to below along with the page containing +// the currently defined official authentication schemes: +// https://tools.ietf.org/html/rfc7235 +// https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml#authschemes +// +// This method overrides the Authorization scheme set by method `Client.SetAuthScheme`. +func (lReq *LazyRequest) SetAuthScheme(scheme string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetAuthScheme(scheme) + }) + return lReq +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) SetOutput() { + panic("do not use this in LazyRequest") +} + +// SetSRV method sets the details to query the service SRV record and execute the +// request. +// client.LR(). +// SetSRV(SRVRecord{"web", "testservice.com"}). +// Get("/get") +func (lReq *LazyRequest) SetSRV(srv *resty.SRVRecord) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetSRV(srv) + }) + return lReq +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) SetDoNotParseResponse() { + panic("do not use this in LazyRequest") +} + +// SetPathParam method sets single URL path key-value pair in the +// Resty current request instance. +// client.LR().SetPathParam("userId", "sample@sample.com") +// +// Result: +// URL - /v1/users/{userId}/details +// Composed URL - /v1/users/sample@sample.com/details +// It replaces the value of the key while composing the request URL. Also you can +// override Path Params value, which was set at client instance level. +func (lReq *LazyRequest) SetPathParam(param, value string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetPathParam(param, value) + }) + return lReq +} + +// SetPathParams method sets multiple URL path key-value pairs at one go in the +// Resty current request instance. +// client.LR().SetPathParams(map[string]string{ +// "userId": "sample@sample.com", +// "subAccountId": "100002", +// }) +// +// Result: +// URL - /v1/users/{userId}/{subAccountId}/details +// Composed URL - /v1/users/sample@sample.com/100002/details +// It replaces the value of the key while composing request URL. Also you can +// override Path Params value, which was set at client instance level. +func (lReq *LazyRequest) SetPathParams(params map[string]string) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetPathParams(params) + }) + return lReq +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) ExpectContentType() { + panic("do not use this in LazyRequest") +} + +// Deprecated: This usage is intentionally not supported. +func (lReq *LazyRequest) ForceContentType() { + panic("do not use this in LazyRequest") +} + +// SetJSONEscapeHTML method is to enable/disable the HTML escape on JSON marshal. +// +// Note: This option only applicable to standard JSON Marshaller. +func (lReq *LazyRequest) SetJSONEscapeHTML(b bool) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetJSONEscapeHTML(b) + }) + return lReq +} + +// SetCookie method appends a single cookie in the current request instance. +// client.LR().SetCookie(&http.Cookie{ +// Name:"go-resty", +// Value:"This is cookie value", +// }) +// +// Note: Method appends the Cookie value into existing Cookie if already existing. +func (lReq *LazyRequest) SetCookie(hc *http.Cookie) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetCookie(hc) + }) + return lReq +} + +// SetCookies method sets an array of cookies in the current request instance. +// cookies := []*http.Cookie{ +// &http.Cookie{ +// Name:"go-resty-1", +// Value:"This is cookie 1 value", +// }, +// &http.Cookie{ +// Name:"go-resty-2", +// Value:"This is cookie 2 value", +// }, +// } +// +// // Setting a cookies into resty's current request +// client.LR().SetCookies(cookies) +// +// Note: Method appends the Cookie value into existing Cookie if already existing. +func (lReq *LazyRequest) SetCookies(rs []*http.Cookie) *LazyRequest { + lReq.opsR = append(lReq.opsR, func(r *resty.Request) { + r.SetCookies(rs) + }) + return lReq } diff --git a/util/client/httpclient/request_test.go b/util/client/httpclient/request_test.go deleted file mode 100644 index 129e12f6e6..0000000000 --- a/util/client/httpclient/request_test.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. - -package httpclient - -import ( - "fmt" - "net" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestRemoteEndpointError(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Some internal error", http.StatusInternalServerError) - })) - defer ts.Close() - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().Get(ts.URL) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.False(t, resp.IsSuccess()) - require.Equal(t, http.StatusInternalServerError, resp.StatusCode()) -} - -func TestRemoteEndpointBadServer(t *testing.T) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - require.Nil(t, err) - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - _ = conn.Close() - } - }() - defer func() { - _ = listener.Close() - }() - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().Get(fmt.Sprintf("http://%s/foo", listener.Addr().String())) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.False(t, resp.IsSuccess()) -} - -func TestBadScheme(t *testing.T) { - client := New(Config{}) - cancel, resp, err := client.LifecycleR().Get("foo://abc.com") - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.False(t, resp.IsSuccess()) -} - -func TestTimeoutHeader(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(time.Second * 1) - _, _ = fmt.Fprintln(w, "OK") - })) - defer ts.Close() - - now := time.Now() - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().SetTimeout(100 * time.Millisecond).Get(ts.URL) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.False(t, resp.IsSuccess()) - require.LessOrEqual(t, time.Since(now), 500*time.Millisecond) -} - -func TestTimeoutBody(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - w.(http.Flusher).Flush() - time.Sleep(time.Second * 1) - _, _ = fmt.Fprintln(w, "OK") - })) - defer ts.Close() - - now := time.Now() - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().SetTimeout(100 * time.Millisecond).Get(ts.URL) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - // Note: in this case, since a response code is returned, we have IsSuccess = true. An error is also returned. - require.True(t, resp.IsSuccess()) - require.Equal(t, http.StatusOK, resp.StatusCode()) - require.LessOrEqual(t, time.Since(now), 500*time.Millisecond) -} - -func TestUnmarshalFailure1(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("content-type", "application/json") - _, _ = fmt.Fprintln(w, "InvalidJSON") - })) - defer ts.Close() - - type respType struct { - Foo int - } - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().SetJSONResult(&respType{}).Get(ts.URL) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.True(t, resp.IsSuccess()) -} - -func TestUnmarshalFailure2(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprintln(w, "InvalidJSON") - })) - defer ts.Close() - - type respType struct { - Foo int - } - - client := New(Config{}) - cancel, resp, err := client.LifecycleR().SetJSONResult(&respType{}).Get(ts.URL) - defer cancel() - require.NotNil(t, err) - require.NotNil(t, resp) - require.True(t, resp.IsSuccess()) -} diff --git a/util/client/httpclient/response.go b/util/client/httpclient/response.go new file mode 100644 index 0000000000..d0d57ceffb --- /dev/null +++ b/util/client/httpclient/response.go @@ -0,0 +1,222 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package httpclient + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "runtime" + "runtime/debug" + "strings" + "time" + + "github.com/go-resty/resty/v2" + + "github.com/pingcap/tidb-dashboard/util/israce" + "github.com/pingcap/tidb-dashboard/util/nocopy" +) + +const ( + defaultTimeout = time.Minute * 5 // Just a default long enough timeout. +) + +type requestUpdateFn func(r *resty.Request) + +type clientUpdateFn func(c *resty.Client) + +// LazyResponse provides access to the response body and response headers in convenient ways. +// No request is actually sent until LazyResponse is read. +type LazyResponse struct { + nocopy.NoCopy + + // The source request object to execute. It is a clone of the original request object + // to allow concurrent executions. + requestSnapshot *LazyRequest + + // stackAtNew is the stack when Response is created. It is only available when Golang race mode is enabled. + // This is used to report the missing `Close()` calls. + stackAtNew []byte + + // Fields below are set only after the request is actually sent. + isExecuted bool + executedResponseWithoutBody *http.Response + executedResponseBody io.ReadCloser + executedError error + executeInfo *execInfo // Contains some execution information. Will be logged when error happens. +} + +func newResponse(sourceSnapshot *LazyRequest) *LazyResponse { + er := &LazyResponse{ + requestSnapshot: sourceSnapshot, + } + runtime.SetFinalizer(er, (*LazyResponse).finalize) + if israce.Enabled { + er.stackAtNew = debug.Stack() + } + return er +} + +func (lResp *LazyResponse) doExecutionOnce() { + if lResp.isExecuted { + return + } + + client := resty.NewWithClient(&http.Client{ + Transport: lResp.requestSnapshot.transport, + }) + client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(10)) + client.SetTimeout(defaultTimeout) + for _, op := range lResp.requestSnapshot.opsC { + op(client) + } + restyReq := client.R() + restyReq.SetDoNotParseResponse(true) + for _, op := range lResp.requestSnapshot.opsR { + op(restyReq) + } + + info := &execInfo{kindTag: lResp.requestSnapshot.kindTag} + info.reqURL = restyReq.URL + info.reqMethod = restyReq.Method + + restyResp, err := restyReq.Send() + if err != nil { + // Turn all errors into ErrRequestFailed. + err = ErrRequestFailed.WrapWithNoMessage(err) + } + + if (restyResp == nil || restyResp.RawResponse == nil) && err == nil { + // Response and error come from 3rd-party libraries, we are not sure about this. + // Let's try out best to catch it. + err = ErrRequestFailed.New("%s %s (%s): internal error, no response", + restyResp.Request.Method, + restyResp.Request.URL, + lResp.requestSnapshot.kindTag) + restyResp = nil + } + + if !restyResp.IsSuccess() && err == nil { + // Turn all non success responses to an error, like 301 Moved Permanently and 404 Not Found. + // Note: IsError != !IsSuccess. + err = ErrRequestFailed.New("%s %s (%s): Response status %d", + restyResp.Request.Method, + restyResp.Request.URL, + lResp.requestSnapshot.kindTag, + restyResp.StatusCode()) + } + + if err != nil { + // Turn response into nil when there is an error. + if restyResp != nil && restyResp.RawResponse != nil { + data, _ := ioutil.ReadAll(restyResp.RawResponse.Body) + _ = restyResp.RawResponse.Body.Close() + info.respStatus = restyResp.Status() + info.respBody = string(data) + restyResp = nil + } + info.Warn("Request failed", err) + } + + if restyResp != nil && restyResp.RawResponse != nil { + lResp.executedResponseBody = restyResp.RawResponse.Body + lResp.executedResponseWithoutBody = restyResp.RawResponse + restyResp.RawResponse.Body = nil + } else { + lResp.executedResponseBody = nil + lResp.executedResponseWithoutBody = nil + } + lResp.executedError = err + lResp.executeInfo = info + lResp.isExecuted = true + + // The request is executed, no need to schedule a check for the execution any more. + runtime.SetFinalizer(lResp, nil) +} + +func (lResp *LazyResponse) close() { + _ = lResp.executedResponseBody.Close() +} + +// Finish closes the response. Read is not possible any more. +// The returned raw response does not have a body so that you don't need to close it manually. +func (lResp *LazyResponse) Finish() (respNoBody *http.Response, err error) { + lResp.doExecutionOnce() + if lResp.executedError != nil { + return nil, lResp.executedError + } + respNoBody = lResp.executedResponseWithoutBody + lResp.close() + return +} + +func (lResp *LazyResponse) PipeBody(w io.Writer) (written int64, respNoBody *http.Response, err error) { + lResp.doExecutionOnce() + if lResp.executedError != nil { + return 0, nil, lResp.executedError + } + respNoBody = lResp.executedResponseWithoutBody + written, err = io.Copy(w, lResp.executedResponseBody) + if err != nil { + respNoBody = nil + err = ErrRequestFailed.WrapWithNoMessage(err) + lResp.executeInfo.Warn("Request failed", err) + } + lResp.close() + return +} + +func (lResp *LazyResponse) ReadBodyAsBytes() (bytes []byte, respNoBody *http.Response, err error) { + lResp.doExecutionOnce() + if lResp.executedError != nil { + return nil, nil, lResp.executedError + } + respNoBody = lResp.executedResponseWithoutBody + bytes, err = ioutil.ReadAll(lResp.executedResponseBody) + if err != nil { + bytes = nil + respNoBody = nil + err = ErrRequestFailed.WrapWithNoMessage(err) + lResp.executeInfo.Warn("Request failed", err) + } + lResp.close() + return +} + +func (lResp *LazyResponse) ReadBodyAsString() (data string, respNoBody *http.Response, err error) { + bytes, resp, err := lResp.ReadBodyAsBytes() + if err != nil { + return "", nil, err + } + return strings.TrimSpace(string(bytes)), resp, nil +} + +func (lResp *LazyResponse) ReadBodyAsJSON(destination interface{}) (respNoBody *http.Response, err error) { + bytes, resp, err := lResp.ReadBodyAsBytes() + if err != nil { + return nil, err + } + err = json.Unmarshal(bytes, destination) + if err != nil { + err = ErrRequestFailed.WrapWithNoMessage(err) + ei := *lResp.executeInfo + ei.respStatus = lResp.executedResponseWithoutBody.Status + ei.respBody = string(bytes) + ei.Warn("Request failed", err) + return nil, err + } + return resp, nil +} + +func (lResp *LazyResponse) finalize() { + if israce.Enabled { + // try to catch incorrect usages + _, _ = os.Stderr.Write(lResp.stackAtNew) + panic(fmt.Sprintf("%T is not used correctly, one of PipeBody(), ReadBodyAsBytes(), ReadBodyAsString(), ReadBodyAsJSON() or Finish() must be called", lResp)) + } + // If a LazyResponse is GCed without actually sending the request, then we can just do nothing. + // There is even no need to close the response body, since the request is not sent. +} diff --git a/util/client/httpclient/response_test.go b/util/client/httpclient/response_test.go new file mode 100644 index 0000000000..bcc5fb5035 --- /dev/null +++ b/util/client/httpclient/response_test.go @@ -0,0 +1,986 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +package httpclient + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/joomcode/errorx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" +) + +func TestReadBodyAsString(t *testing.T) { + requestTimes := atomic.Int32{} + responseStatus := atomic.Int32{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + w.WriteHeader(int(responseStatus.Load())) + _, _ = fmt.Fprintf(w, "Basically OK, Req #%d", requestTimes.Load()) + })) + defer ts.Close() + + client := New(Config{}) + + responseStatus.Store(200) + req := client.LR() + resp := req.Get(ts.URL) + responseStatus.Store(202) // Lazy request + require.Equal(t, int32(0), requestTimes.Load()) // Lazy request + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, "Basically OK, Req #1", dataStr) + require.Nil(t, rawResp.Body) + require.Equal(t, 202, rawResp.StatusCode) // Due to lazy request, we should get 202 + + // Read again should result in error + dataStrE, rawRespE, err := resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "read on closed response body") + require.Empty(t, dataStrE) + require.Nil(t, rawRespE) + + // Other kind of read operations should also result in error + bytesE, rawRespE, err := resp.ReadBodyAsBytes() + require.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "read on closed response body") + require.Nil(t, bytesE) + require.Nil(t, rawRespE) + + // Test sending a new request via Get() over the same request again + responseStatus.Store(201) + resp = req.Get(ts.URL) + require.Equal(t, int32(1), requestTimes.Load()) + dataStr2, rawResp2, err := resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, "Basically OK, Req #2", dataStr2) + require.Nil(t, rawResp2.Body) + require.Equal(t, 202, rawResp.StatusCode) // The previous response should not be changed by a new request + require.Equal(t, 201, rawResp2.StatusCode) + + // Sending a new request via LR() over the same client + responseStatus.Store(200) + resp = client.LR().Get(ts.URL) + require.Equal(t, int32(2), requestTimes.Load()) + dataStr3, rawResp3, err := resp.ReadBodyAsString() + require.Equal(t, int32(3), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, "Basically OK, Req #3", dataStr3) + require.Nil(t, rawResp3.Body) + require.Equal(t, 202, rawResp.StatusCode) + require.Equal(t, 201, rawResp2.StatusCode) + require.Equal(t, 200, rawResp3.StatusCode) + + // Sending a new request via creating a new client + client2 := New(Config{}) + responseStatus.Store(202) + resp = client2.LR().Get(ts.URL) + require.Equal(t, int32(3), requestTimes.Load()) + dataStr4, rawResp4, err := resp.ReadBodyAsString() + require.Equal(t, int32(4), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, "Basically OK, Req #4", dataStr4) + require.Nil(t, rawResp4.Body) + require.Equal(t, 202, rawResp.StatusCode) + require.Equal(t, 201, rawResp2.StatusCode) + require.Equal(t, 200, rawResp3.StatusCode) + require.Equal(t, 202, rawResp4.StatusCode) +} + +func TestFinish(t *testing.T) { + requestTimes := atomic.Int32{} + responseStatus := atomic.Int32{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + w.WriteHeader(int(responseStatus.Load())) + _, _ = fmt.Fprintf(w, "Basically OK, Req #%d", requestTimes.Load()) + })) + defer ts.Close() + + client := New(Config{}) + responseStatus.Store(200) + + resp := client.LR().Get(ts.URL) + responseStatus.Store(202) // Lazy request + require.Equal(t, int32(0), requestTimes.Load()) // Lazy request + rawResp, err := resp.Finish() + require.Equal(t, int32(1), requestTimes.Load()) + require.Nil(t, err) + require.Nil(t, rawResp.Body) + require.Equal(t, 202, rawResp.StatusCode) + + // Call Finish() again should not send a new request + rawResp, err = resp.Finish() + require.Equal(t, int32(1), requestTimes.Load()) + require.Nil(t, err) + require.Nil(t, rawResp.Body) + require.Equal(t, 202, rawResp.StatusCode) + + // Read after Finish() should become errors + dataStrE, rawRespE, err := resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "read on closed response body") + require.Empty(t, dataStrE) + require.Nil(t, rawRespE) + bytesE, rawRespE, err := resp.ReadBodyAsBytes() + require.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "read on closed response body") + require.Nil(t, bytesE) + require.Nil(t, rawRespE) + + // Finish() after read is fine. + resp = client.LR().Get(ts.URL) + responseStatus.Store(200) + require.Equal(t, int32(1), requestTimes.Load()) + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, "Basically OK, Req #2", dataStr) + require.Nil(t, rawResp.Body) + require.Equal(t, 200, rawResp.StatusCode) + rawResp2, err := resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.Nil(t, err) + require.Same(t, rawResp, rawResp2) +} + +func TestReadBodyAsJSON(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, `{"foo":"bar"}`) + })) + defer ts.Close() + + // Unmarshal into map + client := New(Config{}) + var respMap map[string]interface{} + rawResp, err := client.LR().Get(ts.URL).ReadBodyAsJSON(&respMap) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + expectedMap := map[string]interface{}{ + "foo": "bar", + } + require.Equal(t, expectedMap, respMap) + + // Unmarshal into struct + type Response struct { + Foo string `json:"foo"` + } + var respStruct Response + req := client.LR().Get(ts.URL) + rawResp, err = req.ReadBodyAsJSON(&respStruct) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, Response{Foo: "bar"}, respStruct) +} + +func TestReadBodyAsJSON_UnmarshalFailure(t *testing.T) { + requestTimes := atomic.Int32{} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + _, _ = fmt.Fprintln(w, `bad_json`) + })) + defer ts.Close() + + client := New(Config{}) + + var respMap map[string]interface{} + assert.Equal(t, int32(0), requestTimes.Load()) + req := client.LR().Get(ts.URL) + rawResp, err := req.ReadBodyAsJSON(&respMap) + assert.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "invalid character") + require.Nil(t, rawResp) + require.Nil(t, respMap) + + // Read JSON again should not send new request + rawResp, err = req.ReadBodyAsJSON(&respMap) + assert.Equal(t, int32(1), requestTimes.Load()) + require.Contains(t, err.Error(), "read on closed response body") + require.Nil(t, rawResp) + require.Nil(t, respMap) + + // Finish should success without sending new requests + // Unlike other Read errors, for unmarshal errors, Finish() will succeed since an OK response is read successfully + rawResp, err = req.Finish() + assert.Equal(t, int32(1), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) +} + +type myWriter struct { + writedBytes int + writeCalled int + errorRaised int +} + +func (w *myWriter) Write(p []byte) (int, error) { + w.writeCalled++ + if w.writedBytes > 5 { + w.errorRaised++ + return 0, fmt.Errorf("write too many bytes") + } + w.writedBytes += len(p) + return len(p), nil +} + +func TestPipeBody(t *testing.T) { + requestTimes := atomic.Int32{} + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + _, _ = fmt.Fprintln(w, "Hello world") + })) + defer ts.Close() + + client := New(Config{}) + + buf := bytes.Buffer{} + assert.Equal(t, int32(0), requestTimes.Load()) + req := client.LR().Get(ts.URL) + wBytes, rawResp, err := req.PipeBody(&buf) + assert.Equal(t, int32(1), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, "Hello world\n", buf.String()) + require.Equal(t, int64(12), wBytes) + + // The copy chunk size is large, so that there will be only one write call to the writer + w := myWriter{} + assert.Equal(t, int32(1), requestTimes.Load()) + wBytes, rawResp, err = client.LR().Get(ts.URL).PipeBody(&w) + assert.Equal(t, int32(2), requestTimes.Load()) + require.Equal(t, int64(12), wBytes) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, 1, w.writeCalled) + require.Equal(t, 12, w.writedBytes) + require.Equal(t, 0, w.errorRaised) + + // Now the server write data chunk by chunk... + ctx, cancel := context.WithCancel(context.Background()) + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + _, _ = w.Write([]byte("Partial...")) + w.(http.Flusher).Flush() + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Second): + _, _ = fmt.Fprintln(w, "Done") + } + })) + defer ts.Close() + defer cancel() + + // PipeData should produce data chunk by chunk + w = myWriter{} + assert.Equal(t, int32(2), requestTimes.Load()) + resp := client.LR().Get(ts.URL) + wBytes, rawResp, err = resp.PipeBody(&w) + assert.Equal(t, int32(3), requestTimes.Load()) + require.Equal(t, int64(10), wBytes) + require.NotNil(t, err) + require.Contains(t, err.Error(), "write too many bytes") + require.Nil(t, rawResp) + require.Equal(t, 2, w.writeCalled) + require.Equal(t, 10, w.writedBytes) // The size of the first chunk + require.Equal(t, 1, w.errorRaised) + // Call PipeBody again should fail due to response is closed + wBytes, rawResp, err = resp.PipeBody(&w) + assert.Equal(t, int32(3), requestTimes.Load()) + require.Equal(t, int64(0), wBytes) + require.NotNil(t, err) + require.Contains(t, err.Error(), "read on closed response body") + require.Nil(t, rawResp) + require.Equal(t, 2, w.writeCalled) // Unchanged + require.Equal(t, 10, w.writedBytes) + require.Equal(t, 1, w.errorRaised) + + // PipeBody should copy all data when there are multiple chunks from the server + buf = bytes.Buffer{} + assert.Equal(t, int32(3), requestTimes.Load()) + req = client.LR().Get(ts.URL) + wBytes, rawResp, err = req.PipeBody(&buf) + assert.Equal(t, int32(4), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, "Partial...Done\n", buf.String()) + require.Equal(t, int64(15), wBytes) +} + +func TestResponseHeader(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("foo", "bar") + w.WriteHeader(http.StatusAlreadyReported) + _, _ = fmt.Fprintln(w, "Fine!") + })) + defer ts.Close() + + client := New(Config{}) + resp := client.LR().Get(ts.URL) + rawResp, err := resp.Finish() + require.Nil(t, err) + require.Equal(t, http.StatusAlreadyReported, rawResp.StatusCode) + require.Equal(t, "bar", rawResp.Header.Get("foo")) +} + +func TestSetURL(t *testing.T) { + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "Result from server 1") + })) + defer ts1.Close() + + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "Result from server 2") + })) + defer ts2.Close() + + client := New(Config{}) + req := client.LR() + + // SetXxx should make changes in place + r1 := req.SetURL(ts1.URL) + r2 := req.SetURL(ts2.URL) + require.Same(t, r1, r2) + dataStr, _, err := r1.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + dataStr, _, err = r2.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + + r1.SetURL(ts1.URL) + dataStr, _, err = r2.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + + // SetURL should not affect another request in the same client + req2 := client.LR() + req2.SetURL(ts2.URL) + dataStr, _, err = r1.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = r2.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = req2.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + dataStr, _, err = req.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = r1.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) +} + +func TestLR(t *testing.T) { + client := New(Config{}) + req1 := client.LR() + req2 := client.LR() + require.NotSame(t, req1, req2) +} + +func TestGet(t *testing.T) { + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "Result from server 1") + })) + defer ts1.Close() + + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "Result from server 2") + })) + defer ts2.Close() + + client := New(Config{}) + + // "Get" from different requests should not affect each other + resp1 := client.LR().Get(ts1.URL) + resp2 := client.LR().Get(ts2.URL) + dataStr, _, err := resp1.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = resp2.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + + // "Get" should not affect each other + req := client.LR() + resp1 = req.Get(ts1.URL) + resp2 = req.Get(ts2.URL) + dataStr, _, err = resp1.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = resp2.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + resp3 := req.Get(ts1.URL) + dataStr, _, err = resp3.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + + // "Get()" should not affect the previous "SetURL()" call + req = client.LR() + req.SetURL(ts1.URL) + dataStr, _, err = req.Get(ts2.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) + dataStr, _, err = req.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + + // "SetURL()" should not affect the previous "Get()" call + req = client.LR() + resp1 = req.Get(ts1.URL) + req.SetURL(ts2.URL) + dataStr, _, err = resp1.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 1", dataStr) + dataStr, _, err = req.Send().ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Result from server 2", dataStr) +} + +func TestPost(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := ioutil.ReadAll(r.Body) + _, _ = fmt.Fprintf(w, "Body is %s", string(body)) + })) + defer ts.Close() + + client := New(Config{}) + + // SetBody from different requests should not affect each other + r1 := client.LR().SetBody("foo") + dataStr, _, err := r1.Post(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Body is foo", dataStr) + + r2 := client.LR().SetBody("bar") + dataStr, _, err = r2.Post(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Body is bar", dataStr) + + dataStr, _, err = r1.Post(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Body is foo", dataStr) +} + +func TestSetHeader(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, r.Header.Get("X-Test")) + })) + defer ts.Close() + + client := New(Config{}) + req := client.LR().SetHeader("X-Test", "foobar") + + // SetHeader from different requests should not affect each other + dataStr, _, err := req.Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "foobar", dataStr) + + dataStr, _, err = client.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "", dataStr) + + dataStr, _, err = req.Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "foobar", dataStr) + + // SetHeader after Get should not taking effect + req = client.LR() + resp := req.Get(ts.URL) + req.SetHeader("X-Test", "hello") + dataStr, _, err = resp.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "", dataStr) + + resp = req.Get(ts.URL) + dataStr, _, err = resp.ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "hello", dataStr) +} + +func TestSetBaseURL(t *testing.T) { + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "ts1"+r.URL.Path) + })) + defer ts1.Close() + + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "ts2"+r.URL.Path) + })) + defer ts2.Close() + + client := New(Config{}) + dataStr, _, err := client.LR().SetBaseURL(ts1.URL).Get("/foo").ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "ts1/foo", dataStr) + + // BaseURL can be overwritten + dataStr, _, err = client.LR().SetBaseURL(ts1.URL).Get(ts2.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "ts2/", dataStr) +} + +func TestFailureStatusCode(t *testing.T) { + requestTimes := atomic.Int32{} + responseStatus := atomic.Int32{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + w.WriteHeader(int(responseStatus.Load())) + _, _ = fmt.Fprintf(w, "Fail from req #%d", requestTimes.Load()) + })) + defer ts.Close() + + // Although request succeeded, failure status code will turn into errors by design. + + client := New(Config{}) + + // ReadBodyAsBytes should fail + responseStatus.Store(500) + require.Equal(t, int32(0), requestTimes.Load()) + bytes, rawResp, err := client.LR().Get(ts.URL).ReadBodyAsBytes() + require.Equal(t, int32(1), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 500") + require.Nil(t, bytes) + require.Nil(t, rawResp) + + // ReadBodyAsString should return empty string + responseStatus.Store(400) + require.Equal(t, int32(1), requestTimes.Load()) + resp := client.LR().Get(ts.URL) + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 400") + require.Empty(t, dataStr) + require.Nil(t, rawResp) + // Read again after failure should not send request again + responseStatus.Store(500) + dataStr, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 400") + require.Empty(t, dataStr) + require.Nil(t, rawResp) + + // ReadBodyAsJSON should fail + var respMap map[string]interface{} + resp = client.LR().Get(ts.URL) + rawResp, err = resp.ReadBodyAsJSON(respMap) + require.Equal(t, int32(3), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 500") + require.Empty(t, dataStr) + require.Nil(t, rawResp) + require.Nil(t, respMap) + rawResp, err = resp.ReadBodyAsJSON(respMap) + require.Equal(t, int32(3), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 500") + require.Empty(t, dataStr) + require.Nil(t, rawResp) + require.Nil(t, respMap) + + // Finish should fail + responseStatus.Store(404) + require.Equal(t, int32(3), requestTimes.Load()) + resp = client.LR().Get(ts.URL) + rawResp, err = resp.Finish() + require.Equal(t, int32(4), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 404") + require.Nil(t, rawResp) + // Finish again after failure should not send request again + responseStatus.Store(200) + rawResp, err = resp.Finish() + require.Equal(t, int32(4), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 404") + require.Nil(t, rawResp) + // Mix Finish() and ReadBodyAsString() + responseStatus.Store(403) + dataStr, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(4), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 404") + require.Empty(t, dataStr) + require.Nil(t, rawResp) + responseStatus.Store(200) + rawResp, err = resp.Finish() + require.Equal(t, int32(4), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), "Response status 404") + require.Nil(t, rawResp) +} + +func TestBadServer(t *testing.T) { + requestTimes := atomic.Int32{} + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.Nil(t, err) + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + requestTimes.Inc() + _, _ = conn.Write([]byte("Hello")) + _ = conn.Close() + } + }() + defer func() { + _ = listener.Close() + }() + url := fmt.Sprintf("http://%s/foo", listener.Addr().String()) + + client := New(Config{}) + + // ReadBodyAsString should return empty string + require.Equal(t, int32(0), requestTimes.Load()) + resp := client.LR().Get(url) + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Empty(t, dataStr) + require.Nil(t, rawResp) + // Call multiple times + dataStr, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Empty(t, dataStr) + require.Nil(t, rawResp) + + // Response should fail + require.Equal(t, int32(1), requestTimes.Load()) + resp = client.LR().Get(url) + rawResp, err = resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Nil(t, rawResp) + // Call multiple times + rawResp, err = resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Nil(t, rawResp) + // Mix Finish() and ReadBodyAsString() + dataStr, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Empty(t, dataStr) + require.Nil(t, rawResp) + rawResp, err = resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Nil(t, rawResp) + + // ReadBodyASJSON should fail + require.Equal(t, int32(2), requestTimes.Load()) + resp = client.LR().Get(url) + var respMap map[string]interface{} + rawResp, err = resp.ReadBodyAsJSON(respMap) + require.Equal(t, int32(3), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Nil(t, rawResp) + require.Nil(t, respMap) + // Call multiple times + rawResp, err = resp.ReadBodyAsJSON(respMap) + require.Equal(t, int32(3), requestTimes.Load()) + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Nil(t, rawResp) + require.Nil(t, respMap) +} + +func TestBadScheme(t *testing.T) { + client := New(Config{}) + bytes, rawResp, err := client.LR().Get("foo://abc.com").ReadBodyAsBytes() + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), `unsupported protocol scheme "foo"`) + require.Nil(t, bytes) + require.Nil(t, rawResp) + + rawResp, err = client.LR().Get("bar://abc.com").Finish() + require.True(t, errorx.IsOfType(err, ErrRequestFailed)) + require.Contains(t, err.Error(), `unsupported protocol scheme "bar"`) + require.Nil(t, rawResp) +} + +func TestConnectionReuse(t *testing.T) { + newConn := atomic.Int32{} + closedConn := atomic.Int32{} + + requestTimes := atomic.Int32{} + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + _, _ = fmt.Fprintf(w, "Req #%d", requestTimes.Load()) + })) + ts.Config.ConnState = func(c net.Conn, cs http.ConnState) { + switch cs { + case http.StateNew: + newConn.Inc() + case http.StateHijacked, http.StateClosed: + closedConn.Inc() + default: + // we do not care other states + } + } + ts.Start() + defer ts.Close() + + require.Equal(t, int32(0), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) + + client := New(Config{}) + dataStr, _, err := client.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Req #1", dataStr) + require.Equal(t, int32(1), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) + + // Use the same client to send request, the connection is expected to be reused + dataStr, _, err = client.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Req #2", dataStr) + require.Equal(t, int32(1), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) + + // A new client should create a new connection + client2 := New(Config{}) + dataStr, _, err = client2.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Req #3", dataStr) + require.Equal(t, int32(2), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) + + // Connections are reused + dataStr, _, err = client.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Req #4", dataStr) + require.Equal(t, int32(2), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) + dataStr, _, err = client2.LR().Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.Equal(t, "Req #5", dataStr) + require.Equal(t, int32(2), newConn.Load()) + require.Equal(t, int32(0), closedConn.Load()) +} + +func TestClone(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := make(map[string]string) + for header, value := range r.Header { + if strings.HasPrefix(header, "X-") { + m[header] = value[0] + } + } + j, _ := json.Marshal(m) + _, _ = w.Write(j) + })) + defer ts.Close() + + client := New(Config{}) + + req1 := client.LR() + req1.SetHeader("x-req1header1", "value1") + + req2 := req1.Clone() + // After clone, they will not affect each other + req1.SetHeader("x-req1header2", "value2") + req2.SetHeader("x-req2header1", "value1") + + dataStr, _, err := req1.Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.JSONEq(t, `{"X-Req1header1":"value1","X-Req1header2":"value2"}`, dataStr) + + dataStr, _, err = req2.Get(ts.URL).ReadBodyAsString() + require.Nil(t, err) + require.JSONEq(t, `{"X-Req1header1":"value1","X-Req2header1":"value1"}`, dataStr) +} + +func TestTimeoutHeader(t *testing.T) { + requestTimes := atomic.Int32{} + ctx, cancel := context.WithCancel(context.Background()) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + select { + case <-ctx.Done(): + w.WriteHeader(http.StatusGatewayTimeout) + case <-time.After(1 * time.Second): + _, _ = fmt.Fprintln(w, "OK") + } + })) + defer ts.Close() + defer cancel() + + client := New(Config{}) + tBegin := time.Now() + require.Equal(t, int32(0), requestTimes.Load()) + resp := client.LR().SetTimeout(100 * time.Millisecond).Get(ts.URL) + _, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 300*time.Millisecond) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + // Read again + _, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + rawResp, err = resp.Finish() + require.Equal(t, int32(1), requestTimes.Load()) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + // Even if the request is finished then, we should still get timeout error. + time.Sleep(1 * time.Second) + _, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(1), requestTimes.Load()) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + + // Call Finish() directly should fail + tBegin = time.Now() + resp = client.LR().SetTimeout(100 * time.Millisecond).Get(ts.URL) + rawResp, err = resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 300*time.Millisecond) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + + // Read using long enough timeout should succeed + resp = client.LR().SetTimeout(1200 * time.Millisecond).Get(ts.URL) + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(3), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, "OK", dataStr) +} + +func TestTimeoutBody(t *testing.T) { + requestTimes := atomic.Int32{} + ctx, cancel := context.WithCancel(context.Background()) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestTimes.Inc() + _, _ = w.Write([]byte("Partial...")) + w.(http.Flusher).Flush() + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Second): + _, _ = fmt.Fprintln(w, "Done") + } + })) + defer ts.Close() + defer cancel() + + client := New(Config{}) + + // Finish() should succeed, since a header is successfully returned + tBegin := time.Now() + resp := client.LR().SetTimeout(100 * time.Millisecond).Get(ts.URL) + rawResp, err := resp.Finish() + require.Equal(t, int32(1), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 50*time.Millisecond) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + + // ReadBodyAsString() should fail + tBegin = time.Now() + resp = client.LR().SetTimeout(100 * time.Millisecond).Get(ts.URL) + _, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 300*time.Millisecond) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + // Read again + _, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + // Wait enough time and read again + time.Sleep(1 * time.Second) + _, rawResp, err = resp.ReadBodyAsString() + require.Equal(t, int32(2), requestTimes.Load()) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + // Finish() should succeed + tBegin = time.Now() + rawResp, err = resp.Finish() + require.Equal(t, int32(2), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 50*time.Millisecond) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + + // PipeBody() should fail + buf := bytes.Buffer{} + tBegin = time.Now() + resp = client.LR().SetTimeout(100 * time.Millisecond).Get(ts.URL) + wBytes, rawResp, err := resp.PipeBody(&buf) + require.Equal(t, int32(3), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 300*time.Millisecond) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + require.Equal(t, int64(10), wBytes) // The first chunk is written + require.Equal(t, "Partial...", buf.String()) + // PipeBody again should fail + wBytes, rawResp, err = resp.PipeBody(&buf) + require.Equal(t, int32(3), requestTimes.Load()) + require.Less(t, time.Since(tBegin), 300*time.Millisecond) + require.NotNil(t, err) + require.Contains(t, err.Error(), "Client.Timeout") + require.Nil(t, rawResp) + require.Equal(t, int64(0), wBytes) // No more chunk is written + require.Equal(t, "Partial...", buf.String()) + + // Read using long enough timeout should succeed + resp = client.LR().SetTimeout(1200 * time.Millisecond).Get(ts.URL) + dataStr, rawResp, err := resp.ReadBodyAsString() + require.Equal(t, int32(4), requestTimes.Load()) + require.Nil(t, err) + require.Equal(t, http.StatusOK, rawResp.StatusCode) + require.Equal(t, "Partial...Done", dataStr) +} + +// FIXME: Seems that there is no way to test the panic happens inside runtime finalizers. +//func TestUsageCheck(t *testing.T) { +// if !israce.Enabled { +// t.Skipf("LazyResponse usage check will be tested only when race detector is enabled") +// return +// } +// client := New(Config{}) +// client.LR().Get("foo://example.com") +// assert.Panics(t, func() { runtime.GC() }) +//} + +// TODO: TestCtxRequest + +// TODO: TestCtxResponse +// This test shows that ctx doesn't really restrict the response's lifetime. + +// TODO: Test log output diff --git a/util/client/pdclient/pd_api.go b/util/client/pdclient/pd_api.go index 948fa43736..85b435e190 100644 --- a/util/client/pdclient/pd_api.go +++ b/util/client/pdclient/pd_api.go @@ -8,13 +8,9 @@ type GetStatusResponse struct { StartTimestamp int64 `json:"start_timestamp"` } -func (api *APIClient) GetStatus() (*GetStatusResponse, error) { - cancel, resp, err := api.LifecycleR().SetJSONResult(&GetStatusResponse{}).Get("/status") - defer cancel() - if err != nil { - return nil, err - } - return resp.Result().(*GetStatusResponse), nil +func (api *APIClient) GetStatus() (resp *GetStatusResponse, err error) { + _, err = api.LR().Get("/status").ReadBodyAsJSON(resp) + return } type GetHealthResponse []struct { @@ -22,13 +18,9 @@ type GetHealthResponse []struct { Health bool `json:"health"` } -func (api *APIClient) GetHealth() (*GetHealthResponse, error) { - cancel, resp, err := api.LifecycleR().SetJSONResult(&GetHealthResponse{}).Get("/health") - defer cancel() - if err != nil { - return nil, err - } - return resp.Result().(*GetHealthResponse), nil +func (api *APIClient) GetHealth() (resp *GetHealthResponse, err error) { + _, err = api.LR().Get("/health").ReadBodyAsJSON(resp) + return } type GetMembersResponse struct { @@ -41,26 +33,18 @@ type GetMembersResponse struct { } `json:"members"` } -func (api *APIClient) GetMembers() (*GetMembersResponse, error) { - cancel, resp, err := api.LifecycleR().SetJSONResult(&GetMembersResponse{}).Get("/members") - defer cancel() - if err != nil { - return nil, err - } - return resp.Result().(*GetMembersResponse), nil +func (api *APIClient) GetMembers() (resp *GetMembersResponse, err error) { + _, err = api.LR().Get("/members").ReadBodyAsJSON(resp) + return } type GetConfigReplicateResponse struct { LocationLabels string `json:"location-labels"` } -func (api *APIClient) GetConfigReplicate() (*GetConfigReplicateResponse, error) { - cancel, resp, err := api.LifecycleR().SetJSONResult(&GetConfigReplicateResponse{}).Get("/config/replicate") - defer cancel() - if err != nil { - return nil, err - } - return resp.Result().(*GetConfigReplicateResponse), nil +func (api *APIClient) GetConfigReplicate() (resp *GetConfigReplicateResponse, err error) { + _, err = api.LR().Get("/config/replicate").ReadBodyAsJSON(resp) + return } type GetStoresResponseStore struct { @@ -84,11 +68,7 @@ type GetStoresResponse struct { } `json:"stores"` } -func (api *APIClient) GetStores() (*GetStoresResponse, error) { - cancel, resp, err := api.LifecycleR().SetJSONResult(&GetStoresResponse{}).Get("/stores") - defer cancel() - if err != nil { - return nil, err - } - return resp.Result().(*GetStoresResponse), nil +func (api *APIClient) GetStores() (resp *GetStoresResponse, err error) { + _, err = api.LR().Get("/stores").ReadBodyAsJSON(resp) + return } diff --git a/util/client/pdclient/pd_api_client.go b/util/client/pdclient/pd_api_client.go index b711b31cfa..bf5c5aa55d 100644 --- a/util/client/pdclient/pd_api_client.go +++ b/util/client/pdclient/pd_api_client.go @@ -12,12 +12,12 @@ type APIClient struct { *httpclient.Client } -// Returns error when config is invalid. +// NewAPIClient returns error when config is invalid. func NewAPIClient(config httpclient.APIClientConfig) (*APIClient, error) { c2, err := config.IntoConfig(distro.R().PD) if err != nil { return nil, err } - c2.BaseURL += "/pd/api/v1" + c2.DefaultBaseURL += "/pd/api/v1" return &APIClient{httpclient.New(c2)}, nil } diff --git a/util/israce/no_race.go b/util/israce/no_race.go new file mode 100644 index 0000000000..911121a5e0 --- /dev/null +++ b/util/israce/no_race.go @@ -0,0 +1,9 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +// +build !race + +// Package israce reports if the Go race detector is enabled. +package israce + +// Enabled reports if the race detector is enabled. +const Enabled = false diff --git a/util/israce/race.go b/util/israce/race.go new file mode 100644 index 0000000000..09878684af --- /dev/null +++ b/util/israce/race.go @@ -0,0 +1,9 @@ +// Copyright 2021 PingCAP, Inc. Licensed under Apache-2.0. + +// +build race + +// Package israce reports if the Go race detector is enabled. +package israce + +// Enabled reports if the race detector is enabled. +const Enabled = true diff --git a/util/testutil/testutil.go b/util/testutil/testutil.go index 876f7a9769..11e5ec51b3 100644 --- a/util/testutil/testutil.go +++ b/util/testutil/testutil.go @@ -3,6 +3,7 @@ package testutil import ( + "runtime" "testing" "github.com/gin-gonic/gin" @@ -13,4 +14,5 @@ func TestMain(m *testing.M) { EnableDebugLog() gin.SetMode(gin.TestMode) goleak.VerifyTestMain(m) + runtime.GC() }