Skip to content

Commit

Permalink
azidentity: Adding Resource ID support in ManagedIdentityCredential (#…
Browse files Browse the repository at this point in the history
…14741)

* enable resourceID for MSI cred

* update const and rename

* improve docs, add tests

* review feedback

* update type alias name
  • Loading branch information
catalinaperalta authored Jun 7, 2021
1 parent 8cce9bd commit 8812d1a
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 16 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/aad_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
qpPassword = "password"
qpRedirectURI = "redirect_uri"
qpRefreshToken = "refresh_token"
qpResID = "mi_res_id"
qpResponseType = "response_type"
qpScope = "scope"
qpUsername = "username"
Expand Down
24 changes: 16 additions & 8 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type managedIdentityClient struct {
imdsAvailableTimeoutMS time.Duration
msiType msiType
endpoint string
id ManagedIdentityIDKind
}

type wrappedNumber json.Number
Expand All @@ -72,6 +73,7 @@ func (n *wrappedNumber) UnmarshalJSON(b []byte) error {
func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *managedIdentityClient {
logEnvVars()
return &managedIdentityClient{
id: options.ID,
pipeline: newDefaultMSIPipeline(*options), // a pipeline that includes the specific requirements for MSI authentication, such as custom retry policy options
imdsAPIVersion: imdsAPIVersion, // this field will be set to whatever value exists in the constant and is used when creating requests to IMDS
imdsAvailableTimeoutMS: 500, // we allow a timeout of 500 ms since the endpoint might be slow to respond
Expand Down Expand Up @@ -162,7 +164,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID
}
}

func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) {
func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) {
request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -171,14 +173,16 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, clien
q := request.URL.Query()
q.Add("api-version", c.imdsAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if clientID != "" {
q.Add(qpClientID, clientID)
if c.id == ResourceID {
q.Add(qpResID, id)
} else if id != "" {
q.Add(qpClientID, id)
}
request.URL.RawQuery = q.Encode()
return request, nil
}

func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) {
func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) {
request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
Expand All @@ -188,16 +192,20 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
request.Header.Set("secret", os.Getenv(msiSecret))
q.Add("api-version", "2017-09-01")
q.Add("resource", strings.Join(scopes, " "))
if clientID != "" {
if c.id == ResourceID {
q.Add(qpResID, id)
} else if id != "" {
// the legacy 2017 API version specifically specifies "clientid" and not "client_id" as a query param
q.Add("clientid", clientID)
q.Add("clientid", id)
}
} else if c.msiType == msiTypeAppServiceV20190801 {
request.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeader))
q.Add("api-version", "2019-08-01")
q.Add("resource", scopes[0])
if clientID != "" {
q.Add(qpClientID, clientID)
if c.id == ResourceID {
q.Add(qpResID, id)
} else if id != "" {
q.Add(qpClientID, id)
}
}

Expand Down
33 changes: 25 additions & 8 deletions sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,25 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// ManagedIdentityIDKind is used to specify the type of identifier that is passed in for a user-assigned managed identity.
type ManagedIdentityIDKind int

const (
// ClientID is the default identifier for a user-assigned managed identity.
ClientID ManagedIdentityIDKind = 0
// ResourceID is set when the resource ID of the user-assigned managed identity is to be used.
ResourceID ManagedIdentityIDKind = 1
)

// ManagedIdentityCredentialOptions contains parameters that can be used to configure the pipeline used with Managed Identity Credential.
// All zero-value fields will be initialized with their default values.
type ManagedIdentityCredentialOptions struct {
// ID is used to configure an alternate identifier for a user-assigned identity. The default is client ID.
// Select the identifier to be used and pass the corresponding ID value in the string param in
// NewManagedIdentityCredential().
// Hint: Choose from the list of allowed ManagedIdentityIDKind values.
ID ManagedIdentityIDKind

// HTTPClient sets the transport for making HTTP requests.
// Leave this as nil to use the default HTTP transport.
HTTPClient azcore.Transport
Expand All @@ -29,16 +45,17 @@ type ManagedIdentityCredentialOptions struct {
// managed identity environments such as Azure VMs, App Service, Azure Functions, Azure CloudShell, among others. More information about configuring managed identities can be found here:
// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
type ManagedIdentityCredential struct {
clientID string
client *managedIdentityClient
id string
client *managedIdentityClient
}

// NewManagedIdentityCredential creates an instance of the ManagedIdentityCredential capable of authenticating a resource that has a managed identity.
// clientID: The client ID to authenticate for a user assigned managed identity.
// id: The ID that corresponds to the user assigned managed identity. Defaults to the identity's client ID. To use another identifier,
// pass in the value for the identifier here AND choose the correct ID kind to be used in the request by setting ManagedIdentityIDKind in the options.
// options: ManagedIdentityCredentialOptions that configure the pipeline for requests sent to Azure Active Directory.
// More information on user assigned managed identities cam be found here:
// https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview#how-a-user-assigned-managed-identity-works-with-an-azure-vm
func NewManagedIdentityCredential(clientID string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) {
func NewManagedIdentityCredential(id string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) {
// Create a new Managed Identity Client with default options
if options == nil {
options = &ManagedIdentityCredentialOptions{}
Expand All @@ -54,10 +71,10 @@ func NewManagedIdentityCredential(clientID string, options *ManagedIdentityCrede
// Assign the msiType discovered onto the client
client.msiType = msiType
// check if no clientID is specified then check if it exists in an environment variable
if len(clientID) == 0 {
clientID = os.Getenv("AZURE_CLIENT_ID")
if len(id) == 0 {
id = os.Getenv("AZURE_CLIENT_ID")
}
return &ManagedIdentityCredential{clientID: clientID, client: client}, nil
return &ManagedIdentityCredential{id: id, client: client}, nil
}

// GetToken obtains an AccessToken from the Managed Identity service if available.
Expand All @@ -76,7 +93,7 @@ func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts azcore.To
}
// The following code will remove the /.default suffix from any scopes passed into the method since ManagedIdentityCredentials expect a resource string instead of a scope string
opts.Scopes[0] = strings.TrimSuffix(opts.Scopes[0], defaultSuffix)
tk, err := c.client.authenticate(ctx, c.clientID, opts.Scopes)
tk, err := c.client.authenticate(ctx, c.id, opts.Scopes)
if err != nil {
addGetTokenFailureLogs("Managed Identity Credential", err, true)
return nil, err
Expand Down
88 changes: 88 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,91 @@ func TestManagedIdentityCredential_GetTokenMultipleResources(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}

func TestManagedIdentityCredential_UseResourceID(t *testing.T) {
resetEnvironmentVarsForTest()
srv, close := mock.NewServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(appServiceWindowsSuccessResp)))
_ = os.Setenv("MSI_ENDPOINT", srv.URL())
_ = os.Setenv("MSI_SECRET", "secret")
defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET")
options := ManagedIdentityCredentialOptions{}
options.HTTPClient = srv
options.ID = ResourceID
cred, err := NewManagedIdentityCredential("sample/resource/id", &options)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != "new_token" {
t.Fatalf("unexpected token returned. Expected: %s, Received: %s", "new_token", tk.Token)
}
}

func TestManagedIdentityCredential_ResourceID_AppService(t *testing.T) {
// setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type in order
// to test App Service authentication request creation.
_ = os.Setenv("IDENTITY_ENDPOINT", "somevalue")
_ = os.Setenv("IDENTITY_HEADER", "header")
defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER")
resID := "sample/resource/id"
cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
cred.client.endpoint = imdsEndpoint
req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope})
if err != nil {
t.Fatal(err)
}
if req.Request.Header.Get("X-IDENTITY-HEADER") != "header" {
t.Fatalf("Unexpected value for secret header")
}
reqQueryParams, err := url.ParseQuery(req.URL.RawQuery)
if err != nil {
t.Fatalf("Unable to parse App Service request query params: %v", err)
}
if reqQueryParams["api-version"][0] != "2019-08-01" {
t.Fatalf("Unexpected App Service API version")
}
if reqQueryParams["resource"][0] != msiScope {
t.Fatalf("Unexpected resource in resource query param")
}
if reqQueryParams[qpResID][0] != resID {
t.Fatalf("Unexpected resource ID in resource query param")
}
}

func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) {
// setting a dummy value for MSI_ENDPOINT in order to avoid failure in the constructor
_ = os.Setenv("MSI_ENDPOINT", "http://foo.com/")
defer clearEnvVars("MSI_ENDPOINT")
resID := "sample/resource/id"
cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
cred.client.msiType = msiTypeIMDS
cred.client.endpoint = imdsEndpoint
req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope})
if err != nil {
t.Fatal(err)
}
reqQueryParams, err := url.ParseQuery(req.URL.RawQuery)
if err != nil {
t.Fatalf("Unable to parse App Service request query params: %v", err)
}
if reqQueryParams["api-version"][0] != "2018-02-01" {
t.Fatalf("Unexpected App Service API version")
}
if reqQueryParams["resource"][0] != msiScope {
t.Fatalf("Unexpected resource in resource query param")
}
if reqQueryParams[qpResID][0] != resID {
t.Fatalf("Unexpected resource ID in resource query param")
}
}

0 comments on commit 8812d1a

Please sign in to comment.