diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 6069c1ae1cb43..ad02b10d35933 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -2524,9 +2524,10 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a // Create test database auth tokens generator. testAuth, err := newTestAuth(common.AuthConfig{ - AuthClient: c.authClient, - Clients: &clients.TestCloudClients{}, - Clock: c.clock, + AuthClient: c.authClient, + AccessPoint: c.authClient, + Clients: &clients.TestCloudClients{}, + Clock: c.clock, }) require.NoError(t, err) diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 5cee3d170d991..5367abdf6a5b9 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -116,7 +116,11 @@ type AuthClient interface { // GenerateDatabaseCert generates client certificate used by a database // service to authenticate with the database instance. GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) - // GetAuthPreference returns the cluster authentication config. +} + +// AccessPoint is an interface that defines a subset of +// authclient.DatabaseAccessPoint that are required for database auth. +type AccessPoint interface { GetAuthPreference(ctx context.Context) (types.AuthPreference, error) } @@ -124,6 +128,8 @@ type AuthClient interface { type AuthConfig struct { // AuthClient is the cluster auth client. AuthClient AuthClient + // AccessPoint is a caching client connected to the Auth Server. + AccessPoint AccessPoint // Clients provides interface for obtaining cloud provider clients. Clients cloud.Clients // Clock is the clock implementation. @@ -137,6 +143,9 @@ func (c *AuthConfig) CheckAndSetDefaults() error { if c.AuthClient == nil { return trace.BadParameter("missing AuthClient") } + if c.AccessPoint == nil { + return trace.BadParameter("missing AccessPoint") + } if c.Clients == nil { return trace.BadParameter("missing Clients") } @@ -151,10 +160,11 @@ func (c *AuthConfig) CheckAndSetDefaults() error { func (c *AuthConfig) withLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) AuthConfig { return AuthConfig{ - AuthClient: c.AuthClient, - Clients: c.Clients, - Clock: c.Clock, - Log: getUpdatedLogger(c.Log), + AuthClient: c.AuthClient, + AccessPoint: c.AccessPoint, + Clients: c.Clients, + Clock: c.Clock, + Log: getUpdatedLogger(c.Log), } } @@ -1030,7 +1040,7 @@ func (a *dbAuth) GenerateDatabaseClientKey(ctx context.Context) (*keys.PrivateKe // GetAuthPreference returns the cluster authentication config. func (a *dbAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { - return a.cfg.AuthClient.GetAuthPreference(ctx) + return a.cfg.AccessPoint.GetAuthPreference(ctx) } // GetAzureIdentityResourceID returns the Azure identity resource ID attached to diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 63d6247e4d646..bda8d1d425c28 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -49,7 +49,8 @@ func TestAuthGetAzureCacheForRedisToken(t *testing.T) { t.Parallel() auth, err := NewAuth(AuthConfig{ - AuthClient: new(authClientMock), + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ AzureRedis: libcloudazure.NewRedisClientByAPI(&libcloudazure.ARMRedisMock{ Token: "azure-redis-token", @@ -108,8 +109,9 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) { stsMock := &mocks.STSMock{} clock := clockwork.NewFakeClock() auth, err := NewAuth(AuthConfig{ - Clock: clock, - AuthClient: new(authClientMock), + Clock: clock, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: stsMock, RedshiftServerless: &mocks.RedshiftServerlessMock{ @@ -135,8 +137,9 @@ func TestAuthGetTLSConfig(t *testing.T) { t.Parallel() auth, err := NewAuth(AuthConfig{ - AuthClient: new(authClientMock), - Clients: &cloud.TestCloudClients{}, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), + Clients: &cloud.TestCloudClients{}, }) require.NoError(t, err) @@ -343,8 +346,9 @@ func TestGetAzureIdentityResourceID(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { auth, err := NewAuth(AuthConfig{ - AuthClient: new(authClientMock), - Clients: tc.clients, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), + Clients: tc.clients, }) require.NoError(t, err) @@ -365,8 +369,9 @@ func TestGetAzureIdentityResourceIDCache(t *testing.T) { clock := clockwork.NewFakeClock() auth, err := NewAuth(AuthConfig{ - Clock: clock, - AuthClient: new(authClientMock), + Clock: clock, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ InstanceMetadata: &imdsMock{ id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm", @@ -584,8 +589,9 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { stsMock := &mocks.STSMock{} clock := clockwork.NewFakeClockAt(time.Date(2001, time.February, 3, 0, 0, 0, 0, time.UTC)) auth, err := NewAuth(AuthConfig{ - Clock: clock, - AuthClient: new(authClientMock), + Clock: clock, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: stsMock, RDS: &mocks.RDSMock{}, @@ -673,8 +679,9 @@ func TestGetAWSIAMCreds(t *testing.T) { } { t.Run(name, func(t *testing.T) { auth, err := NewAuth(AuthConfig{ - Clock: clock, - AuthClient: new(authClientMock), + Clock: clock, + AuthClient: new(authClientMock), + AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: tt.stsMock, }, @@ -970,8 +977,11 @@ func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.Da }, nil } +type accessPointMock struct { +} + // GetAuthPreference always returns types.DefaultAuthPreference(). -func (m *authClientMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { +func (m accessPointMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) { return types.DefaultAuthPreference(), nil } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index e52642bb1dc7b..90b93a688ad52 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -194,9 +194,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } if c.Auth == nil { c.Auth, err = common.NewAuth(common.AuthConfig{ - AuthClient: c.AuthClient, - Clock: c.Clock, - Clients: c.CloudClients, + AuthClient: c.AuthClient, + AccessPoint: c.AccessPoint, + Clock: c.Clock, + Clients: c.CloudClients, }) if err != nil { return trace.Wrap(err)