Skip to content

Commit

Permalink
fix(cba): Update credsNewAuth to support oauth2 over mTLS (#2610)
Browse files Browse the repository at this point in the history
this logic is ported over from "baseCreds" from the same file.
  • Loading branch information
andyrzhao authored Jun 12, 2024
1 parent ebc44d1 commit 953f728
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
40 changes: 29 additions & 11 deletions internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) {
return creds, nil
}

// GetOAuth2Configuration determines configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
func GetOAuth2Configuration(ctx context.Context, settings *DialSettings) (string, *http.Client, error) {
clientCertSource, err := getClientCertificateSource(settings)
if err != nil {
return "", nil, err
}
tokenURL := oAuth2Endpoint(clientCertSource)
var oauth2Client *http.Client
if clientCertSource != nil {
tlsConfig := &tls.Config{
GetClientCertificate: clientCertSource,
}
oauth2Client = customHTTPClient(tlsConfig)
} else {
oauth2Client = oauth2.NewClient(ctx, nil)
}
return tokenURL, oauth2Client, nil
}

func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credentials, error) {
// Preserve old options behavior
if settings.InternalCredentials != nil {
Expand Down Expand Up @@ -80,13 +100,18 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti
aud = settings.DefaultAudience
}

tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, settings)
if err != nil {
return nil, err
}
creds, err := credentials.DetectDefault(&credentials.DetectOptions{
Scopes: scopes,
Audience: aud,
CredentialsFile: settings.CredentialsFile,
CredentialsJSON: settings.CredentialsJSON,
UseSelfSignedJWT: useSelfSignedJWT,
Client: oauth2.NewClient(ctx, nil),
TokenURL: tokenURL,
Client: oauth2Client,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -147,19 +172,12 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g
var params google.CredentialsParams
params.Scopes = ds.GetScopes()

// Determine configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
clientCertSource, err := getClientCertificateSource(ds)
tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, ds)
if err != nil {
return nil, err
}
params.TokenURL = oAuth2Endpoint(clientCertSource)
if clientCertSource != nil {
tlsConfig := &tls.Config{
GetClientCertificate: clientCertSource,
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig))
}
params.TokenURL = tokenURL
ctx = context.WithValue(ctx, oauth2.HTTPClient, oauth2Client)

// By default, a standard OAuth 2.0 token source is created
cred, err := google.CredentialsFromJSONWithParams(ctx, data, params)
Expand Down
8 changes: 7 additions & 1 deletion transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna
defaultEndpointTemplate = ds.DefaultEndpoint
}

tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds)
if err != nil {
return nil, err
}

pool, err := grpctransport.Dial(ctx, secure, &grpctransport.Options{
DisableTelemetry: ds.TelemetryDisabled,
DisableAuthentication: ds.NoAuth,
Expand All @@ -231,7 +236,8 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna
Audience: aud,
CredentialsFile: ds.CredentialsFile,
CredentialsJSON: ds.CredentialsJSON,
Client: oauth2.NewClient(ctx, nil),
TokenURL: tokenURL,
Client: oauth2Client,
},
InternalOptions: &grpctransport.InternalOptions{
EnableNonDefaultSAForDirectPath: ds.AllowNonDefaultServiceAccount,
Expand Down
7 changes: 6 additions & 1 deletion transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal.
if ds.RequestReason != "" {
headers.Set("X-goog-request-reason", ds.RequestReason)
}
tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds)
if err != nil {
return nil, err
}
client, err := httptransport.NewClient(&httptransport.Options{
DisableTelemetry: ds.TelemetryDisabled,
DisableAuthentication: ds.NoAuth,
Expand All @@ -121,7 +125,8 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal.
Audience: aud,
CredentialsFile: ds.CredentialsFile,
CredentialsJSON: ds.CredentialsJSON,
Client: oauth2.NewClient(ctx, nil),
TokenURL: tokenURL,
Client: oauth2Client,
},
InternalOptions: &httptransport.InternalOptions{
EnableJWTWithScope: ds.EnableJwtWithScope,
Expand Down

0 comments on commit 953f728

Please sign in to comment.