diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index f8628605..0b703282 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -18,6 +18,8 @@ import ( "encoding/pem" "errors" "fmt" + "os" + "strings" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" @@ -315,16 +317,21 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client if err != nil { return Client{}, err } - + region := os.Getenv("MSAL_FORCE_REGION") opts := clientOptions{ authority: authority, // if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache disableInstanceDiscovery: cred.tokenProvider != nil, httpClient: shared.DefaultClient, + azureRegion: region, } for _, o := range options { o(&opts) } + if strings.EqualFold(opts.azureRegion, "DisableMsalForceRegion") { + opts.azureRegion = "" + } + baseOpts := []base.Option{ base.WithCacheAccessor(opts.accessor), base.WithClientCapabilities(opts.capabilities), diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 28bad83e..7c635266 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -164,6 +164,76 @@ func TestAcquireTokenByCredential(t *testing.T) { } } +func TestRegionAutoEnable(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + tests := []struct { + region string + envRegion string + }{ + { + region: "", + envRegion: "envRegion", + }, + { + region: "region", + envRegion: "envRegion", + }, + { + region: "DisableMsalForceRegion", + envRegion: "envRegion", + }, + } + + for _, test := range tests { + lmo := "login.microsoftonline.com" + tenant := "tenant" + mockClient := mock.Client{} + if test.envRegion != "" { + err := os.Setenv("MSAL_FORCE_REGION", test.envRegion) + if err != nil { + t.Fatal(err) + } + } + var client Client + if test.region != "" { + client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region)) + if err != nil { + t.Fatal(err) + } + } else { + client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + } + + t.Cleanup(func() { + os.Unsetenv("MSAL_FORCE_REGION") + }) + if test.region == "" { + if test.envRegion != "" { + if client.base.AuthParams.AuthorityInfo.Region != test.envRegion { + t.Fatalf("wanted %q, got %q", test.envRegion, client.base.AuthParams.AuthorityInfo.Region) + } + } + } else { + if test.region == "DisableMsalForceRegion" { + if client.base.AuthParams.AuthorityInfo.Region != "" { + t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region) + } + } else { + + if client.base.AuthParams.AuthorityInfo.Region != test.region { + t.Fatalf("wanted %q, got %q", test.region, client.base.AuthParams.AuthorityInfo.Region) + } + } + } + } +} + func TestAcquireTokenOnBehalfOf(t *testing.T) { // this test is an offline version of TestOnBehalfOf in integration_test.go cred, err := NewCredFromSecret(fakeSecret)