From b8633a5d3c56af5782c641620899a6d96cf554a7 Mon Sep 17 00:00:00 2001 From: Haitao Chen Date: Thu, 6 Aug 2020 16:54:34 -0700 Subject: [PATCH] allow MSI login with "mi_res_id" (#544) * allow login with resourceID * test * tweaks * fix * tested with cmd * fix unittest * add new test, remove debug trace * fix unittest * fix with url encode --- autorest/adal/cmd/adal.go | 84 ++++++++++++++++++++++++++++++++++--- autorest/adal/token.go | 22 ++++++++-- autorest/adal/token_test.go | 43 +++++++++++++++++-- 3 files changed, 135 insertions(+), 14 deletions(-) diff --git a/autorest/adal/cmd/adal.go b/autorest/adal/cmd/adal.go index 7214dcabb..6c2aa1272 100644 --- a/autorest/adal/cmd/adal.go +++ b/autorest/adal/cmd/adal.go @@ -31,10 +31,13 @@ import ( ) const ( - deviceMode = "device" - clientSecretMode = "secret" - clientCertMode = "cert" - refreshMode = "refresh" + deviceMode = "device" + clientSecretMode = "secret" + clientCertMode = "cert" + refreshMode = "refresh" + msiDefaultMode = "msiDefault" + msiClientIDMode = "msiClientID" + msiResourceIDMode = "msiResourceID" activeDirectoryEndpoint = "https://login.microsoftonline.com/" ) @@ -48,8 +51,9 @@ var ( mode string resource string - tenantID string - applicationID string + tenantID string + applicationID string + identityResourceID string applicationSecret string certificatePath string @@ -82,10 +86,28 @@ func init() { flag.StringVar(&applicationSecret, "secret", "", "application secret") flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate") flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache") + flag.StringVar(&identityResourceID, "identityResourceID", "", "managedIdentity azure resource id") flag.Parse() switch mode = strings.TrimSpace(mode); mode { + case msiDefaultMode: + checkMandatoryOptions(msiDefaultMode, + option{name: "resource", value: resource}, + option{name: "tenantId", value: tenantID}, + ) + case msiClientIDMode: + checkMandatoryOptions(msiClientIDMode, + option{name: "resource", value: resource}, + option{name: "tenantId", value: tenantID}, + option{name: "applicationId", value: applicationID}, + ) + case msiResourceIDMode: + checkMandatoryOptions(msiResourceIDMode, + option{name: "resource", value: resource}, + option{name: "tenantId", value: tenantID}, + option{name: "identityResourceID", value: identityResourceID}, + ) case clientSecretMode: checkMandatoryOptions(clientSecretMode, option{name: "resource", value: resource}, @@ -150,6 +172,42 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private return certificate, rsaPrivateKey, nil } +func acquireTokenMSIFlow(applicationID string, + identityResourceID string, + resource string, + callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { + + // only one of them can be present: + if applicationID != "" && identityResourceID != "" { + return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time") + } + + msiEndpoint, _ := adal.GetMSIVMEndpoint() + var spt *adal.ServicePrincipalToken + var err error + + // both can be empty, systemAssignedMSI scenario + if applicationID == "" && identityResourceID == "" { + spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...) + } + + // msi login with clientID + if applicationID != "" { + spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...) + } + + // msi login with resourceID + if identityResourceID != "" { + spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...) + } + + if err != nil { + return nil, err + } + + return spt, spt.Refresh() +} + func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig, applicationID string, applicationCertPath string, @@ -283,6 +341,20 @@ func main() { if err == nil { err = saveToken(spt.Token()) } + case msiResourceIDMode: + fallthrough + case msiClientIDMode: + fallthrough + case msiDefaultMode: + var spt *adal.ServicePrincipalToken + spt, err = acquireTokenMSIFlow( + applicationID, + identityResourceID, + resource, + callback) + if err == nil { + err = saveToken(spt.Token()) + } case refreshMode: _, err = refreshToken( *oauthConfig, diff --git a/autorest/adal/token.go b/autorest/adal/token.go index c026f7d12..d45a3fa57 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -678,16 +678,22 @@ func GetMSIEndpoint() (string, error) { // NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension. // It will use the system assigned identity when creating the token. func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { - return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...) + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, nil, callbacks...) } // NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension. -// It will use the specified user assigned identity when creating the token. +// It will use the clientID of specified user assigned identity when creating the token. func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { - return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...) + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, nil, callbacks...) } -func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { +// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension. +// It will use the azure resource id of user assigned identity when creating the token. +func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &identityResourceID, callbacks...) +} + +func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil { return nil, err } @@ -699,6 +705,11 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI return nil, err } } + if identityResourceID != nil { + if err := validateStringParam(*identityResourceID, "identityResourceID"); err != nil { + return nil, err + } + } // We set the oauth config token endpoint to be MSI's endpoint msiEndpointURL, err := url.Parse(msiEndpoint) if err != nil { @@ -716,6 +727,9 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI if userAssignedID != nil { v.Set("client_id", *userAssignedID) } + if identityResourceID != nil { + v.Set("mi_res_id", *identityResourceID) + } msiEndpointURL.RawQuery = v.Encode() spt := &ServicePrincipalToken{ diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index d123ac012..931504a1b 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -694,7 +694,7 @@ func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) { } func TestNewServicePrincipalTokenFromMSI(t *testing.T) { - resource := "https://resource" + const resource = "https://resource" cb := func(token Token) error { return nil } spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb) @@ -717,8 +717,10 @@ func TestNewServicePrincipalTokenFromMSI(t *testing.T) { } func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { - resource := "https://resource" - userID := "abc123" + const ( + resource = "https://resource" + userID = "abc123" + ) cb := func(token Token) error { return nil } spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb) @@ -744,6 +746,39 @@ func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) { } } +func TestNewServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) { + const ( + resource = "https://resource" + identityResourceID = "/subscriptions/testSub/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity" + ) + cb := func(token Token) error { return nil } + + spt, err := NewServicePrincipalTokenFromMSIWithIdentityResourceID("http://msiendpoint/", resource, identityResourceID, cb) + if err != nil { + t.Fatalf("Failed to get MSI SPT: %v", err) + } + + // check some of the SPT fields + if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok { + t.Fatal("SPT secret was not of MSI type") + } + + if spt.inner.Resource != resource { + t.Fatal("SPT came back with incorrect resource") + } + + if len(spt.refreshCallbacks) != 1 { + t.Fatal("SPT had incorrect refresh callbacks.") + } + + urlPathParameter := url.Values{} + urlPathParameter.Set("mi_res_id", identityResourceID) + + if !strings.Contains(spt.inner.OauthConfig.TokenEndpoint.RawQuery, urlPathParameter.Encode()) { + t.Fatal("SPT tokenEndpoint should contains mi_res_id") + } +} + func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) { token := newToken() secret := &ServicePrincipalAuthorizationCodeSecret{ @@ -895,7 +930,7 @@ func TestMarshalServicePrincipalCertificateSecret(t *testing.T) { } func TestMarshalServicePrincipalMSISecret(t *testing.T) { - spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil) + spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil, nil) if err != nil { t.Fatalf("failed to get MSI SPT: %+v", err) }