Skip to content

Commit

Permalink
Revise azidentity errors (#15924)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Oct 27, 2021
1 parent 394eecb commit 47db7e2
Show file tree
Hide file tree
Showing 27 changed files with 157 additions and 173 deletions.
2 changes: 2 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
* `AzureCLICredential` no longer reads the environment variable `AZURE_CLI_PATH`
* `NewManagedIdentityCredential` no longer reads environment variables `AZURE_CLIENT_ID` and
`AZURE_RESOURCE_ID`. Use `ManagedIdentityCredentialOptions.ID` instead.
* Unexported `AuthenticationFailedError` and `CredentialUnavailableError` structs. In their place are two
interfaces having the same names.

### Bugs Fixed
* `AzureCLICredential.GetToken` no longer mutates its `opts.Scopes`
Expand Down
9 changes: 5 additions & 4 deletions sdk/azidentity/aad_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package azidentity
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -84,7 +85,7 @@ func getError(resp *http.Response) error {
} else {
msg = fmt.Sprintf("authentication failed: %s", authFailed.Message)
}
return &AuthenticationFailedError{msg: msg, resp: resp}
return newAuthenticationFailedError(errors.New(msg), resp)
}

// refreshAccessToken creates a refresh token request and returns the resulting Access Token or
Expand Down Expand Up @@ -169,7 +170,7 @@ func (c *aadIdentityClient) createAccessToken(res *http.Response) (*azcore.Acces
ExpiresOn string `json:"expires_on"`
}{}
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("internal AccessToken: %w", err)
return nil, fmt.Errorf("internal AccessToken: %v", err)
}
t, err := value.ExpiresIn.Int64()
if err != nil {
Expand All @@ -191,7 +192,7 @@ func (c *aadIdentityClient) createRefreshAccessToken(res *http.Response) (*token
ExpiresOn string `json:"expires_on"`
}{}
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("internal AccessToken: %w", err)
return nil, fmt.Errorf("internal AccessToken: %v", err)
}
t, err := value.ExpiresIn.Int64()
if err != nil {
Expand Down Expand Up @@ -319,7 +320,7 @@ func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Contex
func createDeviceCodeResult(res *http.Response) (*deviceCodeResult, error) {
value := &deviceCodeResult{}
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("DeviceCodeResult: %w", err)
return nil, fmt.Errorf("DeviceCodeResult: %v", err)
}
return value, nil
}
Expand Down
3 changes: 2 additions & 1 deletion sdk/azidentity/authorization_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package azidentity

