Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

azidentity: add support for service fabric MSI environment #14783

Merged
merged 6 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ const (
)

const (
arcIMDSEndpoint = "IMDS_ENDPOINT"
identityEndpoint = "IDENTITY_ENDPOINT"
identityHeader = "IDENTITY_HEADER"
msiEndpoint = "MSI_ENDPOINT"
msiSecret = "MSI_SECRET"
imdsAPIVersion = "2018-02-01"
azureArcAPIVersion = "2019-08-15"
arcIMDSEndpoint = "IMDS_ENDPOINT"
identityEndpoint = "IDENTITY_ENDPOINT"
identityHeader = "IDENTITY_HEADER"
identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
msiEndpoint = "MSI_ENDPOINT"
msiSecret = "MSI_SECRET"
imdsAPIVersion = "2018-02-01"
azureArcAPIVersion = "2019-08-15"
serviceFabricAPIVersion = "2019-07-01-preview"
)

type msiType int
Expand All @@ -43,6 +45,7 @@ const (
msiTypeUnavailable msiType = 4
msiTypeAppServiceV20190801 msiType = 5
msiTypeAzureArc msiType = 6
msiTypeServiceFabric msiType = 7
)

// managedIdentityClient provides the base for authenticating in managed identity environments
Expand Down Expand Up @@ -109,7 +112,7 @@ func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore
Token string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid
ExpiresOn string `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
catalinaperalta marked this conversation as resolved.
Show resolved Hide resolved
}{}
if err := res.UnmarshalAsJSON(&value); err != nil {
return nil, fmt.Errorf("internal AccessToken: %w", err)
Expand All @@ -121,19 +124,29 @@ func (c *managedIdentityClient) createAccessToken(res *azcore.Response) (*azcore
}
return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresIn)).UTC()}, nil
}
if expiresOn, err := strconv.Atoi(value.ExpiresOn); err == nil {
return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresOn)).UTC()}, nil
}
// this is the case when expires_on is a time string
// this is the format of the string coming from the service
if expiresOn, err := time.Parse("1/2/2006 15:04:05 PM +00:00", value.ExpiresOn); err == nil { // the date string specified is for Windows OS
eo := expiresOn.UTC()
return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil
} else if expiresOn, err := time.Parse("1/2/2006 15:04:05 +00:00", value.ExpiresOn); err == nil { // the date string specified is for Linux OS
eo := expiresOn.UTC()
return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil
} else {
return nil, err
switch v := value.ExpiresOn.(type) {
case int:
// service fabric is one of the MSI environments that returns an int
return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(v), 0).UTC()}, nil
case float64:
return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(v), 0).UTC()}, nil
case string:
if expiresOn, err := strconv.Atoi(v); err == nil {
return &azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil
}
// this is the case when expires_on is a time string
// this is the format of the string coming from the service
if expiresOn, err := time.Parse("1/2/2006 15:04:05 PM +00:00", v); err == nil { // the date string specified is for Windows OS
eo := expiresOn.UTC()
return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil
} else if expiresOn, err := time.Parse("1/2/2006 15:04:05 +00:00", v); err == nil { // the date string specified is for Linux OS
eo := expiresOn.UTC()
return &azcore.AccessToken{Token: value.Token, ExpiresOn: eo}, nil
} else {
return nil, err
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
}
default:
return nil, &AuthenticationFailedError{msg: fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)}
}
}

Expand All @@ -150,6 +163,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID
return nil, &AuthenticationFailedError{inner: err, msg: "Failed to retreive secret key from the identity endpoint."}
}
return c.createAzureArcAuthRequest(ctx, key, scopes)
case msiTypeServiceFabric:
return c.createServiceFabricAuthRequest(ctx, clientID, scopes)
case msiTypeCloudShell:
return c.createCloudShellAuthRequest(ctx, clientID, scopes)
default:
Expand Down Expand Up @@ -213,6 +228,23 @@ func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context,
return request, nil
}

func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id string, scopes []string) (*azcore.Request, error) {
request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint)
if err != nil {
return nil, err
}
q := request.URL.Query()
request.Header.Set("Accept", "application/json")
request.Header.Set("Secret", os.Getenv(identityHeader))
q.Add("api-version", serviceFabricAPIVersion)
q.Add("resource", strings.Join(scopes, " "))
if id != "" {
q.Add(qpClientID, id)
}
request.URL.RawQuery = q.Encode()
return request, nil
}

func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resources []string) (string, error) {
// create the request to retreive the secret key challenge provided by the HIMDS service
request, err := azcore.NewRequest(ctx, http.MethodGet, c.endpoint)
Expand Down Expand Up @@ -296,6 +328,9 @@ func (c *managedIdentityClient) getMSIType() (msiType, error) {
c.endpoint = endpointEnvVar
if header := os.Getenv(identityHeader); header != "" { // if BOTH the env vars IDENTITY_ENDPOINT and IDENTITY_HEADER are set the msiType is AppService
c.msiType = msiTypeAppServiceV20190801
if thumbprint := os.Getenv(identityServerThumbprint); thumbprint != "" { // if IDENTITY_SERVER_THUMBPRINT is set the environment is Service Fabric
c.msiType = msiTypeServiceFabric
}
} else if arcIMDS := os.Getenv(arcIMDSEndpoint); arcIMDS != "" {
c.msiType = msiTypeAzureArc
} else {
Expand Down
23 changes: 22 additions & 1 deletion sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
appServiceWindowsSuccessResp = `{"access_token": "new_token", "expires_on": "9/14/2017 00:00:00 PM +00:00", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
appServiceLinuxSuccessResp = `{"access_token": "new_token", "expires_on": "09/14/2017 00:00:00 +00:00", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
expiresOnIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": "1560974028", "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
expiresOnNonStringIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": 1560974028, "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
)

func clearEnvVars(envVars ...string) {
Expand Down Expand Up @@ -294,7 +295,7 @@ func TestManagedIdentityCredential_CreateAppServiceAuthRequestV20170901(t *testi
}
}

func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) {
func TestManagedIdentityCredential_CreateAccessTokenExpiresOnStringInt(t *testing.T) {
resetEnvironmentVarsForTest()
srv, close := mock.NewServer()
defer close()
Expand Down Expand Up @@ -620,3 +621,23 @@ func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) {
t.Fatalf("Unexpected resource ID in resource query param")
}
}

func TestManagedIdentityCredential_CreateAccessTokenExpiresOnInt(t *testing.T) {
resetEnvironmentVarsForTest()
srv, close := mock.NewServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(expiresOnNonStringIntResp)))
_ = os.Setenv("MSI_ENDPOINT", srv.URL())
_ = os.Setenv("MSI_SECRET", "secret")
defer clearEnvVars("MSI_ENDPOINT", "MSI_SECRET")
options := ManagedIdentityCredentialOptions{}
options.HTTPClient = srv
msiCred, err := NewManagedIdentityCredential(clientID, &options)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err = msiCred.GetToken(context.Background(), azcore.TokenRequestOptions{Scopes: []string{msiScope}})
if err != nil {
t.Fatalf("Received an error when attempting to retrieve a token")
}
}