Skip to content

Commit

Permalink
Implement LRU cache for storing SVIDs in SPIRE Agent
Browse files Browse the repository at this point in the history
  • Loading branch information
prasadborole1 committed Jun 22, 2022
1 parent 2cec901 commit 49f442a
Show file tree
Hide file tree
Showing 30 changed files with 1,321 additions and 260 deletions.
2 changes: 1 addition & 1 deletion cmd/spire-agent/cli/api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func adaptCommand(env *cli.Env, clientsMaker workloadClientMaker, cmd command) *
clientsMaker: clientsMaker,
cmd: cmd,
env: env,
timeout: cli.DurationFlag(time.Second),
timeout: cli.DurationFlag(2 * time.Second),
}

fs := flag.NewFlagSet(cmd.name(), flag.ContinueOnError)
Expand Down
15 changes: 14 additions & 1 deletion cmd/spire-agent/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ type experimentalConfig struct {
SyncInterval string `hcl:"sync_interval"`
TCPSocketPort int `hcl:"tcp_socket_port"`

UnusedKeys []string `hcl:",unusedKeys"`
UnusedKeys []string `hcl:",unusedKeys"`
MaxSvidCacheSize int `hcl:"max_svid_cache_size"`
SVIDCacheExpiryPeriod string `hcl:"svid_cache_expiry_interval"`
}

type Command struct {
Expand Down Expand Up @@ -400,6 +402,17 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool)
}
}

if c.Agent.Experimental.MaxSvidCacheSize != 0 {
ac.MaxSvidCacheSize = c.Agent.Experimental.MaxSvidCacheSize
}
if c.Agent.Experimental.SVIDCacheExpiryPeriod != "" {
var err error
ac.SVIDCacheExpiryPeriod, err = time.ParseDuration(c.Agent.Experimental.SVIDCacheExpiryPeriod)
if err != nil {
return nil, fmt.Errorf("could not parse svid cache expiry interval: %w", err)
}
}

serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort))
ac.ServerAddress = fmt.Sprintf("dns:///%s", serverHostPort)

Expand Down
44 changes: 44 additions & 0 deletions cmd/spire-agent/cli/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,50 @@ func TestNewAgentConfig(t *testing.T) {
require.Nil(t, c)
},
},
{
msg: "svid_cache_expiry_interval parses a duration",
input: func(c *Config) {
c.Agent.Experimental.SVIDCacheExpiryPeriod = "1s50ms"
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 1050000000, c.SVIDCacheExpiryPeriod)
},
},
{
msg: "invalid svid_cache_expiry_interval returns an error",
expectError: true,
input: func(c *Config) {
c.Agent.Experimental.SVIDCacheExpiryPeriod = "moo"
},
test: func(t *testing.T, c *agent.Config) {
require.Nil(t, c)
},
},
{
msg: "svid_cache_expiry_interval is not set",
input: func(c *Config) {
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 0, c.SVIDCacheExpiryPeriod)
},
},
{
msg: "max_svid_cache_size is set",
input: func(c *Config) {
c.Agent.Experimental.MaxSvidCacheSize = 100
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 100, c.MaxSvidCacheSize)
},
},
{
msg: "max_svid_cache_size is not set",
input: func(c *Config) {
},
test: func(t *testing.T, c *agent.Config) {
require.EqualValues(t, 0, c.MaxSvidCacheSize)
},
},
{
msg: "admin_socket_path not provided",
input: func(c *Config) {
Expand Down
24 changes: 13 additions & 11 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,19 @@ func (a *Agent) attest(ctx context.Context, cat catalog.Catalog, metrics telemet
func (a *Agent) newManager(ctx context.Context, cat catalog.Catalog, metrics telemetry.Metrics, as *node_attestor.AttestationResult, cache *storecache.Cache) (manager.Manager, error) {
config := &manager.Config{
SVID: as.SVID,
SVIDKey: as.Key,
Bundle: as.Bundle,
Catalog: cat,
TrustDomain: a.c.TrustDomain,
ServerAddr: a.c.ServerAddress,
Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager),
Metrics: metrics,
BundleCachePath: a.bundleCachePath(),
SVIDCachePath: a.agentSVIDPath(),
SyncInterval: a.c.SyncInterval,
SVIDStoreCache: cache,
SVIDKey: as.Key,
Bundle: as.Bundle,
Catalog: cat,
TrustDomain: a.c.TrustDomain,
ServerAddr: a.c.ServerAddress,
Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager),
Metrics: metrics,
BundleCachePath: a.bundleCachePath(),
SVIDCachePath: a.agentSVIDPath(),
SyncInterval: a.c.SyncInterval,
MaxSvidCacheSize: a.c.MaxSvidCacheSize,
SVIDCacheExpiryPeriod: a.c.SVIDCacheExpiryPeriod,
SVIDStoreCache: cache,
}

mgr := manager.New(config)
Expand Down
18 changes: 9 additions & 9 deletions pkg/agent/api/delegatedidentity/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,24 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger
}
}

