Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow MSI login with "mi_res_id" #544

Merged
merged 9 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions autorest/adal/cmd/adal.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ import (
)

const (
deviceMode = "device"
clientSecretMode = "secret"
clientCertMode = "cert"
refreshMode = "refresh"
deviceMode = "device"
clientSecretMode = "secret"
clientCertMode = "cert"
refreshMode = "refresh"
msiDefaultMode = "msiDefault"
msiClientIDMode = "msiClientID"
msiResourceIDMode = "msiResourceID"

activeDirectoryEndpoint = "https://login.microsoftonline.com/"
)
Expand All @@ -48,8 +51,9 @@ var (
mode string
resource string

tenantID string
applicationID string
tenantID string
applicationID string
identityResourceID string

applicationSecret string
certificatePath string
Expand Down Expand Up @@ -82,10 +86,28 @@ func init() {
flag.StringVar(&applicationSecret, "secret", "", "application secret")
flag.StringVar(&certificatePath, "certificatePath", "", "path to pk12/PFC application certificate")
flag.StringVar(&tokenCachePath, "tokenCachePath", defaultTokenCachePath(), "location of oath token cache")
flag.StringVar(&identityResourceID, "identityResourceID", "", "managedIdentity azure resource id")

flag.Parse()

switch mode = strings.TrimSpace(mode); mode {
case msiDefaultMode:
checkMandatoryOptions(msiDefaultMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
)
case msiClientIDMode:
checkMandatoryOptions(msiClientIDMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "applicationId", value: applicationID},
)
case msiResourceIDMode:
checkMandatoryOptions(msiResourceIDMode,
option{name: "resource", value: resource},
option{name: "tenantId", value: tenantID},
option{name: "identityResourceID", value: identityResourceID},
)
case clientSecretMode:
checkMandatoryOptions(clientSecretMode,
option{name: "resource", value: resource},
Expand Down Expand Up @@ -150,6 +172,42 @@ func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.Private
return certificate, rsaPrivateKey, nil
}

func acquireTokenMSIFlow(applicationID string,
identityResourceID string,
resource string,
callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) {

// only one of them can be present:
if applicationID != "" && identityResourceID != "" {
return nil, fmt.Errorf("didn't expect applicationID and identityResourceID at same time")
}

msiEndpoint, _ := adal.GetMSIVMEndpoint()
var spt *adal.ServicePrincipalToken
var err error

// both can be empty, systemAssignedMSI scenario
if applicationID == "" && identityResourceID == "" {
spt, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource, callbacks...)
}

// msi login with clientID
if applicationID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, applicationID, callbacks...)
}

// msi login with resourceID
if identityResourceID != "" {
spt, err = adal.NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource, identityResourceID, callbacks...)
}

if err != nil {
return nil, err
}

return spt, spt.Refresh()
}

func acquireTokenClientCertFlow(oauthConfig adal.OAuthConfig,
applicationID string,
applicationCertPath string,
Expand Down Expand Up @@ -283,6 +341,20 @@ func main() {
if err == nil {
err = saveToken(spt.Token())
}
case msiResourceIDMode:
fallthrough
case msiClientIDMode:
fallthrough
case msiDefaultMode:
var spt *adal.ServicePrincipalToken
spt, err = acquireTokenMSIFlow(
applicationID,
identityResourceID,
resource,
callback)
if err == nil {
err = saveToken(spt.Token())
}
case refreshMode:
_, err = refreshToken(
*oauthConfig,
Expand Down
22 changes: 18 additions & 4 deletions autorest/adal/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,16 +678,22 @@ func GetMSIEndpoint() (string, error) {
// NewServicePrincipalTokenFromMSI creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the system assigned identity when creating the token.
func NewServicePrincipalTokenFromMSI(msiEndpoint, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, callbacks...)
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, nil, callbacks...)
}

// NewServicePrincipalTokenFromMSIWithUserAssignedID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the specified user assigned identity when creating the token.
// It will use the clientID of specified user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource string, userAssignedID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, callbacks...)
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, &userAssignedID, nil, callbacks...)
}

