Skip to content

Commit

Permalink
Implement LRU cache for storing SVIDs in SPIRE Agent (#3181)
Browse files Browse the repository at this point in the history
Signed-off-by: Prasad Borole <prasadb@uber.com>
Co-authored-by: Ryan Turner <turner@uber.com>
  • Loading branch information
prasadborole1 and Ryan Turner authored Sep 14, 2022
1 parent 6c6f43d commit 6689c36
Show file tree
Hide file tree
Showing 39 changed files with 2,965 additions and 144 deletions.
4 changes: 3 additions & 1 deletion cmd/spire-agent/cli/api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"google.golang.org/grpc/metadata"
)

const commandTimeout = 5 * time.Second

type workloadClient struct {
workload.SpiffeWorkloadAPIClient
timeout time.Duration
Expand Down Expand Up @@ -71,7 +73,7 @@ func adaptCommand(env *cli.Env, clientsMaker workloadClientMaker, cmd command) *
clientsMaker: clientsMaker,
cmd: cmd,
env: env,
timeout: cli.DurationFlag(time.Second),
timeout: cli.DurationFlag(commandTimeout),
}

fs := flag.NewFlagSet(cmd.name(), flag.ContinueOnError)
Expand Down
8 changes: 7 additions & 1 deletion cmd/spire-agent/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ type experimentalConfig struct {

Flags fflag.RawConfig `hcl:"feature_flags"`

UnusedKeys []string `hcl:",unusedKeys"`
UnusedKeys []string `hcl:",unusedKeys"`
X509SVIDCacheMaxSize int `hcl:"x509_svid_cache_max_size"`
}

type Command struct {
Expand Down Expand Up @@ -394,6 +395,11 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool)
}
}

if c.Agent.Experimental.X509SVIDCacheMaxSize < 0 {
return nil, errors.New("x509_svid_cache_max_size should not be negative")
}
ac.X509SVIDCacheMaxSize = c.Agent.Experimental.X509SVIDCacheMaxSize

serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort))
ac.ServerAddress = fmt.Sprintf("dns:///%s", serverHostPort)

Expand Down
36 changes: 36 additions & 0 deletions cmd/spire-agent/cli/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,42 @@ func TestNewAgentConfig(t *testing.T) {
require.Nil(t, c)
},
},
{
msg: "x509_svid_cache_max_size is set",
input: func(c *Config) {
c.Agent.Experimental.X509SVIDCacheMaxSize = 100
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 100, c.X509SVIDCacheMaxSize)
},
},
{
msg: "x509_svid_cache_max_size is not set",
input: func(c *Config) {
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 0, c.X509SVIDCacheMaxSize)
},
},
{
msg: "x509_svid_cache_max_size is zero",
input: func(c *Config) {
c.Agent.Experimental.X509SVIDCacheMaxSize = 0
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 0, c.X509SVIDCacheMaxSize)
},
},
{
msg: "x509_svid_cache_max_size is negative",
expectError: true,
input: func(c *Config) {
c.Agent.Experimental.X509SVIDCacheMaxSize = -10
},
test: func(t *testing.T, c *agent.Config) {
require.Nil(t, c)
},
},
{
msg: "allowed_foreign_jwt_claims provided",
input: func(c *Config) {
Expand Down
28 changes: 14 additions & 14 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,20 @@ func (a *Agent) attest(ctx context.Context, sto storage.Storage, cat catalog.Cat

func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog.Catalog, metrics telemetry.Metrics, as *node_attestor.AttestationResult, cache *storecache.Cache, na nodeattestor.NodeAttestor) (manager.Manager, error) {
config := &manager.Config{
SVID: as.SVID,
SVIDKey: as.Key,
Bundle: as.Bundle,
Reattestable: as.Reattestable,
Catalog: cat,
TrustDomain: a.c.TrustDomain,
ServerAddr: a.c.ServerAddress,
Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager),
Metrics: metrics,
WorkloadKeyType: a.c.WorkloadKeyType,
Storage: sto,
SyncInterval: a.c.SyncInterval,
SVIDStoreCache: cache,
NodeAttestor: na,
SVID: as.SVID,
SVIDKey: as.Key,
Bundle: as.Bundle,
Catalog: cat,
TrustDomain: a.c.TrustDomain,
ServerAddr: a.c.ServerAddress,
Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager),
Metrics: metrics,
WorkloadKeyType: a.c.WorkloadKeyType,
Storage: sto,
SyncInterval: a.c.SyncInterval,
SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize,
SVIDStoreCache: cache,
NodeAttestor: na,
}

mgr := manager.New(config)
Expand Down
28 changes: 16 additions & 12 deletions pkg/agent/api/delegatedidentity/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,24 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger
}
}

