From 8812d1abba3551c12ed54d7ff9875d1be6a519fc Mon Sep 17 00:00:00 2001 From: catalinaperalta Date: Mon, 7 Jun 2021 14:19:43 -0700 Subject: [PATCH] azidentity: Adding Resource ID support in ManagedIdentityCredential (#14741) * enable resourceID for MSI cred * update const and rename * improve docs, add tests * review feedback * update type alias name --- sdk/azidentity/aad_identity_client.go | 1 + sdk/azidentity/managed_identity_client.go | 24 +++-- sdk/azidentity/managed_identity_credential.go | 33 +++++-- .../managed_identity_credential_test.go | 88 +++++++++++++++++++ 4 files changed, 130 insertions(+), 16 deletions(-) diff --git a/sdk/azidentity/aad_identity_client.go b/sdk/azidentity/aad_identity_client.go index f6479e5f5938..76ef981f2a5f 100644 --- a/sdk/azidentity/aad_identity_client.go +++ b/sdk/azidentity/aad_identity_client.go @@ -31,6 +31,7 @@ const ( qpPassword = "password" qpRedirectURI = "redirect_uri" qpRefreshToken = "refresh_token" + qpResID = "mi_res_id" qpResponseType = "response_type" qpScope = "scope" qpUsername = "username" diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index cf211bd94d69..933ab1740ab6 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -53,6 +53,7 @@ type managedIdentityClient struct { imdsAvailableTimeoutMS time.Duration msiType msiType endpoint string + id ManagedIdentityIDKind } type wrappedNumber json.Number @@ -72,6 +73,7 @@ func (n *wrappedNumber) UnmarshalJSON(b []byte) error { func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *managedIdentityClient { logEnvVars() return &managedIdentityClient{ + id: options.ID, pipeline: newDefaultMSIPipeline(*options), // a pipeline that includes the specific requirements for MSI authentication, such as custom retry policy options imdsAPIVersion: imdsAPIVersion, // this field will be set to whatever value exists in the constant and is used when creating requests to IMDS imdsAvailableTimeoutMS: 500, // we allow a timeout of 500 ms since the endpoint might be slow to respond @@ -162,7 +164,7 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID } } -func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) { +func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) { request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -171,14 +173,16 @@ func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, clien q := request.URL.Query() q.Add("api-version", c.imdsAPIVersion) q.Add("resource", strings.Join(scopes, " ")) - if clientID != "" { - q.Add(qpClientID, clientID) + if c.id == ResourceID { + q.Add(qpResID, id) + } else if id != "" { + q.Add(qpClientID, id) } request.URL.RawQuery = q.Encode() return request, nil } -func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, clientID string, scopes []string) (*azcore.Request, error) { +func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) { request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint) if err != nil { return nil, err @@ -188,16 +192,20 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, request.Header.Set("secret", os.Getenv(msiSecret)) q.Add("api-version", "2017-09-01") q.Add("resource", strings.Join(scopes, " ")) - if clientID != "" { + if c.id == ResourceID { + q.Add(qpResID, id) + } else if id != "" { // the legacy 2017 API version specifically specifies "clientid" and not "client_id" as a query param - q.Add("clientid", clientID) + q.Add("clientid", id) } } else if c.msiType == msiTypeAppServiceV20190801 { request.Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeader)) q.Add("api-version", "2019-08-01") q.Add("resource", scopes[0]) - if clientID != "" { - q.Add(qpClientID, clientID) + if c.id == ResourceID { + q.Add(qpResID, id) + } else if id != "" { + q.Add(qpClientID, id) } } diff --git a/sdk/azidentity/managed_identity_credential.go b/sdk/azidentity/managed_identity_credential.go index c4501323a807..48880536dc31 100644 --- a/sdk/azidentity/managed_identity_credential.go +++ b/sdk/azidentity/managed_identity_credential.go @@ -11,9 +11,25 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) +// ManagedIdentityIDKind is used to specify the type of identifier that is passed in for a user-assigned managed identity. +type ManagedIdentityIDKind int + +const ( + // ClientID is the default identifier for a user-assigned managed identity. + ClientID ManagedIdentityIDKind = 0 + // ResourceID is set when the resource ID of the user-assigned managed identity is to be used. + ResourceID ManagedIdentityIDKind = 1 +) + // ManagedIdentityCredentialOptions contains parameters that can be used to configure the pipeline used with Managed Identity Credential. // All zero-value fields will be initialized with their default values. type ManagedIdentityCredentialOptions struct { + // ID is used to configure an alternate identifier for a user-assigned identity. The default is client ID. + // Select the identifier to be used and pass the corresponding ID value in the string param in + // NewManagedIdentityCredential(). + // Hint: Choose from the list of allowed ManagedIdentityIDKind values. + ID ManagedIdentityIDKind + // HTTPClient sets the transport for making HTTP requests. // Leave this as nil to use the default HTTP transport. HTTPClient azcore.Transport @@ -29,16 +45,17 @@ type ManagedIdentityCredentialOptions struct { // managed identity environments such as Azure VMs, App Service, Azure Functions, Azure CloudShell, among others. More information about configuring managed identities can be found here: // https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview type ManagedIdentityCredential struct { - clientID string - client *managedIdentityClient + id string + client *managedIdentityClient } // NewManagedIdentityCredential creates an instance of the ManagedIdentityCredential capable of authenticating a resource that has a managed identity. -// clientID: The client ID to authenticate for a user assigned managed identity. +// id: The ID that corresponds to the user assigned managed identity. Defaults to the identity's client ID. To use another identifier, +// pass in the value for the identifier here AND choose the correct ID kind to be used in the request by setting ManagedIdentityIDKind in the options. // options: ManagedIdentityCredentialOptions that configure the pipeline for requests sent to Azure Active Directory. // More information on user assigned managed identities cam be found here: // https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview#how-a-user-assigned-managed-identity-works-with-an-azure-vm -func NewManagedIdentityCredential(clientID string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) { +func NewManagedIdentityCredential(id string, options *ManagedIdentityCredentialOptions) (*ManagedIdentityCredential, error) { // Create a new Managed Identity Client with default options if options == nil { options = &ManagedIdentityCredentialOptions{} @@ -54,10 +71,10 @@ func NewManagedIdentityCredential(clientID string, options *ManagedIdentityCrede // Assign the msiType discovered onto the client client.msiType = msiType // check if no clientID is specified then check if it exists in an environment variable - if len(clientID) == 0 { - clientID = os.Getenv("AZURE_CLIENT_ID") + if len(id) == 0 { + id = os.Getenv("AZURE_CLIENT_ID") } - return &ManagedIdentityCredential{clientID: clientID, client: client}, nil + return &ManagedIdentityCredential{id: id, client: client}, nil } // GetToken obtains an AccessToken from the Managed Identity service if available. @@ -76,7 +93,7 @@ func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts azcore.To } // The following code will remove the /.default suffix from any scopes passed into the method since ManagedIdentityCredentials expect a resource string instead of a scope string opts.Scopes[0] = strings.TrimSuffix(opts.Scopes[0], defaultSuffix) - tk, err := c.client.authenticate(ctx, c.clientID, opts.Scopes) + tk, err := c.client.authenticate(ctx, c.id, opts.Scopes) if err != nil { addGetTokenFailureLogs("Managed Identity Credential", err, true) return nil, err diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 31c8a48a3d68..1cafdfe0ac2a 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -532,3 +532,91 @@ func TestManagedIdentityCredential_GetTokenMultipleResources(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestManagedIdentityCredential_UseResourceID(t *testing.T) { + resetEnvironmentVarsForTest() + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(appServiceWindowsSuccessResp))) + _ = os.Setenv("MSI_ENDPOINT", srv.URL()) + _ = os.Setenv("MSI_SECRET", "secret") + defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET") + options := ManagedIdentityCredentialOptions{} + options.HTTPClient = srv + options.ID = ResourceID + cred, err := NewManagedIdentityCredential("sample/resource/id", &options) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tk, err := cred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}}) + if err != nil { + t.Fatal(err) + } + if tk.Token != "new_token" { + t.Fatalf("unexpected token returned. Expected: %s, Received: %s", "new_token", tk.Token) + } +} + +func TestManagedIdentityCredential_ResourceID_AppService(t *testing.T) { + // setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type in order + // to test App Service authentication request creation. + _ = os.Setenv("IDENTITY_ENDPOINT", "somevalue") + _ = os.Setenv("IDENTITY_HEADER", "header") + defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER") + resID := "sample/resource/id" + cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cred.client.endpoint = imdsEndpoint + req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope}) + if err != nil { + t.Fatal(err) + } + if req.Request.Header.Get("X-IDENTITY-HEADER") != "header" { + t.Fatalf("Unexpected value for secret header") + } + reqQueryParams, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + t.Fatalf("Unable to parse App Service request query params: %v", err) + } + if reqQueryParams["api-version"][0] != "2019-08-01" { + t.Fatalf("Unexpected App Service API version") + } + if reqQueryParams["resource"][0] != msiScope { + t.Fatalf("Unexpected resource in resource query param") + } + if reqQueryParams[qpResID][0] != resID { + t.Fatalf("Unexpected resource ID in resource query param") + } +} + +func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) { + // setting a dummy value for MSI_ENDPOINT in order to avoid failure in the constructor + _ = os.Setenv("MSI_ENDPOINT", "http://foo.com/") + defer clearEnvVars("MSI_ENDPOINT") + resID := "sample/resource/id" + cred, err := NewManagedIdentityCredential(resID, &ManagedIdentityCredentialOptions{ID: ResourceID}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cred.client.msiType = msiTypeIMDS + cred.client.endpoint = imdsEndpoint + req, err := cred.client.createAuthRequest(context.Background(), resID, []string{msiScope}) + if err != nil { + t.Fatal(err) + } + reqQueryParams, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + t.Fatalf("Unable to parse App Service request query params: %v", err) + } + if reqQueryParams["api-version"][0] != "2018-02-01" { + t.Fatalf("Unexpected App Service API version") + } + if reqQueryParams["resource"][0] != msiScope { + t.Fatalf("Unexpected resource in resource query param") + } + if reqQueryParams[qpResID][0] != resID { + t.Fatalf("Unexpected resource ID in resource query param") + } +}