From 7f385cc0585ef9509961bed95c8944e0a0776fd5 Mon Sep 17 00:00:00 2001 From: Catalina Peralta Date: Fri, 23 Jul 2021 12:47:59 -0700 Subject: [PATCH 1/2] expiring resource rework --- sdk/azidentity/bearer_token_policy.go | 197 ++++++++++++++------- sdk/azidentity/bearer_token_policy_test.go | 49 +++++ 2 files changed, 179 insertions(+), 67 deletions(-) diff --git a/sdk/azidentity/bearer_token_policy.go b/sdk/azidentity/bearer_token_policy.go index d363c7cb81a6..50b13cde5317 100644 --- a/sdk/azidentity/bearer_token_policy.go +++ b/sdk/azidentity/bearer_token_policy.go @@ -4,7 +4,9 @@ package azidentity import ( + "fmt" "net/http" + "strings" "sync" "time" @@ -16,97 +18,158 @@ const ( ) type bearerTokenPolicy struct { - // cond is used to synchronize token refresh. the locker - // must be locked when updating the following shared state. + // mainResource is the resource to be retreived using the tenant specified in the credential + mainResource *expiringResource + // auxResources are additional resources that are required for cross-tenant applications + auxResources map[string]*expiringResource + // the following fields are read-only + creds azcore.TokenCredential + options azcore.TokenRequestOptions +} + +type expiringResource struct { + // cond is used to synchronize access to the shared resource embodied by the remaining fields cond *sync.Cond - // renewing indicates that the token is in the process of being refreshed - renewing bool + // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource + acquiring bool - // header contains the authorization header value - header string + // resource contains the value of the shared resource + resource interface{} - // expiresOn is when the token will expire - expiresOn time.Time + // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired + expiration time.Time - // the following fields are read-only - creds azcore.TokenCredential - options azcore.TokenRequestOptions + // acquireResource is the callback function that actually acquires the resource + acquireResource acquireResource } -func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy { - return &bearerTokenPolicy{ - cond: sync.NewCond(&sync.Mutex{}), - creds: creds, - options: opts.TokenRequest, - } +type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) + +type acquiringResourceState struct { + req *azcore.Request + p bearerTokenPolicy } -func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) { - if req.URL.Scheme != "https" { - // HTTPS must be used, otherwise the tokens are at the risk of being exposed - return nil, &AuthenticationFailedError{msg: "token credentials require a URL using the HTTPS protocol scheme"} +// acquire acquires or updates the resource; only one +// thread/goroutine at a time ever calls this function +func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(acquiringResourceState) + tk, err := s.p.creds.GetToken(s.req.Context(), s.p.options) + if err != nil { + return nil, time.Time{}, err } - // create a "refresh window" before the token's real expiration date. - // this allows callers to continue to use the old token while the - // refresh is in progress. - const window = 2 * time.Minute - now, getToken, header := time.Now(), false, "" + return tk, tk.ExpiresOn, nil +} + +func newExpiringResource(ar acquireResource) *expiringResource { + return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +} + +func (er *expiringResource) GetResource(state interface{}) (interface{}, error) { + // If the resource is expiring within this time window, update it eagerly. + // This allows other threads/goroutines to keep running by using the not-yet-expired + // resource value while one thread/goroutine updates the resource. + const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration + + now, acquire, resource := time.Now(), false, er.resource // acquire exclusive lock - b.cond.L.Lock() + er.cond.L.Lock() for { - if b.expiresOn.IsZero() || b.expiresOn.Before(now) { - // token was never obtained or has expired - if !b.renewing { - // another go routine isn't refreshing the token so this one will - b.renewing = true - getToken = true + if er.expiration.IsZero() || er.expiration.Before(now) { + // The resource was never acquired or has expired + if !er.acquiring { + // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true break } - // getting here means this go routine will wait for the token to refresh - } else if b.expiresOn.Add(-window).Before(now) { - // token is within the expiration window - if !b.renewing { - // another go routine isn't refreshing the token so this one will - b.renewing = true - getToken = true + // Getting here means that this thread/goroutine will wait for the updated resource + } else if er.expiration.Add(-window).Before(now) { + // The resource is valid but is expiring within the time window + if !er.acquiring { + // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true break } - // this go routine will use the existing token while another refreshes it - header = b.header + // This thread/goroutine will use the existing resource value while another updates it + resource = er.resource break } else { - // token is not expiring yet so use it as-is - header = b.header + // The resource is not close to expiring, this thread/goroutine should use its current value + resource = er.resource break } - // wait for the token to refresh - b.cond.Wait() + // If we get here, wait for the new resource value to be acquired/updated + er.cond.Wait() } - b.cond.L.Unlock() - if getToken { - // this go routine has been elected to refresh the token - tk, err := b.creds.GetToken(req.Context(), b.options) - // update shared state - b.cond.L.Lock() - // to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning - b.renewing = false + er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked + + var err error + if acquire { + // This thread/goroutine has been selected to acquire/update the resource + var expiration time.Time + resource, expiration, err = er.acquireResource(state) + + // Atomically, update the shared resource's new value & expiration. + er.cond.L.Lock() + if err == nil { + // No error, update resource & expiration + er.resource, er.expiration = resource, expiration + } + er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce + + // Wake up any waiting threads/goroutines since there is a resource they can ALL use + er.cond.L.Unlock() + er.cond.Broadcast() + } + return resource, err // Return the resource this thread/goroutine can use +} + +func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy { + p := &bearerTokenPolicy{ + creds: creds, + options: opts.TokenRequest, + mainResource: newExpiringResource(acquire), + } + if len(opts.AuxiliaryTenants) > 0 { + p.auxResources = map[string]*expiringResource{} + } + for _, t := range opts.AuxiliaryTenants { + p.auxResources[t] = newExpiringResource(acquire) + + } + return p +} + +func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) { + as := acquiringResourceState{ + p: *b, + req: req, + } + tk, err := b.mainResource.GetResource(as) + if err != nil { + return nil, err + } + if token, ok := tk.(*azcore.AccessToken); ok { + req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Request.Header.Set(headerAuthorization, fmt.Sprintf("Bearer %s", token.Token)) + } + auxTokens := []string{} + for tenant, er := range b.auxResources { + bCopy := *b + bCopy.options.TenantID = tenant + auxAS := acquiringResourceState{ + p: bCopy, + req: req, + } + auxTk, err := er.GetResource(auxAS) if err != nil { - b.unlock() return nil, err } - header = bearerTokenPrefix + tk.Token - b.header = header - b.expiresOn = tk.ExpiresOn - b.unlock() + auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) + } + if len(auxTokens) > 0 { + req.Request.Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) } - req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) - req.Request.Header.Set(headerAuthorization, header) return req.Next() } - -// signal any waiters that the token has been refreshed -func (b *bearerTokenPolicy) unlock() { - b.cond.Broadcast() - b.cond.L.Unlock() -} diff --git a/sdk/azidentity/bearer_token_policy_test.go b/sdk/azidentity/bearer_token_policy_test.go index b11c8512d937..0ce05cd47b0d 100644 --- a/sdk/azidentity/bearer_token_policy_test.go +++ b/sdk/azidentity/bearer_token_policy_test.go @@ -197,3 +197,52 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { t.Fatal("expected nil response") } } + +func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + headerResult := "Bearer new_token, Bearer new_token, Bearer new_token" + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse() + options := ClientSecretCredentialOptions{ + AuthorityHost: srv.URL(), + HTTPClient: srv, + } + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + retryOpts := azcore.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewRetryPolicy(&retryOpts), + cred.NewAuthenticationPolicy( + azcore.AuthenticationOptions{ + TokenRequest: azcore.TokenRequestOptions{ + Scopes: []string{scope}, + }, + AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, + }), + azcore.NewLogPolicy(nil)) + + req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != headerResult { + t.Fatalf("unexpected auxiliary authorization header %s", auxH) + } +} From 6f2b81121212f6a2faca52cf4835404b4eb46201 Mon Sep 17 00:00:00 2001 From: Catalina Peralta Date: Wed, 28 Jul 2021 10:07:25 -0700 Subject: [PATCH 2/2] add auxiliary auth header --- sdk/azidentity/azidentity.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index d3d21f7e8bb9..4563cc745688 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -29,12 +29,13 @@ const ( ) const ( - headerXmsDate = "x-ms-date" - headerUserAgent = "User-Agent" - headerURLEncoded = "application/x-www-form-urlencoded" - headerAuthorization = "Authorization" - headerMetadata = "Metadata" - headerContentType = "Content-Type" + headerXmsDate = "x-ms-date" + headerUserAgent = "User-Agent" + headerURLEncoded = "application/x-www-form-urlencoded" + headerAuthorization = "Authorization" + headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary" + headerMetadata = "Metadata" + headerContentType = "Content-Type" ) const tenantIDValidationErr = "Invalid tenantID provided. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names."