Skip to content

Commit

Permalink
allow MSI login with "mi_res_id" (#544)
Browse files Browse the repository at this point in the history
* allow login with resourceID

* test

* tweaks

* fix

* tested with cmd

* fix unittest

* add new test, remove debug trace

* fix unittest

* fix with url encode
  • Loading branch information
haitch authored Aug 6, 2020
1 parent dadf295 commit b8633a5
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 14 deletions.
84 changes: 78 additions & 6 deletions autorest/adal/cmd/adal.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ import (
)

const (
deviceMode = "device"
clientSecretMode = "secret"
clientCertMode = "cert"
refreshMode = "refresh"
deviceMode = "device"
clientSecretMode = "secret"
clientCertMode = "cert"
refreshMode = "refresh"
msiDefaultMode = "msiDefault"
msiClientIDMode = "msiClientID"
msiResourceIDMode = "msiResourceID"

activeDirectoryEndpoint = "https://login.microsoftonline.com/"
)
Expand All @@ -48,8 +51,9 @@ var (
mode string
resource string

tenantID string
applicationID string
tenantID string
applicationID string
identityResourceID string

applicationSecret string
certificatePath string
Expand Down Expand Up @@ -82,10 +86,28 @@ func init() {
flag.StringVar(&applicationSecret, "secret", "", "application secret")
flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate")
flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache")
flag.StringVar(&identityResourceID, "identityResourceID", "", "managedIdentity azure resource id")

flag.Parse()

switch mode = strings.TrimSpace(mode); mode {
case msiDefaultMode:
checkMandatoryOptions(msiDefaultMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
)
case msiClientIDMode:
checkMandatoryOptions(msiClientIDMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
)
case msiResourceIDMode:
checkMandatoryOptions(msiResourceIDMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "identityResourceID", value: identityResourceID},
)
case clientSecretMode:
checkMandatoryOptions(clientSecretMode,
option{name: "resource", value: resource},
Expand Down Expand Up @@ -150,6 +172,42 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private
return certificate, rsaPrivateKey, nil
}

func acquireTokenMSIFlow(applicationID string,
identityResourceID string,
resource string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {

// only one of them can be present:
if applicationID != "" && identityResourceID != "" {
return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time")
}

msiEndpoint, _ := adal.GetMSIVMEndpoint()
var spt *adal.ServicePrincipalToken
var err error

// both can be empty, systemAssignedMSI scenario
if applicationID == "" && identityResourceID == "" {
spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...)
}

// msi login with clientID
if applicationID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...)
}

// msi login with resourceID
if identityResourceID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...)
}

if err != nil {
return nil, err
}

return spt, spt.Refresh()
}

func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig,
applicationID string,
applicationCertPath string,
Expand Down Expand Up @@ -283,6 +341,20 @@ func main() {
if err == nil {
err = saveToken(spt.Token())
}
case msiResourceIDMode:
fallthrough
case msiClientIDMode:
fallthrough
case msiDefaultMode:
var spt *adal.ServicePrincipalToken
spt, err = acquireTokenMSIFlow(
applicationID,
identityResourceID,
resource,
callback)
if err == nil {
err = saveToken(spt.Token())
}
case refreshMode:
_, err = refreshToken(
*oauthConfig,
Expand Down
22 changes: 18 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,16 +678,22 @@ func GetMSIEndpoint() (string, error) {
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, nil, callbacks...)
}

// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
// It will use the clientID of specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, nil, callbacks...)
}

func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the azure resource id of user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &identityResourceID, callbacks...)
}

func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
Expand All @@ -699,6 +705,11 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
return nil, err
}
}
if identityResourceID != nil {
if err := validateStringParam(*identityResourceID, "identityResourceID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
Expand All @@ -716,6 +727,9 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
}
if identityResourceID != nil {
v.Set("mi_res_id", *identityResourceID)
}
msiEndpointURL.RawQuery = v.Encode()

spt := &ServicePrincipalToken{
Expand Down
43 changes: 39 additions & 4 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
}

func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
resource := "https://resource"
const resource = "https://resource"
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
Expand All @@ -717,8 +717,10 @@ func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
}

func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
resource := "https://resource"
userID := "abc123"
const (
resource = "https://resource"
userID = "abc123"
)
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
Expand All @@ -744,6 +746,39 @@ func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
}
}

func TestNewServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) {
const (
resource = "https://resource"
identityResourceID = "/subscriptions/testSub/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity"
)
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSIWithIdentityResourceID("http://msiendpoint/", resource, identityResourceID, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}

// check some of the SPT fields
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}

if spt.inner.Resource != resource {
t.Fatal("SPT came back with incorrect resource")
}

if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}

urlPathParameter := url.Values{}
urlPathParameter.Set("mi_res_id", identityResourceID)

if !strings.Contains(spt.inner.OauthConfig.TokenEndpoint.RawQuery, urlPathParameter.Encode()) {
t.Fatal("SPT tokenEndpoint should contains mi_res_id")
}
}

func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
token := newToken()
secret := &ServicePrincipalAuthorizationCodeSecret{
Expand Down Expand Up @@ -895,7 +930,7 @@ func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
}

func TestMarshalServicePrincipalMSISecret(t *testing.T) {
spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil)
spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil, nil)
if err != nil {
t.Fatalf("failed to get MSI SPT: %+v", err)
}
Expand Down

0 comments on commit b8633a5

Please sign in to comment.