From 93ef8e85015eb41838a66b8203222907cbe204cd Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Thu, 7 Dec 2023 19:52:07 +0000 Subject: [PATCH 1/3] recorded tests ignore MSAL headers --- sdk/azidentity/live_test.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sdk/azidentity/live_test.go b/sdk/azidentity/live_test.go index 3b5211c144ec..775eae7db323 100644 --- a/sdk/azidentity/live_test.go +++ b/sdk/azidentity/live_test.go @@ -18,6 +18,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" ) @@ -146,7 +147,19 @@ func run(m *testing.M) int { switch recording.GetRecordMode() { case recording.PlaybackMode: setFakeValues() - err := recording.SetBodilessMatcher(nil, nil) + err := recording.SetDefaultMatcher(nil, &recording.SetDefaultMatcherOptions{ + CompareBodies: to.Ptr(false), + // ignore the presence/absence/value of these headers because + // MSAL sets them and they don't affect azidentity behavior + ExcludedHeaders: []string{ + "Client-Request-Id", + "Return-Client-Request-Id", + "X-Client-Cpu", + "X-Client-Os", + "X-Client-Sku", + "X-Client-Ver", + }, + }) if err != nil { panic(err) } From dd65c2587409dcc9b80469ed7d751cab84fba38f Mon Sep 17 00:00:00 2001 From: HandsomeJack Date: Tue, 5 Dec 2023 14:52:34 +0100 Subject: [PATCH 2/3] test(azidentity): add unit test for doForClient function The focus is on methods/behavior implemented by azidentity package. I didn't aim to hit 100% coverage, mainly because there are error paths determined by e.g. std http package etc. Signed-off-by: HandsomeJack --- sdk/azidentity/azidentity_test.go | 113 ++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index d2879a6aa667..a038f337a351 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -7,11 +7,14 @@ package azidentity import ( + "bytes" "context" "crypto/x509" "errors" "fmt" + "io" "net/http" + "net/http/httptest" "os" "path/filepath" "reflect" @@ -22,6 +25,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" @@ -29,6 +33,7 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public" "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1001,6 +1006,105 @@ func TestTokenCachePersistenceOptions(t *testing.T) { } } +func TestDoForClient(t *testing.T) { + var ( + policyHeaderName = "PolicyHeader" + policyHeaderValue = "policyvalue" + + reqBody = []byte(`{"request": "azidentity"}`) + respBody = []byte(`{"response": "golang"}`) + ) + + tests := map[string]struct { + method string + path string + body io.Reader + headers http.Header + }{ + "happy path": { + method: http.MethodGet, + path: "/foo/bar", + body: bytes.NewBuffer(reqBody), + headers: http.Header{ + "Header": []string{"value1", "value2"}, + }, + }, + "no body": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + }, + "nil body": { + method: http.MethodGet, + path: "/", + body: nil, + }, + "headers with empty value": { + method: http.MethodGet, + path: "/", + body: http.NoBody, + headers: http.Header{ + "Header": nil, + }, + }, + } + + client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{ + // add PerCall policy to ensure doForClient calls .Pipeline.Do() + PerCall: []policy.Policy{ + policyFunc(func(req *policy.Request) (*http.Response, error) { + req.Raw().Header.Set(policyHeaderName, policyHeaderValue) + return req.Next() + }), + }, + }, nil) + require.NoError(t, err) + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + assert.Equal(t, tt.method, req.Method) + assert.Equal(t, tt.path, req.URL.Path) + + rb, err := io.ReadAll(req.Body) + assert.NoError(t, err) + + if tt.body != nil && tt.body != http.NoBody { + assert.Equal(t, string(reqBody), string(rb)) + } else { + assert.Empty(t, rb) + } + + assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName)) + + rw.Header().Set("content-type", "application/json") + _, err = rw.Write(respBody) + assert.NoError(t, err) + })) + defer server.Close() + + req, err := http.NewRequestWithContext(context.Background(), tt.method, server.URL+tt.path, tt.body) + require.NoError(t, err) + + for k, vs := range tt.headers { + for _, v := range vs { + req.Header.Add(k, v) + } + } + + resp, err := doForClient(client, req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, string(respBody), string(b)) + }) + } +} + // ================================================================================================================================== type fakeConfidentialClient struct { @@ -1098,3 +1202,12 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes [] } var _ msalPublicClient = (*fakePublicClient)(nil) + +// ================================================================================================================================== + +type policyFunc func(*policy.Request) (*http.Response, error) + +// Do implements the Policy interface on policyFunc. +func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) { + return pf(req) +} From 49629314d58b6e17b6c3e0470fc38354600c7123 Mon Sep 17 00:00:00 2001 From: HandsomeJack Date: Fri, 8 Dec 2023 14:26:29 +0100 Subject: [PATCH 3/3] fix(azidentity): do not strip away request headers in doForClient Some authorities might require certain headers to be passed. For example, in our dSTS auth flow, the request form contains client_info, which needs to be accompanied by X-Client-SKU=MSAL.Go header, else the API call produces AADSTS501791: Client_info is only supported for MSAL/ADAL, please ensure that MSAL/ADAL custom headers are being sent. The `doForClient` function creates new `runtime.Request` from the incoming request, but it fails to propagate the respective headers. This commits is addressing that. Signed-off-by: HandsomeJack --- sdk/azidentity/CHANGELOG.md | 1 + sdk/azidentity/azidentity.go | 11 +++++++++++ sdk/azidentity/azidentity_test.go | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 7c48853658cf..06209c7bca8e 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -8,6 +8,7 @@ ### Breaking Changes ### Bugs Fixed +* `azidentity.doForClient` method no longer removes headers from the incoming request ### Other Changes diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 67ff1cd2763f..b592fddc93cc 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -151,6 +151,17 @@ func doForClient(client *azcore.Client, r *http.Request) (*http.Response, error) return nil, err } } + + // copy headers to the new request, ignoring any for which the new request has a value + h := req.Raw().Header + for key, vals := range r.Header { + if _, has := h[key]; !has { + for _, val := range vals { + h.Add(key, val) + } + } + } + resp, err := client.Pipeline().Do(req) if err != nil { return nil, err diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index a038f337a351..ea07cac60a80 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -1075,6 +1075,10 @@ func TestDoForClient(t *testing.T) { assert.Empty(t, rb) } + for k, v := range tt.headers { + assert.Equal(t, v, req.Header[k]) + } + assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName)) rw.Header().Set("content-type", "application/json")