identities := s.manager.MatchingIdentities(callerSelectors)
numRegisteredIDs := len(identities)
entries := s.manager.MatchingRegistrationEntries(callerSelectors)
numRegisteredEntries := len(entries)

if numRegisteredIDs == 0 {
if numRegisteredEntries == 0 {
log.Error("no identity issued")
return nil, status.Error(codes.PermissionDenied, "no identity issued")
}

for _, identity := range identities {
if _, ok := s.authorizedDelegates[identity.Entry.SpiffeId]; ok {
for _, entry := range entries {
if _, ok := s.authorizedDelegates[entry.SpiffeId]; ok {
return callerSelectors, nil
}
}

// caller has identity associeted with but none is authorized
log.WithFields(logrus.Fields{
"num_registered_ids": numRegisteredIDs,
"default_id": identities[0].Entry.SpiffeId,
"num_registered_entries": numRegisteredEntries,
"default_id": entries[0].SpiffeId,
}).Error("Permission denied; caller not configured as an authorized delegate.")

return nil, status.Error(codes.PermissionDenied, "caller not configured as an authorized delegate")
Expand All @@ -120,7 +120,11 @@ func (s *Service) SubscribeToX509SVIDs(req *delegatedidentityv1.SubscribeToX509S
return status.Error(codes.InvalidArgument, "could not parse provided selectors")
}

subscriber := s.manager.SubscribeToCacheChanges(selectors)
subscriber, err := s.manager.SubscribeToCacheChanges(ctx, selectors)
if err != nil {
log.WithError(err).Error("Subscribe to cache changes failed")
return err
}
defer subscriber.Finish()

for {
Expand Down Expand Up @@ -268,11 +272,11 @@ func (s *Service) FetchJWTSVIDs(ctx context.Context, req *delegatedidentityv1.Fe
}
var spiffeIDs []spiffeid.ID

identities := s.manager.MatchingIdentities(selectors)
for _, identity := range identities {
spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId)
entries := s.manager.MatchingRegistrationEntries(selectors)
for _, entry := range entries {
spiffeID, err := spiffeid.FromString(entry.SpiffeId)
if err != nil {
log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err)
}

Expand Down
29 changes: 23 additions & 6 deletions pkg/agent/api/delegatedidentity/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ func TestSubscribeToX509SVIDs(t *testing.T) {
expectCode: codes.PermissionDenied,
expectMsg: "caller not configured as an authorized delegate",
},
{
testName: "subscribe to cache changes error",
authSpiffeID: []string{"spiffe://example.org/one"},
identities: []cache.Identity{
identityFromX509SVID(x509SVID1),
},
managerErr: errors.New("err"),
expectCode: codes.Unknown,
expectMsg: "err",
},
{
testName: "workload update with one identity",
authSpiffeID: []string{"spiffe://example.org/one"},
Expand Down Expand Up @@ -653,10 +663,6 @@ func (fa FakeAttestor) Attest(ctx context.Context) ([]*common.Selector, error) {
return fa.selectors, fa.err
}

func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity {
return m.identities
}

type FakeManager struct {
manager.Manager

Expand All @@ -677,9 +683,12 @@ func (m *FakeManager) subscriberDone() {
atomic.AddInt32(&m.subscribers, -1)
}

func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber {
func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) {
if m.err != nil {
return nil, m.err
}
atomic.AddInt32(&m.subscribers, 1)
return newFakeSubscriber(m, m.updates)
return newFakeSubscriber(m, m.updates), nil
}

func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) {
Expand All @@ -692,6 +701,14 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au
}, nil
}

func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry {
out := make([]*common.RegistrationEntry, 0, len(m.identities))
for _, identity := range m.identities {
out = append(out, identity.Entry)
}
return out
}

type fakeSubscriber struct {
m *FakeManager
ch chan *cache.WorkloadUpdate
Expand Down
3 changes: 3 additions & 0 deletions pkg/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ type Config struct {
// SyncInterval controls how often the agent sync synchronizer waits
SyncInterval time.Duration

// X509SVIDCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache
X509SVIDCacheMaxSize int

// Trust domain and associated CA bundle
TrustDomain spiffeid.TrustDomain
TrustBundle []*x509.Certificate
Expand Down
8 changes: 6 additions & 2 deletions pkg/agent/endpoints/sdsv2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Attestor interface {
}

type Manager interface {
SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber
SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error)
FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate
}

Expand Down Expand Up @@ -64,7 +64,11 @@ func (h *Handler) StreamSecrets(stream discovery_v2.SecretDiscoveryService_Strea
return err
}

sub := h.c.Manager.SubscribeToCacheChanges(selectors)
sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors)
if err != nil {
log.WithError(err).Error("Subscribe to cache changes failed")
return err
}
defer sub.Finish()

updch := sub.Updates()
Expand Down
4 changes: 2 additions & 2 deletions pkg/agent/endpoints/sdsv2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ func NewFakeManager(t *testing.T) *FakeManager {
}
}

func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber {
func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) {
require.Equal(m.t, workloadSelectors, selectors)

updch := make(chan *cache.WorkloadUpdate, 1)
Expand All @@ -568,7 +568,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S
return NewFakeSubscriber(updch, func() {
delete(m.subs, key)
close(updch)
})
}), nil
}

func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate {
Expand Down
8 changes: 6 additions & 2 deletions pkg/agent/endpoints/sdsv3/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Attestor interface {
}

type Manager interface {
SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber
SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error)
FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate
}

Expand Down Expand Up @@ -74,7 +74,11 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe
return err
}

