diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index d2879a6aa667..c801a999e14c 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -7,11 +7,14 @@ package azidentity import ( + "bytes" "context" "crypto/x509" "errors" "fmt" + "io" "net/http" + "net/http/httptest" "os" "path/filepath" "reflect" @@ -22,6 +25,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" @@ -29,6 +33,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1001,6 +1006,107 @@ func TestTokenCachePersistenceOptions(t *testing.T) { } } +func TestDoForClient(t *testing.T) { + var ( + policyHeaderName = "PolicyHeader" + policyHeaderValue = "policyvalue" + + reqBody = []byte(`{"request": "azidentity"}`) + respBody = []byte(`{"response": "golang"}`) + ) + + tests := map[string]struct { + method string + path string + body io.Reader + headers http.Header + }{ + "happy path": { + method: http.MethodGet, + path: "/foo/bar", + body: bytes.NewBuffer(reqBody), + headers: http.Header{ + "Header": []string{"value1", "value2"}, + }, + }, + "no body": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + }, + "nil body": { + method: http.MethodGet, + path: "/", + body: nil, + }, + "headers with empty value": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + headers: http.Header{ + "Header": nil, + }, + }, + } + + client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{ + // add PerCall policy to ensure doForClient calls .Pipeline.Do() + PerCall: []policy.Policy{ + policyFunc(func(req *policy.Request) (*http.Response, error) { + req.Raw().Header.Set(policyHeaderName, policyHeaderValue) + return req.Next() + }), + }, + }, nil) + require.NoError(t, err) + + for name, tt := range tests { + tt := tt + + t.Run(name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, tt.method, req.Method) + assert.Equal(t, tt.path, req.URL.Path) + + rb, err := io.ReadAll(req.Body) + assert.NoError(t, err) + + if tt.body != nil && tt.body != http.NoBody { + assert.Equal(t, string(reqBody), string(rb)) + } else { + assert.Empty(t, rb) + } + + assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName)) + + rw.Header().Set("content-type", "application/json") + _, err = rw.Write(respBody) + assert.NoError(t, err) + })) + defer server.Close() + + req, err := http.NewRequestWithContext(context.Background(), tt.method, server.URL+tt.path, tt.body) + require.NoError(t, err) + + for k, vs := range tt.headers { + for _, v := range vs { + req.Header.Add(k, v) + } + } + + resp, err := doForClient(client, req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, string(respBody), string(b)) + }) + } +} + // ================================================================================================================================== type fakeConfidentialClient struct { @@ -1098,3 +1204,12 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes [] } var _ msalPublicClient = (*fakePublicClient)(nil) + +// ================================================================================================================================== + +type policyFunc func(*policy.Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) +}