diff --git a/dialer.go b/dialer.go index 00091521..0d3c08b7 100644 --- a/dialer.go +++ b/dialer.go @@ -74,7 +74,7 @@ func getDefaultKeys() (*rsa.PrivateKey, error) { type connectionInfoCache interface { OpenConns() *uint64 - ConnectInfo(context.Context, string) (string, *tls.Config, error) + ConnectionInfo(context.Context) (cloudsql.ConnectionInfo, error) InstanceEngineVersion(context.Context) (string, error) UpdateRefresh(*bool) ForceRefresh() @@ -250,7 +250,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn var endInfo trace.EndSpanFunc ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo") i := d.instance(cn, &cfg.useIAMAuthN) - addr, tlsConfig, err := i.ConnectInfo(ctx, cfg.ipType) + ci, err := i.ConnectionInfo(ctx) if err != nil { d.lock.Lock() defer d.lock.Unlock() @@ -268,11 +268,11 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn // The TLS handshake will not fail on an expired client certificate. It's // not until the first read where the client cert error will be surfaced. // So check that the certificate is valid before proceeding. - if !validClientCert(cn, d.logger, tlsConfig) { + if !validClientCert(cn, d.logger, ci.Conf) { d.logger.Debugf("[%v] Refreshing certificate now", cn.String()) i.ForceRefresh() // Block on refreshed connection info - addr, tlsConfig, err = i.ConnectInfo(ctx, cfg.ipType) + ci, err = i.ConnectionInfo(ctx) if err != nil { d.lock.Lock() defer d.lock.Unlock() @@ -287,6 +287,10 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn var connectEnd trace.EndSpanFunc ctx, connectEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.Connect") defer func() { connectEnd(err) }() + addr, err := ci.Addr(cfg.ipType) + if err != nil { + return nil, err + } addr = net.JoinHostPort(addr, serverProxyPort) f := d.dialFunc if cfg.dialFunc != nil { @@ -309,7 +313,7 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn } } - tlsConn := tls.Client(conn, tlsConfig) + tlsConn := tls.Client(conn, ci.Conf) err = tlsConn.HandshakeContext(ctx) if err != nil { d.logger.Debugf("[%v] TLS handshake failed: %v", cn.String(), err) @@ -452,7 +456,7 @@ func (d *Dialer) instance(cn instance.ConnName, useIAMAuthN *bool) connectionInf useIAMAuthNDial = *useIAMAuthN } d.logger.Debugf("[%v] Connection info added to cache", cn.String()) - i = cloudsql.NewInstance( + i = cloudsql.NewRefreshAheadCache( cn, d.logger, d.sqladmin, d.key, diff --git a/dialer_test.go b/dialer_test.go index e1ee43a1..86e4da92 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -30,6 +30,7 @@ import ( "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" + "cloud.google.com/go/cloudsqlconn/internal/cloudsql" "cloud.google.com/go/cloudsqlconn/internal/mock" "golang.org/x/oauth2" ) @@ -109,8 +110,7 @@ func TestDialWithAdminAPIErrors(t *testing.T) { } func TestDialWithConfigurationErrors(t *testing.T) { - inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance", - mock.WithCertExpiry(time.Now().Add(-time.Hour))) + inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance") svc, cleanup, err := mock.NewSQLAdminService( context.Background(), @@ -732,12 +732,14 @@ type spyConnectionInfoCache struct { connectionInfoCache } -func (s *spyConnectionInfoCache) ConnectInfo(_ context.Context, _ string) (string, *tls.Config, error) { +func (s *spyConnectionInfoCache) ConnectionInfo( + context.Context, +) (cloudsql.ConnectionInfo, error) { s.mu.Lock() defer s.mu.Unlock() res := s.connectInfoCalls[s.connectInfoIndex] s.connectInfoIndex++ - return "unused", res.tls, res.err + return cloudsql.ConnectionInfo{Conf: res.tls}, res.err } func (s *spyConnectionInfoCache) ForceRefresh() { diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index b3f7cb2b..9e66432d 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -18,6 +18,7 @@ import ( "context" "crypto/rsa" "crypto/tls" + "crypto/x509" "fmt" "sync" "time" @@ -56,13 +57,13 @@ type refreshOperation struct { ready chan struct{} // timer that triggers refresh, can be used to cancel. timer *time.Timer - result refreshResult + result ConnectionInfo err error } -// cancel prevents the instanceInfo from starting, if it hasn't already -// started. Returns true if timer was stopped successfully, or false if it has -// already started. +// cancel prevents the the refresh operation from starting, if it hasn't +// already started. Returns true if timer was stopped successfully, or false if +// it has already started. func (r *refreshOperation) cancel() bool { return r.timer.Stop() } @@ -75,18 +76,18 @@ func (r *refreshOperation) isValid() bool { default: return false case <-r.ready: - if r.err != nil || time.Now().After(r.result.expiry.Round(0)) { + if r.err != nil || time.Now().After(r.result.Expiry.Round(0)) { return false } return true } } -// Instance manages the information used to connect to the Cloud SQL instance -// by periodically calling the Cloud SQL Admin API. It automatically refreshes -// the required information approximately 4 minutes before the previous -// certificate expires (every ~56 minutes). -type Instance struct { +// RefreshAheadCache manages the information used to connect to the Cloud SQL +// instance by periodically calling the Cloud SQL Admin API. It automatically +// refreshes the required information approximately 4 minutes before the +// previous certificate expires (every ~56 minutes). +type RefreshAheadCache struct { // openConns is the number of open connections to the instance. openConns uint64 @@ -117,8 +118,8 @@ type Instance struct { cancel context.CancelFunc } -// NewInstance initializes a new Instance given an instance connection name -func NewInstance( +// NewRefreshAheadCache initializes a new Instance given an instance connection name +func NewRefreshAheadCache( cn instance.ConnName, l debug.Logger, client *sqladmin.Service, @@ -127,9 +128,9 @@ func NewInstance( ts oauth2.TokenSource, dialerID string, useIAMAuthNDial bool, -) *Instance { +) *RefreshAheadCache { ctx, cancel := context.WithCancel(context.Background()) - i := &Instance{ + i := &RefreshAheadCache{ connName: cn, logger: l, key: key, @@ -156,13 +157,13 @@ func NewInstance( // OpenConns returns a pointer to the number of open connections to // faciliate changing the value using atomics. -func (i *Instance) OpenConns() *uint64 { +func (i *RefreshAheadCache) OpenConns() *uint64 { return &i.openConns } // Close closes the instance; it stops the refresh cycle and prevents it from // making additional calls to the Cloud SQL Admin API. -func (i *Instance) Close() error { +func (i *RefreshAheadCache) Close() error { i.mu.Lock() defer i.mu.Unlock() i.cancel() @@ -171,14 +172,18 @@ func (i *Instance) Close() error { return nil } -// ConnectInfo returns an IP address specified by ipType (i.e., public or -// private) and a TLS config that can be used to connect to a Cloud SQL -// instance. -func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls.Config, error) { - op, err := i.refreshOperation(ctx) - if err != nil { - return "", nil, err - } +// ConnectionInfo contains all necessary information to connect securely to the +// server-side Proxy running on a Cloud SQL instance. +type ConnectionInfo struct { + addrs map[string]string + ServerCaCert *x509.Certificate + DBVersion string + Conf *tls.Config + Expiry time.Time +} + +// Addr returns the IP address or DNS name for the given IP type. +func (c ConnectionInfo) Addr(ipType string) (string, error) { var ( addr string ok bool @@ -186,39 +191,51 @@ func (i *Instance) ConnectInfo(ctx context.Context, ipType string) (string, *tls switch ipType { case AutoIP: // Try Public first - addr, ok = op.result.ipAddrs[PublicIP] + addr, ok = c.addrs[PublicIP] if !ok { // Try Private second - addr, ok = op.result.ipAddrs[PrivateIP] + addr, ok = c.addrs[PrivateIP] } default: - addr, ok = op.result.ipAddrs[ipType] + addr, ok = c.addrs[ipType] } if !ok { err := errtype.NewConfigError( fmt.Sprintf("instance does not have IP of type %q", ipType), - i.connName.String(), + // i.connName.String(), + "TODO", ) - return "", nil, err + return "", err + } + return addr, nil +} + +// ConnectionInfo returns an IP address specified by ipType (i.e., public or +// private) and a TLS config that can be used to connect to a Cloud SQL +// instance. +func (i *RefreshAheadCache) ConnectionInfo(ctx context.Context) (ConnectionInfo, error) { + op, err := i.refreshOperation(ctx) + if err != nil { + return ConnectionInfo{}, err } - return addr, op.result.conf, nil + return op.result, nil } // InstanceEngineVersion returns the engine type and version for the instance. // The value corresponds to one of the following types for the instance: // https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion -func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) { +func (i *RefreshAheadCache) InstanceEngineVersion(ctx context.Context) (string, error) { op, err := i.refreshOperation(ctx) if err != nil { return "", err } - return op.result.version, nil + return op.result.DBVersion, nil } // UpdateRefresh cancels all existing refresh attempts and schedules new // attempts with the provided config only if it differs from the current // configuration. -func (i *Instance) UpdateRefresh(useIAMAuthNDial *bool) { +func (i *RefreshAheadCache) UpdateRefresh(useIAMAuthNDial *bool) { i.mu.Lock() defer i.mu.Unlock() if useIAMAuthNDial != nil && *useIAMAuthNDial != i.useIAMAuthNDial { @@ -236,7 +253,7 @@ func (i *Instance) UpdateRefresh(useIAMAuthNDial *bool) { // ForceRefresh triggers an immediate refresh operation to be scheduled and // used for future connection attempts. Until the refresh completes, the // existing connection info will be available for use if valid. -func (i *Instance) ForceRefresh() { +func (i *RefreshAheadCache) ForceRefresh() { i.mu.Lock() defer i.mu.Unlock() // If the next refresh hasn't started yet, we can cancel it and start an @@ -253,7 +270,7 @@ func (i *Instance) ForceRefresh() { // refreshOperation returns the most recent refresh operation // waiting for it to complete if necessary -func (i *Instance) refreshOperation(ctx context.Context) (*refreshOperation, error) { +func (i *RefreshAheadCache) refreshOperation(ctx context.Context) (*refreshOperation, error) { i.mu.RLock() cur := i.cur i.mu.RUnlock() @@ -292,7 +309,7 @@ func refreshDuration(now, certExpiry time.Time) time.Duration { // scheduleRefresh schedules a refresh operation to be triggered after a given // duration. The returned refreshOperation can be used to either Cancel or Wait // for the operation's completion. -func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { +func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation { r := &refreshOperation{} r.ready = make(chan struct{}) r.timer = time.AfterFunc(d, func() { @@ -324,7 +341,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { nil, ) } else { - r.result, r.err = i.r.performRefresh( + r.result, r.err = i.r.ConnectionInfo( ctx, i.connName, i.key, i.useIAMAuthNDial) } switch r.err { @@ -336,7 +353,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { i.logger.Debugf( "[%v] Current certificate expiration = %v", i.connName.String(), - r.result.expiry.UTC().Format(time.RFC3339), + r.result.Expiry.UTC().Format(time.RFC3339), ) default: i.logger.Debugf( @@ -375,7 +392,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation { // Update the current results, and schedule the next refresh in // the future i.cur = r - t := refreshDuration(time.Now(), i.cur.result.expiry) + t := refreshDuration(time.Now(), i.cur.result.Expiry) i.logger.Debugf( "[%v] Connection info refresh operation scheduled at %v (now + %v)", i.connName.String(), diff --git a/internal/cloudsql/instance_test.go b/internal/cloudsql/instance_test.go index d18f048a..e3cdc12f 100644 --- a/internal/cloudsql/instance_test.go +++ b/internal/cloudsql/instance_test.go @@ -22,7 +22,6 @@ import ( "testing" "time" - "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" "cloud.google.com/go/cloudsqlconn/internal/mock" ) @@ -69,7 +68,7 @@ func TestInstanceEngineVersion(t *testing.T) { t.Fatalf("%v", err) } }() - i := NewInstance( + i := NewRefreshAheadCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, nil, "", false, ) @@ -106,28 +105,32 @@ func TestConnectInfo(t *testing.T) { } }() - i := NewInstance( + i := NewRefreshAheadCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, nil, "", false, ) - gotAddr, gotTLSCfg, err := i.ConnectInfo(ctx, PublicIP) + ci, err := i.ConnectionInfo(ctx) if err != nil { t.Fatalf("failed to retrieve connect info: %v", err) } - if gotAddr != wantAddr { + got, err := ci.Addr(PublicIP) + if err != nil { + t.Fatal(err) + } + if got != wantAddr { t.Fatalf( "ConnectInfo returned unexpected IP address, want = %v, got = %v", - wantAddr, gotAddr, + wantAddr, got, ) } wantServerName := "my-project:my-region:my-instance" - if gotTLSCfg.ServerName != wantServerName { + if ci.Conf.ServerName != wantServerName { t.Fatalf( "ConnectInfo return unexpected server name in TLS Config, want = %v, got = %v", - wantServerName, gotTLSCfg.ServerName, + wantServerName, ci.Conf.ServerName, ) } } @@ -174,7 +177,7 @@ func TestConnectInfoAutoIP(t *testing.T) { } }() - i := NewInstance( + i := NewRefreshAheadCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, nil, "", false, ) @@ -182,11 +185,15 @@ func TestConnectInfoAutoIP(t *testing.T) { t.Fatalf("failed to create mock instance: %v", err) } - got, _, err := i.ConnectInfo(context.Background(), AutoIP) + ci, err := i.ConnectionInfo(context.Background()) if err != nil { t.Fatalf("failed to retrieve connect info: %v", err) } + got, err := ci.Addr(AutoIP) + if err != nil { + t.Fatal(err) + } if got != tc.wantIP { t.Fatalf( "ConnectInfo returned unexpected IP address, want = %v, got = %v", @@ -196,32 +203,9 @@ func TestConnectInfoAutoIP(t *testing.T) { } } -func TestConnectInfoErrors(t *testing.T) { - ctx := context.Background() - - client, cleanup, err := mock.NewSQLAdminService(ctx) - if err != nil { - t.Fatalf("%s", err) - } - defer cleanup() - - // Use a timeout that should fail instantly - i := NewInstance( - testInstanceConnName(), nullLogger{}, client, - RSAKey, 0, nil, "", false, - ) - - _, _, err = i.ConnectInfo(ctx, PublicIP) - var wantErr *errtype.DialError - if !errors.As(err, &wantErr) { - t.Fatalf("when connect info fails, want = %T, got = %v", wantErr, err) - } - - // when client asks for wrong IP address type - gotAddr, _, err := i.ConnectInfo(ctx, PrivateIP) - if err == nil { - t.Fatalf("expected ConnectInfo to fail but returned IP address = %v", gotAddr) - } +func TestConnectionInfoErrors(t *testing.T) { + // ConnectionInfo{} + // TODO unit tests for happy path and errors } func TestClose(t *testing.T) { @@ -234,13 +218,13 @@ func TestClose(t *testing.T) { defer cleanup() // Set up an instance and then close it immediately - i := NewInstance( + i := NewRefreshAheadCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, nil, "", false, ) i.Close() - _, _, err = i.ConnectInfo(ctx, PublicIP) + _, err = i.ConnectionInfo(ctx) if !errors.Is(err, context.Canceled) { t.Fatalf("failed to retrieve connect info: %v", err) } @@ -300,7 +284,7 @@ func TestContextCancelled(t *testing.T) { defer cleanup() // Set up an instance and then close it immediately - i := NewInstance( + i := NewRefreshAheadCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, nil, "", false, ) diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index 3e1fefae..8fc56c17 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -297,15 +297,6 @@ func newRefresher( } } -// refreshResult contains all the resulting data from the refresh operation. -type refreshResult struct { - ipAddrs map[string]string - serverCaCert *x509.Certificate - version string - conf *tls.Config - expiry time.Time -} - // refresher manages the SQL Admin API access to instance metadata and to // ephemeral certificates. type refresher struct { @@ -317,11 +308,11 @@ type refresher struct { ts oauth2.TokenSource } -// performRefresh immediately performs a full refresh operation using the Cloud +// ConnectionInfo immediately performs a full refresh operation using the Cloud // SQL Admin API. -func (r refresher) performRefresh( +func (r refresher) ConnectionInfo( ctx context.Context, cn instance.ConnName, k *rsa.PrivateKey, iamAuthNDial bool, -) (rr refreshResult, err error) { +) (ci ConnectionInfo, err error) { var refreshEnd trace.EndSpanFunc ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.RefreshConnection", @@ -365,15 +356,15 @@ func (r refresher) performRefresh( select { case r := <-mdC: if r.err != nil { - return refreshResult{}, fmt.Errorf("failed to get instance: %w", r.err) + return ConnectionInfo{}, fmt.Errorf("failed to get instance: %w", r.err) } md = r.md case <-ctx.Done(): - return rr, fmt.Errorf("refresh failed: %w", ctx.Err()) + return ci, fmt.Errorf("refresh failed: %w", ctx.Err()) } if iamAuthNDial { if vErr := supportsAutoIAMAuthN(md.version); vErr != nil { - return refreshResult{}, vErr + return ConnectionInfo{}, vErr } } @@ -381,11 +372,11 @@ func (r refresher) performRefresh( select { case r := <-ecC: if r.err != nil { - return refreshResult{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err) + return ConnectionInfo{}, fmt.Errorf("fetch ephemeral cert failed: %w", r.err) } ec = r.ec case <-ctx.Done(): - return refreshResult{}, fmt.Errorf("refresh failed: %w", ctx.Err()) + return ConnectionInfo{}, fmt.Errorf("refresh failed: %w", ctx.Err()) } c := createTLSConfig(cn, md, ec) @@ -394,12 +385,12 @@ func (r refresher) performRefresh( if len(c.Certificates) > 0 { expiry = c.Certificates[0].Leaf.NotAfter } - return refreshResult{ - ipAddrs: md.ipAddrs, - serverCaCert: md.serverCaCert, - version: md.version, - conf: c, - expiry: expiry, + return ConnectionInfo{ + addrs: md.ipAddrs, + ServerCaCert: md.serverCaCert, + DBVersion: md.version, + Conf: c, + Expiry: expiry, }, nil } diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 3e5266bb..96b260a4 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -61,37 +61,37 @@ func TestRefresh(t *testing.T) { }() r := newRefresher(nullLogger{}, client, nil, testDialerID) - rr, err := r.performRefresh(context.Background(), cn, RSAKey, false) + rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false) if err != nil { t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err) } - gotIP, ok := rr.ipAddrs[PublicIP] + gotIP, ok := rr.addrs[PublicIP] if !ok { t.Fatal("metadata IP addresses did not include public address") } if wantPublicIP != gotIP { t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPublicIP, gotIP) } - gotIP, ok = rr.ipAddrs[PrivateIP] + gotIP, ok = rr.addrs[PrivateIP] if !ok { t.Fatal("metadata IP addresses did not include private address") } if wantPrivateIP != gotIP { t.Fatalf("metadata IP mismatch, want = %v, got = %v", wantPrivateIP, gotIP) } - gotPSC, ok := rr.ipAddrs[PSC] + gotPSC, ok := rr.addrs[PSC] if !ok { t.Fatal("metadata IP addresses did not include PSC endpoint") } if wantPSC != gotPSC { t.Fatalf("metadata IP mismatch, want = %v. got = %v", wantPSC, gotPSC) } - if wantExpiry != rr.expiry { - t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.expiry) + if wantExpiry != rr.Expiry { + t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiry) } - if wantConnName != rr.conf.ServerName { - t.Fatalf("server name mismatch, want = %v, got = %v", wantConnName, rr.conf.ServerName) + if wantConnName != rr.Conf.ServerName { + t.Fatalf("server name mismatch, want = %v, got = %v", wantConnName, rr.Conf.ServerName) } } @@ -109,7 +109,7 @@ func TestRefreshFailsFast(t *testing.T) { defer cleanup() r := newRefresher(nullLogger{}, client, nil, testDialerID) - _, err = r.performRefresh(context.Background(), cn, RSAKey, false) + _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false) if err != nil { t.Fatalf("expected no error, got = %v", err) } @@ -117,7 +117,7 @@ func TestRefreshFailsFast(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // context is canceled - _, err = r.performRefresh(ctx, cn, RSAKey, false) + _, err = r.ConnectionInfo(ctx, cn, RSAKey, false) if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled error, got = %v", err) } @@ -191,12 +191,12 @@ func TestRefreshAdjustsCertExpiry(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ts := &fakeTokenSource{responses: tc.resps} r := newRefresher(nullLogger{}, client, ts, testDialerID) - rr, err := r.performRefresh(context.Background(), cn, RSAKey, true) + rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true) if err != nil { t.Fatalf("want no error, got = %v", err) } - if tc.wantExpiry != rr.expiry { - t.Fatalf("want = %v, got = %v", tc.wantExpiry, rr.expiry) + if tc.wantExpiry != rr.Expiry { + t.Fatalf("want = %v, got = %v", tc.wantExpiry, rr.Expiry) } }) } @@ -237,7 +237,7 @@ func TestRefreshWithIAMAuthErrors(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ts := &fakeTokenSource{responses: tc.resps} r := newRefresher(nullLogger{}, client, ts, testDialerID) - _, err := r.performRefresh(context.Background(), cn, RSAKey, true) + _, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true) if err == nil { t.Fatalf("expected get failed error, got = %v", err) } @@ -297,7 +297,7 @@ func TestRefreshMetadataConfigError(t *testing.T) { defer cleanup() r := newRefresher(nullLogger{}, client, nil, testDialerID) - _, err = r.performRefresh(context.Background(), cn, RSAKey, false) + _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false) if !errors.As(err, &tc.wantErr) { t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err) } @@ -362,7 +362,7 @@ func TestRefreshMetadataRefreshError(t *testing.T) { defer cleanup() r := newRefresher(nullLogger{}, client, nil, testDialerID) - _, err = r.performRefresh(context.Background(), cn, RSAKey, false) + _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false) if !errors.As(err, &tc.wantErr) { t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err) } @@ -427,7 +427,7 @@ func TestRefreshWithFailedEphemeralCertCall(t *testing.T) { defer cleanup() r := newRefresher(nullLogger{}, client, nil, testDialerID) - _, err = r.performRefresh(context.Background(), cn, RSAKey, false) + _, err = r.ConnectionInfo(context.Background(), cn, RSAKey, false) if !errors.As(err, &tc.wantErr) { t.Errorf("[%v] PerformRefresh failed with unexpected error, want = %T, got = %v", i, tc.wantErr, err) @@ -454,39 +454,39 @@ func TestRefreshBuildsTLSConfig(t *testing.T) { defer cleanup() r := newRefresher(nullLogger{}, client, nil, testDialerID) - rr, err := r.performRefresh(context.Background(), cn, RSAKey, false) + rr, err := r.ConnectionInfo(context.Background(), cn, RSAKey, false) if err != nil { t.Fatalf("expected no error, got = %v", err) } - if wantServerName != rr.conf.ServerName { + if wantServerName != rr.Conf.ServerName { t.Fatalf( "TLS config has incorrect server name, want = %v, got = %v", - wantServerName, rr.conf.ServerName, + wantServerName, rr.Conf.ServerName, ) } wantCertLen := 1 - if wantCertLen != len(rr.conf.Certificates) { + if wantCertLen != len(rr.Conf.Certificates) { t.Fatalf( "TLS config has unexpected number of certificates, want = %v, got = %v", - wantCertLen, len(rr.conf.Certificates), + wantCertLen, len(rr.Conf.Certificates), ) } wantInsecure := true - if wantInsecure != rr.conf.InsecureSkipVerify { + if wantInsecure != rr.Conf.InsecureSkipVerify { t.Fatalf( "TLS config should skip verification, want = %v, got = %v", - wantInsecure, rr.conf.InsecureSkipVerify, + wantInsecure, rr.Conf.InsecureSkipVerify, ) } - if rr.conf.RootCAs == nil { + if rr.Conf.RootCAs == nil { t.Fatal("TLS config should include RootCA, got nil") } - verifyPeerCert := rr.conf.VerifyPeerCertificate + verifyPeerCert := rr.Conf.VerifyPeerCertificate b, _ := pem.Decode(certBytes) err = verifyPeerCert([][]byte{b.Bytes}, nil) if err != nil {