Skip to content

Commit

Permalink
refactor: move TLS config creation to ConnectionInfo
Browse files Browse the repository at this point in the history
This refactor moves all TLS configuration into the connection info type
and leaves room for altering the configuration based on the request
connection path, rather than hiding the configuration deep within the
code that retrieves the ephemeral certificate.
  • Loading branch information
enocom committed Apr 2, 2024
1 parent a45707d commit 3185637
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 214 deletions.
20 changes: 5 additions & 15 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// The TLS handshake will not fail on an expired client certificate. It's
// not until the first read where the client cert error will be surfaced.
// So check that the certificate is valid before proceeding.
if !validClientCert(cn, d.logger, ci.Conf) {
if !validClientCert(cn, d.logger, ci.Expiration) {
d.logger.Debugf("[%v] Refreshing certificate now", cn.String())
i.ForceRefresh()
// Block on refreshed connection info
Expand Down Expand Up @@ -313,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}
}

tlsConn := tls.Client(conn, ci.Conf)
tlsConn := tls.Client(conn, ci.TLSConfig())
err = tlsConn.HandshakeContext(ctx)
if err != nil {
d.logger.Debugf("[%v] TLS handshake failed: %v", cn.String(), err)
Expand All @@ -339,28 +339,18 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// validClientCert checks that the ephemeral client certificate retrieved from
// the cache is unexpired. The time comparisons strip the monotonic clock value
// to ensure an accurate result, even after laptop sleep.
func validClientCert(cn instance.ConnName, l debug.Logger, c *tls.Config) bool {
// The following conditions should be impossible (no certs, nil leaf), but
// just in case there's an unknown edge case, check assumptions before
// proceeding.
if len(c.Certificates) == 0 {
return false
}
if c.Certificates[0].Leaf == nil {
return false
}
func validClientCert(cn instance.ConnName, l debug.Logger, expiration time.Time) bool {
// Use UTC() to strip monotonic clock value to guard against inaccurate
// comparisons, especially after laptop sleep.
// See the comments on the monotonic clock in the Go documentation for
// details: https://pkg.go.dev/time#hdr-Monotonic_Clocks
now := time.Now().UTC()
expiration := c.Certificates[0].Leaf.NotAfter.UTC()
valid := expiration.After(now)
valid := expiration.UTC().After(now)
l.Debugf(
"[%v] Now = %v, Current cert expiration = %v",
cn.String(),
now.Format(time.RFC3339),
expiration.Format(time.RFC3339),
expiration.UTC().Format(time.RFC3339),
)
l.Debugf("[%v] Cert is valid = %v", cn.String(), valid)
return valid
Expand Down
26 changes: 10 additions & 16 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package cloudsqlconn

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -630,8 +628,8 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
badCN, _ := instance.ParseConnName(badInstanceConnectionName)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}{{
err: errors.New("connect info failed"),
}},
Expand Down Expand Up @@ -671,18 +669,14 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
cn, _ := instance.ParseConnName(icn)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}{
// First call returns expired certificate
{
tls: &tls.Config{
Certificates: []tls.Certificate{{
Leaf: &x509.Certificate{
// Certificate expired 10 hours ago.
NotAfter: time.Now().Add(-10 * time.Hour),
},
}},
// Certificate expired 10 hours ago.
info: cloudsql.ConnectionInfo{
Expiration: time.Now().Add(-10 * time.Hour),
},
},
// Second call errors to validate error path
Expand Down Expand Up @@ -723,8 +717,8 @@ type spyConnectionInfoCache struct {
mu sync.Mutex
connectInfoIndex int
connectInfoCalls []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}
closeWasCalled bool
forceRefreshWasCalled bool
Expand All @@ -739,7 +733,7 @@ func (s *spyConnectionInfoCache) ConnectionInfo(
defer s.mu.Unlock()
res := s.connectInfoCalls[s.connectInfoIndex]
s.connectInfoIndex++
return cloudsql.ConnectionInfo{Conf: res.tls}, res.err
return res.info, res.err
}

func (s *spyConnectionInfoCache) ForceRefresh() {
Expand Down
87 changes: 77 additions & 10 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (r *refreshOperation) isValid() bool {
default:
return false
case <-r.ready:
if r.err != nil || time.Now().After(r.result.Expiry.Round(0)) {
if r.err != nil || time.Now().After(r.result.Expiration.Round(0)) {
return false
}
return true
Expand Down Expand Up @@ -175,11 +175,13 @@ func (i *RefreshAheadCache) Close() error {
// ConnectionInfo contains all necessary information to connect securely to the
// server-side Proxy running on a Cloud SQL instance.
type ConnectionInfo struct {
addrs map[string]string
ServerCaCert *x509.Certificate
DBVersion string
Conf *tls.Config
Expiry time.Time
ConnectionName instance.ConnName
ClientCertificate tls.Certificate
ServerCaCert *x509.Certificate
DBVersion string
Expiration time.Time

addrs map[string]string
}

// Addr returns the IP address or DNS name for the given IP type.
Expand All @@ -202,14 +204,79 @@ func (c ConnectionInfo) Addr(ipType string) (string, error) {
if !ok {
err := errtype.NewConfigError(
fmt.Sprintf("instance does not have IP of type %q", ipType),
// i.connName.String(),
"TODO",
c.ConnectionName.String(),
)
return "", err
}
return addr, nil
}

// TLSConfig constructs a TLS configuration for the given connection info.
func (c ConnectionInfo) TLSConfig() *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(c.ServerCaCert)
return &tls.Config{
ServerName: c.ConnectionName.String(),
Certificates: []tls.Certificate{c.ClientCertificate},
RootCAs: pool,
// We need to set InsecureSkipVerify to true due to
// https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194
// https://tip.golang.org/doc/go1.11#crypto/x509
//
// Since we have a secure channel to the Cloud SQL API which we use to
// retrieve the certificates, we instead need to implement our own
// VerifyPeerCertificate function that will verify that the certificate
// is OK.
InsecureSkipVerify: true,

Check failure

Code scanning / CodeQL

Disabled TLS certificate check High

InsecureSkipVerify should not be used in production code.
VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
MinVersion: tls.VersionTLS13,
}
}

// verifyPeerCertificateFunc creates a VerifyPeerCertificate func that
// verifies that the peer certificate is in the cert pool. We need to define
// our own because CloudSQL instances use the instance name (e.g.,
// my-project:my-instance) instead of a valid domain name for the certificate's
// Common Name.
func verifyPeerCertificateFunc(
cn instance.ConnName, pool *x509.CertPool,
) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return errtype.NewDialError(
"no certificate to verify", cn.String(), nil,
)
}

cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errtype.NewDialError(
"failed to parse X.509 certificate", cn.String(), err,
)
}

opts := x509.VerifyOptions{Roots: pool}
if _, err = cert.Verify(opts); err != nil {
return errtype.NewDialError(
"failed to verify certificate", cn.String(), err,
)
}

certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name())
if cert.Subject.CommonName != certInstanceName {
return errtype.NewDialError(
fmt.Sprintf(
"certificate had CN %q, expected %q",
cert.Subject.CommonName, certInstanceName,
),
cn.String(),
nil,
)
}
return nil
}
}

// ConnectionInfo returns an IP address specified by ipType (i.e., public or
// private) and a TLS config that can be used to connect to a Cloud SQL
// instance.
Expand Down Expand Up @@ -353,7 +420,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
i.logger.Debugf(
"[%v] Current certificate expiration = %v",
i.connName.String(),
r.result.Expiry.UTC().Format(time.RFC3339),
r.result.Expiration.UTC().Format(time.RFC3339),
)
default:
i.logger.Debugf(
Expand Down Expand Up @@ -392,7 +459,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
// Update the current results, and schedule the next refresh in
// the future
i.cur = r
t := refreshDuration(time.Now(), i.cur.result.Expiry)
t := refreshDuration(time.Now(), i.cur.result.Expiration)
i.logger.Debugf(
"[%v] Connection info refresh operation scheduled at %v (now + %v)",
i.connName.String(),
Expand Down
119 changes: 113 additions & 6 deletions internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"testing"
"time"

"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/mock"
)
Expand Down Expand Up @@ -87,10 +91,12 @@ func TestInstanceEngineVersion(t *testing.T) {
}
}

func TestConnectInfo(t *testing.T) {
func TestConnectionInfo(t *testing.T) {
ctx := context.Background()
wantAddr := "0.0.0.0"
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", mock.WithPublicIP(wantAddr))
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance", mock.WithPublicIP(wantAddr),
)
client, cleanup, err := mock.NewSQLAdminService(
ctx,
mock.InstanceGetSuccess(inst, 1),
Expand Down Expand Up @@ -125,14 +131,115 @@ func TestConnectInfo(t *testing.T) {
wantAddr, got,
)
}
}

func TestConnectionInfoTLSConfig(t *testing.T) {
cn := testInstanceConnName()
i := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
// Generate a client certificate with the client's public key and signed by
// the server's private key
cert, err := i.ClientCert(&RSAKey.PublicKey)
if err != nil {
t.Fatal(err)
}
// Now parse the bytes back out as structured data
// TODO: this should be done in the ClientCert method and not here.
b, _ := pem.Decode(cert)
clientCert, err := x509.ParseCertificate(b.Bytes)
if err != nil {
t.Fatal(err)
}

// Now self sign the server's cert
// TODO: this also should return structured data and handle the PEM
// encoding elsewhere
certBytes, err := mock.SelfSign(i.Cert, i.Key)
if err != nil {
t.Fatal(err)
}
b, _ = pem.Decode(certBytes)
serverCert, err := x509.ParseCertificate(b.Bytes)
if err != nil {
t.Fatal(err)
}

// Assemble a connection info with the raw and parsed client cert
// and the self-signed server certificate
ci := ConnectionInfo{
ConnectionName: cn,
ClientCertificate: tls.Certificate{
Certificate: [][]byte{clientCert.Raw},
PrivateKey: RSAKey,
Leaf: clientCert,
},
ServerCaCert: serverCert,
DBVersion: "doesn't matter here",
Expiration: clientCert.NotAfter,
}

got := ci.TLSConfig()
wantServerName := cn.String()
if got.ServerName != wantServerName {
t.Fatalf(
"ConnectInfo return unexpected server name in TLS Config, "+
"want = %v, got = %v",
wantServerName, got.ServerName,
)
}

if got.MinVersion != tls.VersionTLS13 {
t.Fatalf(
"want TLS 1.3, got = %v", got.MinVersion,
)
}

if got.Certificates[0].Leaf != ci.ClientCertificate.Leaf {
t.Fatal("leaf certificates do not match")
}

verifyPeerCert := got.VerifyPeerCertificate
err = verifyPeerCert([][]byte{serverCert.Raw}, nil)
if err != nil {
t.Fatalf("expected to verify peer cert, got error: %v", err)
}

err = verifyPeerCert(nil, nil)
var wantErr *errtype.DialError
if !errors.As(err, &wantErr) {
t.Fatalf(
"when verify peer cert fails, want = %T, got = %v", wantErr, err,
)
}

wantServerName := "my-project:my-region:my-instance"
if ci.Conf.ServerName != wantServerName {
// Ensure invalid certs result in an error
err = verifyPeerCert([][]byte{[]byte("not a cert")}, nil)
if !errors.As(err, &wantErr) {
t.Fatalf(
"ConnectInfo return unexpected server name in TLS Config, want = %v, got = %v",
wantServerName, ci.Conf.ServerName,
"when verify fails on invalid cert, want = %T, got = %v",
wantErr, err,
)
}

// Ensure the common name is verified againsts the expected name
badCert := mock.GenerateCertWithCommonName(i, "wrong:wrong")
err = verifyPeerCert([][]byte{badCert}, nil)
if !errors.As(err, &wantErr) {
t.Fatalf(
"when common names mismatch, want = %T, got = %v", wantErr, err,
)
}

// Verify an unreconigzed authority is rejected
other := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name())
cert, err = mock.SelfSign(other.Cert, other.Key)
if err != nil {
t.Fatalf("failed to sign certificate: %v", err)
}
b, _ = pem.Decode(cert)
err = verifyPeerCert([][]byte{b.Bytes}, nil)
if !errors.As(err, &wantErr) {
t.Fatalf("when certification fails, want = %T, got = %v", wantErr, err)
}
}

func TestConnectInfoAutoIP(t *testing.T) {
Expand Down
Loading

0 comments on commit 3185637

Please sign in to comment.