diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index dcde20193d9d..76ece3ac0a64 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -54,6 +54,8 @@ * `AzureCLICredential` no longer reads the environment variable `AZURE_CLI_PATH` * `NewManagedIdentityCredential` no longer reads environment variables `AZURE_CLIENT_ID` and `AZURE_RESOURCE_ID`. Use `ManagedIdentityCredentialOptions.ID` instead. +* Unexported `AuthenticationFailedError` and `CredentialUnavailableError` structs. In their place are two + interfaces having the same names. ### Bugs Fixed * `AzureCLICredential.GetToken` no longer mutates its `opts.Scopes` diff --git a/sdk/azidentity/aad_identity_client.go b/sdk/azidentity/aad_identity_client.go index 10eb09bbe45e..e0e17874bea4 100644 --- a/sdk/azidentity/aad_identity_client.go +++ b/sdk/azidentity/aad_identity_client.go @@ -6,6 +6,7 @@ package azidentity import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -84,7 +85,7 @@ func getError(resp *http.Response) error { } else { msg = fmt.Sprintf("authentication failed: %s", authFailed.Message) } - return &AuthenticationFailedError{msg: msg, resp: resp} + return newAuthenticationFailedError(errors.New(msg), resp) } // refreshAccessToken creates a refresh token request and returns the resulting Access Token or @@ -169,7 +170,7 @@ func (c *aadIdentityClient) createAccessToken(res *http.Response) (*azcore.Acces ExpiresOn string `json:"expires_on"` }{} if err := runtime.UnmarshalAsJSON(res, &value); err != nil { - return nil, fmt.Errorf("internal AccessToken: %w", err) + return nil, fmt.Errorf("internal AccessToken: %v", err) } t, err := value.ExpiresIn.Int64() if err != nil { @@ -191,7 +192,7 @@ func (c *aadIdentityClient) createRefreshAccessToken(res *http.Response) (*token ExpiresOn string `json:"expires_on"` }{} if err := runtime.UnmarshalAsJSON(res, &value); err != nil { - return nil, fmt.Errorf("internal AccessToken: %w", err) + return nil, fmt.Errorf("internal AccessToken: %v", err) } t, err := value.ExpiresIn.Int64() if err != nil { @@ -319,7 +320,7 @@ func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Contex func createDeviceCodeResult(res *http.Response) (*deviceCodeResult, error) { value := &deviceCodeResult{} if err := runtime.UnmarshalAsJSON(res, &value); err != nil { - return nil, fmt.Errorf("DeviceCodeResult: %w", err) + return nil, fmt.Errorf("DeviceCodeResult: %v", err) } return value, nil } diff --git a/sdk/azidentity/authorization_code_credential.go b/sdk/azidentity/authorization_code_credential.go index abf70cfe57f2..e62d9d7a52bf 100644 --- a/sdk/azidentity/authorization_code_credential.go +++ b/sdk/azidentity/authorization_code_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -42,7 +43,7 @@ type AuthorizationCodeCredential struct { // options: Manage the configuration of the requests sent to Azure Active Directory, they can also include a client secret for web app authentication. func NewAuthorizationCodeCredential(tenantID string, clientID string, authCode string, redirectURL string, options *AuthorizationCodeCredentialOptions) (*AuthorizationCodeCredential, error) { if !validTenantID(tenantID) { - return nil, &CredentialUnavailableError{credentialType: "Authorization Code Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } cp := AuthorizationCodeCredentialOptions{} if options != nil { diff --git a/sdk/azidentity/authorization_code_credential_test.go b/sdk/azidentity/authorization_code_credential_test.go index 477bda1dfa65..3c3a8f626e7c 100644 --- a/sdk/azidentity/authorization_code_credential_test.go +++ b/sdk/azidentity/authorization_code_credential_test.go @@ -28,10 +28,6 @@ func TestAuthorizationCodeCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestAuthorizationCodeCredential_CreateAuthRequestSuccess(t *testing.T) { @@ -109,7 +105,7 @@ func TestAuthorizationCodeCredential_GetTokenInvalidCredentials(t *testing.T) { if err == nil { t.Fatalf("Expected an error but did not receive one.") } - var authFailed *AuthenticationFailedError + var authFailed AuthenticationFailedError if !errors.As(err, &authFailed) { t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) } diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index aa63afd11675..908611ab91b4 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -12,7 +12,6 @@ import ( "regexp" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" ) // AuthorityHost is the base URL for Azure Active Directory @@ -53,55 +52,6 @@ type tokenResponse struct { refreshToken string } -// AuthenticationFailedError is returned when the authentication request has failed. -type AuthenticationFailedError struct { - inner error - msg string - resp *http.Response -} - -// Unwrap method on AuthenticationFailedError provides access to the inner error if available. -func (e *AuthenticationFailedError) Unwrap() error { - return e.inner -} - -// NonRetriable indicates that this error should not be retried. -func (e *AuthenticationFailedError) NonRetriable() { - // marker method -} - -func (e *AuthenticationFailedError) Error() string { - return e.msg -} - -// RawResponse returns the HTTP response motivating the error, if available -func (e *AuthenticationFailedError) RawResponse() *http.Response { - return e.resp -} - -var _ azcore.HTTPResponse = (*AuthenticationFailedError)(nil) -var _ errorinfo.NonRetriable = (*AuthenticationFailedError)(nil) - -// CredentialUnavailableError is the error type returned when the conditions required to -// create a credential do not exist or are unavailable. -type CredentialUnavailableError struct { - // CredentialType holds the name of the credential that is unavailable - credentialType string - // Message contains the reason why the credential is unavailable - message string -} - -func (e *CredentialUnavailableError) Error() string { - return e.credentialType + ": " + e.message -} - -// NonRetriable indicates that this error should not be retried. -func (e *CredentialUnavailableError) NonRetriable() { - // marker method -} - -var _ errorinfo.NonRetriable = (*CredentialUnavailableError)(nil) - // setAuthorityHost initializes the authority host for credentials. func setAuthorityHost(authorityHost AuthorityHost) (string, error) { host := string(authorityHost) diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 84a54280db19..c7595d6c7a2b 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -18,6 +18,7 @@ const ( accessTokenRespError = `{"error": "invalid_client","error_description": "Invalid client secret is provided.","error_codes": [0],"timestamp": "2019-12-01 19:00:00Z","trace_id": "2d091b0","correlation_id": "a999","error_uri": "https://login.contoso.com/error?code=0"}` accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}` accessTokenRespMalformed = `{"access_token": 0, "expires_in": 3600}` + tokenValue = "new_token" ) func defaultTestPipeline(srv policy.Transporter, cred azcore.TokenCredential, scope string) runtime.Pipeline { diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index 6947bc7468c0..5713c8f30d53 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -133,7 +133,7 @@ func defaultTokenProvider() func(ctx context.Context, resource string, tenantID // if there's no output in stderr report the error message instead msg = err.Error() } - return nil, &CredentialUnavailableError{credentialType: "Azure CLI Credential", message: msg} + return nil, newCredentialUnavailableError("Azure CLI Credential", msg) } return output, nil diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 128b65d882a3..f2eea3f271e9 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -6,6 +6,7 @@ package azidentity import ( "context" "errors" + "fmt" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -25,15 +26,11 @@ type ChainedTokenCredential struct { // NewChainedTokenCredential creates an instance of ChainedTokenCredential with the specified TokenCredential sources. func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) { if len(sources) == 0 { - credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: "Length of sources cannot be 0"} - logCredentialError(credErr.credentialType, credErr) - return nil, credErr + return nil, errors.New("sources must contain at least one TokenCredential") } for _, source := range sources { if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil - credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: "Sources cannot contain a nil TokenCredential"} - logCredentialError(credErr.credentialType, credErr) - return nil, credErr + return nil, errors.New("sources cannot contain nil") } } cp := make([]azcore.TokenCredential, len(sources)) @@ -43,22 +40,23 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine // GetToken sequentially calls TokenCredential.GetToken on all the specified sources, returning the token from the first successful call to GetToken(). func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { - var errList []*CredentialUnavailableError + var errList []CredentialUnavailableError // loop through all of the credentials provided in sources for _, cred := range c.sources { // make a GetToken request for the current credential in the loop token, err = cred.GetToken(ctx, opts) // check if we received a CredentialUnavailableError - var credErr *CredentialUnavailableError + var credErr CredentialUnavailableError if errors.As(err, &credErr) { // if we did receive a CredentialUnavailableError then we append it to our error slice and continue looping for a good credential errList = append(errList, credErr) } else if err != nil { // if we receive some other type of error then we must stop looping and process the error accordingly - var authenticationFailed *AuthenticationFailedError - if errors.As(err, &authenticationFailed) { + var authFailed AuthenticationFailedError + if errors.As(err, &authFailed) { // if the error is an AuthenticationFailedError we return the error related to the invalid credential and append all of the other error messages received prior to this point - authErr := &AuthenticationFailedError{msg: "Received an AuthenticationFailedError, there is an invalid credential in the chain. " + createChainedErrorMessage(errList), inner: err} + err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) + authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) return nil, authErr } // if we receive some other error type this is unexpected and we simple return the unexpected error @@ -70,14 +68,14 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token } } // if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableErrors - credErr := &CredentialUnavailableError{credentialType: "Chained Token Credential", message: createChainedErrorMessage(errList)} + credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList)) // skip adding the stack trace here as it was already logged by other calls to GetToken() addGetTokenFailureLogs("Chained Token Credential", credErr, false) return nil, credErr } // helper function used to chain the error messages of the CredentialUnavailableError slice -func createChainedErrorMessage(errList []*CredentialUnavailableError) string { +func createChainedErrorMessage(errList []CredentialUnavailableError) string { msg := "" for _, err := range errList { msg += err.Error() diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index b4f6ddb31e39..4d32c0b05bc1 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -48,17 +48,10 @@ func TestChainedTokenCredential_InstantiateFailure(t *testing.T) { if err == nil { t.Fatalf("Expected an error for sending a nil credential in the chain") } - var credErr *CredentialUnavailableError - if !errors.As(err, &credErr) { - t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr) - } _, err = NewChainedTokenCredential([]azcore.TokenCredential{}, nil) if err == nil { t.Fatalf("Expected an error for not sending any credential sources") } - if !errors.As(err, &credErr) { - t.Fatalf("Expected a CredentialUnavailableError, but received: %T", credErr) - } } func TestChainedTokenCredential_GetTokenSuccess(t *testing.T) { @@ -118,9 +111,9 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) { if err == nil { t.Fatalf("Expected an error but did not receive one") } - var authErr *AuthenticationFailedError + var authErr AuthenticationFailedError if !errors.As(err, &authErr) { - t.Fatalf("Expected Error Type: AuthenticationFailedError, ReceivedErrorType: %T", err) + t.Fatalf("Expected AuthenticationFailedError, received %T", err) } if len(err.Error()) == 0 { t.Fatalf("Did not create an appropriate error message") @@ -130,7 +123,7 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) { func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) { srv, close := mock.NewTLSServer() defer close() - srv.AppendError(&CredentialUnavailableError{credentialType: "MockCredential", message: "Mocking a credential unavailable error"}) + srv.AppendError(newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")) srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) options := ClientSecretCredentialOptions{} options.AuthorityHost = AuthorityHost(srv.URL()) diff --git a/sdk/azidentity/client_certificate_credential.go b/sdk/azidentity/client_certificate_credential.go index 4efadb1970b4..d31ac88258ba 100644 --- a/sdk/azidentity/client_certificate_credential.go +++ b/sdk/azidentity/client_certificate_credential.go @@ -49,7 +49,7 @@ type ClientCertificateCredential struct { // options: ClientCertificateCredentialOptions that can be used to provide additional configurations for the credential, such as the certificate password. func NewClientCertificateCredential(tenantID string, clientID string, certData []byte, options *ClientCertificateCredentialOptions) (*ClientCertificateCredential, error) { if !validTenantID(tenantID) { - return nil, &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } cp := ClientCertificateCredentialOptions{} if options != nil { @@ -60,9 +60,8 @@ func NewClientCertificateCredential(tenantID string, clientID string, certData [ cert, err = loadPKCS12Cert(certData, cp.Password, cp.SendCertificateChain) } if err != nil { - credErr := &CredentialUnavailableError{credentialType: "Client Certificate Credential", message: err.Error()} - logCredentialError(credErr.credentialType, credErr) - return nil, credErr + logCredentialError("Client Certificate Credential", err) + return nil, err } authorityHost, err := setAuthorityHost(cp.AuthorityHost) if err != nil { diff --git a/sdk/azidentity/client_certificate_credential_test.go b/sdk/azidentity/client_certificate_credential_test.go index ea5327b8786a..4a23449d3c4d 100644 --- a/sdk/azidentity/client_certificate_credential_test.go +++ b/sdk/azidentity/client_certificate_credential_test.go @@ -40,10 +40,6 @@ func TestClientCertificateCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestClientCertificateCredential_CreateAuthRequestSuccess(t *testing.T) { @@ -221,7 +217,7 @@ func TestClientCertificateCredential_GetTokenInvalidCredentials(t *testing.T) { if err == nil { t.Fatalf("Expected to receive a nil error, but received: %v", err) } - var authFailed *AuthenticationFailedError + var authFailed AuthenticationFailedError if !errors.As(err, &authFailed) { t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) } diff --git a/sdk/azidentity/client_secret_credential.go b/sdk/azidentity/client_secret_credential.go index fba9fad57404..adcfa4d48255 100644 --- a/sdk/azidentity/client_secret_credential.go +++ b/sdk/azidentity/client_secret_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -37,7 +38,7 @@ type ClientSecretCredential struct { // options: allow to configure the management of the requests sent to Azure Active Directory. func NewClientSecretCredential(tenantID string, clientID string, clientSecret string, options *ClientSecretCredentialOptions) (*ClientSecretCredential, error) { if !validTenantID(tenantID) { - return nil, &CredentialUnavailableError{credentialType: "Client Secret Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } cp := ClientSecretCredentialOptions{} if options != nil { diff --git a/sdk/azidentity/client_secret_credential_test.go b/sdk/azidentity/client_secret_credential_test.go index 9814b6ce5a7a..d12a2485d76f 100644 --- a/sdk/azidentity/client_secret_credential_test.go +++ b/sdk/azidentity/client_secret_credential_test.go @@ -21,7 +21,6 @@ const ( clientID = "expected-client-id" secret = "secret" wrongSecret = "wrong_secret" - tokenValue = "new_token" scope = "https://storage.azure.com/.default" defaultTestAuthorityHost = "login.microsoftonline.com" ) @@ -34,10 +33,6 @@ func TestClientSecretCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestClientSecretCredential_CreateAuthRequestSuccess(t *testing.T) { @@ -110,7 +105,7 @@ func TestClientSecretCredential_GetTokenInvalidCredentials(t *testing.T) { if err == nil { t.Fatalf("Expected an error but did not receive one.") } - var authFailed *AuthenticationFailedError + var authFailed AuthenticationFailedError if !errors.As(err, &authFailed) { t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) } diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 29bf79f4966d..51957159c4f2 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -65,17 +66,16 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default errMsg += err.Error() } - cliCred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TenantID: options.TenantID}) + cliCred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TenantID: cp.TenantID}) if err == nil { creds = append(creds, cliCred) } else { errMsg += err.Error() } - // if no credentials are added to the slice of TokenCredentials then return a CredentialUnavailableError if len(creds) == 0 { - err := &CredentialUnavailableError{credentialType: "Default Azure Credential", message: errMsg} - logCredentialError(err.credentialType, err) + err := errors.New(errMsg) + logCredentialError("Default Azure Credential", err) return nil, err } chain, err := NewChainedTokenCredential(creds, nil) diff --git a/sdk/azidentity/device_code_credential.go b/sdk/azidentity/device_code_credential.go index 1498d459b390..90ee87e8766b 100644 --- a/sdk/azidentity/device_code_credential.go +++ b/sdk/azidentity/device_code_credential.go @@ -88,7 +88,7 @@ func NewDeviceCodeCredential(options *DeviceCodeCredentialOptions) (*DeviceCodeC } cp.init() if !validTenantID(cp.TenantID) { - return nil, &CredentialUnavailableError{credentialType: "Device Code Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } authorityHost, err := setAuthorityHost(cp.AuthorityHost) if err != nil { @@ -132,8 +132,9 @@ func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRe // make initial request to the device code endpoint for a device code and instructions for authentication dc, err := c.client.requestNewDeviceCode(ctx, c.tenantID, c.clientID, opts.Scopes) if err != nil { - addGetTokenFailureLogs("Device Code Credential", err, true) - return nil, err // TODO check what error type to return here + authErr := newAuthenticationFailedError(err, nil) + addGetTokenFailureLogs("Device Code Credential", authErr, true) + return nil, authErr } // send authentication flow instructions back to the user to log in and authorize the device @@ -156,8 +157,8 @@ func (c *DeviceCodeCredential) GetToken(ctx context.Context, opts policy.TokenRe } // if there is an error, check for an AADAuthenticationFailedError in order to check the status for token retrieval // if the error is not an AADAuthenticationFailedError, then fail here since something unexpected occurred - var authFailed *AuthenticationFailedError - if errors.As(err, &authFailed) && strings.Contains(authFailed.msg, "authorization_pending") { + var authFailed AuthenticationFailedError + if errors.As(err, &authFailed) && strings.Contains(authFailed.Error(), "authorization_pending") { // wait for the interval specified from the initial device code endpoint and then poll for the token again time.Sleep(time.Duration(dc.Interval) * time.Second) } else { diff --git a/sdk/azidentity/device_code_credential_test.go b/sdk/azidentity/device_code_credential_test.go index 6296edc4bba4..d0c116b9e9dd 100644 --- a/sdk/azidentity/device_code_credential_test.go +++ b/sdk/azidentity/device_code_credential_test.go @@ -35,10 +35,6 @@ func TestDeviceCodeCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestDeviceCodeCredential_CreateAuthRequestSuccess(t *testing.T) { @@ -273,7 +269,7 @@ func TestDeviceCodeCredential_GetTokenWithRefreshTokenFailure(t *testing.T) { if err == nil { t.Fatalf("Expected an error but did not receive one") } - var authFailed *AuthenticationFailedError + var authFailed AuthenticationFailedError if !errors.As(err, &authFailed) { t.Fatalf("Expected AuthenticationFailedError, got %T", err) } diff --git a/sdk/azidentity/environment_credential.go b/sdk/azidentity/environment_credential.go index fcc46641ff54..403fb80ebcbc 100644 --- a/sdk/azidentity/environment_credential.go +++ b/sdk/azidentity/environment_credential.go @@ -5,6 +5,8 @@ package azidentity import ( "context" + "errors" + "fmt" "os" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -46,15 +48,11 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme } tenantID := os.Getenv("AZURE_TENANT_ID") if tenantID == "" { - err := &CredentialUnavailableError{credentialType: "Environment Credential", message: "Missing environment variable AZURE_TENANT_ID"} - logCredentialError(err.credentialType, err) - return nil, err + return nil, errors.New("Missing environment variable AZURE_TENANT_ID") } clientID := os.Getenv("AZURE_CLIENT_ID") if clientID == "" { - err := &CredentialUnavailableError{credentialType: "Environment Credential", message: "Missing environment variable AZURE_CLIENT_ID"} - logCredentialError(err.credentialType, err) - return nil, err + return nil, errors.New("Missing environment variable AZURE_CLIENT_ID") } if clientSecret := os.Getenv("AZURE_CLIENT_SECRET"); clientSecret != "" { log.Write(EventCredential, "Azure Identity => NewEnvironmentCredential() invoking ClientSecretCredential") @@ -68,7 +66,7 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme log.Write(EventCredential, "Azure Identity => NewEnvironmentCredential() invoking ClientCertificateCredential") certData, err := os.ReadFile(certPath) if err != nil { - return nil, &CredentialUnavailableError{credentialType: "Environment Credential", message: "Failed to read certificate file: " + err.Error()} + return nil, fmt.Errorf("Failed to read certificate file: %v", err) } cred, err := NewClientCertificateCredential(tenantID, clientID, certData, &ClientCertificateCredentialOptions{AuthorityHost: cp.AuthorityHost, ClientOptions: cp.ClientOptions}) if err != nil { @@ -86,9 +84,7 @@ func NewEnvironmentCredential(options *EnvironmentCredentialOptions) (*Environme return &EnvironmentCredential{cred: cred}, nil } } - err := &CredentialUnavailableError{credentialType: "Environment Credential", message: "Missing environment variable AZURE_CLIENT_SECRET or AZURE_CLIENT_CERTIFICATE_PATH or AZURE_USERNAME and AZURE_PASSWORD"} - logCredentialError(err.credentialType, err) - return nil, err + return nil, errors.New("Missing environment variable AZURE_CLIENT_SECRET or AZURE_CLIENT_CERTIFICATE_PATH or AZURE_USERNAME and AZURE_PASSWORD") } // GetToken obtains a token from Azure Active Directory, using the underlying credential's GetToken method. diff --git a/sdk/azidentity/environment_credential_test.go b/sdk/azidentity/environment_credential_test.go index 427665df73d2..9d4caaacb314 100644 --- a/sdk/azidentity/environment_credential_test.go +++ b/sdk/azidentity/environment_credential_test.go @@ -4,7 +4,6 @@ package azidentity import ( - "errors" "os" "testing" ) @@ -43,10 +42,6 @@ func TestEnvironmentCredential_TenantIDNotSet(t *testing.T) { if err == nil { t.Fatalf("Expected an error but received nil") } - var credentialUnavailable *CredentialUnavailableError - if !errors.As(err, &credentialUnavailable) { - t.Fatalf("Expected a credential unavailable error, instead received: %T", err) - } } func TestEnvironmentCredential_ClientIDNotSet(t *testing.T) { @@ -63,10 +58,6 @@ func TestEnvironmentCredential_ClientIDNotSet(t *testing.T) { if err == nil { t.Fatalf("Expected an error but received nil") } - var credentialUnavailable *CredentialUnavailableError - if !errors.As(err, &credentialUnavailable) { - t.Fatalf("Expected a credential unavailable error, instead received: %T", err) - } } func TestEnvironmentCredential_ClientSecretNotSet(t *testing.T) { @@ -83,10 +74,6 @@ func TestEnvironmentCredential_ClientSecretNotSet(t *testing.T) { if err == nil { t.Fatalf("Expected an error but received nil") } - var credentialUnavailable *CredentialUnavailableError - if !errors.As(err, &credentialUnavailable) { - t.Fatalf("Expected a credential unavailable error, instead received: %T", err) - } } func TestEnvironmentCredential_ClientSecretSet(t *testing.T) { @@ -153,10 +140,6 @@ func TestEnvironmentCredential_UsernameOnlySet(t *testing.T) { if err == nil { t.Fatalf("Expected an error but received nil") } - var credentialUnavailable *CredentialUnavailableError - if !errors.As(err, &credentialUnavailable) { - t.Fatalf("Expected a credential unavailable error, instead received: %T", err) - } } func TestEnvironmentCredential_UsernamePasswordSet(t *testing.T) { diff --git a/sdk/azidentity/errors.go b/sdk/azidentity/errors.go new file mode 100644 index 000000000000..9c859fc01c41 --- /dev/null +++ b/sdk/azidentity/errors.go @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azidentity + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +// AuthenticationFailedError indicates an authentication request has failed. +type AuthenticationFailedError interface { + azcore.HTTPResponse + errorinfo.NonRetriable + AuthenticationFailed() +} + +type authenticationFailedError struct { + error + resp *http.Response +} + +func newAuthenticationFailedError(err error, resp *http.Response) AuthenticationFailedError { + return authenticationFailedError{err, resp} +} + +// NonRetriable indicates that this error should not be retried. +func (authenticationFailedError) NonRetriable() { + // marker method +} + +// AuthenticationFailed indicates that an authentication attempt failed +func (authenticationFailedError) AuthenticationFailed() { + // marker method +} + +// RawResponse returns the HTTP response motivating the error, if available. +func (e authenticationFailedError) RawResponse() *http.Response { + return e.resp +} + +var _ AuthenticationFailedError = (*authenticationFailedError)(nil) +var _ azcore.HTTPResponse = (*authenticationFailedError)(nil) +var _ errorinfo.NonRetriable = (*authenticationFailedError)(nil) + +// CredentialUnavailableError indicates a credential can't attempt authenticate +// because it lacks required data or state. +type CredentialUnavailableError interface { + errorinfo.NonRetriable + CredentialUnavailable() +} + +type credentialUnavailableError struct { + credType string + message string +} + +func newCredentialUnavailableError(credType, message string) CredentialUnavailableError { + return credentialUnavailableError{credType: credType, message: message} +} + +func (e credentialUnavailableError) Error() string { + return e.credType + ": " + e.message +} + +// NonRetriable indicates that this error should not be retried. +func (e credentialUnavailableError) NonRetriable() { + // marker method +} + +// CredentialUnavailable indicates that the credential didn't attempt to authenticate +func (e credentialUnavailableError) CredentialUnavailable() { + // marker method +} + +var _ CredentialUnavailableError = (*credentialUnavailableError)(nil) +var _ errorinfo.NonRetriable = (*credentialUnavailableError)(nil) diff --git a/sdk/azidentity/interactive_browser_credential.go b/sdk/azidentity/interactive_browser_credential.go index f56c615db2f0..5c5c00c5e907 100644 --- a/sdk/azidentity/interactive_browser_credential.go +++ b/sdk/azidentity/interactive_browser_credential.go @@ -7,6 +7,7 @@ import ( "context" "crypto/sha256" "encoding/base64" + "errors" "math/rand" "net/url" "path" @@ -64,7 +65,7 @@ func NewInteractiveBrowserCredential(options *InteractiveBrowserCredentialOption } cp.init() if !validTenantID(cp.TenantID) { - return nil, &CredentialUnavailableError{credentialType: "Interactive Browser Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } authorityHost, err := setAuthorityHost(cp.AuthorityHost) if err != nil { diff --git a/sdk/azidentity/interactive_browser_credential_test.go b/sdk/azidentity/interactive_browser_credential_test.go index c1234f3e1715..25cf7e146144 100644 --- a/sdk/azidentity/interactive_browser_credential_test.go +++ b/sdk/azidentity/interactive_browser_credential_test.go @@ -24,10 +24,6 @@ func TestInteractiveBrowserCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestInteractiveBrowserCredential_CreateWithNilOptions(t *testing.T) { @@ -105,7 +101,7 @@ func TestInteractiveBrowserCredential_GetTokenInvalidCredentials(t *testing.T) { if err == nil { t.Fatalf("Expected an error but did not receive one.") } - var authFailed *AuthenticationFailedError + var authFailed AuthenticationFailedError if !errors.As(err, &authFailed) { t.Fatalf("Expected: AuthenticationFailedError, Received: %T", err) } diff --git a/sdk/azidentity/jwt.go b/sdk/azidentity/jwt.go index 88a815337f05..84f8c93aa1eb 100644 --- a/sdk/azidentity/jwt.go +++ b/sdk/azidentity/jwt.go @@ -48,7 +48,7 @@ func createClientAssertionJWT(clientID string, audience string, cert *certConten headerJSON, err := json.Marshal(headerData) if err != nil { - return "", fmt.Errorf("marshal headerJWT: %w", err) + return "", fmt.Errorf("marshal headerJWT: %v", err) } header := base64.RawURLEncoding.EncodeToString(headerJSON) jti, err := uuid.New() @@ -66,7 +66,7 @@ func createClientAssertionJWT(clientID string, audience string, cert *certConten payloadJSON, err := json.Marshal(payloadData) if err != nil { - return "", fmt.Errorf("marshal payloadJWT: %w", err) + return "", fmt.Errorf("marshal payloadJWT: %v", err) } payload := base64.RawURLEncoding.EncodeToString(payloadJSON) result := header + "." + payload diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index b5ef1286c394..591a25bc6130 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -136,7 +136,7 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *manage // scopes: The scopes required for the token. func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (*azcore.AccessToken, error) { if len(c.unavailableMessage) > 0 { - return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} + return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage) } msg, err := c.createAuthRequest(ctx, id, scopes) @@ -155,13 +155,13 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKi if c.msiType == msiTypeIMDS && resp.StatusCode == 400 { if id != nil { - return nil, &AuthenticationFailedError{msg: "The requested identity isn't assigned to this resource."} + return nil, newAuthenticationFailedError(errors.New("The requested identity isn't assigned to this resource."), resp) } c.unavailableMessage = "No default identity is assigned to this resource." - return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} + return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage) } - return nil, &AuthenticationFailedError{resp: resp, msg: "authentication failed"} + return nil, newAuthenticationFailedError(errors.New("authentication failed"), resp) } func (c *managedIdentityClient) createAccessToken(res *http.Response) (*azcore.AccessToken, error) { @@ -173,7 +173,7 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (*azcore.A ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string }{} if err := runtime.UnmarshalAsJSON(res, &value); err != nil { - return nil, fmt.Errorf("internal AccessToken: %w", err) + return nil, fmt.Errorf("internal AccessToken: %v", err) } if value.ExpiresIn != "" { expiresIn, err := json.Number(value.ExpiresIn).Int64() @@ -201,7 +201,8 @@ func (c *managedIdentityClient) createAccessToken(res *http.Response) (*azcore.A return nil, err } default: - return nil, &AuthenticationFailedError{msg: fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)} + err := fmt.Errorf("unsupported type received in expires_on: %T, %v", v, v) + return nil, newAuthenticationFailedError(err, res) } } @@ -215,7 +216,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage // need to perform preliminary request to retreive the secret key challenge provided by the HIMDS service key, err := c.getAzureArcSecretKey(ctx, scopes) if err != nil { - return nil, &AuthenticationFailedError{msg: "failed to retreive secret key from the identity endpoint"} + msg := fmt.Errorf("failed to retreive secret key from the identity endpoint: %v", err) + return nil, newAuthenticationFailedError(msg, nil) } return c.createAzureArcAuthRequest(ctx, key, scopes) case msiTypeServiceFabric: @@ -231,7 +233,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id Manage errorMsg = "unknown" } c.unavailableMessage = "managed identity support is " + errorMsg - return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} + return nil, newCredentialUnavailableError("Managed Identity Credential", c.unavailableMessage) } } @@ -326,7 +328,8 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour // the endpoint is expected to return a 401 with the WWW-Authenticate header set to the location // of the secret key file. Any other status code indicates an error in the request. if response.StatusCode != 401 { - return "", &AuthenticationFailedError{resp: response, msg: fmt.Sprintf("expected a 401 response, received %d", response.StatusCode)} + err := fmt.Errorf("expected a 401 response, received %d", response.StatusCode) + return "", newAuthenticationFailedError(err, response) } header := response.Header.Get("WWW-Authenticate") if len(header) == 0 { @@ -339,7 +342,7 @@ func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resour } key, err := ioutil.ReadFile(header[pos+1:]) if err != nil { - return "", fmt.Errorf("could not read file (%s) contents: %w", header[pos+1:], err) + return "", fmt.Errorf("could not read file (%s) contents: %v", header[pos+1:], err) } return string(key), nil } @@ -397,14 +400,14 @@ func (c *managedIdentityClient) getMSIType() (msiType, error) { c.msiType = msiTypeAzureArc } else { c.msiType = msiTypeUnavailable - return c.msiType, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: "this environment is not supported yet"} + return c.msiType, newCredentialUnavailableError("Managed Identity Credential", "this environment is not supported") } } else if c.imdsAvailable() { // if MSI_ENDPOINT is NOT set AND the IMDS endpoint is available the msiType is IMDS. This will timeout after 500 milliseconds c.endpoint = imdsEndpoint c.msiType = msiTypeIMDS } else { // if MSI_ENDPOINT is NOT set and IMDS endpoint is not available Managed Identity is not available c.msiType = msiTypeUnavailable - return c.msiType, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: "no managed identity endpoint is available"} + return c.msiType, newCredentialUnavailableError("Managed Identity Credential", "no managed identity endpoint is available") } } return c.msiType, nil diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index fc9e2a5f17d7..0619c49d922e 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "errors" "fmt" "strings" @@ -82,9 +83,8 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M msiType, err := client.getMSIType() // If there is an error that means that the code is not running in a Managed Identity environment if err != nil { - credErr := &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: "Please make sure you are running in a managed identity environment, such as a VM, Azure Functions, Cloud Shell, etc..."} - logCredentialError(credErr.credentialType, credErr) - return nil, credErr + logCredentialError("Managed Identity Credential", err) + return nil, err } // Assign the msiType discovered onto the client client.msiType = msiType @@ -96,12 +96,12 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M // Returns an AccessToken which can be used to authenticate service client calls. func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (*azcore.AccessToken, error) { if opts.Scopes == nil { - err := &AuthenticationFailedError{msg: "must specify a resource in order to authenticate"} + err := errors.New("must specify a resource in order to authenticate") addGetTokenFailureLogs("Managed Identity Credential", err, true) return nil, err } if len(opts.Scopes) != 1 { - err := &AuthenticationFailedError{msg: "can only specify one resource to authenticate with ManagedIdentityCredential"} + err := errors.New("can only specify one resource to authenticate with ManagedIdentityCredential") addGetTokenFailureLogs("Managed Identity Credential", err, true) return nil, err } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 4257cc37f6f7..bfb57fc2ea5b 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -374,7 +374,7 @@ func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { } // cred should return CredentialUnavailableError when IMDS responds 400 to a token request. // Also, it shouldn't send another token request (mockIMDS will appropriately panic if it does). - var expected *CredentialUnavailableError + var expected CredentialUnavailableError for i := 0; i < 3; i++ { _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{msiScope}}) if !errors.As(err, &expected) { diff --git a/sdk/azidentity/username_password_credential.go b/sdk/azidentity/username_password_credential.go index 60bae9dbb754..b0f526c9073e 100644 --- a/sdk/azidentity/username_password_credential.go +++ b/sdk/azidentity/username_password_credential.go @@ -5,6 +5,7 @@ package azidentity import ( "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -41,7 +42,7 @@ type UsernamePasswordCredential struct { // options: UsernamePasswordCredentialOptions used to configure the pipeline for the requests sent to Azure Active Directory. func NewUsernamePasswordCredential(tenantID string, clientID string, username string, password string, options *UsernamePasswordCredentialOptions) (*UsernamePasswordCredential, error) { if !validTenantID(tenantID) { - return nil, &CredentialUnavailableError{credentialType: "Username Password Credential", message: tenantIDValidationErr} + return nil, errors.New(tenantIDValidationErr) } cp := UsernamePasswordCredentialOptions{} if options != nil { diff --git a/sdk/azidentity/username_password_credential_test.go b/sdk/azidentity/username_password_credential_test.go index ae92d41df39c..501100175db1 100644 --- a/sdk/azidentity/username_password_credential_test.go +++ b/sdk/azidentity/username_password_credential_test.go @@ -5,7 +5,6 @@ package azidentity import ( "context" - "errors" "io/ioutil" "net/http" "net/url" @@ -24,10 +23,6 @@ func TestUsernamePasswordCredential_InvalidTenantID(t *testing.T) { if cred != nil { t.Fatalf("Expected a nil credential value. Received: %v", cred) } - var errType *CredentialUnavailableError - if !errors.As(err, &errType) { - t.Fatalf("Did not receive a CredentialUnavailableError. Received: %t", err) - } } func TestUsernamePasswordCredential_CreateAuthRequestSuccess(t *testing.T) {