import (
"context"
"errors"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -42,7 +43,7 @@ type AuthorizationCodeCredential struct {
// options: Manage the configuration of the requests sent to Azure Active Directory, they can also include a client secret for web app authentication.
func NewAuthorizationCodeCredential(tenantID string, clientID string, authCode string, redirectURL string, options *AuthorizationCodeCredentialOptions) (*AuthorizationCodeCredential, error) {
if !validTenantID(tenantID) {
return nil, &CredentialUnavailableError{credentialType: "Authorization Code Credential", message: tenantIDValidationErr}
return nil, errors.New(tenantIDValidationErr)
}
cp := AuthorizationCodeCredentialOptions{}
if options != nil {
Expand Down
6 changes: 1 addition & 5 deletions sdk/azidentity/authorization_code_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ func TestAuthorizationCodeCredential_InvalidTenantID(t *testing.T) {
if cred != nil {
t.Fatalf("Expected a nil credential value. Received: %v", cred)
}
var errType *CredentialUnavailableError
if !errors.As(err, &errType) {
t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err)
}
}

func TestAuthorizationCodeCredential_CreateAuthRequestSuccess(t *testing.T) {
Expand Down Expand Up @@ -109,7 +105,7 @@ func TestAuthorizationCodeCredential_GetTokenInvalidCredentials(t *testing.T) {
if err == nil {
t.Fatalf("Expected an error but did not receive one.")
}
var authFailed *AuthenticationFailedError
var authFailed AuthenticationFailedError
if !errors.As(err, &authFailed) {
t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err)
}
Expand Down
50 changes: 0 additions & 50 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"regexp"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
)

// AuthorityHost is the base URL for Azure Active Directory
Expand Down Expand Up @@ -53,55 +52,6 @@ type tokenResponse struct {
refreshToken string
}

// AuthenticationFailedError is returned when the authentication request has failed.
type AuthenticationFailedError struct {
inner error
msg string
resp *http.Response
}

// Unwrap method on AuthenticationFailedError provides access to the inner error if available.
func (e *AuthenticationFailedError) Unwrap() error {
return e.inner
}

// NonRetriable indicates that this error should not be retried.
func (e *AuthenticationFailedError) NonRetriable() {
// marker method
}

func (e *AuthenticationFailedError) Error() string {
return e.msg
}

// RawResponse returns the HTTP response motivating the error, if available
func (e *AuthenticationFailedError) RawResponse() *http.Response {
return e.resp
}

var _ azcore.HTTPResponse = (*AuthenticationFailedError)(nil)
var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil)

// CredentialUnavailableError is the error type returned when the conditions required to
// create a credential do not exist or are unavailable.
type CredentialUnavailableError struct {
// CredentialType holds the name of the credential that is unavailable
credentialType string
// Message contains the reason why the credential is unavailable
message string
}

func (e *CredentialUnavailableError) Error() string {
return e.credentialType + ": " + e.message
}

// NonRetriable indicates that this error should not be retried.
func (e *CredentialUnavailableError) NonRetriable() {
// marker method
}

var _ errorinfo.NonRetriable = (*CredentialUnavailableError)(nil)

// setAuthorityHost initializes the authority host for credentials.
func setAuthorityHost(authorityHost AuthorityHost) (string, error) {
host := string(authorityHost)
Expand Down
1 change: 1 addition & 0 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
accessTokenRespError = `{"error": "invalid_client","error_description": "Invalid client secret is provided.","error_codes": [0],"timestamp": "2019-12-01 19:00:00Z","trace_id": "2d091b0","correlation_id": "a999","error_uri": "https://login.contoso.com/error?code=0"}`
accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}`
accessTokenRespMalformed = `{"access_token": 0, "expires_in": 3600}`
tokenValue = "new_token"
)

func defaultTestPipeline(srv policy.Transporter, cred azcore.TokenCredential, scope string) runtime.Pipeline {
Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func defaultTokenProvider() func(ctx context.Context, resource string, tenantID
// if there's no output in stderr report the error message instead
msg = err.Error()
}
return nil, &CredentialUnavailableError{credentialType: "Azure CLI Credential", message: msg}
return nil, newCredentialUnavailableError("Azure CLI Credential", msg)
}

return output, nil
Expand Down
24 changes: 11 additions & 13 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package azidentity
import (
"context"
"errors"
"fmt"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -25,15 +26,11 @@ type ChainedTokenCredential struct {
// NewChainedTokenCredential creates an instance of ChainedTokenCredential with the specified TokenCredential sources.
func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
if len(sources) == 0 {
credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: "Length of sources cannot be 0"}
logCredentialError(credErr.credentialType, credErr)
return nil, credErr
return nil, errors.New("sources must contain at least one TokenCredential")
}
for _, source := range sources {
if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: "Sources cannot contain a nil TokenCredential"}
logCredentialError(credErr.credentialType, credErr)
return nil, credErr
return nil, errors.New("sources cannot contain nil")
}
}
cp := make([]azcore.TokenCredential, len(sources))
Expand All @@ -43,22 +40,23 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine

// GetToken sequentially calls TokenCredential.GetToken on all the specified sources, returning the token from the first successful call to GetToken().
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
var errList []*CredentialUnavailableError
var errList []CredentialUnavailableError
// loop through all of the credentials provided in sources
for _, cred := range c.sources {
// make a GetToken request for the current credential in the loop
token, err = cred.GetToken(ctx, opts)
// check if we received a CredentialUnavailableError
var credErr *CredentialUnavailableError
var credErr CredentialUnavailableError
if errors.As(err, &credErr) {
// if we did receive a CredentialUnavailableError then we append it to our error slice and continue looping for a good credential
errList = append(errList, credErr)
} else if err != nil {
// if we receive some other type of error then we must stop looping and process the error accordingly
var authenticationFailed *AuthenticationFailedError
if errors.As(err, &authenticationFailed) {
var authFailed AuthenticationFailedError
if errors.As(err, &authFailed) {
// if the error is an AuthenticationFailedError we return the error related to the invalid credential and append all of the other error messages received prior to this point
authErr := &AuthenticationFailedError{msg: "Received an AuthenticationFailedError, there is an invalid credential in the chain. " + createChainedErrorMessage(errList), inner: err}
err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err)
authErr := newAuthenticationFailedError(err, authFailed.RawResponse())
return nil, authErr
}
// if we receive some other error type this is unexpected and we simple return the unexpected error
Expand All @@ -70,14 +68,14 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
}
}
// if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableErrors
credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: createChainedErrorMessage(errList)}
credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList))
// skip adding the stack trace here as it was already logged by other calls to GetToken()
addGetTokenFailureLogs("Chained Token Credential", credErr, false)
return nil, credErr
}

// helper function used to chain the error messages of the CredentialUnavailableError slice
func createChainedErrorMessage(errList []*CredentialUnavailableError) string {
func createChainedErrorMessage(errList []CredentialUnavailableError) string {
msg := ""
for _, err := range errList {
msg += err.Error()
Expand Down
13 changes: 3 additions & 10 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,10 @@ func TestChainedTokenCredential_InstantiateFailure(t *testing.T) {
if err == nil {
t.Fatalf("Expected an error for sending a nil credential in the chain")
}
var credErr *CredentialUnavailableError
if !errors.As(err, &credErr) {
t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr)
}
_, err = NewChainedTokenCredential([]azcore.TokenCredential{}, nil)
if err == nil {
t.Fatalf("Expected an error for not sending any credential sources")
}
if !errors.As(err, &credErr) {
t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr)
}
}

func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) {
Expand Down Expand Up @@ -118,9 +111,9 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) {
if err == nil {
t.Fatalf("Expected an error but did not receive one")
}
var authErr *AuthenticationFailedError
var authErr AuthenticationFailedError
if !errors.As(err, &authErr) {
t.Fatalf("Expected Error Type: AuthenticationFailedError, ReceivedErrorType: %T", err)
t.Fatalf("Expected AuthenticationFailedError, received %T", err)
}
if len(err.Error()) == 0 {
t.Fatalf("Did not create an appropriate error message")
Expand All @@ -130,7 +123,7 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) {
func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.AppendError(&CredentialUnavailableError{credentialType: "MockCredential", message: "Mocking a credential unavailable error"})
srv.AppendError(newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error"))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientSecretCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
Expand Down
7 changes: 3 additions & 4 deletions sdk/azidentity/client_certificate_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type ClientCertificateCredential struct {
// options: ClientCertificateCredentialOptions that can be used to provide additional configurations for the credential, such as the certificate password.
func NewClientCertificateCredential(tenantID string, clientID string, certData []byte, options *ClientCertificateCredentialOptions) (*ClientCertificateCredential, error) {
if !validTenantID(tenantID) {
return nil, &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: tenantIDValidationErr}
return nil, errors.New(tenantIDValidationErr)
}
cp := ClientCertificateCredentialOptions{}
if options != nil {
Expand All @@ -60,9 +60,8 @@ func NewClientCertificateCredential(tenantID string, clientID string, certData [
cert, err = loadPKCS12Cert(certData, cp.Password, cp.SendCertificateChain)
}
if err != nil {
credErr := &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: err.Error()}
logCredentialError(credErr.credentialType, credErr)
return nil, credErr
logCredentialError("Client Certificate Credential", err)
return nil, err
}
authorityHost, err := setAuthorityHost(cp.AuthorityHost)
if err != nil {
Expand Down
6 changes: 1 addition & 5 deletions sdk/azidentity/client_certificate_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ func TestClientCertificateCredential_InvalidTenantID(t *testing.T) {
if cred != nil {
t.Fatalf("Expected a nil credential value. Received: %v", cred)
}
var errType *CredentialUnavailableError
if !errors.As(err, &errType) {
t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err)
}
}

func TestClientCertificateCredential_CreateAuthRequestSuccess(t *testing.T) {
Expand Down Expand Up @@ -221,7 +217,7 @@ func TestClientCertificateCredential_GetTokenInvalidCredentials(t *testing.T) {
if err == nil {
t.Fatalf("Expected to receive a nil error, but received: %v", err)
}
var authFailed *AuthenticationFailedError
var authFailed AuthenticationFailedError
if !errors.As(err, &authFailed) {
t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err)
}
Expand Down
3 changes: 2 additions & 1 deletion sdk/azidentity/client_secret_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package azidentity

import (
"context"
"errors"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -37,7 +38,7 @@ type ClientSecretCredential struct {
// options: allow to configure the management of the requests sent to Azure Active Directory.
func NewClientSecretCredential(tenantID string, clientID string, clientSecret string, options *ClientSecretCredentialOptions) (*ClientSecretCredential, error) {
if !validTenantID(tenantID) {
return nil, &CredentialUnavailableError{credentialType: "Client Secret Credential", message: tenantIDValidationErr}
return nil, errors.New(tenantIDValidationErr)
}
cp := ClientSecretCredentialOptions{}
if options != nil {
Expand Down
7 changes: 1 addition & 6 deletions sdk/azidentity/client_secret_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ const (
clientID = "expected-client-id"
secret = "secret"
wrongSecret = "wrong_secret"
tokenValue = "new_token"
scope = "https://storage.azure.com/.default"
defaultTestAuthorityHost = "login.microsoftonline.com"
)
Expand All @@ -34,10 +33,6 @@ func TestClientSecretCredential_InvalidTenantID(t *testing.T) {
if cred != nil {
t.Fatalf("Expected a nil credential value. Received: %v", cred)
}
var errType *CredentialUnavailableError
if !errors.As(err, &errType) {
t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err)
}
}

func TestClientSecretCredential_CreateAuthRequestSuccess(t *testing.T) {
Expand Down Expand Up @@ -110,7 +105,7 @@ func TestClientSecretCredential_GetTokenInvalidCredentials(t *testing.T) {
if err == nil {
t.Fatalf("Expected an error but did not receive one.")
}
var authFailed *AuthenticationFailedError
var authFailed AuthenticationFailedError
if !errors.As(err, &authFailed) {
t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err)
}
Expand Down
8 changes: 4 additions & 4 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package azidentity

import (
"context"
"errors"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -65,17 +66,16 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
errMsg += err.Error()
}

cliCred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TenantID: options.TenantID})
cliCred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TenantID: cp.TenantID})
if err == nil {
creds = append(creds, cliCred)
} else {
errMsg += err.Error()
}

// if no credentials are added to the slice of TokenCredentials then return a CredentialUnavailableError
if len(creds) == 0 {
err := &CredentialUnavailableError{credentialType: "Default Azure Credential", message: errMsg}
logCredentialError(err.credentialType, err)
err := errors.New(errMsg)
logCredentialError("Default Azure Credential", err)
return nil, err
}
chain, err := NewChainedTokenCredential(creds, nil)
Expand Down
Loading

0 comments on commit 47db7e2

Please sign in to comment.