Skip to content

Commit

Permalink
refactor: move TLS config creation to ConnectionInfo
Browse files Browse the repository at this point in the history
This refactor moves all TLS configuration into the connection info type
and leaves room for altering the configuration based on the request
connection path, rather than hiding the configuration deep within the
code that retrieves the ephemeral certificate.
  • Loading branch information
enocom committed Apr 1, 2024
1 parent 942f2e3 commit db5defe
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 214 deletions.
20 changes: 5 additions & 15 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// The TLS handshake will not fail on an expired client certificate. It's
// not until the first read where the client cert error will be surfaced.
// So check that the certificate is valid before proceeding.
if !validClientCert(cn, d.logger, ci.Conf) {
if !validClientCert(cn, d.logger, ci.Expiration) {
d.logger.Debugf("[%v] Refreshing certificate now", cn.String())
i.ForceRefresh()
// Block on refreshed connection info
Expand Down Expand Up @@ -313,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}
}

tlsConn := tls.Client(conn, ci.Conf)
tlsConn := tls.Client(conn, ci.TLSConfig())
err = tlsConn.HandshakeContext(ctx)
if err != nil {
d.logger.Debugf("[%v] TLS handshake failed: %v", cn.String(), err)
Expand All @@ -339,28 +339,18 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// validClientCert checks that the ephemeral client certificate retrieved from
// the cache is unexpired. The time comparisons strip the monotonic clock value
// to ensure an accurate result, even after laptop sleep.
func validClientCert(cn instance.ConnName, l debug.Logger, c *tls.Config) bool {
// The following conditions should be impossible (no certs, nil leaf), but
// just in case there's an unknown edge case, check assumptions before
// proceeding.
if len(c.Certificates) == 0 {
return false
}
if c.Certificates[0].Leaf == nil {
return false
}
func validClientCert(cn instance.ConnName, l debug.Logger, expiration time.Time) bool {
// Use UTC() to strip monotonic clock value to guard against inaccurate
// comparisons, especially after laptop sleep.
// See the comments on the monotonic clock in the Go documentation for
// details: https://pkg.go.dev/time#hdr-Monotonic_Clocks
now := time.Now().UTC()
expiration := c.Certificates[0].Leaf.NotAfter.UTC()
valid := expiration.After(now)
valid := expiration.UTC().After(now)
l.Debugf(
"[%v] Now = %v, Current cert expiration = %v",
cn.String(),
now.Format(time.RFC3339),
expiration.Format(time.RFC3339),
expiration.UTC().Format(time.RFC3339),
)
l.Debugf("[%v] Cert is valid = %v", cn.String(), valid)
return valid
Expand Down
26 changes: 10 additions & 16 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package cloudsqlconn

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -630,8 +628,8 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
badCN, _ := instance.ParseConnName(badInstanceConnectionName)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}{{
err: errors.New("connect info failed"),
}},
Expand Down Expand Up @@ -671,18 +669,14 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
cn, _ := instance.ParseConnName(icn)
spy := &spyConnectionInfoCache{
connectInfoCalls: []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}{
// First call returns expired certificate
{
tls: &tls.Config{
Certificates: []tls.Certificate{{
Leaf: &x509.Certificate{
// Certificate expired 10 hours ago.
NotAfter: time.Now().Add(-10 * time.Hour),
},
}},
// Certificate expired 10 hours ago.
info: cloudsql.ConnectionInfo{
Expiration: time.Now().Add(-10 * time.Hour),
},
},
// Second call errors to validate error path
Expand Down Expand Up @@ -723,8 +717,8 @@ type spyConnectionInfoCache struct {
mu sync.Mutex
connectInfoIndex int
connectInfoCalls []struct {
tls *tls.Config
err error
info cloudsql.ConnectionInfo
err error
}
closeWasCalled bool
forceRefreshWasCalled bool
Expand All @@ -739,7 +733,7 @@ func (s *spyConnectionInfoCache) ConnectionInfo(
defer s.mu.Unlock()
res := s.connectInfoCalls[s.connectInfoIndex]
s.connectInfoIndex++
return cloudsql.ConnectionInfo{Conf: res.tls}, res.err
return res.info, res.err
}

func (s *spyConnectionInfoCache) ForceRefresh() {
Expand Down
87 changes: 77 additions & 10 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (r *refreshOperation) isValid() bool {
default:
return false
case <-r.ready:
if r.err != nil || time.Now().After(r.result.Expiry.Round(0)) {
if r.err != nil || time.Now().After(r.result.Expiration.Round(0)) {
return false
}
return true
Expand Down Expand Up @@ -175,11 +175,13 @@ func (i *RefreshAheadCache) Close() error {
// ConnectionInfo contains all necessary information to connect securely to the
// server-side Proxy running on a Cloud SQL instance.
type ConnectionInfo struct {
addrs map[string]string
ServerCaCert *x509.Certificate
DBVersion string
Conf *tls.Config
Expiry time.Time
ConnectionName instance.ConnName
ClientCertificate tls.Certificate
ServerCaCert *x509.Certificate
DBVersion string
Expiration time.Time

addrs map[string]string
}

// Addr returns the IP address or DNS name for the given IP type.
Expand All @@ -202,14 +204,79 @@ func (c ConnectionInfo) Addr(ipType string) (string, error) {
if !ok {
err := errtype.NewConfigError(
fmt.Sprintf("instance does not have IP of type %q", ipType),
// i.connName.String(),
"TODO",
c.ConnectionName.String(),
)
return "", err
}
return addr, nil
}

// TLSConfig constructs a TLS configuration for the given connection info.
func (c ConnectionInfo) TLSConfig() *tls.Config {
pool := x509.NewCertPool()
pool.AddCert(c.ServerCaCert)
return &tls.Config{
ServerName: c.ConnectionName.String(),
Certificates: []tls.Certificate{c.ClientCertificate},
RootCAs: pool,
// We need to set InsecureSkipVerify to true due to
// https://github.com/GoogleCloudPlatform/cloudsql-proxy/issues/194
// https://tip.golang.org/doc/go1.11#crypto/x509
//
// Since we have a secure channel to the Cloud SQL API which we use to
// retrieve the certificates, we instead need to implement our own
// VerifyPeerCertificate function that will verify that the certificate
// is OK.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyPeerCertificateFunc(c.ConnectionName, pool),
MinVersion: tls.VersionTLS13,
}
}

// verifyPeerCertificateFunc creates a VerifyPeerCertificate func that
// verifies that the peer certificate is in the cert pool. We need to define
// our own because CloudSQL instances use the instance name (e.g.,
// my-project:my-instance) instead of a valid domain name for the certificate's
// Common Name.
func verifyPeerCertificateFunc(
cn instance.ConnName, pool *x509.CertPool,
) func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return errtype.NewDialError(
"no certificate to verify", cn.String(), nil,
)
}

cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errtype.NewDialError(
"failed to parse X.509 certificate", cn.String(), err,
)
}

opts := x509.VerifyOptions{Roots: pool}
if _, err = cert.Verify(opts); err != nil {
return errtype.NewDialError(
"failed to verify certificate", cn.String(), err,
)
}

certInstanceName := fmt.Sprintf("%s:%s", cn.Project(), cn.Name())
if cert.Subject.CommonName != certInstanceName {
return errtype.NewDialError(
fmt.Sprintf(
"certificate had CN %q, expected %q",
cert.Subject.CommonName, certInstanceName,
),
cn.String(),
nil,
)
}
return nil
}
}

// ConnectionInfo returns an IP address specified by ipType (i.e., public or
// private) and a TLS config that can be used to connect to a Cloud SQL
// instance.
Expand Down Expand Up @@ -353,7 +420,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
i.logger.Debugf(
"[%v] Current certificate expiration = %v",
i.connName.String(),
r.result.Expiry.UTC().Format(time.RFC3339),
r.result.Expiration.UTC().Format(time.RFC3339),
)
default:
i.logger.Debugf(
Expand Down Expand Up @@ -392,7 +459,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
// Update the current results, and schedule the next refresh in
// the future
i.cur = r
t := refreshDuration(time.Now(), i.cur.result.Expiry)
t := refreshDuration(time.Now(), i.cur.result.Expiration)
i.logger.Debugf(
"[%v] Connection info refresh operation scheduled at %v (now + %v)",
i.connName.String(),
Expand Down
Loading

0 comments on commit db5defe

Please sign in to comment.