diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index ec382cd83165..dcde20193d9d 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -51,6 +51,12 @@ * Credential options structs now embed `azcore.ClientOptions`. In addition to changing literal initialization syntax, this change renames `HTTPClient` fields to `Transport`. * Renamed `LogCredential` to `EventCredential` +* `AzureCLICredential` no longer reads the environment variable `AZURE_CLI_PATH` +* `NewManagedIdentityCredential` no longer reads environment variables `AZURE_CLIENT_ID` and + `AZURE_RESOURCE_ID`. Use `ManagedIdentityCredentialOptions.ID` instead. + +### Bugs Fixed +* `AzureCLICredential.GetToken` no longer mutates its `opts.Scopes` ### Features Added * Added connection configuration options to `DefaultAzureCredentialOptions` diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index 044ad9017166..6947bc7468c0 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "os" "os/exec" @@ -63,9 +64,12 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent // opts: TokenRequestOptions contains the list of scopes for which the token will have access. // Returns an AccessToken which can be used to authenticate service client calls. func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (*azcore.AccessToken, error) { - // The following code will remove the /.default suffix from the scope passed into the method since AzureCLI expect a resource string instead of a scope string - opts.Scopes[0] = strings.TrimSuffix(opts.Scopes[0], defaultSuffix) - at, err := c.authenticate(ctx, opts.Scopes[0]) + if len(opts.Scopes) != 1 { + return nil, errors.New("this credential requires exactly one scope per token request") + } + // CLI expects an AAD v1 resource, not a v2 scope + scope := strings.TrimSuffix(opts.Scopes[0], defaultSuffix) + at, err := c.authenticate(ctx, scope) if err != nil { addGetTokenFailureLogs("Azure CLI Credential", err, true) return nil, err @@ -74,7 +78,7 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ return at, nil } -const timeoutCLIRequest = 10000 * time.Millisecond +const timeoutCLIRequest = 10 * time.Second // authenticate creates a client secret authentication request and returns the resulting Access Token or // an error in case of authentication failure. @@ -91,44 +95,34 @@ func (c *AzureCLICredential) authenticate(ctx context.Context, resource string) func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) { return func(ctx context.Context, resource string, tenantID string) ([]byte, error) { - // This is the path that a developer can set to tell this class what the install path for Azure CLI is. - const azureCLIPath = "AZURE_CLI_PATH" - - // The default install paths are used to find Azure CLI. This is for security, so that any path in the calling program's Path environment is not used to execute Azure CLI. - azureCLIDefaultPathWindows := fmt.Sprintf("%s\\Microsoft SDKs\\Azure\\CLI2\\wbin; %s\\Microsoft SDKs\\Azure\\CLI2\\wbin", os.Getenv("ProgramFiles(x86)"), os.Getenv("ProgramFiles")) - - // Default path for non-Windows. - const azureCLIDefaultPath = "/bin:/sbin:/usr/bin:/usr/local/bin" - - // Validate resource, since it gets sent as a command line argument to Azure CLI - const invalidResourceErrorTemplate = "resource %s is not in expected format. Only alphanumeric characters, [dot], [colon], [hyphen], and [forward slash] are allowed" match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource) if err != nil { return nil, err } if !match { - return nil, fmt.Errorf(invalidResourceErrorTemplate, resource) + return nil, fmt.Errorf(`unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, resource) } ctx, cancel := context.WithTimeout(ctx, timeoutCLIRequest) defer cancel() - // Execute Azure CLI to get token + commandLine := "az account get-access-token -o json --resource " + resource + if tenantID != "" { + commandLine += " --tenant " + tenantID + } var cliCmd *exec.Cmd if runtime.GOOS == "windows" { - cliCmd = exec.CommandContext(ctx, fmt.Sprintf("%s\\system32\\cmd.exe", os.Getenv("windir"))) - cliCmd.Env = os.Environ() - cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s;%s", os.Getenv(azureCLIPath), azureCLIDefaultPathWindows)) - cliCmd.Args = append(cliCmd.Args, "/c", "az") + dir := os.Getenv("SYSTEMROOT") + if dir == "" { + return nil, errors.New("environment variable 'SYSTEMROOT' has no value") + } + cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine) + cliCmd.Dir = dir } else { - cliCmd = exec.CommandContext(ctx, "az") - cliCmd.Env = os.Environ() - cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s:%s", os.Getenv(azureCLIPath), azureCLIDefaultPath)) - } - cliCmd.Args = append(cliCmd.Args, "account", "get-access-token", "-o", "json", "--resource", resource) - if tenantID != "" { - cliCmd.Args = append(cliCmd.Args, "--tenant", tenantID) + cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine) + cliCmd.Dir = "/bin" } + cliCmd.Env = os.Environ() var stderr bytes.Buffer cliCmd.Stderr = &stderr diff --git a/sdk/azidentity/logging.go b/sdk/azidentity/logging.go index 2a2613ac79fd..0003a20d53ca 100644 --- a/sdk/azidentity/logging.go +++ b/sdk/azidentity/logging.go @@ -38,9 +38,6 @@ func logEnvVars() { if envCheck := os.Getenv("AZURE_AUTHORITY_HOST"); len(envCheck) > 0 { envVars = append(envVars, "AZURE_AUTHORITY_HOST") } - if envCheck := os.Getenv("AZURE_CLI_PATH"); len(envCheck) > 0 { - envVars = append(envVars, "AZURE_CLI_PATH") - } if len(envVars) > 0 { log.Writef(EventCredential, "Azure Identity => Found the following environment variables:\n\t%s", strings.Join(envVars, ", ")) } diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index ad4159f616ea..fc9e2a5f17d7 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -6,7 +6,6 @@ package azidentity import ( "context" "fmt" - "os" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -89,20 +88,7 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M } // Assign the msiType discovered onto the client client.msiType = msiType - // check if no clientID is specified then check if it exists in an environment variable - id := cp.ID - if id == nil { - cID := os.Getenv("AZURE_CLIENT_ID") - if cID != "" { - id = ClientID(cID) - } else { - rID := os.Getenv("AZURE_RESOURCE_ID") - if rID != "" { - id = ResourceID(rID) - } - } - } - return &ManagedIdentityCredential{id: id, client: client}, nil + return &ManagedIdentityCredential{id: cp.ID, client: client}, nil } // GetToken obtains an AccessToken from the Managed Identity service if available. diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 01a9cbb07bdd..4257cc37f6f7 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -704,34 +704,3 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnFail(t *testing.T) t.Fatalf("expected to receive an error but received none") } } - -func TestManagedIdentityCredential_ResourceID_envVar(t *testing.T) { - // setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type - _ = os.Setenv("IDENTITY_ENDPOINT", "somevalue") - _ = os.Setenv("IDENTITY_HEADER", "header") - _ = os.Setenv("AZURE_RESOURCE_ID", "resource_id") - defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER", "AZURE_CLIENT_ID", "AZURE_RESOURCE_ID") - cred, err := NewManagedIdentityCredential(nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if cred.id != ResourceID("resource_id") { - t.Fatal("unexpected id value stored") - } - _ = os.Setenv("AZURE_RESOURCE_ID", "") - _ = os.Setenv("AZURE_CLIENT_ID", "client_id") - cred, err = NewManagedIdentityCredential(nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if cred.id != ClientID("client_id") { - t.Fatal("unexpected id value stored") - } - cred, err = NewManagedIdentityCredential(nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if cred.id != ClientID("client_id") { - t.Fatal("unexpected id value stored") - } -}