diff --git a/sdk/azcore/headers.go b/sdk/azcore/headers.go index 0daa2908ee04..2b8ea562e2d5 100644 --- a/sdk/azcore/headers.go +++ b/sdk/azcore/headers.go @@ -22,6 +22,7 @@ const ( HeaderIfUnmodifiedSince = "If-Unmodified-Since" HeaderMetadata = "Metadata" HeaderRange = "Range" + HeaderRetryAfter = "Retry-After" HeaderURLEncoded = "application/x-www-form-urlencoded" HeaderUserAgent = "User-Agent" HeaderXmsDate = "x-ms-date" diff --git a/sdk/azcore/policy_logging_test.go b/sdk/azcore/policy_logging_test.go index 046a6a448262..7afc73d3f59b 100644 --- a/sdk/azcore/policy_logging_test.go +++ b/sdk/azcore/policy_logging_test.go @@ -25,8 +25,10 @@ func TestPolicyLoggingSuccess(t *testing.T) { srv.SetResponse() pl := NewPipeline(srv, NewRequestLogPolicy(RequestLogOptions{})) req := NewRequest(http.MethodGet, srv.URL()) - req.SetQueryParam("one", "fish") - req.SetQueryParam("sig", "redact") + qp := req.URL.Query() + qp.Set("one", "fish") + qp.Set("sig", "redact") + req.URL.RawQuery = qp.Encode() resp, err := pl.Do(context.Background(), req) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/sdk/azcore/policy_retry.go b/sdk/azcore/policy_retry.go index 7d48f98a42b4..eb33517563d3 100644 --- a/sdk/azcore/policy_retry.go +++ b/sdk/azcore/policy_retry.go @@ -15,30 +15,26 @@ import ( ) const ( - defaultMaxTries = 4 + defaultMaxRetries = 3 ) // RetryOptions configures the retry policy's behavior. type RetryOptions struct { - // MaxTries specifies the maximum number of attempts an operation will be tried before producing an error (0=default). - // A value of zero means that you accept our default policy. A value of 1 means 1 try and no retries. - MaxTries int32 + // MaxRetries specifies the maximum number of attempts a failed operation will be retried + // before producing an error. A value of zero means one try and no retries. + MaxRetries int32 // TryTimeout indicates the maximum time allowed for any single try of an HTTP request. - // A value of zero means that you accept our default timeout. NOTE: When transferring large amounts - // of data, the default TryTimeout will probably not be sufficient. You should override this value - // based on the bandwidth available to the host machine and proximity to the service. A good - // starting point may be something like (60 seconds per MB of anticipated-payload-size). TryTimeout time.Duration - // RetryDelay specifies the amount of delay to use before retrying an operation (0=default). + // RetryDelay specifies the amount of delay to use before retrying an operation. // The delay increases exponentially with each retry up to a maximum specified by MaxRetryDelay. // If you specify 0, then you must also specify 0 for MaxRetryDelay. // If you specify RetryDelay, then you must also specify MaxRetryDelay, and MaxRetryDelay should be // equal to or greater than RetryDelay. RetryDelay time.Duration - // MaxRetryDelay specifies the maximum delay allowed before retrying an operation (0=default). + // MaxRetryDelay specifies the maximum delay allowed before retrying an operation. // If you specify 0, then you must also specify 0 for RetryDelay. MaxRetryDelay time.Duration @@ -49,9 +45,9 @@ type RetryOptions struct { var ( // StatusCodesForRetry is the default set of HTTP status code for which the policy will retry. - StatusCodesForRetry = [6]int{ + // Changing its value will affect future created clients that use the default values. + StatusCodesForRetry = []int{ http.StatusRequestTimeout, // 408 - http.StatusTooManyRequests, // 429 http.StatusInternalServerError, // 500 http.StatusBadGateway, // 502 http.StatusServiceUnavailable, // 503 @@ -62,14 +58,23 @@ var ( // DefaultRetryOptions returns an instance of RetryOptions initialized with default values. func DefaultRetryOptions() RetryOptions { return RetryOptions{ - StatusCodes: StatusCodesForRetry[:], - MaxTries: defaultMaxTries, + StatusCodes: StatusCodesForRetry, + MaxRetries: defaultMaxRetries, TryTimeout: 1 * time.Minute, RetryDelay: 4 * time.Second, MaxRetryDelay: 120 * time.Second, } } +// used as a context key for adding/retrieving RetryOptions +type ctxWithRetryOptionsKey struct{} + +// WithRetryOptions adds the specified RetryOptions to the parent context. +// Use this to specify custom RetryOptions at the API-call level. +func WithRetryOptions(parent context.Context, options RetryOptions) context.Context { + return context.WithValue(parent, ctxWithRetryOptionsKey{}, options) +} + func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0 pow := func(number int64, exponent int32) int64 { // pow is nested helper function var result int64 = 1 @@ -105,6 +110,11 @@ type retryPolicy struct { } func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err error) { + options := p.options + // check if the retry options have been overridden for this call + if override := ctx.Value(ctxWithRetryOptionsKey{}); override != nil { + options = override.(RetryOptions) + } // Exponential retry algorithm: ((2 ^ attempt) - 1) * delay * random(0.8, 1.2) // When to retry: connection failure or temporary/timeout. if req.Body != nil { @@ -134,14 +144,14 @@ func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err } // Set the time for this particular retry operation and then Do the operation. - tryCtx, tryCancel := context.WithTimeout(ctx, p.options.TryTimeout) + tryCtx, tryCancel := context.WithTimeout(ctx, options.TryTimeout) resp, err = req.Next(tryCtx) // Make the request tryCancel() if shouldLog { Log().Write(LogRetryPolicy, fmt.Sprintf("Err=%v, response=%v\n", err, resp)) } - if err == nil && !resp.HasStatusCode(p.options.StatusCodes...) { + if err == nil && !resp.HasStatusCode(options.StatusCodes...) { // if there is no error and the response code isn't in the list of retry codes then we're done. return } else if ctx.Err() != nil { @@ -155,7 +165,7 @@ func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err // drain before retrying so nothing is leaked resp.Drain() - if try == p.options.MaxTries { + if try == options.MaxRetries+1 { // max number of tries has been reached, don't sleep again return } @@ -163,7 +173,7 @@ func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err // use the delay from retry-after if available delay, ok := resp.RetryAfter() if !ok { - delay = p.options.calcDelay(try) + delay = options.calcDelay(try) } if shouldLog { Log().Write(LogRetryPolicy, fmt.Sprintf("Try=%d, Delay=%v\n", try, delay)) diff --git a/sdk/azcore/policy_retry_test.go b/sdk/azcore/policy_retry_test.go index 6a3dc79197cf..d290e48586fe 100644 --- a/sdk/azcore/policy_retry_test.go +++ b/sdk/azcore/policy_retry_test.go @@ -61,10 +61,10 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxTries { - t.Fatalf("wrong retry count, got %d expected %d", r, defaultMaxTries) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != defaultMaxTries-1 { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -116,10 +116,10 @@ func TestRetryPolicyFailOnError(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != defaultMaxTries { - t.Fatalf("wrong retry count, got %d expected %d", r, defaultMaxTries) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != defaultMaxTries-1 { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -145,10 +145,10 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { if resp.StatusCode != http.StatusAccepted { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != defaultMaxTries { - t.Fatalf("wrong retry count, got %d expected %d", r, 3) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != defaultMaxTries-1 { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -212,6 +212,35 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) { } } +func TestWithRetryOptions(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.RepeatResponse(9, mock.WithStatusCode(http.StatusRequestTimeout)) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) + defaultOptions := testRetryOptions() + pl := NewPipeline(srv, NewRetryPolicy(defaultOptions)) + customOptions := *defaultOptions + customOptions.MaxRetries = 10 + customOptions.MaxRetryDelay = 200 * time.Millisecond + retryCtx := WithRetryOptions(context.Background(), customOptions) + req := NewRequest(http.MethodGet, srv.URL()) + body := newRewindTrackingBody("stuff") + req.SetBody(body) + resp, err := pl.Do(retryCtx, req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if body.rcount != int(customOptions.MaxRetries-1) { + t.Fatalf("unexpected rewind count: %d", body.rcount) + } + if !body.closed { + t.Fatal("request body wasn't closed") + } +} + // TODO: add test for retry failing to read response body // TODO: add test for per-retry timeout failed but e2e succeeded diff --git a/sdk/azcore/request.go b/sdk/azcore/request.go index 7d87b20baa07..1ce0ff875c42 100644 --- a/sdk/azcore/request.go +++ b/sdk/azcore/request.go @@ -27,7 +27,6 @@ const ( type Request struct { *http.Request policies []Policy - qp url.Values values opValues } @@ -79,11 +78,6 @@ func (req *Request) Next(ctx context.Context) (*Response, error) { nextPolicy := req.policies[0] nextReq := *req nextReq.policies = nextReq.policies[1:] - // encode any pending query params - if nextReq.qp != nil { - nextReq.Request.URL.RawQuery = nextReq.qp.Encode() - nextReq.qp = nil - } return nextPolicy.Do(ctx, &nextReq) } @@ -125,14 +119,6 @@ func (req *Request) OperationValue(value interface{}) bool { return req.values.get(value) } -// SetQueryParam sets the key to value. -func (req *Request) SetQueryParam(key, value string) { - if req.qp == nil { - req.qp = req.Request.URL.Query() - } - req.qp.Set(key, value) -} - // SetBody sets the specified ReadSeekCloser as the HTTP request body. func (req *Request) SetBody(body ReadSeekCloser) error { // Set the body and content length. diff --git a/sdk/azcore/response.go b/sdk/azcore/response.go index edb19ad244e3..75b649c5b364 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/response.go @@ -96,13 +96,21 @@ func (r *Response) removeBOM() { } } -// RetryAfter returns (non-zero, true) if the response contains a Retry-After header value +// RetryAfter returns (non-zero, true) if the response contains a Retry-After header value. func (r *Response) RetryAfter() (time.Duration, bool) { if r == nil { return 0, false } - if retryAfter, _ := strconv.Atoi(r.Header.Get("Retry-After")); retryAfter > 0 { + ra := r.Header.Get(HeaderRetryAfter) + if ra == "" { + return 0, false + } + // retry-after values are expressed in either number of + // seconds or an HTTP-date indicating when to try again + if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 { return time.Duration(retryAfter) * time.Second, true + } else if t, err := time.Parse(time.RFC1123, ra); err == nil { + return t.Sub(time.Now()), true } return 0, false } diff --git a/sdk/azcore/response_test.go b/sdk/azcore/response_test.go index 263f0c3db674..660a8bc549ee 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/response_test.go @@ -9,6 +9,7 @@ import ( "context" "net/http" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -103,3 +104,31 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) { t.Fatalf("unexpected error unmarshalling: %v", err) } } + +func TestRetryAfter(t *testing.T) { + raw := &http.Response{ + Header: http.Header{}, + } + resp := Response{raw} + if d, ok := resp.RetryAfter(); ok { + t.Fatalf("unexpected retry-after value %d", d) + } + raw.Header.Set(HeaderRetryAfter, "300") + d, ok := resp.RetryAfter() + if !ok { + t.Fatal("expected retry-after value from seconds") + } + if d != 300*time.Second { + t.Fatalf("expected 300 seconds, got %d", d/time.Second) + } + atDate := time.Now().Add(600 * time.Second) + raw.Header.Set(HeaderRetryAfter, atDate.Format(time.RFC1123)) + d, ok = resp.RetryAfter() + if !ok { + t.Fatal("expected retry-after value from date") + } + // d will not be exactly 600 seconds but it will be close + if d/time.Second != 599 { + t.Fatalf("expected ~600 seconds, got %d", d/time.Second) + } +}