diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index 4e5ed8405f25..e02bd3e1a81b 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -27,15 +27,15 @@ type Policy interface { // Do applies the policy to the specified Request. When implementing a Policy, mutate the // request before calling req.Next() to move on to the next policy, and respond to the result // before returning to the caller. - Do(req *Request) (*Response, error) + Do(req *Request) (*http.Response, error) } // policyFunc is a type that implements the Policy interface. // Use this type when implementing a stateless policy as a first-class function. -type policyFunc func(*Request) (*Response, error) +type policyFunc func(*Request) (*http.Response, error) // Do implements the Policy interface on PolicyFunc. -func (pf policyFunc) Do(req *Request) (*Response, error) { +func (pf policyFunc) Do(req *Request) (*http.Response, error) { return pf(req) } @@ -50,7 +50,7 @@ type transportPolicy struct { trans Transporter } -func (tp transportPolicy) Do(req *Request) (*Response, error) { +func (tp transportPolicy) Do(req *Request) (*http.Response, error) { resp, err := tp.trans.Do(req.Request) if err != nil { return nil, err @@ -59,7 +59,7 @@ func (tp transportPolicy) Do(req *Request) (*Response, error) { // this ensures the retry policy will retry the request return nil, errors.New("received nil response") } - return &Response{Response: resp}, nil + return resp, nil } // Pipeline represents a primitive for sending HTTP requests and receiving responses. @@ -84,7 +84,7 @@ func NewPipeline(transport Transporter, policies ...Policy) Pipeline { // Do is called for each and every HTTP request. It passes the request through all // the Policy objects (which can transform the Request's URL/query parameters/headers) // and ultimately sends the transformed HTTP request over the network. -func (p Pipeline) Do(req *Request) (*Response, error) { +func (p Pipeline) Do(req *Request) (*http.Response, error) { if err := req.valid(); err != nil { return nil, err } diff --git a/sdk/azcore/policy_anonymous_credential.go b/sdk/azcore/policy_anonymous_credential.go index 3f5bc5ceecd4..c767a6370af9 100644 --- a/sdk/azcore/policy_anonymous_credential.go +++ b/sdk/azcore/policy_anonymous_credential.go @@ -5,11 +5,13 @@ package azcore +import "net/http" + func anonCredAuthPolicyFunc(AuthenticationOptions) Policy { return policyFunc(anonCredPolicyFunc) } -func anonCredPolicyFunc(req *Request) (*Response, error) { +func anonCredPolicyFunc(req *Request) (*http.Response, error) { return req.Next() } diff --git a/sdk/azcore/policy_body_download.go b/sdk/azcore/policy_body_download.go index 7087242dc4b4..2c9527235bc3 100644 --- a/sdk/azcore/policy_body_download.go +++ b/sdk/azcore/policy_body_download.go @@ -15,7 +15,7 @@ import ( ) // bodyDownloadPolicy creates a policy object that downloads the response's body to a []byte. -func bodyDownloadPolicy(req *Request) (*Response, error) { +func bodyDownloadPolicy(req *Request) (*http.Response, error) { resp, err := req.Next() if err != nil { return resp, err diff --git a/sdk/azcore/policy_body_download_test.go b/sdk/azcore/policy_body_download_test.go index 5d2a00876931..dff5e5a6379a 100644 --- a/sdk/azcore/policy_body_download_test.go +++ b/sdk/azcore/policy_body_download_test.go @@ -29,7 +29,7 @@ func TestDownloadBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -57,7 +57,7 @@ func TestSkipBodyDownload(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -80,7 +80,7 @@ func TestDownloadBodyFail(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - payload, err := resp.Payload() + payload, err := Payload(resp) if err == nil { t.Fatalf("expected an error") } @@ -106,7 +106,7 @@ func TestDownloadBodyWithRetryGet(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -138,7 +138,7 @@ func TestDownloadBodyWithRetryDelete(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -170,7 +170,7 @@ func TestDownloadBodyWithRetryPut(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -205,7 +205,7 @@ func TestDownloadBodyWithRetryPatch(t *testing.T) { if _, ok := err.(*bodyDownloadError); !ok { t.Fatal("expected *bodyDownloadError type") } - payload, err := resp.Payload() + payload, err := Payload(resp) if err == nil { t.Fatalf("expected an error") } @@ -235,7 +235,7 @@ func TestDownloadBodyWithRetryPost(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - payload, err := resp.Payload() + payload, err := Payload(resp) if err == nil { t.Fatalf("expected an error") } @@ -264,7 +264,7 @@ func TestSkipBodyDownloadWith400(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -292,7 +292,7 @@ func TestReadBodyAfterSeek(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - payload, err := resp.Payload() + payload, err := Payload(resp) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/sdk/azcore/policy_http_header.go b/sdk/azcore/policy_http_header.go index 5f5838af5b66..dbdde42bd950 100644 --- a/sdk/azcore/policy_http_header.go +++ b/sdk/azcore/policy_http_header.go @@ -14,7 +14,7 @@ import ( type ctxWithHTTPHeader struct{} // newHTTPHeaderPolicy creates a policy object that adds custom HTTP headers to a request -func httpHeaderPolicy(req *Request) (*Response, error) { +func httpHeaderPolicy(req *Request) (*http.Response, error) { // check if any custom HTTP headers have been specified if header := req.Context().Value(ctxWithHTTPHeader{}); header != nil { for k, v := range header.(http.Header) { diff --git a/sdk/azcore/policy_logging.go b/sdk/azcore/policy_logging.go index 66e2fb8313b3..35291b63c906 100644 --- a/sdk/azcore/policy_logging.go +++ b/sdk/azcore/policy_logging.go @@ -8,6 +8,7 @@ package azcore import ( "bytes" "fmt" + "net/http" "strings" "time" @@ -42,7 +43,7 @@ type logPolicyOpValues struct { start time.Time } -func (p *logPolicy) Do(req *Request) (*Response, error) { +func (p *logPolicy) Do(req *Request) (*http.Response, error) { // Get the per-operation values. These are saved in the Message's map so that they persist across each retry calling into this policy object. var opValues logPolicyOpValues if req.OperationValue(&opValues); opValues.start.IsZero() { @@ -88,7 +89,7 @@ func (p *logPolicy) Do(req *Request) (*Response, error) { // skip frames runtime.Callers() and runtime.StackTrace() b.WriteString(diag.StackTrace(2, StackFrameCount)) } else if p.options.IncludeBody { - err = response.writeBody(b) + err = writeBody(response, b) } log.Write(log.Response, b.String()) } diff --git a/sdk/azcore/policy_retry.go b/sdk/azcore/policy_retry.go index 79a371b8c61b..f7ed396c7bc0 100644 --- a/sdk/azcore/policy_retry.go +++ b/sdk/azcore/policy_retry.go @@ -124,7 +124,7 @@ type retryPolicy struct { options RetryOptions } -func (p *retryPolicy) Do(req *Request) (resp *Response, err error) { +func (p *retryPolicy) Do(req *Request) (resp *http.Response, err error) { options := p.options // check if the retry options have been overridden for this call if override := req.Context().Value(ctxWithRetryOptionsKey{}); override != nil { @@ -170,7 +170,7 @@ func (p *retryPolicy) Do(req *Request) (resp *Response, err error) { log.Writef(log.RetryPolicy, "error %v", err) } - if err == nil && !resp.HasStatusCode(options.StatusCodes...) { + if err == nil && !HasStatusCode(resp, options.StatusCodes...) { // if there is no error and the response code isn't in the list of retry codes then we're done. return } else if ctxErr := req.Context().Err(); ctxErr != nil { @@ -195,10 +195,10 @@ func (p *retryPolicy) Do(req *Request) (resp *Response, err error) { } // drain before retrying so nothing is leaked - resp.Drain() + Drain(resp) // use the delay from retry-after if available - delay := resp.retryAfter() + delay := RetryAfter(resp) if delay <= 0 { delay = options.calcDelay(try) } diff --git a/sdk/azcore/policy_retry_test.go b/sdk/azcore/policy_retry_test.go index d8de29aa37a5..871bc1cdc5a0 100644 --- a/sdk/azcore/policy_retry_test.go +++ b/sdk/azcore/policy_retry_test.go @@ -91,7 +91,7 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError), mock.WithBody([]byte(respBody))) // add a per-request policy that reads and restores the request body. // this is to simulate how something like httputil.DumpRequest works. - pl := NewPipeline(srv, policyFunc(func(r *Request) (*Response, error) { + pl := NewPipeline(srv, policyFunc(func(r *Request) (*http.Response, error) { b, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) diff --git a/sdk/azcore/policy_telemetry.go b/sdk/azcore/policy_telemetry.go index 188ba5c3b8ea..b013785d3c2b 100644 --- a/sdk/azcore/policy_telemetry.go +++ b/sdk/azcore/policy_telemetry.go @@ -8,6 +8,7 @@ package azcore import ( "bytes" "fmt" + "net/http" "os" "runtime" "strings" @@ -64,7 +65,7 @@ func NewTelemetryPolicy(o *TelemetryOptions) Policy { return &tp } -func (p telemetryPolicy) Do(req *Request) (*Response, error) { +func (p telemetryPolicy) Do(req *Request) (*http.Response, error) { if p.telemetryValue == "" { return req.Next() } diff --git a/sdk/azcore/poller.go b/sdk/azcore/poller.go index 185488e03a36..db8d18219c82 100644 --- a/sdk/azcore/poller.go +++ b/sdk/azcore/poller.go @@ -23,7 +23,7 @@ import ( // NewPoller creates a Poller based on the provided initial response. // pollerID - a unique identifier for an LRO, it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPoller(pollerID string, resp *Response, pl Pipeline, eu func(*Response) error) (*Poller, error) { +func NewPoller(pollerID string, resp *http.Response, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { // this is a back-stop in case the swagger is incorrect (i.e. missing one or more status codes for success). // ideally the codegen should return an error if the initial response failed and not even create a poller. if !lroStatusCodeValid(resp) { @@ -54,7 +54,7 @@ func NewPoller(pollerID string, resp *Response, pl Pipeline, eu func(*Response) // NewPollerFromResumeToken creates a Poller from a resume token string. // pollerID - a unique identifier for an LRO, it's usually the client.Method string. // NOTE: this is only meant for internal use in generated code. -func NewPollerFromResumeToken(pollerID string, token string, pl Pipeline, eu func(*Response) error) (*Poller, error) { +func NewPollerFromResumeToken(pollerID string, token string, pl Pipeline, eu func(*http.Response) error) (*Poller, error) { // unmarshal into JSON object to determine the poller type obj := map[string]interface{}{} err := json.Unmarshal([]byte(token), &obj) @@ -99,8 +99,8 @@ func NewPollerFromResumeToken(pollerID string, token string, pl Pipeline, eu fun type Poller struct { lro lroPoller pl Pipeline - eu func(*Response) error - resp *Response + eu func(*http.Response) error + resp *http.Response err error } @@ -117,7 +117,7 @@ func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { if l.Done() { // the LRO has reached a terminal state, don't poll again if l.resp != nil { - return l.resp.Response, nil + return l.resp, nil } return nil, l.err } @@ -140,7 +140,7 @@ func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { return nil, err } l.resp = resp - return l.resp.Response, nil + return l.resp, nil } // ResumeToken returns a token string that can be used to resume a poller that has not yet reached a terminal state. @@ -163,7 +163,7 @@ func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http } // if there's nothing to unmarshall into just return the final response if respType == nil { - return l.resp.Response, nil + return l.resp, nil } u, err := l.lro.FinalGetURL(l.resp) if err != nil { @@ -191,7 +191,7 @@ func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http if err = json.Unmarshal(body, respType); err != nil { return nil, err } - return l.resp.Response, nil + return l.resp, nil } // PollUntilDone will handle the entire span of the polling operation until a terminal state is reached, @@ -204,7 +204,7 @@ func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType log.Writef(log.LongRunningOperation, "BEGIN PollUntilDone() for %T", l.lro) if l.resp != nil { // initial check for a retry-after header existing on the initial response - if retryAfter := RetryAfter(l.resp.Response); retryAfter > 0 { + if retryAfter := RetryAfter(l.resp); retryAfter > 0 { log.Writef(log.LongRunningOperation, "initial Retry-After delay for %s", retryAfter.String()) if err := delay(ctx, retryAfter); err != nil { logPollUntilDoneExit(err) @@ -222,7 +222,7 @@ func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType if l.Done() { logPollUntilDoneExit(l.lro.Status()) if !l.lro.Succeeded() { - return nil, l.eu(&Response{resp}) + return nil, l.eu(resp) } return l.FinalResponse(ctx, respType) } @@ -243,8 +243,8 @@ func (l *Poller) PollUntilDone(ctx context.Context, freq time.Duration, respType // abstracts the differences between concrete poller types type lroPoller interface { Done() bool - Update(resp *Response) error - FinalGetURL(resp *Response) (string, error) + Update(resp *http.Response) error + FinalGetURL(resp *http.Response) (string, error) URL() string Status() string Succeeded() bool @@ -262,7 +262,7 @@ type opPoller struct { status string } -func newOpPoller(pollerType, pollingURL, locationURL string, initialResponse *Response) *opPoller { +func newOpPoller(pollerType, pollingURL, locationURL string, initialResponse *http.Response) *opPoller { return &opPoller{ Type: fmt.Sprintf("%s;opPoller", pollerType), ReqMethod: initialResponse.Request.Method, @@ -286,7 +286,7 @@ func (p *opPoller) Succeeded() bool { return strings.EqualFold(p.status, "succeeded") } -func (p *opPoller) Update(resp *Response) error { +func (p *opPoller) Update(resp *http.Response) error { status, err := extractJSONValue(resp, "status") if err != nil { return err @@ -302,7 +302,7 @@ func (p *opPoller) Update(resp *Response) error { return nil } -func (p *opPoller) FinalGetURL(resp *Response) (string, error) { +func (p *opPoller) FinalGetURL(resp *http.Response) (string, error) { if !p.Done() { return "", errors.New("cannot return a final response from a poller in a non-terminal state") } @@ -358,7 +358,7 @@ func (p *locPoller) Succeeded() bool { return p.status >= 200 && p.status < 300 } -func (p *locPoller) Update(resp *Response) error { +func (p *locPoller) Update(resp *http.Response) error { // if the endpoint returned a location header, update cached value if loc := resp.Header.Get(headerLocation); loc != "" { p.PollURL = loc @@ -367,7 +367,7 @@ func (p *locPoller) Update(resp *Response) error { return nil } -func (*locPoller) FinalGetURL(*Response) (string, error) { +func (*locPoller) FinalGetURL(*http.Response) (string, error) { return "", nil } @@ -392,11 +392,11 @@ func (*nopPoller) Succeeded() bool { return true } -func (*nopPoller) Update(*Response) error { +func (*nopPoller) Update(*http.Response) error { return nil } -func (*nopPoller) FinalGetURL(*Response) (string, error) { +func (*nopPoller) FinalGetURL(*http.Response) (string, error) { return "", nil } @@ -405,12 +405,12 @@ func (*nopPoller) Status() string { } // returns true if the LRO response contains a valid HTTP status code -func lroStatusCodeValid(resp *Response) bool { - return resp.HasStatusCode(http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) +func lroStatusCodeValid(resp *http.Response) bool { + return HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) } // extracs a JSON value from the provided reader -func extractJSONValue(resp *Response, val string) (string, error) { +func extractJSONValue(resp *http.Response, val string) (string, error) { if resp.ContentLength == 0 { return "", errors.New("the response does not contain a body") } diff --git a/sdk/azcore/poller_test.go b/sdk/azcore/poller_test.go index cf35675bd3e2..c2e46b9e3877 100644 --- a/sdk/azcore/poller_test.go +++ b/sdk/azcore/poller_test.go @@ -25,9 +25,9 @@ func (p pollerError) Error() string { return p.Err } -func errUnmarshall(resp *Response) error { +func errUnmarshall(resp *http.Response) error { var pe pollerError - if err := resp.UnmarshalAsJSON(&pe); err != nil { + if err := UnmarshalAsJSON(resp, &pe); err != nil { panic(err) } return pe @@ -38,10 +38,8 @@ type widget struct { } func TestNewPollerFail(t *testing.T) { - p, err := NewPoller("fake.poller", &Response{ - &http.Response{ - StatusCode: http.StatusBadRequest, - }, + p, err := NewPoller("fake.poller", &http.Response{ + StatusCode: http.StatusBadRequest, }, NewPipeline(nil), errUnmarshall) if err == nil { t.Fatal("unexpected nil error") @@ -87,17 +85,15 @@ func TestOpPollerSimple(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPut, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -127,17 +123,15 @@ func TestOpPollerWithWidgetPUT(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPut, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -171,18 +165,16 @@ func TestOpPollerWithWidgetPOSTLocation(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -215,17 +207,15 @@ func TestOpPollerWithWidgetPOST(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPost, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPost, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -260,18 +250,16 @@ func TestOpPollerWithWidgetResourceLocation(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPatch, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPatch, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -303,17 +291,15 @@ func TestOpPollerWithResumeToken(t *testing.T) { if err != nil { t.Fatal(err) } - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Operation-Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, - Request: &http.Request{ - Method: http.MethodPut, - URL: reqURL, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Operation-Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, + }, + Request: &http.Request{ + Method: http.MethodPut, + URL: reqURL, }, } pl := NewPipeline(srv) @@ -362,13 +348,11 @@ func TestLocPollerSimple(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -392,13 +376,11 @@ func TestLocPollerWithWidget(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"size": 3}`))) - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -426,13 +408,11 @@ func TestLocPollerCancelled(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(`{"error": "cancelled"}`))) - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -463,13 +443,11 @@ func TestLocPollerWithError(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendError(errors.New("oops")) - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -500,13 +478,11 @@ func TestLocPollerWithResumeToken(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) defer close() - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - "Retry-After": []string{"1"}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, + "Retry-After": []string{"1"}, }, } pl := NewPipeline(srv) @@ -555,12 +531,10 @@ func TestLocPollerWithTimeout(t *testing.T) { srv.AppendResponse(mock.WithSlowResponse(2 * time.Second)) defer close() - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusAccepted, - Header: http.Header{ - "Location": []string{srv.URL()}, - }, + firstResp := &http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Location": []string{srv.URL()}, }, } pl := NewPipeline(srv) @@ -580,10 +554,8 @@ func TestLocPollerWithTimeout(t *testing.T) { } func TestNopPoller(t *testing.T) { - firstResp := &Response{ - &http.Response{ - StatusCode: http.StatusOK, - }, + firstResp := &http.Response{ + StatusCode: http.StatusOK, } pl := NewPipeline(nil) lro, err := NewPoller("fake.poller", firstResp, pl, errUnmarshall) @@ -600,21 +572,21 @@ func TestNopPoller(t *testing.T) { if err != nil { t.Fatal(err) } - if resp != firstResp.Response { + if resp != firstResp { t.Fatal("unexpected response") } resp, err = lro.Poll(context.Background()) if err != nil { t.Fatal(err) } - if resp != firstResp.Response { + if resp != firstResp { t.Fatal("unexpected response") } resp, err = lro.PollUntilDone(context.Background(), 5*time.Millisecond, nil) if err != nil { t.Fatal(err) } - if resp != firstResp.Response { + if resp != firstResp { t.Fatal("unexpected response") } tk, err := lro.ResumeToken() diff --git a/sdk/azcore/request.go b/sdk/azcore/request.go index ed50a6b98a98..9f28b6a69d93 100644 --- a/sdk/azcore/request.go +++ b/sdk/azcore/request.go @@ -114,7 +114,7 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Reque // If there are no more policies, nil and ErrNoMorePolicies are returned. // This method is intended to be called from pipeline policies. // To send a request through a pipeline call Pipeline.Do(). -func (req *Request) Next() (*Response, error) { +func (req *Request) Next() (*http.Response, error) { if len(req.policies) == 0 { return nil, errors.New("no more policies") } diff --git a/sdk/azcore/response.go b/sdk/azcore/response.go index 28348d9bfeeb..c765a61afe3b 100644 --- a/sdk/azcore/response.go +++ b/sdk/azcore/response.go @@ -20,34 +20,29 @@ import ( "time" ) -// Response represents the response from an HTTP request. -type Response struct { - *http.Response -} - // Payload reads and returns the response body or an error. // On a successful read, the response body is cached. -func (r *Response) Payload() ([]byte, error) { +func Payload(resp *http.Response) ([]byte, error) { // r.Body won't be a nopClosingBytesReader if downloading was skipped - if buf, ok := r.Body.(*nopClosingBytesReader); ok { + if buf, ok := resp.Body.(*nopClosingBytesReader); ok { return buf.Bytes(), nil } - bytesBody, err := ioutil.ReadAll(r.Body) - r.Body.Close() + bytesBody, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() if err != nil { return nil, err } - r.Body = &nopClosingBytesReader{s: bytesBody, i: 0} + resp.Body = &nopClosingBytesReader{s: bytesBody, i: 0} return bytesBody, nil } // HasStatusCode returns true if the Response's status code is one of the specified values. -func (r *Response) HasStatusCode(statusCodes ...int) bool { - if r == nil { +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { return false } for _, sc := range statusCodes { - if r.StatusCode == sc { + if resp.StatusCode == sc { return true } } @@ -55,8 +50,8 @@ func (r *Response) HasStatusCode(statusCodes ...int) bool { } // UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. -func (r *Response) UnmarshalAsByteArray(v *[]byte, format Base64Encoding) error { - p, err := r.Payload() +func UnmarshalAsByteArray(resp *http.Response, v *[]byte, format Base64Encoding) error { + p, err := Payload(resp) if err != nil { return err } @@ -64,8 +59,8 @@ func (r *Response) UnmarshalAsByteArray(v *[]byte, format Base64Encoding) error } // UnmarshalAsJSON calls json.Unmarshal() to unmarshal the received payload into the value pointed to by v. -func (r *Response) UnmarshalAsJSON(v interface{}) error { - payload, err := r.Payload() +func UnmarshalAsJSON(resp *http.Response, v interface{}) error { + payload, err := Payload(resp) if err != nil { return err } @@ -73,7 +68,7 @@ func (r *Response) UnmarshalAsJSON(v interface{}) error { if len(payload) == 0 { return nil } - err = r.removeBOM() + err = removeBOM(resp) if err != nil { return err } @@ -85,8 +80,8 @@ func (r *Response) UnmarshalAsJSON(v interface{}) error { } // UnmarshalAsXML calls xml.Unmarshal() to unmarshal the received payload into the value pointed to by v. -func (r *Response) UnmarshalAsXML(v interface{}) error { - payload, err := r.Payload() +func UnmarshalAsXML(resp *http.Response, v interface{}) error { + payload, err := Payload(resp) if err != nil { return err } @@ -94,7 +89,7 @@ func (r *Response) UnmarshalAsXML(v interface{}) error { if len(payload) == 0 { return nil } - err = r.removeBOM() + err = removeBOM(resp) if err != nil { return err } @@ -106,45 +101,37 @@ func (r *Response) UnmarshalAsXML(v interface{}) error { } // Drain reads the response body to completion then closes it. The bytes read are discarded. -func (r *Response) Drain() { - if r != nil && r.Body != nil { - _, _ = io.Copy(ioutil.Discard, r.Body) - r.Body.Close() +func Drain(resp *http.Response) { + if resp != nil && resp.Body != nil { + _, _ = io.Copy(ioutil.Discard, resp.Body) + resp.Body.Close() } } // removeBOM removes any byte-order mark prefix from the payload if present. -func (r *Response) removeBOM() error { - payload, err := r.Payload() +func removeBOM(resp *http.Response) error { + payload, err := Payload(resp) if err != nil { return err } // UTF8 trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf")) if len(trimmed) < len(payload) { - r.Body.(*nopClosingBytesReader).Set(trimmed) + resp.Body.(*nopClosingBytesReader).Set(trimmed) } return nil } -// helper to reduce nil Response checks -func (r *Response) retryAfter() time.Duration { - if r == nil { - return 0 - } - return RetryAfter(r.Response) -} - // writes to a buffer, used for logging purposes -func (r *Response) writeBody(b *bytes.Buffer) error { - ct := r.Header.Get(headerContentType) +func writeBody(resp *http.Response, b *bytes.Buffer) error { + ct := resp.Header.Get(headerContentType) if ct == "" { fmt.Fprint(b, " Response contained no body\n") return nil } else if !shouldLogBody(b, ct) { return nil } - body, err := r.Payload() + body, err := Payload(resp) if err != nil { fmt.Fprintf(b, " Failed to read response body: %s\n", err.Error()) return err @@ -209,14 +196,14 @@ func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error { // writeRequestWithResponse appends a formatted HTTP request into a Buffer. If request and/or err are // not nil, then these are also written into the Buffer. -func writeRequestWithResponse(b *bytes.Buffer, request *Request, response *Response, err error) { +func writeRequestWithResponse(b *bytes.Buffer, request *Request, resp *http.Response, err error) { // Write the request into the buffer. fmt.Fprint(b, " "+request.Method+" "+request.URL.String()+"\n") writeHeader(b, request.Header) - if response != nil { + if resp != nil { fmt.Fprintln(b, " --------------------------------------------------------------------------------") - fmt.Fprint(b, " RESPONSE Status: "+response.Status+"\n") - writeHeader(b, response.Header) + fmt.Fprint(b, " RESPONSE Status: "+resp.Status+"\n") + writeHeader(b, resp.Header) } if err != nil { fmt.Fprintln(b, " --------------------------------------------------------------------------------") diff --git a/sdk/azcore/response_test.go b/sdk/azcore/response_test.go index a3b5bf419e0a..4f93f16b8063 100644 --- a/sdk/azcore/response_test.go +++ b/sdk/azcore/response_test.go @@ -28,11 +28,11 @@ func TestResponseUnmarshalXML(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } var tx testXML - if err := resp.UnmarshalAsXML(&tx); err != nil { + if err := UnmarshalAsXML(resp, &tx); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } if tx.SomeInt != 1 || tx.SomeString != "s" { @@ -53,7 +53,7 @@ func TestResponseFailureStatusCode(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if resp.HasStatusCode(http.StatusOK) { + if HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } } @@ -71,11 +71,11 @@ func TestResponseUnmarshalJSON(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } var tx testJSON - if err := resp.UnmarshalAsJSON(&tx); err != nil { + if err := UnmarshalAsJSON(resp, &tx); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } if tx.SomeInt != 1 || tx.SomeString != "s" { @@ -97,11 +97,11 @@ func TestResponseUnmarshalJSONskipDownload(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } var tx testJSON - if err := resp.UnmarshalAsJSON(&tx); err != nil { + if err := UnmarshalAsJSON(resp, &tx); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } if tx.SomeInt != 1 || tx.SomeString != "s" { @@ -122,10 +122,10 @@ func TestResponseUnmarshalJSONNoBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if err := resp.UnmarshalAsJSON(nil); err != nil { + if err := UnmarshalAsJSON(resp, nil); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } } @@ -143,24 +143,23 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if err := resp.UnmarshalAsXML(nil); err != nil { + if err := UnmarshalAsXML(resp, nil); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } } func TestRetryAfter(t *testing.T) { - raw := &http.Response{ + resp := &http.Response{ Header: http.Header{}, } - resp := Response{raw} - if d := resp.retryAfter(); d > 0 { + if d := RetryAfter(resp); d > 0 { t.Fatalf("unexpected retry-after value %d", d) } - raw.Header.Set(headerRetryAfter, "300") - d := resp.retryAfter() + resp.Header.Set(headerRetryAfter, "300") + d := RetryAfter(resp) if d <= 0 { t.Fatal("expected retry-after value from seconds") } @@ -168,8 +167,8 @@ func TestRetryAfter(t *testing.T) { t.Fatalf("expected 300 seconds, got %d", d/time.Second) } atDate := time.Now().Add(600 * time.Second) - raw.Header.Set(headerRetryAfter, atDate.Format(time.RFC1123)) - d = resp.retryAfter() + resp.Header.Set(headerRetryAfter, atDate.Format(time.RFC1123)) + d = RetryAfter(resp) if d <= 0 { t.Fatal("expected retry-after value from date") } @@ -192,11 +191,11 @@ func TestResponseUnmarshalAsByteArrayURLFormat(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } var ba []byte - if err := resp.UnmarshalAsByteArray(&ba, Base64URLFormat); err != nil { + if err := UnmarshalAsByteArray(resp, &ba, Base64URLFormat); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } if string(ba) != "a string that gets encoded with base64url" { @@ -217,11 +216,11 @@ func TestResponseUnmarshalAsByteArrayStdFormat(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } - if !resp.HasStatusCode(http.StatusOK) { + if !HasStatusCode(resp, http.StatusOK) { t.Fatalf("unexpected status code: %d", resp.StatusCode) } var ba []byte - if err := resp.UnmarshalAsByteArray(&ba, Base64StdFormat); err != nil { + if err := UnmarshalAsByteArray(resp, &ba, Base64StdFormat); err != nil { t.Fatalf("unexpected error unmarshalling: %v", err) } if string(ba) != "a string that gets encoded with base64url" {