Skip to content

Commit

Permalink
refactor: simplify internal interface
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Apr 1, 2024
1 parent 2dae0e1 commit 942f2e3
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 137 deletions.
16 changes: 10 additions & 6 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func getDefaultKeys() (*rsa.PrivateKey, error) {
type connectionInfoCache interface {
OpenConns() *uint64

ConnectInfo(context.Context, string) (string, *tls.Config, error)
ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error)
InstanceEngineVersion(context.Context) (string, error)
UpdateRefresh(*bool)
ForceRefresh()
Expand Down Expand Up @@ -250,7 +250,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
var endInfo trace.EndSpanFunc
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
i := d.instance(cn, &cfg.useIAMAuthN)
addr, tlsConfig, err := i.ConnectInfo(ctx, cfg.ipType)
ci, err := i.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
Expand All @@ -268,11 +268,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
// The TLS handshake will not fail on an expired client certificate. It's
// not until the first read where the client cert error will be surfaced.
// So check that the certificate is valid before proceeding.
if !validClientCert(cn, d.logger, tlsConfig) {
if !validClientCert(cn, d.logger, ci.Conf) {
d.logger.Debugf("[%v] Refreshing certificate now", cn.String())
i.ForceRefresh()
// Block on refreshed connection info
addr, tlsConfig, err = i.ConnectInfo(ctx, cfg.ipType)
ci, err = i.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
Expand All @@ -287,6 +287,10 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
var connectEnd trace.EndSpanFunc
ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect")
defer func() { connectEnd(err) }()
addr, err := ci.Addr(cfg.ipType)
if err != nil {
return nil, err
}
addr = net.JoinHostPort(addr, serverProxyPort)
f := d.dialFunc
if cfg.dialFunc != nil {
Expand All @@ -309,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
}
}

tlsConn := tls.Client(conn, tlsConfig)
tlsConn := tls.Client(conn, ci.Conf)
err = tlsConn.HandshakeContext(ctx)
if err != nil {
d.logger.Debugf("[%v] TLS handshake failed: %v", cn.String(), err)
Expand Down Expand Up @@ -452,7 +456,7 @@ func (d *Dialer) instance(cn instance.ConnName, useIAMAuthN *bool) connectionInf
useIAMAuthNDial = *useIAMAuthN
}
d.logger.Debugf("[%v] Connection info added to cache", cn.String())
i = cloudsql.NewInstance(
i = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, d.key,
Expand Down
10 changes: 6 additions & 4 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/mock"
"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -109,8 +110,7 @@ func TestDialWithAdminAPIErrors(t *testing.T) {
}

func TestDialWithConfigurationErrors(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance",
mock.WithCertExpiry(time.Now().Add(-time.Hour)))
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")

svc, cleanup, err := mock.NewSQLAdminService(
context.Background(),
Expand Down Expand Up @@ -732,12 +732,14 @@ type spyConnectionInfoCache struct {
connectionInfoCache
}

func (s *spyConnectionInfoCache) ConnectInfo(_ context.Context, _ string) (string, *tls.Config, error) {
func (s *spyConnectionInfoCache) ConnectionInfo(
context.Context,
) (cloudsql.ConnectionInfo, error) {
s.mu.Lock()
defer s.mu.Unlock()
res := s.connectInfoCalls[s.connectInfoIndex]
s.connectInfoIndex++
return "unused", res.tls, res.err
return cloudsql.ConnectionInfo{Conf: res.tls}, res.err
}

func (s *spyConnectionInfoCache) ForceRefresh() {
Expand Down
95 changes: 56 additions & 39 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -56,13 +57,13 @@ type refreshOperation struct {
ready chan struct{}
// timer that triggers refresh, can be used to cancel.
timer *time.Timer
result refreshResult
result ConnectionInfo
err error
}

// cancel prevents the instanceInfo from starting, if it hasn't already
// started. Returns true if timer was stopped successfully, or false if it has
// already started.
// cancel prevents the the refresh operation from starting, if it hasn't
// already started. Returns true if timer was stopped successfully, or false if
// it has already started.
func (r *refreshOperation) cancel() bool {
return r.timer.Stop()
}
Expand All @@ -75,18 +76,18 @@ func (r *refreshOperation) isValid() bool {
default:
return false
case <-r.ready:
if r.err != nil || time.Now().After(r.result.expiry.Round(0)) {
if r.err != nil || time.Now().After(r.result.Expiry.Round(0)) {
return false
}
return true
}
}

// Instance manages the information used to connect to the Cloud SQL instance
// by periodically calling the Cloud SQL Admin API. It automatically refreshes
// the required information approximately 4 minutes before the previous
// certificate expires (every ~56 minutes).
type Instance struct {
// RefreshAheadCache manages the information used to connect to the Cloud SQL
// instance by periodically calling the Cloud SQL Admin API. It automatically
// refreshes the required information approximately 4 minutes before the
// previous certificate expires (every ~56 minutes).
type RefreshAheadCache struct {
// openConns is the number of open connections to the instance.
openConns uint64

Expand Down Expand Up @@ -117,8 +118,8 @@ type Instance struct {
cancel context.CancelFunc
}

// NewInstance initializes a new Instance given an instance connection name
func NewInstance(
// NewRefreshAheadCache initializes a new Instance given an instance connection name
func NewRefreshAheadCache(
cn instance.ConnName,
l debug.Logger,
client *sqladmin.Service,
Expand All @@ -127,9 +128,9 @@ func NewInstance(
ts oauth2.TokenSource,
dialerID string,
useIAMAuthNDial bool,
) *Instance {
) *RefreshAheadCache {
ctx, cancel := context.WithCancel(context.Background())
i := &Instance{
i := &RefreshAheadCache{
connName: cn,
logger: l,
key: key,
Expand All @@ -156,13 +157,13 @@ func NewInstance(

// OpenConns returns a pointer to the number of open connections to
// faciliate changing the value using atomics.
func (i *Instance) OpenConns() *uint64 {
func (i *RefreshAheadCache) OpenConns() *uint64 {
return &i.openConns
}

// Close closes the instance; it stops the refresh cycle and prevents it from
// making additional calls to the Cloud SQL Admin API.
func (i *Instance) Close() error {
func (i *RefreshAheadCache) Close() error {
i.mu.Lock()
defer i.mu.Unlock()
i.cancel()
Expand All @@ -171,54 +172,70 @@ func (i *Instance) Close() error {
return nil
}

// ConnectInfo returns an IP address specified by ipType (i.e., public or
// private) and a TLS config that can be used to connect to a Cloud SQL
// instance.
func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls.Config, error) {
op, err := i.refreshOperation(ctx)
if err != nil {
return "", nil, err
}
// ConnectionInfo contains all necessary information to connect securely to the
// server-side Proxy running on a Cloud SQL instance.
type ConnectionInfo struct {
addrs map[string]string
ServerCaCert *x509.Certificate
DBVersion string
Conf *tls.Config
Expiry time.Time
}

// Addr returns the IP address or DNS name for the given IP type.
func (c ConnectionInfo) Addr(ipType string) (string, error) {
var (
addr string
ok bool
)
switch ipType {
case AutoIP:
// Try Public first
addr, ok = op.result.ipAddrs[PublicIP]
addr, ok = c.addrs[PublicIP]
if !ok {
// Try Private second
addr, ok = op.result.ipAddrs[PrivateIP]
addr, ok = c.addrs[PrivateIP]
}
default:
addr, ok = op.result.ipAddrs[ipType]
addr, ok = c.addrs[ipType]
}
if !ok {
err := errtype.NewConfigError(
fmt.Sprintf("instance does not have IP of type %q", ipType),
i.connName.String(),
// i.connName.String(),
"TODO",
)
return "", nil, err
return "", err
}
return addr, nil
}

// ConnectionInfo returns an IP address specified by ipType (i.e., public or
// private) and a TLS config that can be used to connect to a Cloud SQL
// instance.
func (i *RefreshAheadCache) ConnectionInfo(ctx context.Context) (ConnectionInfo, error) {
op, err := i.refreshOperation(ctx)
if err != nil {
return ConnectionInfo{}, err
}
return addr, op.result.conf, nil
return op.result, nil
}

// InstanceEngineVersion returns the engine type and version for the instance.
// The value corresponds to one of the following types for the instance:
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) {
func (i *RefreshAheadCache) InstanceEngineVersion(ctx context.Context) (string, error) {
op, err := i.refreshOperation(ctx)
if err != nil {
return "", err
}
return op.result.version, nil
return op.result.DBVersion, nil
}

// UpdateRefresh cancels all existing refresh attempts and schedules new
// attempts with the provided config only if it differs from the current
// configuration.
func (i *Instance) UpdateRefresh(useIAMAuthNDial *bool) {
func (i *RefreshAheadCache) UpdateRefresh(useIAMAuthNDial *bool) {
i.mu.Lock()
defer i.mu.Unlock()
if useIAMAuthNDial != nil && *useIAMAuthNDial != i.useIAMAuthNDial {
Expand All @@ -236,7 +253,7 @@ func (i *Instance) UpdateRefresh(useIAMAuthNDial *bool) {
// ForceRefresh triggers an immediate refresh operation to be scheduled and
// used for future connection attempts. Until the refresh completes, the
// existing connection info will be available for use if valid.
func (i *Instance) ForceRefresh() {
func (i *RefreshAheadCache) ForceRefresh() {
i.mu.Lock()
defer i.mu.Unlock()
// If the next refresh hasn't started yet, we can cancel it and start an
Expand All @@ -253,7 +270,7 @@ func (i *Instance) ForceRefresh() {

// refreshOperation returns the most recent refresh operation
// waiting for it to complete if necessary
func (i *Instance) refreshOperation(ctx context.Context) (*refreshOperation, error) {
func (i *RefreshAheadCache) refreshOperation(ctx context.Context) (*refreshOperation, error) {
i.mu.RLock()
cur := i.cur
i.mu.RUnlock()
Expand Down Expand Up @@ -292,7 +309,7 @@ func refreshDuration(now, certExpiry time.Time) time.Duration {
// scheduleRefresh schedules a refresh operation to be triggered after a given
// duration. The returned refreshOperation can be used to either Cancel or Wait
// for the operation's completion.
func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
r := &refreshOperation{}
r.ready = make(chan struct{})
r.timer = time.AfterFunc(d, func() {
Expand Down Expand Up @@ -324,7 +341,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
nil,
)
} else {
r.result, r.err = i.r.performRefresh(
r.result, r.err = i.r.ConnectionInfo(
ctx, i.connName, i.key, i.useIAMAuthNDial)
}
switch r.err {
Expand All @@ -336,7 +353,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
i.logger.Debugf(
"[%v] Current certificate expiration = %v",
i.connName.String(),
r.result.expiry.UTC().Format(time.RFC3339),
r.result.Expiry.UTC().Format(time.RFC3339),
)
default:
i.logger.Debugf(
Expand Down Expand Up @@ -375,7 +392,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
// Update the current results, and schedule the next refresh in
// the future
i.cur = r
t := refreshDuration(time.Now(), i.cur.result.expiry)
t := refreshDuration(time.Now(), i.cur.result.Expiry)
i.logger.Debugf(
"[%v] Connection info refresh operation scheduled at %v (now + %v)",
i.connName.String(),
Expand Down
Loading

0 comments on commit 942f2e3

Please sign in to comment.