Skip to content

Commit

Permalink
fix(transport): relax universe checks (#2376)
Browse files Browse the repository at this point in the history
This PR serves to relax universe mismatch checks, as a user specifying credentials via the option.WithTokenSource has provided a form of credential where the universe information is not directly accessible.

In these cases, we no longer perform the universe mismatch checks.

This PR also augments the existing mismatch checks to exercise a more diverse set of client options.
  • Loading branch information
shollyman authored Jan 26, 2024
1 parent a8d9414 commit 55b0516
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 15 deletions.
8 changes: 6 additions & 2 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
if err != nil {
return nil, err
}
if o.GetUniverseDomain() != credsUniverseDomain {
return nil, internal.ErrUniverseNotMatch(o.GetUniverseDomain(), credsUniverseDomain)
if o.TokenSource == nil {
// We only validate non-tokensource creds, as TokenSource-based credentials
// don't propagate universe.
if o.GetUniverseDomain() != credsUniverseDomain {
return nil, internal.ErrUniverseNotMatch(o.GetUniverseDomain(), credsUniverseDomain)
}
}
grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(grpcTokenSource{
TokenSource: oauth.TokenSource{TokenSource: creds.TokenSource},
Expand Down
8 changes: 6 additions & 2 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
if err != nil {
return nil, err
}
if settings.GetUniverseDomain() != credsUniverseDomain {
return nil, internal.ErrUniverseNotMatch(settings.GetUniverseDomain(), credsUniverseDomain)
if settings.TokenSource == nil {
// We only validate non-tokensource creds, as TokenSource-based credentials
// don't propagate universe.
if settings.GetUniverseDomain() != credsUniverseDomain {
return nil, internal.ErrUniverseNotMatch(settings.GetUniverseDomain(), credsUniverseDomain)
}
}
paramTransport.quotaProject = internal.GetQuotaProject(creds, settings.QuotaProject)
ts := creds.TokenSource
Expand Down
94 changes: 83 additions & 11 deletions transport/http/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ package http
import (
"context"
"fmt"
"strings"
"testing"

"go.opencensus.io/plugin/ochttp"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/internal"
"google.golang.org/api/option"
)

Expand Down Expand Up @@ -41,16 +41,88 @@ func TestNewClient(t *testing.T) {
}
}

func TestNewClient_MismatchedUniverseDomainCreds(t *testing.T) {
func TestNewClient_MismatchedUniverseChecks(t *testing.T) {

rootTokenScope := "https://www.googleapis.com/auth/cloud-platform"
universeDomain := "example.com"
universeDomainDefault := "googleapis.com"
creds := &google.Credentials{} // universeDomainDefault
wantErr := internal.ErrUniverseNotMatch(universeDomain, universeDomainDefault)
_, _, err := NewClient(context.Background(), option.WithUniverseDomain(universeDomain),
option.WithCredentials(creds), option.WithScopes(rootTokenScope))

if err.Error() != wantErr.Error() {
t.Fatalf("got: %v, want: %v", err, wantErr)
otherUniverse := "example.com"
defaultUniverse := "googleapis.com"
fakeCreds := `
{"type": "service_account",
"project_id": "some-project",
"universe_domain": "UNIVERSE"}`

// utility function to make a fake credential quickly
makeFakeCredF := func(universe string) option.ClientOption {
data := []byte(strings.ReplaceAll(fakeCreds, "UNIVERSE", universe))
creds, _ := google.CredentialsFromJSON(context.Background(), data, rootTokenScope)
return option.WithCredentials(creds)
}

testCases := []struct {
description string
opts []option.ClientOption
wantErr bool
}{
{
description: "default creds and no universe",
opts: []option.ClientOption{
option.WithCredentials(&google.Credentials{}),
},
wantErr: false,
},
{
description: "default creds and default universe",
opts: []option.ClientOption{
option.WithCredentials(&google.Credentials{}),
option.WithUniverseDomain(defaultUniverse),
},
wantErr: false,
},
{
description: "default creds and mismatched universe",
opts: []option.ClientOption{
option.WithCredentials(&google.Credentials{}),
option.WithUniverseDomain(otherUniverse),
},
wantErr: true,
},
{
description: "foreign universe creds and default universe",
opts: []option.ClientOption{
makeFakeCredF(otherUniverse),
option.WithUniverseDomain(defaultUniverse),
},
wantErr: true,
},
{
description: "foreign universe creds and foreign universe",
opts: []option.ClientOption{
makeFakeCredF(otherUniverse),
option.WithUniverseDomain(otherUniverse),
},
wantErr: false,
},
{
description: "tokensource + mismatched universe",
opts: []option.ClientOption{
option.WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{})),
option.WithUniverseDomain(otherUniverse),
},
wantErr: false,
},
}

for _, tc := range testCases {
opts := []option.ClientOption{
option.WithScopes(rootTokenScope),
}
opts = append(opts, tc.opts...)
_, _, gotErr := NewClient(context.Background(), opts...)
if tc.wantErr && gotErr == nil {
t.Errorf("%q: wanted error, got none", tc.description)
}
if !tc.wantErr && gotErr != nil {
t.Errorf("%q: wanted success, got err: %v", tc.description, gotErr)
}
}
}

0 comments on commit 55b0516

Please sign in to comment.