identities := s.manager.MatchingIdentities(callerSelectors)
numRegisteredIDs := len(identities)
entries := s.manager.MatchingRegistrationEntries(callerSelectors)
numRegisteredIDs := len(entries)

if numRegisteredIDs == 0 {
log.Error("no identity issued")
return nil, status.Error(codes.PermissionDenied, "no identity issued")
}

for _, identity := range identities {
if _, ok := s.authorizedDelegates[identity.Entry.SpiffeId]; ok {
for _, entry := range entries {
if _, ok := s.authorizedDelegates[entry.SpiffeId]; ok {
return callerSelectors, nil
}
}

// caller has identity associeted with but none is authorized
log.WithFields(logrus.Fields{
"num_registered_ids": numRegisteredIDs,
"default_id": identities[0].Entry.SpiffeId,
"default_id": entries[0].SpiffeId,
}).Error("Permission denied; caller not configured as an authorized delegate.")

return nil, status.Error(codes.PermissionDenied, "caller not configured as an authorized delegate")
Expand Down Expand Up @@ -268,11 +268,11 @@ func (s *Service) FetchJWTSVIDs(ctx context.Context, req *delegatedidentityv1.Fe
}
var spiffeIDs []spiffeid.ID

identities := s.manager.MatchingIdentities(selectors)
for _, identity := range identities {
spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId)
entries := s.manager.MatchingRegistrationEntries(selectors)
for _, entry := range entries {
spiffeID, err := spiffeid.FromString(entry.SpiffeId)
if err != nil {
log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err)
}

Expand Down
15 changes: 10 additions & 5 deletions pkg/agent/api/delegatedidentity/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/andres-erbsen/clock"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
Expand Down Expand Up @@ -653,10 +654,6 @@ func (fa FakeAttestor) Attest(ctx context.Context) ([]*common.Selector, error) {
return fa.selectors, fa.err
}

func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity {
return m.identities
}

type FakeManager struct {
manager.Manager

Expand Down Expand Up @@ -692,6 +689,14 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au
}, nil
}

func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry {
out := make([]*common.RegistrationEntry, 0, len(m.identities))
for _, identity := range m.identities {
out = append(out, identity.Entry)
}
return out
}

type fakeSubscriber struct {
m *FakeManager
ch chan *cache.WorkloadUpdate
Expand Down Expand Up @@ -794,5 +799,5 @@ func (m *FakeManager) SubscribeToBundleChanges() *cache.BundleStream {

func newTestCache() *cache.Cache {
log, _ := test.NewNullLogger()
return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{})
return cache.New(log, trustDomain1, bundle1, telemetry.Blackhole{}, 0, 0, clock.NewMock())
}
6 changes: 6 additions & 0 deletions pkg/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ type Config struct {
// SyncInterval controls how often the agent sync synchronizer waits
SyncInterval time.Duration

// MaxSvidCacheSize is a soft limit of max number of SVIDs that would be stored in cache
MaxSvidCacheSize int

// SVIDCacheExpiryPeriod is a period after which svids that don't have subscribers will be removed from cache
SVIDCacheExpiryPeriod time.Duration

// Trust domain and associated CA bundle
TrustDomain spiffeid.TrustDomain
TrustBundle []*x509.Certificate
Expand Down
12 changes: 6 additions & 6 deletions pkg/agent/endpoints/workload/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (

type Manager interface {
SubscribeToCacheChanges(cache.Selectors) cache.Subscriber
MatchingIdentities([]*common.Selector) []cache.Identity
MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry
FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error)
FetchWorkloadUpdate([]*common.Selector) *cache.WorkloadUpdate
}
Expand Down Expand Up @@ -84,15 +84,15 @@ func (h *Handler) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest

log = log.WithField(telemetry.Registered, true)

identities := h.c.Manager.MatchingIdentities(selectors)
for _, identity := range identities {
if req.SpiffeId != "" && identity.Entry.SpiffeId != req.SpiffeId {
entries := h.c.Manager.MatchingRegistrationEntries(selectors)
for _, entry := range entries {
if req.SpiffeId != "" && entry.SpiffeId != req.SpiffeId {
continue
}

spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId)
spiffeID, err := spiffeid.FromString(entry.SpiffeId)
if err != nil {
log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID")
return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err)
}

Expand Down
8 changes: 6 additions & 2 deletions pkg/agent/endpoints/workload/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,8 +1014,12 @@ type FakeManager struct {
err error
}

func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity {
return m.identities
func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry {
out := make([]*common.RegistrationEntry, 0, len(m.identities))
for _, identity := range m.identities {
out = append(out, identity.Entry)
}
return out
}

func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) {
Expand Down
Loading

0 comments on commit 49f442a

Please sign in to comment.