diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 7c48853658cf..06209c7bca8e 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -8,6 +8,7 @@ ### Breaking Changes ### Bugs Fixed +* `azidentity.doForClient` method no longer removes headers from the incoming request ### Other Changes diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 67ff1cd2763f..b592fddc93cc 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -151,6 +151,17 @@ func doForClient(client *azcore.Client, r *http.Request) (*http.Response, error) return nil, err } } + + // copy headers to the new request, ignoring any for which the new request has a value + h := req.Raw().Header + for key, vals := range r.Header { + if _, has := h[key]; !has { + for _, val := range vals { + h.Add(key, val) + } + } + } + resp, err := client.Pipeline().Do(req) if err != nil { return nil, err diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index d2879a6aa667..ea07cac60a80 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -7,11 +7,14 @@ package azidentity import ( + "bytes" "context" "crypto/x509" "errors" "fmt" + "io" "net/http" + "net/http/httptest" "os" "path/filepath" "reflect" @@ -22,6 +25,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" @@ -29,6 +33,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1001,6 +1006,109 @@ func TestTokenCachePersistenceOptions(t *testing.T) { } } +func TestDoForClient(t *testing.T) { + var ( + policyHeaderName = "PolicyHeader" + policyHeaderValue = "policyvalue" + + reqBody = []byte(`{"request": "azidentity"}`) + respBody = []byte(`{"response": "golang"}`) + ) + + tests := map[string]struct { + method string + path string + body io.Reader + headers http.Header + }{ + "happy path": { + method: http.MethodGet, + path: "/foo/bar", + body: bytes.NewBuffer(reqBody), + headers: http.Header{ + "Header": []string{"value1", "value2"}, + }, + }, + "no body": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + }, + "nil body": { + method: http.MethodGet, + path: "/", + body: nil, + }, + "headers with empty value": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + headers: http.Header{ + "Header": nil, + }, + }, + } + + client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{ + // add PerCall policy to ensure doForClient calls .Pipeline.Do() + PerCall: []policy.Policy{ + policyFunc(func(req *policy.Request) (*http.Response, error) { + req.Raw().Header.Set(policyHeaderName, policyHeaderValue) + return req.Next() + }), + }, + }, nil) + require.NoError(t, err) + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, tt.method, req.Method) + assert.Equal(t, tt.path, req.URL.Path) + + rb, err := io.ReadAll(req.Body) + assert.NoError(t, err) + + if tt.body != nil && tt.body != http.NoBody { + assert.Equal(t, string(reqBody), string(rb)) + } else { + assert.Empty(t, rb) + } + + for k, v := range tt.headers { + assert.Equal(t, v, req.Header[k]) + } + + assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName)) + + rw.Header().Set("content-type", "application/json") + _, err = rw.Write(respBody) + assert.NoError(t, err) + })) + defer server.Close() + + req, err := http.NewRequestWithContext(context.Background(), tt.method, server.URL+tt.path, tt.body) + require.NoError(t, err) + + for k, vs := range tt.headers { + for _, v := range vs { + req.Header.Add(k, v) + } + } + + resp, err := doForClient(client, req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, string(respBody), string(b)) + }) + } +} + // ================================================================================================================================== type fakeConfidentialClient struct { @@ -1098,3 +1206,12 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes [] } var _ msalPublicClient = (*fakePublicClient)(nil) + +// ================================================================================================================================== + +type policyFunc func(*policy.Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) +} diff --git a/sdk/azidentity/live_test.go b/sdk/azidentity/live_test.go index 3b5211c144ec..775eae7db323 100644 --- a/sdk/azidentity/live_test.go +++ b/sdk/azidentity/live_test.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" ) @@ -146,7 +147,19 @@ func run(m *testing.M) int { switch recording.GetRecordMode() { case recording.PlaybackMode: setFakeValues() - err := recording.SetBodilessMatcher(nil, nil) + err := recording.SetDefaultMatcher(nil, &recording.SetDefaultMatcherOptions{ + CompareBodies: to.Ptr(false), + // ignore the presence/absence/value of these headers because + // MSAL sets them and they don't affect azidentity behavior + ExcludedHeaders: []string{ + "Client-Request-Id", + "Return-Client-Request-Id", + "X-Client-Cpu", + "X-Client-Os", + "X-Client-Sku", + "X-Client-Ver", + }, + }) if err != nil { panic(err) }