Skip to content

Commit

Permalink
DB to use cache for auth preference (#46672)
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 authored Sep 18, 2024
1 parent 9fba457 commit 0d80b7d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 26 deletions.
7 changes: 4 additions & 3 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2524,9 +2524,10 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t testing.TB, p a

// Create test database auth tokens generator.
testAuth, err := newTestAuth(common.AuthConfig{
AuthClient: c.authClient,
Clients: &clients.TestCloudClients{},
Clock: c.clock,
AuthClient: c.authClient,
AccessPoint: c.authClient,
Clients: &clients.TestCloudClients{},
Clock: c.clock,
})
require.NoError(t, err)

Expand Down
22 changes: 16 additions & 6 deletions lib/srv/db/common/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,20 @@ type AuthClient interface {
// GenerateDatabaseCert generates client certificate used by a database
// service to authenticate with the database instance.
GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error)
// GetAuthPreference returns the cluster authentication config.
}

// AccessPoint is an interface that defines a subset of
// authclient.DatabaseAccessPoint that are required for database auth.
type AccessPoint interface {
GetAuthPreference(ctx context.Context) (types.AuthPreference, error)
}

// AuthConfig is the database access authenticator configuration.
type AuthConfig struct {
// AuthClient is the cluster auth client.
AuthClient AuthClient
// AccessPoint is a caching client connected to the Auth Server.
AccessPoint AccessPoint
// Clients provides interface for obtaining cloud provider clients.
Clients cloud.Clients
// Clock is the clock implementation.
Expand All @@ -137,6 +143,9 @@ func (c *AuthConfig) CheckAndSetDefaults() error {
if c.AuthClient == nil {
return trace.BadParameter("missing AuthClient")
}
if c.AccessPoint == nil {
return trace.BadParameter("missing AccessPoint")
}
if c.Clients == nil {
return trace.BadParameter("missing Clients")
}
Expand All @@ -151,10 +160,11 @@ func (c *AuthConfig) CheckAndSetDefaults() error {

func (c *AuthConfig) withLogger(getUpdatedLogger func(logrus.FieldLogger) logrus.FieldLogger) AuthConfig {
return AuthConfig{
AuthClient: c.AuthClient,
Clients: c.Clients,
Clock: c.Clock,
Log: getUpdatedLogger(c.Log),
AuthClient: c.AuthClient,
AccessPoint: c.AccessPoint,
Clients: c.Clients,
Clock: c.Clock,
Log: getUpdatedLogger(c.Log),
}
}

Expand Down Expand Up @@ -1030,7 +1040,7 @@ func (a *dbAuth) GenerateDatabaseClientKey(ctx context.Context) (*keys.PrivateKe

// GetAuthPreference returns the cluster authentication config.
func (a *dbAuth) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return a.cfg.AuthClient.GetAuthPreference(ctx)
return a.cfg.AccessPoint.GetAuthPreference(ctx)
}

// GetAzureIdentityResourceID returns the Azure identity resource ID attached to
Expand Down
38 changes: 24 additions & 14 deletions lib/srv/db/common/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func TestAuthGetAzureCacheForRedisToken(t *testing.T) {
t.Parallel()

auth, err := NewAuth(AuthConfig{
AuthClient: new(authClientMock),
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
AzureRedis: libcloudazure.NewRedisClientByAPI(&libcloudazure.ARMRedisMock{
Token: "azure-redis-token",
Expand Down Expand Up @@ -108,8 +109,9 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) {
stsMock := &mocks.STSMock{}
clock := clockwork.NewFakeClock()
auth, err := NewAuth(AuthConfig{
Clock: clock,
AuthClient: new(authClientMock),
Clock: clock,
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
STS: stsMock,
RedshiftServerless: &mocks.RedshiftServerlessMock{
Expand All @@ -135,8 +137,9 @@ func TestAuthGetTLSConfig(t *testing.T) {
t.Parallel()

auth, err := NewAuth(AuthConfig{
AuthClient: new(authClientMock),
Clients: &cloud.TestCloudClients{},
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{},
})
require.NoError(t, err)

Expand Down Expand Up @@ -343,8 +346,9 @@ func TestGetAzureIdentityResourceID(t *testing.T) {
} {
t.Run(tc.desc, func(t *testing.T) {
auth, err := NewAuth(AuthConfig{
AuthClient: new(authClientMock),
Clients: tc.clients,
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: tc.clients,
})
require.NoError(t, err)

Expand All @@ -365,8 +369,9 @@ func TestGetAzureIdentityResourceIDCache(t *testing.T) {
clock := clockwork.NewFakeClock()

auth, err := NewAuth(AuthConfig{
Clock: clock,
AuthClient: new(authClientMock),
Clock: clock,
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
InstanceMetadata: &imdsMock{
id: "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/rg/providers/microsoft.compute/virtualmachines/vm",
Expand Down Expand Up @@ -584,8 +589,9 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
stsMock := &mocks.STSMock{}
clock := clockwork.NewFakeClockAt(time.Date(2001, time.February, 3, 0, 0, 0, 0, time.UTC))
auth, err := NewAuth(AuthConfig{
Clock: clock,
AuthClient: new(authClientMock),
Clock: clock,
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
STS: stsMock,
RDS: &mocks.RDSMock{},
Expand Down Expand Up @@ -673,8 +679,9 @@ func TestGetAWSIAMCreds(t *testing.T) {
} {
t.Run(name, func(t *testing.T) {
auth, err := NewAuth(AuthConfig{
Clock: clock,
AuthClient: new(authClientMock),
Clock: clock,
AuthClient: new(authClientMock),
AccessPoint: new(accessPointMock),
Clients: &cloud.TestCloudClients{
STS: tt.stsMock,
},
Expand Down Expand Up @@ -970,8 +977,11 @@ func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.Da
}, nil
}

type accessPointMock struct {
}

// GetAuthPreference always returns types.DefaultAuthPreference().
func (m *authClientMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
func (m accessPointMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
return types.DefaultAuthPreference(), nil
}

Expand Down
7 changes: 4 additions & 3 deletions lib/srv/db/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) {
}
if c.Auth == nil {
c.Auth, err = common.NewAuth(common.AuthConfig{
AuthClient: c.AuthClient,
Clock: c.Clock,
Clients: c.CloudClients,
AuthClient: c.AuthClient,
AccessPoint: c.AccessPoint,
Clock: c.Clock,
Clients: c.CloudClients,
})
if err != nil {
return trace.Wrap(err)
Expand Down

0 comments on commit 0d80b7d

Please sign in to comment.