Skip to content

Commit

Permalink
feat: add additional checks before using S2A (#2103)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmenxk authored Aug 8, 2023
1 parent 8029f73 commit c62e5c6
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
40 changes: 29 additions & 11 deletions internal/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,10 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
s2aMTLSEndpoint: "",
}

// Check the env to determine whether to use S2A.
if !isGoogleS2AEnabled() {
if !shouldUseS2A(clientCertSource, settings) {
return &defaultTransportConfig, nil
}

// If client cert is found, use that over S2A.
// If MTLS is not enabled for the endpoint, skip S2A.
if clientCertSource != nil || !mtlsEndpointEnabledForS2A() {
return &defaultTransportConfig, nil
}
s2aMTLSEndpoint := settings.DefaultMTLSEndpoint
// If there is endpoint override, honor it.
if settings.Endpoint != "" {
Expand All @@ -118,10 +112,6 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
}, nil
}

func isGoogleS2AEnabled() bool {
return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true"
}

// getClientCertificateSource returns a default client certificate source, if
// not provided by the user.
//
Expand Down Expand Up @@ -275,8 +265,36 @@ func GetHTTPTransportConfigAndEndpoint(settings *DialSettings) (cert.Source, fun
return nil, dialTLSContextFunc, config.s2aMTLSEndpoint, nil
}

func shouldUseS2A(clientCertSource cert.Source, settings *DialSettings) bool {
// If client cert is found, use that over S2A.
if clientCertSource != nil {
return false
}
// If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A.
if !isGoogleS2AEnabled() {
return false
}
// If DefaultMTLSEndpoint is not set, skip S2A.
if settings.DefaultMTLSEndpoint == "" {
return false
}
// If MTLS is not enabled for this endpoint, skip S2A.
if !mtlsEndpointEnabledForS2A() {
return false
}
// If custom HTTP client is provided, skip S2A.
if settings.HTTPClient != nil {
return false
}
return true
}

// mtlsEndpointEnabledForS2A checks if the endpoint is indeed MTLS-enabled, so that we can use S2A for MTLS connection.
var mtlsEndpointEnabledForS2A = func() bool {
// TODO(xmenxk): determine this via discovery config.
return true
}

func isGoogleS2AEnabled() bool {
return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true"
}
24 changes: 24 additions & 0 deletions internal/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package internal

import (
"crypto/tls"
"net/http"
"os"
"testing"
"time"
Expand Down Expand Up @@ -278,6 +279,29 @@ func TestGetHTTPTransportConfigAndEndpoint(t *testing.T) {
testOverrideEndpoint,
false,
},
{
"no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
&DialSettings{
DefaultMTLSEndpoint: "",
DefaultEndpoint: testRegularEndpoint,
},
validConfigResp,
func() bool { return true },
testRegularEndpoint,
true,
},
{
"no client cert, endpoint is MTLS enabled, S2A address not empty, custom HTTP client",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
HTTPClient: http.DefaultClient,
},
validConfigResp,
func() bool { return true },
testRegularEndpoint,
true,
},
}

defer setupTest()()
Expand Down

0 comments on commit c62e5c6

Please sign in to comment.