Skip to content

Commit

Permalink
Modified challenge policy test (#20554)
Browse files Browse the repository at this point in the history
  • Loading branch information
gapra-msft authored Apr 5, 2023
1 parent 5ab558f commit 4bdfb89
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 46 deletions.
4 changes: 0 additions & 4 deletions sdk/storage/azblob/internal/shared/challenge_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ func NewStorageChallengePolicy(cred azcore.TokenCredential) policy.Policy {
}

func (s *storageAuthorizer) onRequest(req *policy.Request, authNZ func(policy.TokenRequestOptions) error) error {
if len(s.scopes) == 0 || s.tenantID == "" {
// returning nil indicates the bearer token policy should send the request
return nil
}
return authNZ(policy.TokenRequestOptions{Scopes: s.scopes})
}

Expand Down
100 changes: 58 additions & 42 deletions sdk/storage/azblob/internal/shared/challenge_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
"net/http"
"strings"
"testing"
"time"
Expand All @@ -25,52 +24,69 @@ func (cf credentialFunc) GetToken(ctx context.Context, options policy.TokenReque
return cf(ctx, options)
}

func TestChallengePolicy(t *testing.T) {
func TestChallengePolicyStorage(t *testing.T) {
accessToken := "***"
storageResource := "https://storage.azure.com"
storageScope := "https://storage.azure.com/.default"
challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"`

srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithStatusCode(200),
)
authenticated := false
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
authenticated = true
require.Equal(t, []string{storageScope}, tro.Scopes)
return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewStorageChallengePolicy(cred)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)
req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost")
require.NoError(t, err)
_, err = pl.Do(req)
require.NoError(t, err)
require.True(t, authenticated, "policy should have authenticated")
}

func TestChallengePolicyDisk(t *testing.T) {
accessToken := "***"
diskResource := "https://disk.azure.com/"
diskScope := "https://disk.azure.com//.default"
challenge := `Bearer authorization_uri="https://login.microsoftonline.com/{tenant}", resource_id="{storageResource}"`

for _, test := range []struct {
expectedScope, format, resource string
}{
{format: challenge, resource: storageResource, expectedScope: storageScope},
{format: challenge, resource: diskResource, expectedScope: diskScope},
} {
t.Run("", func(t *testing.T) {
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(test.format, "{storageResource}", test.resource)),
mock.WithStatusCode(401),
)
srv.AppendResponse(mock.WithPredicate(func(r *http.Request) bool {
if authz := r.Header.Values("Authorization"); len(authz) != 1 || authz[0] != "Bearer "+accessToken {
t.Errorf(`unexpected Authorization "%s"`, authz)
}
return true
}))
srv.AppendResponse()
authenticated := false
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
authenticated = true
require.Equal(t, []string{test.expectedScope}, tro.Scopes)
return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil
})
p := NewStorageChallengePolicy(cred)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)
req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost")
require.NoError(t, err)
_, err = pl.Do(req)
require.NoError(t, err)
require.True(t, authenticated, "policy should have authenticated")
})
}
srv, close := mock.NewServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.AppendResponse(
mock.WithHeader("WWW-Authenticate", strings.ReplaceAll(challenge, "{storageResource}", diskResource)),
mock.WithStatusCode(401),
)
srv.AppendResponse(
mock.WithStatusCode(200),
)
attemptedAuthentication := false
authenticated := false
cred := credentialFunc(func(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
if attemptedAuthentication {
authenticated = true
require.Equal(t, []string{diskScope}, tro.Scopes)
return azcore.AccessToken{Token: accessToken, ExpiresOn: time.Now().Add(time.Hour)}, nil
}
attemptedAuthentication = true
return azcore.AccessToken{}, nil
})
p := NewStorageChallengePolicy(cred)
pl := runtime.NewPipeline("", "",
runtime.PipelineOptions{PerRetry: []policy.Policy{p}},
&policy.ClientOptions{Transport: srv},
)
req, err := runtime.NewRequest(context.Background(), "GET", "https://localhost")
require.NoError(t, err)
_, err = pl.Do(req)
require.NoError(t, err)
require.True(t, authenticated, "policy should have authenticated")
}

func TestParseTenant(t *testing.T) {
Expand Down

0 comments on commit 4bdfb89

Please sign in to comment.