Skip to content

Commit

Permalink
azidentity credentials preserve MSAL headers (#22098)
Browse files Browse the repository at this point in the history
  • Loading branch information
handsomejack-42 authored Jan 2, 2024
1 parent 2d7b90a commit 4831055
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 1 deletion.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Breaking Changes

### Bugs Fixed
* `azidentity.doForClient` method no longer removes headers from the incoming request

### Other Changes

Expand Down
11 changes: 11 additions & 0 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ func doForClient(client *azcore.Client, r *http.Request) (*http.Response, error)
return nil, err
}
}

// copy headers to the new request, ignoring any for which the new request has a value
h := req.Raw().Header
for key, vals := range r.Header {
if _, has := h[key]; !has {
for _, val := range vals {
h.Add(key, val)
}
}
}

resp, err := client.Pipeline().Do(req)
if err != nil {
return nil, err
Expand Down
117 changes: 117 additions & 0 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
package azidentity

import (
"bytes"
"context"
"crypto/x509"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
Expand All @@ -22,13 +25,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1001,6 +1006,109 @@ func TestTokenCachePersistenceOptions(t *testing.T) {
}
}

func TestDoForClient(t *testing.T) {
var (
policyHeaderName = "PolicyHeader"
policyHeaderValue = "policyvalue"

reqBody = []byte(`{"request": "azidentity"}`)
respBody = []byte(`{"response": "golang"}`)
)

tests := map[string]struct {
method string
path string
body io.Reader
headers http.Header
}{
"happy path": {
method: http.MethodGet,
path: "/foo/bar",
body: bytes.NewBuffer(reqBody),
headers: http.Header{
"Header": []string{"value1", "value2"},
},
},
"no body": {
method: http.MethodGet,
path: "/",
body: http.NoBody,
},
"nil body": {
method: http.MethodGet,
path: "/",
body: nil,
},
"headers with empty value": {
method: http.MethodGet,
path: "/",
body: http.NoBody,
headers: http.Header{
"Header": nil,
},
},
}

client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{
// add PerCall policy to ensure doForClient calls .Pipeline.Do()
PerCall: []policy.Policy{
policyFunc(func(req *policy.Request) (*http.Response, error) {
req.Raw().Header.Set(policyHeaderName, policyHeaderValue)
return req.Next()
}),
},
}, nil)
require.NoError(t, err)

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, tt.method, req.Method)
assert.Equal(t, tt.path, req.URL.Path)

rb, err := io.ReadAll(req.Body)
assert.NoError(t, err)

if tt.body != nil && tt.body != http.NoBody {
assert.Equal(t, string(reqBody), string(rb))
} else {
assert.Empty(t, rb)
}

for k, v := range tt.headers {
assert.Equal(t, v, req.Header[k])
}

assert.Equal(t, policyHeaderValue, req.Header.Get(policyHeaderName))

rw.Header().Set("content-type", "application/json")
_, err = rw.Write(respBody)
assert.NoError(t, err)
}))
defer server.Close()

req, err := http.NewRequestWithContext(context.Background(), tt.method, server.URL+tt.path, tt.body)
require.NoError(t, err)

for k, vs := range tt.headers {
for _, v := range vs {
req.Header.Add(k, v)
}
}

resp, err := doForClient(client, req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)

b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, string(respBody), string(b))
})
}
}

// ==================================================================================================================================

type fakeConfidentialClient struct {
Expand Down Expand Up @@ -1098,3 +1206,12 @@ func (f fakePublicClient) AcquireTokenInteractive(ctx context.Context, scopes []
}

var _ msalPublicClient = (*fakePublicClient)(nil)

// ==================================================================================================================================

type policyFunc func(*policy.Request) (*http.Response, error)

// Do implements the Policy interface on policyFunc.
func (pf policyFunc) Do(req *policy.Request) (*http.Response, error) {
return pf(req)
}
15 changes: 14 additions & 1 deletion sdk/azidentity/live_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
)

Expand Down Expand Up @@ -146,7 +147,19 @@ func run(m *testing.M) int {
switch recording.GetRecordMode() {
case recording.PlaybackMode:
setFakeValues()
err := recording.SetBodilessMatcher(nil, nil)
err := recording.SetDefaultMatcher(nil, &recording.SetDefaultMatcherOptions{
CompareBodies: to.Ptr(false),
// ignore the presence/absence/value of these headers because
// MSAL sets them and they don't affect azidentity behavior
ExcludedHeaders: []string{
"Client-Request-Id",
"Return-Client-Request-Id",
"X-Client-Cpu",
"X-Client-Os",
"X-Client-Sku",
"X-Client-Ver",
},
})
if err != nil {
panic(err)
}
Expand Down

0 comments on commit 4831055

Please sign in to comment.