Skip to content

Commit

Permalink
Align azidentity environment variables with other SDKs (Azure#15928)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored and jhendrixMSFT committed Jan 12, 2022
1 parent 3683a5f commit e1c1c6c
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 77 deletions.
6 changes: 6 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
* Credential options structs now embed `azcore.ClientOptions`. In addition to changing literal initialization
syntax, this change renames `HTTPClient` fields to `Transport`.
* Renamed `LogCredential` to `EventCredential`
* `AzureCLICredential` no longer reads the environment variable `AZURE_CLI_PATH`
* `NewManagedIdentityCredential` no longer reads environment variables `AZURE_CLIENT_ID` and
`AZURE_RESOURCE_ID`. Use `ManagedIdentityCredentialOptions.ID` instead.

### Bugs Fixed
* `AzureCLICredential.GetToken` no longer mutates its `opts.Scopes`

### Features Added
* Added connection configuration options to `DefaultAzureCredentialOptions`
Expand Down
50 changes: 22 additions & 28 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -63,9 +64,12 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent
// opts: TokenRequestOptions contains the list of scopes for which the token will have access.
// Returns an AccessToken which can be used to authenticate service client calls.
func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (*azcore.AccessToken, error) {
// The following code will remove the /.default suffix from the scope passed into the method since AzureCLI expect a resource string instead of a scope string
opts.Scopes[0] = strings.TrimSuffix(opts.Scopes[0], defaultSuffix)
at, err := c.authenticate(ctx, opts.Scopes[0])
if len(opts.Scopes) != 1 {
return nil, errors.New("this credential requires exactly one scope per token request")
}
// CLI expects an AAD v1 resource, not a v2 scope
scope := strings.TrimSuffix(opts.Scopes[0], defaultSuffix)
at, err := c.authenticate(ctx, scope)
if err != nil {
addGetTokenFailureLogs("Azure CLI Credential", err, true)
return nil, err
Expand All @@ -74,7 +78,7 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
return at, nil
}

const timeoutCLIRequest = 10000 * time.Millisecond
const timeoutCLIRequest = 10 * time.Second

// authenticate creates a client secret authentication request and returns the resulting Access Token or
// an error in case of authentication failure.
Expand All @@ -91,44 +95,34 @@ func (c *AzureCLICredential) authenticate(ctx context.Context, resource string)

func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
return func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
// This is the path that a developer can set to tell this class what the install path for Azure CLI is.
const azureCLIPath = "AZURE_CLI_PATH"

// The default install paths are used to find Azure CLI. This is for security, so that any path in the calling program's Path environment is not used to execute Azure CLI.
azureCLIDefaultPathWindows := fmt.Sprintf("%s\\Microsoft SDKs\\Azure\\CLI2\\wbin; %s\\Microsoft SDKs\\Azure\\CLI2\\wbin", os.Getenv("ProgramFiles(x86)"), os.Getenv("ProgramFiles"))

// Default path for non-Windows.
const azureCLIDefaultPath = "/bin:/sbin:/usr/bin:/usr/local/bin"

// Validate resource, since it gets sent as a command line argument to Azure CLI
const invalidResourceErrorTemplate = "resource %s is not in expected format. Only alphanumeric characters, [dot], [colon], [hyphen], and [forward slash] are allowed"
match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource)
if err != nil {
return nil, err
}
if !match {
return nil, fmt.Errorf(invalidResourceErrorTemplate, resource)
return nil, fmt.Errorf(`unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, resource)
}

ctx, cancel := context.WithTimeout(ctx, timeoutCLIRequest)
defer cancel()

// Execute Azure CLI to get token
commandLine := "az account get-access-token -o json --resource " + resource
if tenantID != "" {
commandLine += " --tenant " + tenantID
}
var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
cliCmd = exec.CommandContext(ctx, fmt.Sprintf("%s\\system32\\cmd.exe", os.Getenv("windir")))
cliCmd.Env = os.Environ()
cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s;%s", os.Getenv(azureCLIPath), azureCLIDefaultPathWindows))
cliCmd.Args = append(cliCmd.Args, "/c", "az")
dir := os.Getenv("SYSTEMROOT")
if dir == "" {
return nil, errors.New("environment variable 'SYSTEMROOT' has no value")
}
cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine)
cliCmd.Dir = dir
} else {
cliCmd = exec.CommandContext(ctx, "az")
cliCmd.Env = os.Environ()
cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s:%s", os.Getenv(azureCLIPath), azureCLIDefaultPath))
}
cliCmd.Args = append(cliCmd.Args, "account", "get-access-token", "-o", "json", "--resource", resource)
if tenantID != "" {
cliCmd.Args = append(cliCmd.Args, "--tenant", tenantID)
cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine)
cliCmd.Dir = "/bin"
}
cliCmd.Env = os.Environ()
var stderr bytes.Buffer
cliCmd.Stderr = &stderr

Expand Down
3 changes: 0 additions & 3 deletions sdk/azidentity/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ func logEnvVars() {
if envCheck := os.Getenv("AZURE_AUTHORITY_HOST"); len(envCheck) > 0 {
envVars = append(envVars, "AZURE_AUTHORITY_HOST")
}
if envCheck := os.Getenv("AZURE_CLI_PATH"); len(envCheck) > 0 {
envVars = append(envVars, "AZURE_CLI_PATH")
}
if len(envVars) > 0 {
log.Writef(EventCredential, "Azure Identity => Found the following environment variables:\n\t%s", strings.Join(envVars, ", "))
}
Expand Down
16 changes: 1 addition & 15 deletions sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package azidentity
import (
"context"
"fmt"
"os"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand Down Expand Up @@ -89,20 +88,7 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M
}
// Assign the msiType discovered onto the client
client.msiType = msiType
// check if no clientID is specified then check if it exists in an environment variable
id := cp.ID
if id == nil {
cID := os.Getenv("AZURE_CLIENT_ID")
if cID != "" {
id = ClientID(cID)
} else {
rID := os.Getenv("AZURE_RESOURCE_ID")
if rID != "" {
id = ResourceID(rID)
}
}
}
return &ManagedIdentityCredential{id: id, client: client}, nil
return &ManagedIdentityCredential{id: cp.ID, client: client}, nil
}

// GetToken obtains an AccessToken from the Managed Identity service if available.
Expand Down
31 changes: 0 additions & 31 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,34 +704,3 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnFail(t *testing.T)
t.Fatalf("expected to receive an error but received none")
}
}

func TestManagedIdentityCredential_ResourceID_envVar(t *testing.T) {
// setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type
_ = os.Setenv("IDENTITY_ENDPOINT", "somevalue")
_ = os.Setenv("IDENTITY_HEADER", "header")
_ = os.Setenv("AZURE_RESOURCE_ID", "resource_id")
defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER", "AZURE_CLIENT_ID", "AZURE_RESOURCE_ID")
cred, err := NewManagedIdentityCredential(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != ResourceID("resource_id") {
t.Fatal("unexpected id value stored")
}
_ = os.Setenv("AZURE_RESOURCE_ID", "")
_ = os.Setenv("AZURE_CLIENT_ID", "client_id")
cred, err = NewManagedIdentityCredential(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != ClientID("client_id") {
t.Fatal("unexpected id value stored")
}
cred, err = NewManagedIdentityCredential(nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != ClientID("client_id") {
t.Fatal("unexpected id value stored")
}
}

0 comments on commit e1c1c6c

Please sign in to comment.