sub := h.c.Manager.SubscribeToCacheChanges(selectors)
sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors)
if err != nil {
log.WithError(err).Error("Subscribe to cache changes failed")
return err
}
defer sub.Finish()

updch := sub.Updates()
Expand Down
37 changes: 33 additions & 4 deletions pkg/agent/endpoints/sdsv3/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,21 @@ func TestStreamSecretsBadNonce(t *testing.T) {
requireSecrets(t, resp, workloadTLSCertificate2)
}

func TestStreamSecretsErrInSubscribeToCacheChanges(t *testing.T) {
test := setupErrTest(t)
defer test.server.Stop()

stream, err := test.handler.StreamSecrets(context.Background())
require.NoError(t, err)
defer func() {
require.NoError(t, stream.CloseSend())
}()

resp, err := stream.Recv()
require.Error(t, err)
require.Nil(t, resp)
}

func TestFetchSecrets(t *testing.T) {
for _, tt := range []struct {
name string
Expand Down Expand Up @@ -1174,11 +1189,16 @@ func DeltaSecretsTest(t *testing.T) {
}

func setupTest(t *testing.T) *handlerTest {
return setupTestWithConfig(t, Config{})
return setupTestWithManager(t, Config{}, NewFakeManager(t))
}

func setupTestWithConfig(t *testing.T, c Config) *handlerTest {
func setupErrTest(t *testing.T) *handlerTest {
manager := NewFakeManager(t)
manager.err = errors.New("bad-error")
return setupTestWithManager(t, Config{}, manager)
}

func setupTestWithManager(t *testing.T, c Config, manager *FakeManager) *handlerTest {
defaultConfig := Config{
Manager: manager,
Attestor: FakeAttestor(workloadSelectors),
Expand Down Expand Up @@ -1220,6 +1240,11 @@ func setupTestWithConfig(t *testing.T, c Config) *handlerTest {
return test
}

func setupTestWithConfig(t *testing.T, c Config) *handlerTest {
manager := NewFakeManager(t)
return setupTestWithManager(t, c, manager)
}

type handlerTest struct {
t *testing.T

Expand Down Expand Up @@ -1279,6 +1304,7 @@ type FakeManager struct {
upd *cache.WorkloadUpdate
next int
subs map[int]chan *cache.WorkloadUpdate
err error
}

func NewFakeManager(t *testing.T) *FakeManager {
Expand All @@ -1288,7 +1314,10 @@ func NewFakeManager(t *testing.T) *FakeManager {
}
}

func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber {
func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) {
if m.err != nil {
return nil, m.err
}
require.Equal(m.t, workloadSelectors, selectors)

updch := make(chan *cache.WorkloadUpdate, 1)
Expand All @@ -1304,7 +1333,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S
return NewFakeSubscriber(updch, func() {
delete(m.subs, key)
close(updch)
})
}), nil
}

func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate {
Expand Down
Loading

0 comments on commit 6689c36

Please sign in to comment.