Skip to content

Commit

Permalink
Added Region auto enable
Browse files Browse the repository at this point in the history
Added Region auto enable in confidential client and its test
  • Loading branch information
4gust committed Oct 30, 2024
1 parent bf74752 commit 462d177
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
9 changes: 8 additions & 1 deletion apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"encoding/pem"
"errors"
"fmt"
"os"
"strings"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
Expand Down Expand Up @@ -315,16 +317,21 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client
if err != nil {
return Client{}, err
}

region := os.Getenv("MSAL_FORCE_REGION")
opts := clientOptions{
authority: authority,
// if the caller specified a token provider, it will handle all details of authentication, using Client only as a token cache
disableInstanceDiscovery: cred.tokenProvider != nil,
httpClient: shared.DefaultClient,
azureRegion: region,
}
for _, o := range options {
o(&opts)
}
if strings.EqualFold(opts.azureRegion, "DisableMsalForceRegion") {
opts.azureRegion = ""
}

baseOpts := []base.Option{
base.WithCacheAccessor(opts.accessor),
base.WithClientCapabilities(opts.capabilities),
Expand Down
70 changes: 70 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,76 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
}

func TestRegionAutoEnable(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
tests := []struct {
region string
envRegion string
}{
{
region: "",
envRegion: "envRegion",
},
{
region: "region",
envRegion: "envRegion",
},
{
region: "DisableMsalForceRegion",
envRegion: "envRegion",
},
}

for _, test := range tests {
lmo := "login.microsoftonline.com"
tenant := "tenant"
mockClient := mock.Client{}
if test.envRegion != "" {
err := os.Setenv("MSAL_FORCE_REGION", test.envRegion)
if err != nil {
t.Fatal(err)
}
}
var client Client
if test.region != "" {
client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region))
if err != nil {
t.Fatal(err)
}
} else {
client, err = New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
}

t.Cleanup(func() {
os.Unsetenv("MSAL_FORCE_REGION")
})
if test.region == "" {
if test.envRegion != "" {
if client.base.AuthParams.AuthorityInfo.Region != test.envRegion {
t.Fatalf("wanted %q, got %q", test.envRegion, client.base.AuthParams.AuthorityInfo.Region)
}
}
} else {
if test.region == "DisableMsalForceRegion" {
if client.base.AuthParams.AuthorityInfo.Region != "" {
t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region)
}
} else {

if client.base.AuthParams.AuthorityInfo.Region != test.region {
t.Fatalf("wanted %q, got %q", test.region, client.base.AuthParams.AuthorityInfo.Region)
}
}
}
}
}

func TestAcquireTokenOnBehalfOf(t *testing.T) {
// this test is an offline version of TestOnBehalfOf in integration_test.go
cred, err := NewCredFromSecret(fakeSecret)
Expand Down

0 comments on commit 462d177

Please sign in to comment.