diff --git a/internal/creds.go b/internal/creds.go index b6dbace4c97..e6c4fe90d42 100644 --- a/internal/creds.go +++ b/internal/creds.go @@ -42,6 +42,26 @@ func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) { return creds, nil } +// GetOAuth2Configuration determines configurations for the OAuth2 transport, which is separate from the API transport. +// The OAuth2 transport and endpoint will be configured for mTLS if applicable. +func GetOAuth2Configuration(ctx context.Context, settings *DialSettings) (string, *http.Client, error) { + clientCertSource, err := getClientCertificateSource(settings) + if err != nil { + return "", nil, err + } + tokenURL := oAuth2Endpoint(clientCertSource) + var oauth2Client *http.Client + if clientCertSource != nil { + tlsConfig := &tls.Config{ + GetClientCertificate: clientCertSource, + } + oauth2Client = customHTTPClient(tlsConfig) + } else { + oauth2Client = oauth2.NewClient(ctx, nil) + } + return tokenURL, oauth2Client, nil +} + func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credentials, error) { // Preserve old options behavior if settings.InternalCredentials != nil { @@ -80,13 +100,18 @@ func credsNewAuth(ctx context.Context, settings *DialSettings) (*google.Credenti aud = settings.DefaultAudience } + tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, settings) + if err != nil { + return nil, err + } creds, err := credentials.DetectDefault(&credentials.DetectOptions{ Scopes: scopes, Audience: aud, CredentialsFile: settings.CredentialsFile, CredentialsJSON: settings.CredentialsJSON, UseSelfSignedJWT: useSelfSignedJWT, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }) if err != nil { return nil, err @@ -147,19 +172,12 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g var params google.CredentialsParams params.Scopes = ds.GetScopes() - // Determine configurations for the OAuth2 transport, which is separate from the API transport. - // The OAuth2 transport and endpoint will be configured for mTLS if applicable. - clientCertSource, err := getClientCertificateSource(ds) + tokenURL, oauth2Client, err := GetOAuth2Configuration(ctx, ds) if err != nil { return nil, err } - params.TokenURL = oAuth2Endpoint(clientCertSource) - if clientCertSource != nil { - tlsConfig := &tls.Config{ - GetClientCertificate: clientCertSource, - } - ctx = context.WithValue(ctx, oauth2.HTTPClient, customHTTPClient(tlsConfig)) - } + params.TokenURL = tokenURL + ctx = context.WithValue(ctx, oauth2.HTTPClient, oauth2Client) // By default, a standard OAuth 2.0 token source is created cred, err := google.CredentialsFromJSONWithParams(ctx, data, params) diff --git a/transport/grpc/dial.go b/transport/grpc/dial.go index 2e66d02b378..2d4f90c9c1f 100644 --- a/transport/grpc/dial.go +++ b/transport/grpc/dial.go @@ -218,6 +218,11 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna defaultEndpointTemplate = ds.DefaultEndpoint } + tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds) + if err != nil { + return nil, err + } + pool, err := grpctransport.Dial(ctx, secure, &grpctransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, @@ -231,7 +236,8 @@ func dialPoolNewAuth(ctx context.Context, secure bool, poolSize int, ds *interna Audience: aud, CredentialsFile: ds.CredentialsFile, CredentialsJSON: ds.CredentialsJSON, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }, InternalOptions: &grpctransport.InternalOptions{ EnableNonDefaultSAForDirectPath: ds.AllowNonDefaultServiceAccount, diff --git a/transport/http/dial.go b/transport/http/dial.go index d1cd83b62d9..a36e24315ba 100644 --- a/transport/http/dial.go +++ b/transport/http/dial.go @@ -107,6 +107,10 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal. if ds.RequestReason != "" { headers.Set("X-goog-request-reason", ds.RequestReason) } + tokenURL, oauth2Client, err := internal.GetOAuth2Configuration(ctx, ds) + if err != nil { + return nil, err + } client, err := httptransport.NewClient(&httptransport.Options{ DisableTelemetry: ds.TelemetryDisabled, DisableAuthentication: ds.NoAuth, @@ -121,7 +125,8 @@ func newClientNewAuth(ctx context.Context, base http.RoundTripper, ds *internal. Audience: aud, CredentialsFile: ds.CredentialsFile, CredentialsJSON: ds.CredentialsJSON, - Client: oauth2.NewClient(ctx, nil), + TokenURL: tokenURL, + Client: oauth2Client, }, InternalOptions: &httptransport.InternalOptions{ EnableJWTWithScope: ds.EnableJwtWithScope,