func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
// NewServicePrincipalTokenFromMSIWithIdentityResourceID creates a ServicePrincipalToken via the MSI VM Extension.
// It will use the azure resource id of user assigned identity when creating the token.
func NewServicePrincipalTokenFromMSIWithIdentityResourceID(msiEndpoint, resource string, identityResourceID string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
return newServicePrincipalTokenFromMSI(msiEndpoint, resource, nil, &identityResourceID, callbacks...)
}

func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedID *string, identityResourceID *string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) {
if err := validateStringParam(msiEndpoint, "msiEndpoint"); err != nil {
return nil, err
}
Expand All @@ -699,6 +705,11 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
return nil, err
}
}
if identityResourceID != nil {
if err := validateStringParam(*identityResourceID, "identityResourceID"); err != nil {
return nil, err
}
}
// We set the oauth config token endpoint to be MSI's endpoint
msiEndpointURL, err := url.Parse(msiEndpoint)
if err != nil {
Expand All @@ -716,6 +727,9 @@ func newServicePrincipalTokenFromMSI(msiEndpoint, resource string, userAssignedI
if userAssignedID != nil {
v.Set("client_id", *userAssignedID)
}
if identityResourceID != nil {
v.Set("mi_res_id", *identityResourceID)
}
msiEndpointURL.RawQuery = v.Encode()

spt := &ServicePrincipalToken{
Expand Down
43 changes: 39 additions & 4 deletions autorest/adal/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ func TestServicePrincipalTokenManualRefreshFailsWithoutRefresh(t *testing.T) {
}

func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
resource := "https://resource"
const resource = "https://resource"
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSI("http://msiendpoint/", resource, cb)
Expand All @@ -717,8 +717,10 @@ func TestNewServicePrincipalTokenFromMSI(t *testing.T) {
}

func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
resource := "https://resource"
userID := "abc123"
const (
resource = "https://resource"
userID = "abc123"
)
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSIWithUserAssignedID("http://msiendpoint/", resource, userID, cb)
Expand All @@ -744,6 +746,39 @@ func TestNewServicePrincipalTokenFromMSIWithUserAssignedID(t *testing.T) {
}
}

func TestNewServicePrincipalTokenFromMSIWithIdentityResourceID(t *testing.T) {
const (
resource = "https://resource"
identityResourceID = "/subscriptions/testSub/resourceGroups/testGroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity"
)
cb := func(token Token) error { return nil }

spt, err := NewServicePrincipalTokenFromMSIWithIdentityResourceID("http://msiendpoint/", resource, identityResourceID, cb)
if err != nil {
t.Fatalf("Failed to get MSI SPT: %v", err)
}

// check some of the SPT fields
if _, ok := spt.inner.Secret.(*ServicePrincipalMSISecret); !ok {
t.Fatal("SPT secret was not of MSI type")
}

if spt.inner.Resource != resource {
t.Fatal("SPT came back with incorrect resource")
}

if len(spt.refreshCallbacks) != 1 {
t.Fatal("SPT had incorrect refresh callbacks.")
}

urlPathParameter := url.Values{}
urlPathParameter.Set("mi_res_id", identityResourceID)

if !strings.Contains(spt.inner.OauthConfig.TokenEndpoint.RawQuery, urlPathParameter.Encode()) {
t.Fatal("SPT tokenEndpoint should contains mi_res_id")
}
}

func TestNewServicePrincipalTokenFromManualTokenSecret(t *testing.T) {
token := newToken()
secret := &ServicePrincipalAuthorizationCodeSecret{
Expand Down Expand Up @@ -895,7 +930,7 @@ func TestMarshalServicePrincipalCertificateSecret(t *testing.T) {
}

func TestMarshalServicePrincipalMSISecret(t *testing.T) {
spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil)
spt, err := newServicePrincipalTokenFromMSI("http://msiendpoint/", "https://resource", nil, nil)
if err != nil {
t.Fatalf("failed to get MSI SPT: %+v", err)
}
Expand Down