diff --git a/sdk/azcore/arm/internal/pollers/async/async.go b/sdk/azcore/arm/internal/pollers/async/async.go index b4f3f0f23d21..f6708c479a86 100644 --- a/sdk/azcore/arm/internal/pollers/async/async.go +++ b/sdk/azcore/arm/internal/pollers/async/async.go @@ -69,7 +69,7 @@ func New(resp *http.Response, finalState pollers.FinalStateVia, pollerID string) } // check for provisioning state state, err := armpollers.GetProvisioningState(resp) - if errors.Is(err, shared.ErrNoBody) || state == "" { + if errors.Is(err, pollers.ErrNoBody) || state == "" { // NOTE: the ARM RPC spec explicitly states that for async PUT the initial response MUST // contain a provisioning state. to maintain compat with track 1 and other implementations // we are explicitly relaxing this requirement. diff --git a/sdk/azcore/arm/internal/pollers/body/body.go b/sdk/azcore/arm/internal/pollers/body/body.go index 0abe8969b4d1..baf1fb90c327 100644 --- a/sdk/azcore/arm/internal/pollers/body/body.go +++ b/sdk/azcore/arm/internal/pollers/body/body.go @@ -12,7 +12,6 @@ import ( armpollers "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) @@ -50,7 +49,7 @@ func New(resp *http.Response, pollerID string) (*Poller, error) { // status code and provisioning state, we might change the value. curState := pollers.StatusInProgress provState, err := armpollers.GetProvisioningState(resp) - if err != nil && !errors.Is(err, shared.ErrNoBody) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return nil, err } if resp.StatusCode == http.StatusCreated && provState != "" { @@ -87,7 +86,7 @@ func (p *Poller) Update(resp *http.Response) error { return nil } state, err := armpollers.GetProvisioningState(resp) - if errors.Is(err, shared.ErrNoBody) { + if errors.Is(err, pollers.ErrNoBody) { // a missing response body in non-204 case is an error return err } else if state == "" { diff --git a/sdk/azcore/arm/internal/pollers/body/body_test.go b/sdk/azcore/arm/internal/pollers/body/body_test.go index 76bf8dd63947..3b7706eda03e 100644 --- a/sdk/azcore/arm/internal/pollers/body/body_test.go +++ b/sdk/azcore/arm/internal/pollers/body/body_test.go @@ -14,7 +14,7 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" ) const ( @@ -113,7 +113,7 @@ func TestUpdateNoProvStateFail(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - if !errors.Is(err, shared.ErrNoBody) { + if !errors.Is(err, pollers.ErrNoBody) { t.Fatalf("unexpected error type %T", err) } } diff --git a/sdk/azcore/arm/internal/pollers/loc/loc.go b/sdk/azcore/arm/internal/pollers/loc/loc.go index 548250ff3adb..ca5a95386461 100644 --- a/sdk/azcore/arm/internal/pollers/loc/loc.go +++ b/sdk/azcore/arm/internal/pollers/loc/loc.go @@ -75,7 +75,7 @@ func (p *Poller) Update(resp *http.Response) error { if runtime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) { // if a 200/201 returns a provisioning state, use that instead state, err := armpollers.GetProvisioningState(resp) - if err != nil && !errors.Is(err, shared.ErrNoBody) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return err } if state != "" { diff --git a/sdk/azcore/arm/internal/pollers/pollers.go b/sdk/azcore/arm/internal/pollers/pollers.go index f23159fd10b9..3bad751076d5 100644 --- a/sdk/azcore/arm/internal/pollers/pollers.go +++ b/sdk/azcore/arm/internal/pollers/pollers.go @@ -9,7 +9,7 @@ package pollers import ( "net/http" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" ) // provisioningState returns the provisioning state from the response or the empty string. @@ -50,7 +50,7 @@ func status(jsonBody map[string]interface{}) string { // Typically used for Azure-AsyncOperation flows. // If there is no status in the response body the empty string is returned. func GetStatus(resp *http.Response) (string, error) { - jsonBody, err := shared.GetJSON(resp) + jsonBody, err := pollers.GetJSON(resp) if err != nil { return "", err } @@ -60,7 +60,7 @@ func GetStatus(resp *http.Response) (string, error) { // GetProvisioningState returns the LRO's state from the response body. // If there is no state in the response body the empty string is returned. func GetProvisioningState(resp *http.Response) (string, error) { - jsonBody, err := shared.GetJSON(resp) + jsonBody, err := pollers.GetJSON(resp) if err != nil { return "", err } diff --git a/sdk/azcore/arm/internal/pollers/pollers_test.go b/sdk/azcore/arm/internal/pollers/pollers_test.go index f9af2d899608..307ce15a4600 100644 --- a/sdk/azcore/arm/internal/pollers/pollers_test.go +++ b/sdk/azcore/arm/internal/pollers/pollers_test.go @@ -13,7 +13,7 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" ) func TestGetStatusSuccess(t *testing.T) { @@ -35,14 +35,14 @@ func TestGetNoBody(t *testing.T) { Body: http.NoBody, } status, err := GetStatus(resp) - if !errors.Is(err, shared.ErrNoBody) { + if !errors.Is(err, pollers.ErrNoBody) { t.Fatalf("unexpected error %T", err) } if status != "" { t.Fatal("expected empty status") } status, err = GetProvisioningState(resp) - if !errors.Is(err, shared.ErrNoBody) { + if !errors.Is(err, pollers.ErrNoBody) { t.Fatalf("unexpected error %T", err) } if status != "" { diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index 1628d5111237..d0408be83c67 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -10,39 +10,38 @@ import ( "errors" "reflect" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" ) // NewPipeline creates a pipeline from connection options. // The telemetry policy, when enabled, will use the specified module and version info. -func NewPipeline(module, version string, cred shared.TokenCredential, plOpts azruntime.PipelineOptions, options *arm.ClientOptions) (pipeline.Pipeline, error) { +func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azruntime.PipelineOptions, options *arm.ClientOptions) (azruntime.Pipeline, error) { if options == nil { options = &arm.ClientOptions{} } conf, err := getConfiguration(&options.ClientOptions) if err != nil { - return pipeline.Pipeline{}, err + return azruntime.Pipeline{}, err } authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{ Scopes: []string{conf.Audience + "/.default"}, AuxiliaryTenants: options.AuxiliaryTenants, }) - perRetry := make([]pipeline.Policy, 0, len(plOpts.PerRetry)+1) + perRetry := make([]azpolicy.Policy, 0, len(plOpts.PerRetry)+1) copy(perRetry, plOpts.PerRetry) plOpts.PerRetry = append(perRetry, authPolicy) if !options.DisableRPRegistration { regRPOpts := armpolicy.RegistrationOptions{ClientOptions: options.ClientOptions} regPolicy, err := NewRPRegistrationPolicy(cred, ®RPOpts) if err != nil { - return pipeline.Pipeline{}, err + return azruntime.Pipeline{}, err } - perCall := make([]pipeline.Policy, 0, len(plOpts.PerCall)+1) + perCall := make([]azpolicy.Policy, 0, len(plOpts.PerCall)+1) copy(perCall, plOpts.PerCall) plOpts.PerCall = append(perCall, regPolicy) } diff --git a/sdk/azcore/arm/runtime/pipeline_test.go b/sdk/azcore/arm/runtime/pipeline_test.go index f4a9d112075f..239bce745a73 100644 --- a/sdk/azcore/arm/runtime/pipeline_test.go +++ b/sdk/azcore/arm/runtime/pipeline_test.go @@ -13,9 +13,9 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" @@ -179,7 +179,7 @@ func TestPipelineAudience(t *testing.T) { t.Fatal("unexpected audience " + audience) } getTokenCalled := false - cred := mockCredential{getTokenImpl: func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { + cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { getTokenCalled = true if n := len(options.Scopes); n != 1 { t.Fatalf("expected 1 scope, got %d", n) @@ -187,7 +187,7 @@ func TestPipelineAudience(t *testing.T) { if options.Scopes[0] != audience+"/.default" { t.Fatalf(`unexpected scope "%s"`, options.Scopes[0]) } - return &shared.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil + return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil }} req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/arm/runtime/policy_bearer_token.go b/sdk/azcore/arm/runtime/policy_bearer_token.go index 88e2945e5226..4dc32b3a8df6 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -23,8 +24,8 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) { - tk, err := state.p.cred.GetToken(state.ctx, shared.TokenRequestOptions{ +func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) { + tk, err := state.p.cred.GetToken(state.ctx, azpolicy.TokenRequestOptions{ Scopes: state.p.options.Scopes, TenantID: state.tenant, }) @@ -37,18 +38,18 @@ func acquire(state acquiringResourceState) (newResource *shared.AccessToken, new // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] + mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] // auxResources are additional resources that are required for cross-tenant applications - auxResources map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] + auxResources map[string]*shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] // the following fields are read-only - cred shared.TokenCredential + cred azcore.TokenCredential options armpolicy.BearerTokenOptions } // NewBearerTokenPolicy creates a policy object that authorizes requests with bearer tokens. // cred: an azcore.TokenCredential implementation such as a credential object from azidentity // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. -func NewBearerTokenPolicy(cred shared.TokenCredential, opts *armpolicy.BearerTokenOptions) *BearerTokenPolicy { +func NewBearerTokenPolicy(cred azcore.TokenCredential, opts *armpolicy.BearerTokenOptions) *BearerTokenPolicy { if opts == nil { opts = &armpolicy.BearerTokenOptions{} } @@ -58,7 +59,7 @@ func NewBearerTokenPolicy(cred shared.TokenCredential, opts *armpolicy.BearerTok mainResource: shared.NewExpiringResource(acquire), } if len(opts.AuxiliaryTenants) > 0 { - p.auxResources = map[string]*shared.ExpiringResource[*shared.AccessToken, acquiringResourceState]{} + p.auxResources = map[string]*shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState]{} } for _, t := range opts.AuxiliaryTenants { p.auxResources[t] = shared.NewExpiringResource(acquire) diff --git a/sdk/azcore/arm/runtime/policy_bearer_token_test.go b/sdk/azcore/arm/runtime/policy_bearer_token_test.go index e381393da8a1..54849b31d7d8 100644 --- a/sdk/azcore/arm/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/arm/runtime/policy_bearer_token_test.go @@ -12,10 +12,11 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -29,14 +30,14 @@ const ( ) type mockCredential struct { - getTokenImpl func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) + getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) } -func (mc mockCredential) GetToken(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { +func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } - return &shared.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil + return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } func (mc mockCredential) NewAuthenticationPolicy() azpolicy.Policy { @@ -47,11 +48,11 @@ func (mc mockCredential) Do(req *azpolicy.Request) (*http.Response, error) { return nil, nil } -func newTestPipeline(opts *azpolicy.ClientOptions) pipeline.Pipeline { +func newTestPipeline(opts *azpolicy.ClientOptions) runtime.Pipeline { return runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{}, opts) } -func defaultTestPipeline(srv azpolicy.Transporter, scope string) (pipeline.Pipeline, error) { +func defaultTestPipeline(srv azpolicy.Transporter, scope string) (runtime.Pipeline, error) { retryOpts := azpolicy.RetryOptions{ MaxRetryDelay: 500 * time.Millisecond, RetryDelay: time.Millisecond, @@ -97,7 +98,7 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} - failCredential.getTokenImpl = func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { + failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { return nil, expectedErr } b := NewBearerTokenPolicy(failCredential, nil) @@ -156,7 +157,7 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { MaxRetries: 3, } b := NewBearerTokenPolicy(mockCredential{}, nil) - pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []azpolicy.Policy{b}}) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) @@ -189,7 +190,7 @@ func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, }, ) - pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + pipeline := newTestPipeline(&azpolicy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []azpolicy.Policy{b}}) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("Unexpected error: %v", err) diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go index 7d798169fc13..7bbc1836e480 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp.go +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -15,8 +15,9 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" @@ -49,7 +50,7 @@ func setDefaults(r *armpolicy.RegistrationOptions) { // NewRPRegistrationPolicy creates a policy object configured using the specified options. // The policy controls whether an unregistered resource provider should automatically be // registered. See https://aka.ms/rps-not-found for more information. -func NewRPRegistrationPolicy(cred shared.TokenCredential, o *armpolicy.RegistrationOptions) (azpolicy.Policy, error) { +func NewRPRegistrationPolicy(cred azcore.TokenCredential, o *armpolicy.RegistrationOptions) (azpolicy.Policy, error) { if o == nil { o = &armpolicy.RegistrationOptions{} } @@ -60,7 +61,7 @@ func NewRPRegistrationPolicy(cred shared.TokenCredential, o *armpolicy.Registrat authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{conf.Audience + "/.default"}}) p := &rpRegistrationPolicy{ endpoint: conf.Endpoint, - pipeline: runtime.NewPipeline(shared.Module, shared.Version, runtime.PipelineOptions{PerRetry: []pipeline.Policy{authPolicy}}, &o.ClientOptions), + pipeline: runtime.NewPipeline(shared.Module, shared.Version, runtime.PipelineOptions{PerRetry: []azpolicy.Policy{authPolicy}}, &o.ClientOptions), options: *o, } // init the copy @@ -70,7 +71,7 @@ func NewRPRegistrationPolicy(cred shared.TokenCredential, o *armpolicy.Registrat type rpRegistrationPolicy struct { endpoint string - pipeline pipeline.Pipeline + pipeline runtime.Pipeline options armpolicy.RegistrationOptions } @@ -207,7 +208,7 @@ type serviceErrorDetails struct { /////////////////////////////////////////////////////////////////////////////////////////////// type providersOperations struct { - p pipeline.Pipeline + p runtime.Pipeline u string subID string } @@ -247,7 +248,7 @@ func (client *providersOperations) getCreateRequest(ctx context.Context, resourc // getHandleResponse handles the Get response. func (client *providersOperations) getHandleResponse(resp *http.Response) (*ProviderResponse, error) { if !runtime.HasStatusCode(resp, http.StatusOK) { - return nil, shared.NewResponseError(resp) + return nil, exported.NewResponseError(resp) } result := ProviderResponse{RawResponse: resp} err := runtime.UnmarshalAsJSON(resp, &result.Provider) @@ -292,7 +293,7 @@ func (client *providersOperations) registerCreateRequest(ctx context.Context, re // registerHandleResponse handles the Register response. func (client *providersOperations) registerHandleResponse(resp *http.Response) (*ProviderResponse, error) { if !runtime.HasStatusCode(resp, http.StatusOK) { - return nil, shared.NewResponseError(resp) + return nil, exported.NewResponseError(resp) } result := ProviderResponse{RawResponse: resp} err := runtime.UnmarshalAsJSON(resp, &result.Provider) diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go index 0fcdfd262958..1ec452626190 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp_test.go +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -15,11 +15,11 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -56,7 +56,7 @@ const rpRegisteredResp = `{ const requestEndpoint = "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/fakeResourceGroupo/providers/Microsoft.Storage/storageAccounts/fakeAccountName" -func newTestRPRegistrationPipeline(t *testing.T, srv *mock.Server) pipeline.Pipeline { +func newTestRPRegistrationPipeline(t *testing.T, srv *mock.Server) runtime.Pipeline { opts := testRPRegistrationOptions(srv) rp, err := NewRPRegistrationPolicy(mockCredential{}, testRPRegistrationOptions(srv)) if err != nil { @@ -347,7 +347,7 @@ func TestRPRegistrationPolicyDisabled(t *testing.T) { if err != nil { t.Fatal(err) } - pl := runtime.NewPipeline("test", "v0.1.0", runtime.PipelineOptions{PerCall: []pipeline.Policy{rp}}, nil) + pl := runtime.NewPipeline("test", "v0.1.0", runtime.PipelineOptions{PerCall: []azpolicy.Policy{rp}}, nil) req, err := runtime.NewRequest(context.Background(), http.MethodGet, runtime.JoinPaths(srv.URL(), requestEndpoint)) if err != nil { t.Fatal(err) @@ -395,7 +395,7 @@ func TestRPRegistrationPolicyAudience(t *testing.T) { }, } getTokenCalled := false - cred := mockCredential{getTokenImpl: func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { + cred := mockCredential{getTokenImpl: func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { getTokenCalled = true if n := len(options.Scopes); n != 1 { t.Fatalf("expected 1 scope, got %d", n) @@ -403,7 +403,7 @@ func TestRPRegistrationPolicyAudience(t *testing.T) { if options.Scopes[0] != audience+"/.default" { t.Fatalf(`unexpected scope "%s"`, options.Scopes[0]) } - return &shared.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil + return &azcore.AccessToken{Token: "...", ExpiresOn: time.Now().Add(time.Hour)}, nil }} opts := azpolicy.ClientOptions{Cloud: conf, Transport: srv} rp, err := NewRPRegistrationPolicy(cred, &armpolicy.RegistrationOptions{ClientOptions: opts}) diff --git a/sdk/azcore/arm/runtime/poller.go b/sdk/azcore/arm/runtime/poller.go index cc03ad097f28..6e902ffc27ff 100644 --- a/sdk/azcore/arm/runtime/poller.go +++ b/sdk/azcore/arm/runtime/poller.go @@ -17,8 +17,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) @@ -50,7 +50,7 @@ type NewPollerOptions[T any] struct { } // NewPoller creates a Poller based on the provided initial response. -func NewPoller[T any](resp *http.Response, pl pipeline.Pipeline, options *NewPollerOptions[T]) (*Poller[T], error) { +func NewPoller[T any](resp *http.Response, pl runtime.Pipeline, options *NewPollerOptions[T]) (*Poller[T], error) { if options == nil { options = &NewPollerOptions[T]{} } @@ -98,7 +98,7 @@ type NewPollerFromResumeTokenOptions[T any] struct { } // NewPollerFromResumeToken creates a Poller from a resume token string. -func NewPollerFromResumeToken[T any](token string, pl pipeline.Pipeline, options *NewPollerFromResumeTokenOptions[T]) (*Poller[T], error) { +func NewPollerFromResumeToken[T any](token string, pl runtime.Pipeline, options *NewPollerFromResumeTokenOptions[T]) (*Poller[T], error) { if options == nil { options = &NewPollerFromResumeTokenOptions[T]{} } diff --git a/sdk/azcore/arm/runtime/poller_test.go b/sdk/azcore/arm/runtime/poller_test.go index 5a16a41170fc..bd3b827189e0 100644 --- a/sdk/azcore/arm/runtime/poller_test.go +++ b/sdk/azcore/arm/runtime/poller_test.go @@ -18,7 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/async" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/body" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/internal/pollers/loc" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -42,11 +42,11 @@ type mockType struct { Field *string `json:"field,omitempty"` } -func getPipeline(srv *mock.Server) pipeline.Pipeline { +func getPipeline(srv *mock.Server) runtime.Pipeline { return runtime.NewPipeline( "test", "v0.1.0", - runtime.PipelineOptions{PerRetry: []pipeline.Policy{runtime.NewLogPolicy(nil)}}, + runtime.PipelineOptions{PerRetry: []policy.Policy{runtime.NewLogPolicy(nil)}}, &policy.ClientOptions{Transport: srv}, ) } @@ -270,7 +270,7 @@ func TestNewPollerFailedWithError(t *testing.T) { if err == nil { t.Fatal(err) } - if _, ok := err.(*shared.ResponseError); !ok { + if _, ok := err.(*exported.ResponseError); !ok { t.Fatalf("unexpected error type %T", err) } } diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index 2f1f5a52b07b..b188b7f0e954 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -7,17 +7,25 @@ package azcore import ( + "context" "reflect" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) // AccessToken represents an Azure service bearer access token with expiry information. -type AccessToken = shared.AccessToken +type AccessToken struct { + Token string + ExpiresOn time.Time +} // TokenCredential represents a credential capable of providing an OAuth token. -type TokenCredential = shared.TokenCredential +type TokenCredential interface { + // GetToken requests an access token for the specified set of scopes. + GetToken(ctx context.Context, options policy.TokenRequestOptions) (*AccessToken, error) +} // holds sentinel values used to send nulls var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{} diff --git a/sdk/azcore/errors.go b/sdk/azcore/errors.go index dda41c3331c4..17bd50c67320 100644 --- a/sdk/azcore/errors.go +++ b/sdk/azcore/errors.go @@ -6,11 +6,9 @@ package azcore -import ( - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" -) +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" // ResponseError is returned when a request is made to a service and // the service returns a non-success HTTP status code. // Use errors.As() to access this type in the error chain. -type ResponseError = shared.ResponseError +type ResponseError = exported.ResponseError diff --git a/sdk/azcore/internal/exported/exported.go b/sdk/azcore/internal/exported/exported.go new file mode 100644 index 000000000000..7efde892a185 --- /dev/null +++ b/sdk/azcore/internal/exported/exported.go @@ -0,0 +1,117 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "errors" + "io" + "io/ioutil" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" +) + +type nopCloser struct { + io.ReadSeeker +} + +func (n nopCloser) Close() error { + return nil +} + +// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. +// Exported as streaming.NopCloser(). +func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { + return nopCloser{rs} +} + +// HasStatusCode returns true if the Response's status code is one of the specified values. +// Exported as runtime.HasStatusCode(). +func HasStatusCode(resp *http.Response, statusCodes ...int) bool { + if resp == nil { + return false + } + for _, sc := range statusCodes { + if resp.StatusCode == sc { + return true + } + } + return false +} + +// Payload reads and returns the response body or an error. +// On a successful read, the response body is cached. +// Subsequent reads will access the cached value. +// Exported as runtime.Payload(). +func Payload(resp *http.Response) ([]byte, error) { + // r.Body won't be a nopClosingBytesReader if downloading was skipped + if buf, ok := resp.Body.(*nopClosingBytesReader); ok { + return buf.Bytes(), nil + } + bytesBody, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + resp.Body = &nopClosingBytesReader{s: bytesBody, i: 0} + return bytesBody, nil +} + +// NopClosingBytesReader is an io.ReadSeekCloser around a byte slice. +// It also provides direct access to the byte slice to avoid rereading. +type nopClosingBytesReader struct { + s []byte + i int64 +} + +// Bytes returns the underlying byte slice. +func (r *nopClosingBytesReader) Bytes() []byte { + return r.s +} + +// Close implements the io.Closer interface. +func (*nopClosingBytesReader) Close() error { + return nil +} + +// Read implements the io.Reader interface. +func (r *nopClosingBytesReader) Read(b []byte) (n int, err error) { + if r.i >= int64(len(r.s)) { + return 0, io.EOF + } + n = copy(b, r.s[r.i:]) + r.i += int64(n) + return +} + +// Set replaces the existing byte slice with the specified byte slice and resets the reader. +func (r *nopClosingBytesReader) Set(b []byte) { + r.s = b + r.i = 0 +} + +// Seek implements the io.Seeker interface. +func (r *nopClosingBytesReader) Seek(offset int64, whence int) (int64, error) { + var i int64 + switch whence { + case io.SeekStart: + i = offset + case io.SeekCurrent: + i = r.i + offset + case io.SeekEnd: + i = int64(len(r.s)) + offset + default: + return 0, errors.New("nopClosingBytesReader: invalid whence") + } + if i < 0 { + return 0, errors.New("nopClosingBytesReader: negative position") + } + r.i = i + return i, nil +} + +var _ shared.BytesSetter = (*nopClosingBytesReader)(nil) diff --git a/sdk/azcore/internal/exported/exported_test.go b/sdk/azcore/internal/exported/exported_test.go new file mode 100644 index 000000000000..dcb4aaafa781 --- /dev/null +++ b/sdk/azcore/internal/exported/exported_test.go @@ -0,0 +1,116 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exported + +import ( + "io" + "net/http" + "strings" + "testing" +) + +func TestNopCloser(t *testing.T) { + nc := NopCloser(strings.NewReader("foo")) + if err := nc.Close(); err != nil { + t.Fatal(err) + } +} + +func TestHasStatusCode(t *testing.T) { + if HasStatusCode(nil, http.StatusAccepted) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{}) { + t.Fatal("unexpected success") + } + if HasStatusCode(&http.Response{StatusCode: http.StatusBadGateway}, http.StatusBadRequest) { + t.Fatal("unexpected success") + } + if !HasStatusCode(&http.Response{StatusCode: http.StatusOK}, http.StatusAccepted, http.StatusOK, http.StatusNoContent) { + t.Fatal("unexpected failure") + } +} + +func TestPayload(t *testing.T) { + const val = "payload" + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(val)), + } + b, err := Payload(resp) + if err != nil { + t.Fatal(err) + } + if string(b) != val { + t.Fatalf("got %s, want %s", string(b), val) + } + b, err = Payload(resp) + if err != nil { + t.Fatal(err) + } + if string(b) != val { + t.Fatalf("got %s, want %s", string(b), val) + } +} + +func TestNopClosingBytesReader(t *testing.T) { + const val1 = "the data" + ncbr := &nopClosingBytesReader{s: []byte(val1)} + b, err := io.ReadAll(ncbr) + if err != nil { + t.Fatal(err) + } + if string(b) != val1 { + t.Fatalf("got %s, want %s", string(b), val1) + } + const val2 = "something else" + ncbr.Set([]byte(val2)) + b, err = io.ReadAll(ncbr) + if err != nil { + t.Fatal(err) + } + if string(b) != val2 { + t.Fatalf("got %s, want %s", string(b), val2) + } + if err = ncbr.Close(); err != nil { + t.Fatal(err) + } + // seek to beginning and read again + i, err := ncbr.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if i != 0 { + t.Fatalf("got %d, want %d", i, 0) + } + b, err = io.ReadAll(ncbr) + if err != nil { + t.Fatal(err) + } + if string(b) != val2 { + t.Fatalf("got %s, want %s", string(b), val2) + } + // seek to middle from the end + i, err = ncbr.Seek(-4, io.SeekEnd) + if err != nil { + t.Fatal(err) + } + if l := int64(len(val2)) - 4; i != l { + t.Fatalf("got %d, want %d", l, i) + } + b, err = io.ReadAll(ncbr) + if err != nil { + t.Fatal(err) + } + if string(b) != "else" { + t.Fatalf("got %s, want %s", string(b), "else") + } + // underflow + _, err = ncbr.Seek(-int64(len(val2)+1), io.SeekCurrent) + if err == nil { + t.Fatal("unexpected nil error") + } +} diff --git a/sdk/azcore/internal/pipeline/pipeline.go b/sdk/azcore/internal/exported/pipeline.go similarity index 93% rename from sdk/azcore/internal/pipeline/pipeline.go rename to sdk/azcore/internal/exported/pipeline.go index f44b53db5e42..c44efd6eff57 100644 --- a/sdk/azcore/internal/pipeline/pipeline.go +++ b/sdk/azcore/internal/exported/pipeline.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package pipeline +package exported import ( "errors" @@ -16,6 +16,7 @@ import ( // Policy represents an extensibility point for the Pipeline that can mutate the specified // Request and react to the received Response. +// Exported as policy.Policy. type Policy interface { // Do applies the policy to the specified Request. When implementing a Policy, mutate the // request before calling req.Next() to move on to the next policy, and respond to the result @@ -25,11 +26,13 @@ type Policy interface { // Pipeline represents a primitive for sending HTTP requests and receiving responses. // Its behavior can be extended by specifying policies during construction. +// Exported as runtime.Pipeline. type Pipeline struct { policies []Policy } // Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. +// Exported as policy.Transporter. type Transporter interface { // Do sends the HTTP request and returns the HTTP response or error. Do(req *http.Request) (*http.Response, error) @@ -56,6 +59,7 @@ func (tp transportPolicy) Do(req *Request) (*http.Response, error) { } // NewPipeline creates a new Pipeline object from the specified Policies. +// Not directly exported, but used as part of runtime.NewPipeline(). func NewPipeline(transport Transporter, policies ...Policy) Pipeline { // transport policy must always be the last in the slice policies = append(policies, transportPolicy{trans: transport}) diff --git a/sdk/azcore/internal/pipeline/pipeline_test.go b/sdk/azcore/internal/exported/pipeline_test.go similarity index 99% rename from sdk/azcore/internal/pipeline/pipeline_test.go rename to sdk/azcore/internal/exported/pipeline_test.go index ad836d3160f4..b048e870abfe 100644 --- a/sdk/azcore/internal/pipeline/pipeline_test.go +++ b/sdk/azcore/internal/exported/pipeline_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package pipeline +package exported import ( "context" diff --git a/sdk/azcore/internal/pipeline/request.go b/sdk/azcore/internal/exported/request.go similarity index 92% rename from sdk/azcore/internal/pipeline/request.go rename to sdk/azcore/internal/exported/request.go index 3c00320d6687..4aeec158937b 100644 --- a/sdk/azcore/internal/pipeline/request.go +++ b/sdk/azcore/internal/exported/request.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package pipeline +package exported import ( "context" @@ -18,17 +18,9 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) -// PolicyFunc is a type that implements the Policy interface. -// Use this type when implementing a stateless policy as a first-class function. -type PolicyFunc func(*Request) (*http.Response, error) - -// Do implements the Policy interface on PolicyFunc. -func (pf PolicyFunc) Do(req *Request) (*http.Response, error) { - return pf(req) -} - // Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. // Don't use this type directly, use NewRequest() instead. +// Exported as policy.Request. type Request struct { req *http.Request body io.ReadSeekCloser @@ -53,6 +45,7 @@ func (ov opValues) get(value interface{}) bool { } // NewRequest creates a new Request with the specified input. +// Exported as runtime.NewRequest(). func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Request, error) { req, err := http.NewRequestWithContext(ctx, httpMethod, endpoint, nil) if err != nil { diff --git a/sdk/azcore/internal/pipeline/request_test.go b/sdk/azcore/internal/exported/request_test.go similarity index 87% rename from sdk/azcore/internal/pipeline/request_test.go rename to sdk/azcore/internal/exported/request_test.go index a72c445e629d..73de17500f6e 100644 --- a/sdk/azcore/internal/pipeline/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -4,15 +4,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package pipeline +package exported import ( "context" "net/http" "strings" "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) const testURL = "http://test.contoso.com/" @@ -36,6 +34,12 @@ func TestNewRequest(t *testing.T) { } } +type testPolicy struct{} + +func (testPolicy) Do(*Request) (*http.Response, error) { + return &http.Response{}, nil +} + func TestRequestPolicies(t *testing.T) { req, err := NewRequest(context.Background(), http.MethodPost, testURL) if err != nil { @@ -56,10 +60,7 @@ func TestRequestPolicies(t *testing.T) { if resp != nil { t.Fatal("expected nil response") } - testPolicy := func(*Request) (*http.Response, error) { - return &http.Response{}, nil - } - req.policies = []Policy{PolicyFunc(testPolicy)} + req.policies = []Policy{testPolicy{}} resp, err = req.Next() if err != nil { t.Fatal(err) @@ -80,7 +81,7 @@ func TestRequestBody(t *testing.T) { if err := req.Close(); err != nil { t.Fatal(err) } - if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + if err := req.SetBody(NopCloser(strings.NewReader("test")), "application/text"); err != nil { t.Fatal(err) } if err := req.RewindBody(); err != nil { @@ -96,7 +97,7 @@ func TestRequestClone(t *testing.T) { if err != nil { t.Fatal(err) } - if err := req.SetBody(shared.NopCloser(strings.NewReader("test")), "application/text"); err != nil { + if err := req.SetBody(NopCloser(strings.NewReader("test")), "application/text"); err != nil { t.Fatal(err) } type ensureCloned struct { diff --git a/sdk/azcore/internal/shared/response_error.go b/sdk/azcore/internal/exported/response_error.go similarity index 97% rename from sdk/azcore/internal/shared/response_error.go rename to sdk/azcore/internal/exported/response_error.go index 419e08f06ad9..3db6acc83258 100644 --- a/sdk/azcore/internal/shared/response_error.go +++ b/sdk/azcore/internal/exported/response_error.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package shared +package exported import ( "bytes" @@ -15,6 +15,7 @@ import ( ) // NewResponseError creates a new *ResponseError from the provided HTTP response. +// Exported as runtime.NewResponseError(). func NewResponseError(resp *http.Response) error { respErr := &ResponseError{ StatusCode: resp.StatusCode, @@ -94,6 +95,7 @@ func extractErrorCodeXML(body []byte) string { // ResponseError is returned when a request is made to a service and // the service returns a non-success HTTP status code. // Use errors.As() to access this type in the error chain. +// Exported as azcore.ResponseError. type ResponseError struct { // ErrorCode is the error code returned by the resource provider if available. ErrorCode string diff --git a/sdk/azcore/internal/shared/response_error_test.go b/sdk/azcore/internal/exported/response_error_test.go similarity index 99% rename from sdk/azcore/internal/shared/response_error_test.go rename to sdk/azcore/internal/exported/response_error_test.go index 9cd9b2525ce2..7b4a44150ef1 100644 --- a/sdk/azcore/internal/shared/response_error_test.go +++ b/sdk/azcore/internal/exported/response_error_test.go @@ -4,7 +4,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package shared +package exported import ( "errors" diff --git a/sdk/azcore/internal/pollers/op/op.go b/sdk/azcore/internal/pollers/op/op.go index 7d9c0572000c..488eca1e19b5 100644 --- a/sdk/azcore/internal/pollers/op/op.go +++ b/sdk/azcore/internal/pollers/op/op.go @@ -52,7 +52,7 @@ func New(resp *http.Response, finalState pollers.FinalStateVia, pollerID string) // service sent us a status then use that instead. curState := pollers.StatusInProgress status, err := getValue(resp, "status") - if err != nil && !errors.Is(err, shared.ErrNoBody) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return nil, err } if status != "" { @@ -99,7 +99,7 @@ func (p *Poller) Update(resp *http.Response) error { } // check for resourceLocation resLoc, err := getValue(resp, "resourceLocation") - if err != nil && !errors.Is(err, shared.ErrNoBody) { + if err != nil && !errors.Is(err, pollers.ErrNoBody) { return err } else if resLoc != "" { p.FinalGET = resLoc @@ -116,7 +116,7 @@ func (p *Poller) Status() string { } func getValue(resp *http.Response, val string) (string, error) { - jsonBody, err := shared.GetJSON(resp) + jsonBody, err := pollers.GetJSON(resp) if err != nil { return "", err } diff --git a/sdk/azcore/internal/pollers/poller.go b/sdk/azcore/internal/pollers/poller.go index e29f5a030f51..3b4118694935 100644 --- a/sdk/azcore/internal/pollers/poller.go +++ b/sdk/azcore/internal/pollers/poller.go @@ -15,7 +15,7 @@ import ( "reflect" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) @@ -71,14 +71,14 @@ func PollerType(p *Poller) reflect.Type { } // NewPoller creates a Poller from the specified input. -func NewPoller(lro Operation, resp *http.Response, pl pipeline.Pipeline) *Poller { +func NewPoller(lro Operation, resp *http.Response, pl exported.Pipeline) *Poller { return &Poller{lro: lro, pl: pl, resp: resp} } // Poller encapsulates state and logic for polling on long-running operations. type Poller struct { lro Operation - pl pipeline.Pipeline + pl exported.Pipeline resp *http.Response err error } @@ -100,7 +100,7 @@ func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { } return nil, l.err } - req, err := pipeline.NewRequest(ctx, http.MethodGet, l.lro.URL()) + req, err := exported.NewRequest(ctx, http.MethodGet, l.lro.URL()) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { defer resp.Body.Close() if !StatusCodeValid(resp) { // the LRO failed. unmarshall the error and update state - l.err = shared.NewResponseError(resp) + l.err = exported.NewResponseError(resp) l.resp = nil return nil, l.err } @@ -122,7 +122,7 @@ func (l *Poller) Poll(ctx context.Context) (*http.Response, error) { l.resp = resp log.Writef(log.EventLRO, "Status %s", l.lro.Status()) if Failed(l.lro.Status()) { - l.err = shared.NewResponseError(resp) + l.err = exported.NewResponseError(resp) l.resp = nil return nil, l.err } @@ -150,7 +150,7 @@ func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http // update l.resp with the content from final GET if applicable if u := l.lro.FinalGetURL(); u != "" { log.Write(log.EventLRO, "Performing final GET.") - req, err := pipeline.NewRequest(ctx, http.MethodGet, u) + req, err := exported.NewRequest(ctx, http.MethodGet, u) if err != nil { return nil, err } @@ -159,7 +159,7 @@ func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http return nil, err } if !StatusCodeValid(resp) { - return nil, shared.NewResponseError(resp) + return nil, exported.NewResponseError(resp) } l.resp = resp } @@ -170,7 +170,7 @@ func (l *Poller) FinalResponse(ctx context.Context, respType interface{}) (*http log.Write(log.EventLRO, "final response specifies a response type but no payload was received") return l.resp, nil } - body, err := shared.Payload(l.resp) + body, err := exported.Payload(l.resp) if err != nil { return nil, err } diff --git a/sdk/azcore/internal/pollers/poller_test.go b/sdk/azcore/internal/pollers/poller_test.go index 8e2ea6da9098..2a4ed681a412 100644 --- a/sdk/azcore/internal/pollers/poller_test.go +++ b/sdk/azcore/internal/pollers/poller_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -105,7 +105,7 @@ func TestNewPoller(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) // terminal defer close() - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) firstResp := &http.Response{ StatusCode: http.StatusAccepted, Header: http.Header{}, @@ -158,7 +158,7 @@ func TestNewPollerWithFinalGET(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) // terminal srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{ "shape": "round" }`))) // final GET defer close() - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) firstResp := &http.Response{ StatusCode: http.StatusAccepted, } @@ -194,13 +194,13 @@ func TestNewPollerFail1(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusConflict)) // terminal defer close() - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) firstResp := &http.Response{ StatusCode: http.StatusAccepted, } p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl) resp, err := p.PollUntilDone(context.Background(), time.Second, nil) - var respErr *shared.ResponseError + var respErr *exported.ResponseError if !errors.As(err, &respErr) { t.Fatalf("unexpected error type %T", err) } @@ -217,13 +217,13 @@ func TestNewPollerFail2(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendResponse(mock.WithStatusCode(http.StatusCreated)) // terminal defer close() - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) firstResp := &http.Response{ StatusCode: http.StatusAccepted, } p := NewPoller(&fakePoller{Ep: srv.URL()}, firstResp, pl) resp, err := p.PollUntilDone(context.Background(), time.Second, nil) - var respErr *shared.ResponseError + var respErr *exported.ResponseError if !errors.As(err, &respErr) { t.Fatalf("unexpected error type %T", err) } @@ -240,7 +240,7 @@ func TestNewPollerError(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) srv.AppendError(errors.New("fatal")) defer close() - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) firstResp := &http.Response{ StatusCode: http.StatusAccepted, } diff --git a/sdk/azcore/internal/pollers/util.go b/sdk/azcore/internal/pollers/util.go index c13caaee12fe..e7012cc904cc 100644 --- a/sdk/azcore/internal/pollers/util.go +++ b/sdk/azcore/internal/pollers/util.go @@ -7,6 +7,7 @@ package pollers import ( + "encoding/json" "errors" "fmt" "net/http" @@ -14,6 +15,7 @@ import ( "reflect" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) @@ -45,7 +47,7 @@ func Failed(s string) bool { // returns true if the LRO response contains a valid HTTP status code func StatusCodeValid(resp *http.Response) bool { - return shared.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) + return exported.HasStatusCode(resp, http.StatusOK, http.StatusAccepted, http.StatusCreated, http.StatusNoContent) } // IsValidURL verifies that the URL is valid and absolute. @@ -93,6 +95,27 @@ func DecodeID(tk string) (string, string, error) { return parts[0], parts[1], nil } +// ErrNoBody is returned if the response didn't contain a body. +var ErrNoBody = errors.New("the response did not contain a body") + +// GetJSON reads the response body into a raw JSON object. +// It returns ErrNoBody if there was no content. +func GetJSON(resp *http.Response) (map[string]interface{}, error) { + body, err := exported.Payload(resp) + if err != nil { + return nil, err + } + if len(body) == 0 { + return nil, ErrNoBody + } + // unmarshall the body to get the value + var jsonBody map[string]interface{} + if err = json.Unmarshal(body, &jsonBody); err != nil { + return nil, err + } + return jsonBody, nil +} + // used if the operation synchronously completed type NopPoller struct{} diff --git a/sdk/azcore/internal/pollers/util_test.go b/sdk/azcore/internal/pollers/util_test.go index a9cb5df26541..5255b425c640 100644 --- a/sdk/azcore/internal/pollers/util_test.go +++ b/sdk/azcore/internal/pollers/util_test.go @@ -7,6 +7,8 @@ package pollers import ( + "errors" + "io/ioutil" "net/http" "strings" "testing" @@ -151,6 +153,23 @@ func TestFailed(t *testing.T) { } } +func TestGetJSON(t *testing.T) { + j, err := GetJSON(&http.Response{Body: http.NoBody}) + if !errors.Is(err, ErrNoBody) { + t.Fatal(err) + } + if j != nil { + t.Fatal("expected nil json") + } + j, err = GetJSON(&http.Response{Body: ioutil.NopCloser(strings.NewReader(`{ "foo": "bar" }`))}) + if err != nil { + t.Fatal(err) + } + if v := j["foo"]; v != "bar" { + t.Fatalf("unexpected value %s", v) + } +} + func TestNopPoller(t *testing.T) { np := NopPoller{} if !np.Done() { diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index fb1c509f1464..2b13c78349ec 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -21,11 +21,6 @@ const ( HeaderOperationLocation = "Operation-Location" HeaderRetryAfter = "Retry-After" HeaderUserAgent = "User-Agent" - HeaderXmsDate = "x-ms-date" -) - -const ( - DefaultMaxRetries = 3 ) const BearerTokenPrefix = "Bearer " diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go index 0109cc79d460..5051299afbc2 100644 --- a/sdk/azcore/internal/shared/shared.go +++ b/sdk/azcore/internal/shared/shared.go @@ -8,37 +8,12 @@ package shared import ( "context" - "encoding/json" - "errors" - "io" - "io/ioutil" "net/http" "reflect" "strconv" "time" ) -// TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. -type TokenRequestOptions struct { - // Scopes contains the list of permission scopes required for the token. - Scopes []string - // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set - // it will override the tenant ID that was added at credential creation time. - TenantID string -} - -// TokenCredential represents a credential capable of providing an OAuth token. -type TokenCredential interface { - // GetToken requests an access token for the specified set of scopes. - GetToken(ctx context.Context, options TokenRequestOptions) (*AccessToken, error) -} - -// AccessToken represents an Azure service bearer access token with expiry information. -type AccessToken struct { - Token string - ExpiresOn time.Time -} - // CtxWithHTTPHeaderKey is used as a context key for adding/retrieving http.Header. type CtxWithHTTPHeaderKey struct{} @@ -48,19 +23,6 @@ type CtxWithRetryOptionsKey struct{} // CtxIncludeResponseKey is used as a context key for retrieving the raw response. type CtxIncludeResponseKey struct{} -type nopCloser struct { - io.ReadSeeker -} - -func (n nopCloser) Close() error { - return nil -} - -// NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. -func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { - return nopCloser{rs} -} - // Delay waits for the duration to elapse or the context to be cancelled. func Delay(ctx context.Context, delay time.Duration) error { select { @@ -71,27 +33,6 @@ func Delay(ctx context.Context, delay time.Duration) error { } } -// ErrNoBody is returned if the response didn't contain a body. -var ErrNoBody = errors.New("the response did not contain a body") - -// GetJSON reads the response body into a raw JSON object. -// It returns ErrNoBody if there was no content. -func GetJSON(resp *http.Response) (map[string]interface{}, error) { - body, err := Payload(resp) - if err != nil { - return nil, err - } - if len(body) == 0 { - return nil, ErrNoBody - } - // unmarshall the body to get the value - var jsonBody map[string]interface{} - if err = json.Unmarshal(body, &jsonBody); err != nil { - return nil, err - } - return jsonBody, nil -} - // RetryAfter returns non-zero if the response contains a Retry-After header value. func RetryAfter(resp *http.Response) time.Duration { if resp == nil { @@ -111,97 +52,14 @@ func RetryAfter(resp *http.Response) time.Duration { return 0 } -// HasStatusCode returns true if the Response's status code is one of the specified values. -func HasStatusCode(resp *http.Response, statusCodes ...int) bool { - if resp == nil { - return false - } - for _, sc := range statusCodes { - if resp.StatusCode == sc { - return true - } - } - return false -} - -// Payload reads and returns the response body or an error. -// On a successful read, the response body is cached. -// Subsequent reads will access the cached value. -func Payload(resp *http.Response) ([]byte, error) { - // r.Body won't be a nopClosingBytesReader if downloading was skipped - if buf, ok := resp.Body.(*NopClosingBytesReader); ok { - return buf.Bytes(), nil - } - bytesBody, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, err - } - resp.Body = &NopClosingBytesReader{s: bytesBody, i: 0} - return bytesBody, nil -} - -// NopClosingBytesReader is an io.ReadSeekCloser around a byte slice. -// It also provides direct access to the byte slice to avoid rereading. -type NopClosingBytesReader struct { - s []byte - i int64 -} - -// NewNopClosingBytesReader creates a new NopClosingBytesReader around the specified byte slice. -func NewNopClosingBytesReader(data []byte) *NopClosingBytesReader { - return &NopClosingBytesReader{s: data} -} - -// Bytes returns the underlying byte slice. -func (r *NopClosingBytesReader) Bytes() []byte { - return r.s -} - -// Close implements the io.Closer interface. -func (*NopClosingBytesReader) Close() error { - return nil -} - -// Read implements the io.Reader interface. -func (r *NopClosingBytesReader) Read(b []byte) (n int, err error) { - if r.i >= int64(len(r.s)) { - return 0, io.EOF - } - n = copy(b, r.s[r.i:]) - r.i += int64(n) - return -} - -// Set replaces the existing byte slice with the specified byte slice and resets the reader. -func (r *NopClosingBytesReader) Set(b []byte) { - r.s = b - r.i = 0 -} - -// Seek implements the io.Seeker interface. -func (r *NopClosingBytesReader) Seek(offset int64, whence int) (int64, error) { - var i int64 - switch whence { - case io.SeekStart: - i = offset - case io.SeekCurrent: - i = r.i + offset - case io.SeekEnd: - i = int64(len(r.s)) + offset - default: - return 0, errors.New("nopClosingBytesReader: invalid whence") - } - if i < 0 { - return 0, errors.New("nopClosingBytesReader: negative position") - } - r.i = i - return i, nil -} - // TypeOfT returns the type of the generic type param. func TypeOfT[T any]() reflect.Type { // you can't, at present, obtain the type of // a type parameter, so this is the trick return reflect.TypeOf((*T)(nil)).Elem() } + +// BytesSetter abstracts replacing a byte slice on some type. +type BytesSetter interface { + Set(b []byte) +} diff --git a/sdk/azcore/internal/shared/shared_test.go b/sdk/azcore/internal/shared/shared_test.go index 242d32cd1d54..5c724731af73 100644 --- a/sdk/azcore/internal/shared/shared_test.go +++ b/sdk/azcore/internal/shared/shared_test.go @@ -8,23 +8,12 @@ package shared import ( "context" - "errors" - "io" - "io/ioutil" "net/http" "reflect" - "strings" "testing" "time" ) -func TestNopCloser(t *testing.T) { - nc := NopCloser(strings.NewReader("foo")) - if err := nc.Close(); err != nil { - t.Fatal(err) - } -} - func TestDelay(t *testing.T) { if err := Delay(context.Background(), 5*time.Millisecond); err != nil { t.Fatal(err) @@ -36,23 +25,6 @@ func TestDelay(t *testing.T) { } } -func TestGetJSON(t *testing.T) { - j, err := GetJSON(&http.Response{Body: http.NoBody}) - if !errors.Is(err, ErrNoBody) { - t.Fatal(err) - } - if j != nil { - t.Fatal("expected nil json") - } - j, err = GetJSON(&http.Response{Body: ioutil.NopCloser(strings.NewReader(`{ "foo": "bar" }`))}) - if err != nil { - t.Fatal(err) - } - if v := j["foo"]; v != "bar" { - t.Fatalf("unexpected value %s", v) - } -} - func TestRetryAfter(t *testing.T) { if RetryAfter(nil) != 0 { t.Fatal("expected zero duration") @@ -83,101 +55,6 @@ func TestRetryAfter(t *testing.T) { } } -func TestHasStatusCode(t *testing.T) { - if HasStatusCode(nil, http.StatusAccepted) { - t.Fatal("unexpected success") - } - if HasStatusCode(&http.Response{}) { - t.Fatal("unexpected success") - } - if HasStatusCode(&http.Response{StatusCode: http.StatusBadGateway}, http.StatusBadRequest) { - t.Fatal("unexpected success") - } - if !HasStatusCode(&http.Response{StatusCode: http.StatusOK}, http.StatusAccepted, http.StatusOK, http.StatusNoContent) { - t.Fatal("unexpected failure") - } -} - -func TestPayload(t *testing.T) { - const val = "payload" - resp := &http.Response{ - Body: io.NopCloser(strings.NewReader(val)), - } - b, err := Payload(resp) - if err != nil { - t.Fatal(err) - } - if string(b) != val { - t.Fatalf("got %s, want %s", string(b), val) - } - b, err = Payload(resp) - if err != nil { - t.Fatal(err) - } - if string(b) != val { - t.Fatalf("got %s, want %s", string(b), val) - } -} - -func TestNopClosingBytesReader(t *testing.T) { - const val1 = "the data" - ncbr := NewNopClosingBytesReader([]byte(val1)) - b, err := io.ReadAll(ncbr) - if err != nil { - t.Fatal(err) - } - if string(b) != val1 { - t.Fatalf("got %s, want %s", string(b), val1) - } - const val2 = "something else" - ncbr.Set([]byte(val2)) - b, err = io.ReadAll(ncbr) - if err != nil { - t.Fatal(err) - } - if string(b) != val2 { - t.Fatalf("got %s, want %s", string(b), val2) - } - if err = ncbr.Close(); err != nil { - t.Fatal(err) - } - // seek to beginning and read again - i, err := ncbr.Seek(0, io.SeekStart) - if err != nil { - t.Fatal(err) - } - if i != 0 { - t.Fatalf("got %d, want %d", i, 0) - } - b, err = io.ReadAll(ncbr) - if err != nil { - t.Fatal(err) - } - if string(b) != val2 { - t.Fatalf("got %s, want %s", string(b), val2) - } - // seek to middle from the end - i, err = ncbr.Seek(-4, io.SeekEnd) - if err != nil { - t.Fatal(err) - } - if l := int64(len(val2)) - 4; i != l { - t.Fatalf("got %d, want %d", l, i) - } - b, err = io.ReadAll(ncbr) - if err != nil { - t.Fatal(err) - } - if string(b) != "else" { - t.Fatalf("got %s, want %s", string(b), "else") - } - // underflow - _, err = ncbr.Seek(-int64(len(val2)+1), io.SeekCurrent) - if err == nil { - t.Fatal("unexpected nil error") - } -} - func TestTypeOfT(t *testing.T) { if tt := TypeOfT[bool](); tt != reflect.TypeOf(true) { t.Fatalf("unexpected type %s", tt) diff --git a/sdk/azcore/policy/policy.go b/sdk/azcore/policy/policy.go index 50be89588225..a4bb5dd10d7a 100644 --- a/sdk/azcore/policy/policy.go +++ b/sdk/azcore/policy/policy.go @@ -10,20 +10,19 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" ) // Policy represents an extensibility point for the Pipeline that can mutate the specified // Request and react to the received Response. -type Policy = pipeline.Policy +type Policy = exported.Policy // Transporter represents an HTTP pipeline transport used to send HTTP requests and receive responses. -type Transporter = pipeline.Transporter +type Transporter = exported.Transporter // Request is an abstraction over the creation of an HTTP request as it passes through the pipeline. // Don't use this type directly, use runtime.NewRequest() instead. -type Request = pipeline.Request +type Request = exported.Request // ClientOptions contains optional settings for a client's pipeline. // All zero-value fields will be initialized with default values. @@ -109,7 +108,13 @@ type TelemetryOptions struct { } // TokenRequestOptions contain specific parameter that may be used by credentials types when attempting to get a token. -type TokenRequestOptions = shared.TokenRequestOptions +type TokenRequestOptions struct { + // Scopes contains the list of permission scopes required for the token. + Scopes []string + // TenantID contains the tenant ID to use in a multi-tenant authentication scenario, if TenantID is set + // it will override the tenant ID that was added at credential creation time. + TenantID string +} // BearerTokenOptions configures the bearer token policy's behavior. type BearerTokenOptions struct { diff --git a/sdk/azcore/runtime/errors.go b/sdk/azcore/runtime/errors.go index 02d1b8d7373d..6d03b291ebff 100644 --- a/sdk/azcore/runtime/errors.go +++ b/sdk/azcore/runtime/errors.go @@ -9,11 +9,11 @@ package runtime import ( "net/http" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" ) // NewResponseError creates an *azcore.ResponseError from the provided HTTP response. // Call this when a service request returns a non-successful status code. func NewResponseError(resp *http.Response) error { - return shared.NewResponseError(resp) + return exported.NewResponseError(resp) } diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index 81b7d89151b5..f8fc59b71d98 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -13,8 +13,7 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" ) @@ -34,7 +33,7 @@ func pageResponseFetcher(ctx context.Context, pl Pipeline, endpoint string) (Pag return PageResponse{}, err } if !HasStatusCode(resp, http.StatusOK) { - return PageResponse{}, shared.NewResponseError(resp) + return PageResponse{}, NewResponseError(resp) } pr := PageResponse{} if err := UnmarshalAsJSON(resp, &pr); err != nil { @@ -47,7 +46,7 @@ func TestPagerSinglePage(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5]}`))) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pager := NewPager(PageProcessor[PageResponse]{ More: func(current PageResponse) bool { @@ -79,7 +78,7 @@ func TestPagerMultiplePages(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5], "next": true}`))) srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [6, 7, 8], "next": true}`))) srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [9, 0, 1, 2]}`))) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pageCount := 0 pager := NewPager(PageProcessor[PageResponse]{ @@ -123,7 +122,7 @@ func TestPagerLROMultiplePages(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [6, 7, 8]}`))) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pager := NewPager(PageProcessor[PageResponse]{ More: func(current PageResponse) bool { @@ -177,7 +176,7 @@ func TestPagerPipelineError(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetError(errors.New("pipeline failed")) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pager := NewPager(PageProcessor[PageResponse]{ More: func(current PageResponse) bool { @@ -199,7 +198,7 @@ func TestPagerSecondPageError(t *testing.T) { defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5], "next": true}`))) srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{"message": "didn't work", "code": "PageError"}`))) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pageCount := 0 pager := NewPager(PageProcessor[PageResponse]{ @@ -227,7 +226,7 @@ func TestPagerSecondPageError(t *testing.T) { require.True(t, page.NextPage) case 2: require.Error(t, err) - var respErr *shared.ResponseError + var respErr *exported.ResponseError require.True(t, errors.As(err, &respErr)) require.Equal(t, "PageError", respErr.ErrorCode) goto ExitLoop @@ -241,7 +240,7 @@ func TestPagerResponderError(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`incorrect JSON response`))) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) pager := NewPager(PageProcessor[PageResponse]{ More: func(current PageResponse) bool { diff --git a/sdk/azcore/runtime/pipeline.go b/sdk/azcore/runtime/pipeline.go index a69da8c85eac..7affa38c9a06 100644 --- a/sdk/azcore/runtime/pipeline.go +++ b/sdk/azcore/runtime/pipeline.go @@ -7,7 +7,9 @@ package runtime import ( - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) @@ -17,6 +19,10 @@ type PipelineOptions struct { PerCall, PerRetry []policy.Policy } +// Pipeline represents a primitive for sending HTTP requests and receiving responses. +// Its behavior can be extended by specifying policies during construction. +type Pipeline = exported.Pipeline + // NewPipeline creates a pipeline from connection options, with any additional policies as specified. // module, version: used by the telemetry policy, when enabled // perCall: additional policies to invoke once per request @@ -40,7 +46,7 @@ func NewPipeline(module, version string, plOpts PipelineOptions, options *policy } // we put the includeResponsePolicy at the very beginning so that the raw response // is populated with the final response (some policies might mutate the response) - policies := []policy.Policy{pipeline.PolicyFunc(includeResponsePolicy)} + policies := []policy.Policy{policyFunc(includeResponsePolicy)} if !cp.Telemetry.Disabled { policies = append(policies, NewTelemetryPolicy(module, version, &cp.Telemetry)) } @@ -50,10 +56,19 @@ func NewPipeline(module, version string, plOpts PipelineOptions, options *policy policies = append(policies, cp.PerRetryPolicies...) policies = append(policies, plOpts.PerRetry...) policies = append(policies, NewLogPolicy(&cp.Logging)) - policies = append(policies, pipeline.PolicyFunc(httpHeaderPolicy), pipeline.PolicyFunc(bodyDownloadPolicy)) + policies = append(policies, policyFunc(httpHeaderPolicy), policyFunc(bodyDownloadPolicy)) transport := cp.Transport if transport == nil { transport = defaultHTTPClient } - return pipeline.NewPipeline(transport, policies...) + return exported.NewPipeline(transport, policies...) +} + +// policyFunc is a type that implements the Policy interface. +// Use this type when implementing a stateless policy as a first-class function. +type policyFunc func(*policy.Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) } diff --git a/sdk/azcore/runtime/pipeline_test.go b/sdk/azcore/runtime/pipeline_test.go index f118a9ef0b22..42b6928eef0a 100644 --- a/sdk/azcore/runtime/pipeline_test.go +++ b/sdk/azcore/runtime/pipeline_test.go @@ -14,7 +14,6 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -97,7 +96,7 @@ func TestNewPipelineCustomPolicies(t *testing.T) { perRetryPolicy := &countingPolicy{} pl := NewPipeline("", "", - PipelineOptions{PerCall: []pipeline.Policy{perCallPolicy}, PerRetry: []pipeline.Policy{perRetryPolicy}}, + PipelineOptions{PerCall: []policy.Policy{perCallPolicy}, PerRetry: []policy.Policy{perRetryPolicy}}, &opts, ) _, err = pl.Do(req) diff --git a/sdk/azcore/runtime/policy_bearer_token.go b/sdk/azcore/runtime/policy_bearer_token.go index 75a23f035322..3cfbff363029 100644 --- a/sdk/azcore/runtime/policy_bearer_token.go +++ b/sdk/azcore/runtime/policy_bearer_token.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) @@ -14,9 +15,9 @@ import ( // BearerTokenPolicy authorizes requests with bearer tokens acquired from a TokenCredential. type BearerTokenPolicy struct { // mainResource is the resource to be retreived using the tenant specified in the credential - mainResource *shared.ExpiringResource[*shared.AccessToken, acquiringResourceState] + mainResource *shared.ExpiringResource[*azcore.AccessToken, acquiringResourceState] // the following fields are read-only - cred shared.TokenCredential + cred azcore.TokenCredential scopes []string } @@ -27,8 +28,8 @@ type acquiringResourceState struct { // acquire acquires or updates the resource; only one // thread/goroutine at a time ever calls this function -func acquire(state acquiringResourceState) (newResource *shared.AccessToken, newExpiration time.Time, err error) { - tk, err := state.p.cred.GetToken(state.req.Raw().Context(), shared.TokenRequestOptions{Scopes: state.p.scopes}) +func acquire(state acquiringResourceState) (newResource *azcore.AccessToken, newExpiration time.Time, err error) { + tk, err := state.p.cred.GetToken(state.req.Raw().Context(), policy.TokenRequestOptions{Scopes: state.p.scopes}) if err != nil { return nil, time.Time{}, err } @@ -39,7 +40,7 @@ func acquire(state acquiringResourceState) (newResource *shared.AccessToken, new // cred: an azcore.TokenCredential implementation such as a credential object from azidentity // scopes: the list of permission scopes required for the token. // opts: optional settings. Pass nil to accept default values; this is the same as passing a zero-value options. -func NewBearerTokenPolicy(cred shared.TokenCredential, scopes []string, opts *policy.BearerTokenOptions) *BearerTokenPolicy { +func NewBearerTokenPolicy(cred azcore.TokenCredential, scopes []string, opts *policy.BearerTokenOptions) *BearerTokenPolicy { return &BearerTokenPolicy{ cred: cred, scopes: scopes, diff --git a/sdk/azcore/runtime/policy_bearer_token_test.go b/sdk/azcore/runtime/policy_bearer_token_test.go index 17c19df46c87..78000a6f4d77 100644 --- a/sdk/azcore/runtime/policy_bearer_token_test.go +++ b/sdk/azcore/runtime/policy_bearer_token_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -25,14 +25,14 @@ const ( ) type mockCredential struct { - getTokenImpl func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) + getTokenImpl func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) } -func (mc mockCredential) GetToken(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { +func (mc mockCredential) GetToken(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { if mc.getTokenImpl != nil { return mc.getTokenImpl(ctx, options) } - return &shared.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil + return &azcore.AccessToken{Token: "***", ExpiresOn: time.Now().Add(time.Hour)}, nil } func (mc mockCredential) NewAuthenticationPolicy() policy.Policy { @@ -82,7 +82,7 @@ func TestBearerPolicy_CredentialFailGetToken(t *testing.T) { defer close() expectedErr := errors.New("oops") failCredential := mockCredential{} - failCredential.getTokenImpl = func(ctx context.Context, options shared.TokenRequestOptions) (*shared.AccessToken, error) { + failCredential.getTokenImpl = func(ctx context.Context, options policy.TokenRequestOptions) (*azcore.AccessToken, error) { return nil, expectedErr } b := NewBearerTokenPolicy(failCredential, nil, nil) @@ -138,7 +138,7 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { MaxRetries: 3, } b := NewBearerTokenPolicy(mockCredential{}, nil, nil) - pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []pipeline.Policy{b}}) + pipeline := newTestPipeline(&policy.ClientOptions{Transport: srv, Retry: retryOpts, PerRetryPolicies: []policy.Policy{b}}) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatal(err) diff --git a/sdk/azcore/runtime/policy_body_download.go b/sdk/azcore/runtime/policy_body_download.go index e83f9562b425..02d621ee89e2 100644 --- a/sdk/azcore/runtime/policy_body_download.go +++ b/sdk/azcore/runtime/policy_body_download.go @@ -11,7 +11,7 @@ import ( "net/http" "strings" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" ) @@ -29,7 +29,7 @@ func bodyDownloadPolicy(req *policy.Request) (*http.Response, error) { } // Either bodyDownloadPolicyOpValues was not specified (so skip is false) // or it was specified and skip is false: don't skip downloading the body - _, err = shared.Payload(resp) + _, err = exported.Payload(resp) if err != nil { return resp, newBodyDownloadError(err, req) } diff --git a/sdk/azcore/runtime/policy_body_download_test.go b/sdk/azcore/runtime/policy_body_download_test.go index 92deff3fe3e7..917a4b6f7cc3 100644 --- a/sdk/azcore/runtime/policy_body_download_test.go +++ b/sdk/azcore/runtime/policy_body_download_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -308,7 +307,7 @@ func TestReadBodyAfterSeek(t *testing.T) { if string(payload) != message { t.Fatal("incorrect payload") } - nb, ok := resp.Body.(*shared.NopClosingBytesReader) + nb, ok := resp.Body.(io.ReadSeekCloser) if !ok { t.Fatalf("unexpected body type: %t", resp.Body) } diff --git a/sdk/azcore/runtime/policy_logging_test.go b/sdk/azcore/runtime/policy_logging_test.go index 63b7230b6758..18c3a43c610e 100644 --- a/sdk/azcore/runtime/policy_logging_test.go +++ b/sdk/azcore/runtime/policy_logging_test.go @@ -14,7 +14,7 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -27,7 +27,7 @@ func TestPolicyLoggingSuccess(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := pipeline.NewPipeline(srv, NewLogPolicy(nil)) + pl := exported.NewPipeline(srv, NewLogPolicy(nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -83,7 +83,7 @@ func TestPolicyLoggingError(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetError(errors.New("bogus error")) - pl := pipeline.NewPipeline(srv, NewLogPolicy(nil)) + pl := exported.NewPipeline(srv, NewLogPolicy(nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/sdk/azcore/runtime/policy_retry.go b/sdk/azcore/runtime/policy_retry.go index 5f525eaaa976..d4f2978bde5f 100644 --- a/sdk/azcore/runtime/policy_retry.go +++ b/sdk/azcore/runtime/policy_retry.go @@ -21,9 +21,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/internal/log" ) +const ( + defaultMaxRetries = 3 +) + func setDefaults(o *policy.RetryOptions) { if o.MaxRetries == 0 { - o.MaxRetries = shared.DefaultMaxRetries + o.MaxRetries = defaultMaxRetries } else if o.MaxRetries < 0 { o.MaxRetries = 0 } diff --git a/sdk/azcore/runtime/policy_retry_test.go b/sdk/azcore/runtime/policy_retry_test.go index 7bd7e87ae625..c1843d3eeef0 100644 --- a/sdk/azcore/runtime/policy_retry_test.go +++ b/sdk/azcore/runtime/policy_retry_test.go @@ -18,7 +18,7 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" @@ -35,7 +35,7 @@ func TestRetryPolicySuccess(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK)) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(nil)) + pl := exported.NewPipeline(srv, NewRetryPolicy(nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -63,7 +63,7 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError)) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -79,10 +79,10 @@ func TestRetryPolicyFailOnStatusCode(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != shared.DefaultMaxRetries { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -97,7 +97,7 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { srv.SetResponse(mock.WithStatusCode(http.StatusInternalServerError), mock.WithBody([]byte(respBody))) // add a per-request policy that reads and restores the request body. // this is to simulate how something like httputil.DumpRequest works. - pl := pipeline.NewPipeline(srv, pipeline.PolicyFunc(func(r *policy.Request) (*http.Response, error) { + pl := exported.NewPipeline(srv, policyFunc(func(r *policy.Request) (*http.Response, error) { b, err := ioutil.ReadAll(r.Raw().Body) if err != nil { t.Fatal(err) @@ -120,10 +120,10 @@ func TestRetryPolicyFailOnStatusCodeRespBodyPreserved(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != shared.DefaultMaxRetries { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -145,7 +145,7 @@ func TestRetryPolicySuccessWithRetry(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse() - pl := pipeline.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -182,7 +182,7 @@ func TestRetryPolicySuccessRetryWithNilResponse(t *testing.T) { t: srv, r: []int{2}, // send a nil on the second request } - pl := pipeline.NewPipeline(nilInjector, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(nilInjector, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -215,7 +215,7 @@ func TestRetryPolicyNoRetries(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse() - pl := pipeline.NewPipeline(srv, NewRetryPolicy(&policy.RetryOptions{MaxRetries: -1})) + pl := exported.NewPipeline(srv, NewRetryPolicy(&policy.RetryOptions{MaxRetries: -1})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -240,7 +240,7 @@ func TestRetryPolicyUnlimitedRetryDelay(t *testing.T) { srv.AppendResponse() opt := testRetryOptions() opt.MaxRetryDelay = -1 - pl := pipeline.NewPipeline(srv, NewRetryPolicy(opt)) + pl := exported.NewPipeline(srv, NewRetryPolicy(opt)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -262,7 +262,7 @@ func TestRetryPolicyFailOnError(t *testing.T) { defer close() fakeErr := errors.New("bogus error") srv.SetError(fakeErr) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodPost, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -278,10 +278,10 @@ func TestRetryPolicyFailOnError(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != shared.DefaultMaxRetries { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -296,7 +296,7 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { srv.AppendError(errors.New("bogus error")) srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError)) srv.AppendResponse(mock.WithStatusCode(http.StatusAccepted)) - pl := pipeline.NewPipeline(srv, pipeline.PolicyFunc(includeResponsePolicy), NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, policyFunc(includeResponsePolicy), NewRetryPolicy(testRetryOptions())) var respFromCtx *http.Response ctxWithResp := WithCaptureResponse(context.Background(), &respFromCtx) req, err := NewRequest(ctxWithResp, http.MethodGet, srv.URL()) @@ -317,10 +317,10 @@ func TestRetryPolicySuccessWithRetryComplex(t *testing.T) { if resp.StatusCode != http.StatusAccepted { t.Fatalf("unexpected status code: %d", resp.StatusCode) } - if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } - if body.rcount != shared.DefaultMaxRetries { + if body.rcount != defaultMaxRetries { t.Fatalf("unexpected rewind count: %d", body.rcount) } if !body.closed { @@ -332,7 +332,7 @@ func TestRetryPolicyRequestTimedOut(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetError(errors.New("bogus error")) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(nil)) + pl := exported.NewPipeline(srv, NewRetryPolicy(nil)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() req, err := NewRequest(ctx, http.MethodPost, srv.URL()) @@ -378,7 +378,7 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) { defer close() srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendError(theErr) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -418,7 +418,7 @@ func TestWithRetryOptionsE2E(t *testing.T) { srv.RepeatResponse(9, mock.WithStatusCode(http.StatusRequestTimeout)) srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) defaultOptions := testRetryOptions() - pl := pipeline.NewPipeline(srv, NewRetryPolicy(defaultOptions)) + pl := exported.NewPipeline(srv, NewRetryPolicy(defaultOptions)) customOptions := *defaultOptions customOptions.MaxRetries = 10 customOptions.MaxRetryDelay = 200 * time.Millisecond @@ -451,7 +451,7 @@ func TestRetryPolicyFailOnErrorNoDownload(t *testing.T) { defer close() fakeErr := errors.New("bogus error") srv.SetError(fakeErr) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions())) req, err := NewRequest(context.Background(), http.MethodPost, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -464,8 +464,8 @@ func TestRetryPolicyFailOnErrorNoDownload(t *testing.T) { if resp != nil { t.Fatal("unexpected response") } - if r := srv.Requests(); r != shared.DefaultMaxRetries+1 { - t.Fatalf("wrong request count, got %d expected %d", r, shared.DefaultMaxRetries+1) + if r := srv.Requests(); r != defaultMaxRetries+1 { + t.Fatalf("wrong request count, got %d expected %d", r, defaultMaxRetries+1) } } @@ -473,7 +473,7 @@ func TestRetryPolicySuccessNoDownload(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte("response body"))) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(nil)) + pl := exported.NewPipeline(srv, NewRetryPolicy(nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -493,7 +493,7 @@ func TestRetryPolicySuccessNoDownloadNoBody(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithStatusCode(http.StatusOK)) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(nil)) + pl := exported.NewPipeline(srv, NewRetryPolicy(nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -546,7 +546,7 @@ func TestRetryPolicyRequestTimedOutTooSlow(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithSlowResponse(5 * time.Second)) - pl := pipeline.NewPipeline(srv, NewRetryPolicy(nil)) + pl := exported.NewPipeline(srv, NewRetryPolicy(nil)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() req, err := NewRequest(ctx, http.MethodPost, srv.URL()) @@ -579,7 +579,7 @@ func TestRetryPolicySuccessWithPerTryTimeout(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusOK)) opt := testRetryOptions() opt.TryTimeout = 1 * time.Second - pl := pipeline.NewPipeline(srv, NewRetryPolicy(opt)) + pl := exported.NewPipeline(srv, NewRetryPolicy(opt)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/sdk/azcore/runtime/policy_telemetry_test.go b/sdk/azcore/runtime/policy_telemetry_test.go index bdb56765786d..20d4a406c4bf 100644 --- a/sdk/azcore/runtime/policy_telemetry_test.go +++ b/sdk/azcore/runtime/policy_telemetry_test.go @@ -12,7 +12,7 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" @@ -24,7 +24,7 @@ func TestPolicyTelemetryDefault(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -42,7 +42,7 @@ func TestPolicyTelemetryPreserveExisting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse() - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", nil)) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -63,7 +63,7 @@ func TestPolicyTelemetryWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -82,7 +82,7 @@ func TestPolicyTelemetryWithAppIDSanitized(t *testing.T) { defer close() srv.SetResponse() const appID = "This will get the spaces removed and truncated." - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -102,7 +102,7 @@ func TestPolicyTelemetryPreserveExistingWithAppID(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -123,7 +123,7 @@ func TestPolicyTelemetryDisabled(t *testing.T) { defer close() srv.SetResponse() const appID = "my_application" - pl := pipeline.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID, Disabled: true})) + pl := exported.NewPipeline(srv, NewTelemetryPolicy("test", "v1.2.3", &policy.TelemetryOptions{ApplicationID: appID, Disabled: true})) req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/sdk/azcore/runtime/poller.go b/sdk/azcore/runtime/poller.go index daf866fd07d5..dfdb8b10a179 100644 --- a/sdk/azcore/runtime/poller.go +++ b/sdk/azcore/runtime/poller.go @@ -14,7 +14,7 @@ import ( "net/http" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" @@ -49,7 +49,7 @@ type NewPollerOptions[T any] struct { } // NewPoller creates a Poller based on the provided initial response. -func NewPoller[T any](resp *http.Response, pl pipeline.Pipeline, options *NewPollerOptions[T]) (*Poller[T], error) { +func NewPoller[T any](resp *http.Response, pl exported.Pipeline, options *NewPollerOptions[T]) (*Poller[T], error) { if options == nil { options = &NewPollerOptions[T]{} } @@ -90,7 +90,7 @@ type NewPollerFromResumeTokenOptions[T any] struct { } // NewPollerFromResumeToken creates a Poller from a resume token string. -func NewPollerFromResumeToken[T any](token string, pl pipeline.Pipeline, options *NewPollerFromResumeTokenOptions[T]) (*Poller[T], error) { +func NewPollerFromResumeToken[T any](token string, pl exported.Pipeline, options *NewPollerFromResumeTokenOptions[T]) (*Poller[T], error) { if options == nil { options = &NewPollerFromResumeTokenOptions[T]{} } diff --git a/sdk/azcore/runtime/poller_test.go b/sdk/azcore/runtime/poller_test.go index 5e6582965093..757843f4ac41 100644 --- a/sdk/azcore/runtime/poller_test.go +++ b/sdk/azcore/runtime/poller_test.go @@ -16,8 +16,8 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -166,7 +166,7 @@ func TestLocPollerCancelled(t *testing.T) { if err == nil { t.Fatal("unexpected nil error") } - if _, ok := err.(*shared.ResponseError); !ok { + if _, ok := err.(*exported.ResponseError); !ok { t.Fatal("expected pollerError") } if w.Size != 0 { diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index b0805d432830..21e5c578d542 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -19,15 +19,11 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) -// Pipeline represents a primitive for sending HTTP requests and receiving responses. -// Its behavior can be extended by specifying policies during construction. -type Pipeline = pipeline.Pipeline - // Base64Encoding is usesd to specify which base-64 encoder/decoder to use when // encoding/decoding a slice of bytes to/from a string. type Base64Encoding int @@ -41,8 +37,8 @@ const ( ) // NewRequest creates a new policy.Request with the specified input. -func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*pipeline.Request, error) { - return pipeline.NewRequest(ctx, httpMethod, endpoint) +func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*policy.Request, error) { + return exported.NewRequest(ctx, httpMethod, endpoint) } // JoinPaths concatenates multiple URL path segments into one path, @@ -87,7 +83,7 @@ func EncodeByteArray(v []byte, format Base64Encoding) string { func MarshalAsByteArray(req *policy.Request, v []byte, format Base64Encoding) error { // send as a JSON string encode := fmt.Sprintf("\"%s\"", EncodeByteArray(v, format)) - return req.SetBody(shared.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON) + return req.SetBody(exported.NopCloser(strings.NewReader(encode)), shared.ContentTypeAppJSON) } // MarshalAsJSON calls json.Marshal() to get the JSON encoding of v then calls SetBody. @@ -97,7 +93,7 @@ func MarshalAsJSON(req *policy.Request, v interface{}) error { if err != nil { return fmt.Errorf("error marshalling type %T: %s", v, err) } - return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON) + return req.SetBody(exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppJSON) } // MarshalAsXML calls xml.Marshal() to get the XML encoding of v then calls SetBody. @@ -108,7 +104,7 @@ func MarshalAsXML(req *policy.Request, v interface{}) error { } // inclue the XML header as some services require it b = []byte(xml.Header + string(b)) - return req.SetBody(shared.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppXML) + return req.SetBody(exported.NopCloser(bytes.NewReader(b)), shared.ContentTypeAppXML) } // SetMultipartFormData writes the specified keys/values as multi-part form @@ -142,7 +138,7 @@ func SetMultipartFormData(req *policy.Request, formData map[string]interface{}) if err := writer.Close(); err != nil { return err } - return req.SetBody(shared.NopCloser(bytes.NewReader(body.Bytes())), writer.FormDataContentType()) + return req.SetBody(exported.NopCloser(bytes.NewReader(body.Bytes())), writer.FormDataContentType()) } // SkipBodyDownload will disable automatic downloading of the response body. diff --git a/sdk/azcore/runtime/request_test.go b/sdk/azcore/runtime/request_test.go index 5ee2e451d3f4..504fc190d9b8 100644 --- a/sdk/azcore/runtime/request_test.go +++ b/sdk/azcore/runtime/request_test.go @@ -22,7 +22,7 @@ import ( "time" "unsafe" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) @@ -535,7 +535,7 @@ func TestRequestSetBodyContentLengthHeader(t *testing.T) { for i := 0; i < buffLen; i++ { buff[i] = 1 } - err = req.SetBody(shared.NopCloser(bytes.NewReader(buff)), "application/octet-stream") + err = req.SetBody(exported.NopCloser(bytes.NewReader(buff)), "application/octet-stream") if err != nil { t.Fatal(err) } @@ -565,7 +565,7 @@ func TestRequestValidFail(t *testing.T) { t.Fatal(err) } req.Raw().Header.Add("inval d", "header") - p := pipeline.NewPipeline(nil) + p := exported.NewPipeline(nil) resp, err := p.Do(req) if err == nil { t.Fatal("unexpected nil error") @@ -593,7 +593,7 @@ func TestSetMultipartFormData(t *testing.T) { err = SetMultipartFormData(req, map[string]interface{}{ "string": "value", "int": 1, - "data": shared.NopCloser(strings.NewReader("some data")), + "data": exported.NopCloser(strings.NewReader("some data")), }) if err != nil { t.Fatal(err) diff --git a/sdk/azcore/runtime/response.go b/sdk/azcore/runtime/response.go index 086f7af57105..2322f0a201ba 100644 --- a/sdk/azcore/runtime/response.go +++ b/sdk/azcore/runtime/response.go @@ -16,6 +16,7 @@ import ( "io/ioutil" "net/http" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) @@ -23,12 +24,12 @@ import ( // On a successful read, the response body is cached. // Subsequent reads will access the cached value. func Payload(resp *http.Response) ([]byte, error) { - return shared.Payload(resp) + return exported.Payload(resp) } // HasStatusCode returns true if the Response's status code is one of the specified values. func HasStatusCode(resp *http.Response, statusCodes ...int) bool { - return shared.HasStatusCode(resp, statusCodes...) + return exported.HasStatusCode(resp, statusCodes...) } // UnmarshalAsByteArray will base-64 decode the received payload and place the result into the value pointed to by v. @@ -99,7 +100,7 @@ func removeBOM(resp *http.Response) error { // UTF8 trimmed := bytes.TrimPrefix(payload, []byte("\xef\xbb\xbf")) if len(trimmed) < len(payload) { - resp.Body.(*shared.NopClosingBytesReader).Set(trimmed) + resp.Body.(shared.BytesSetter).Set(trimmed) } return nil } diff --git a/sdk/azcore/streaming/progress.go b/sdk/azcore/streaming/progress.go index 50f9a0a8d877..8563375af07e 100644 --- a/sdk/azcore/streaming/progress.go +++ b/sdk/azcore/streaming/progress.go @@ -9,7 +9,7 @@ package streaming import ( "io" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" ) type progress struct { @@ -21,7 +21,7 @@ type progress struct { // NopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker. func NopCloser(rs io.ReadSeeker) io.ReadSeekCloser { - return shared.NopCloser(rs) + return exported.NopCloser(rs) } // NewRequestProgress adds progress reporting to an HTTP request's body stream. diff --git a/sdk/azcore/streaming/progress_test.go b/sdk/azcore/streaming/progress_test.go index 9e6ffd7a6c94..eb0f4380d00f 100644 --- a/sdk/azcore/streaming/progress_test.go +++ b/sdk/azcore/streaming/progress_test.go @@ -15,7 +15,7 @@ import ( "reflect" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" ) @@ -30,7 +30,7 @@ func TestProgressReporting(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -79,7 +79,7 @@ func TestProgressReportingSeek(t *testing.T) { srv, close := mock.NewServer() defer close() srv.SetResponse(mock.WithBody(content)) - pl := pipeline.NewPipeline(srv) + pl := exported.NewPipeline(srv) req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { t.Fatalf("unexpected error: %v", err)