Skip to content

Commit

Permalink
Update generic storage service to support validation and custom keys (#…
Browse files Browse the repository at this point in the history
…46319)

Adds two new optional configuration options to the generic.Service:
ValidateFunc and KeyFunc.

ValidateFunc is a custom function that is called prior to persisting
a resource in the backend. If the function returns an error, the
resource is not stored and the error is returned to users.

KeyFunc is a custom function used to derive the key for a particular
resource. By default the generic service uses the metadata name as
the key, however, in some scenarios more control over the key might
be desired. For instance, a singleton resource might want to enforce
that the key is also a static value instead of something that a
user may supply.
  • Loading branch information
rosstimothy committed Sep 10, 2024
1 parent c67492f commit c98c20f
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 51 deletions.
74 changes: 59 additions & 15 deletions lib/services/local/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,28 @@ type UnmarshalFunc[T any] func([]byte, ...services.MarshalOption) (T, error)

// ServiceConfig is the configuration for the service configuration.
type ServiceConfig[T Resource] struct {
Backend backend.Backend
ResourceKind string
PageLimit uint
// Backend used to persist the resource.
Backend backend.Backend
// ResourceKind is the friendly name of the resource.
ResourceKind string
// PageLimit
PageLimit uint
// BackendPrefix used when constructing the [backend.Item.Key].
BackendPrefix string
MarshalFunc MarshalFunc[T]
// MarshlFunc converts the resource to bytes for persistence.
MarshalFunc MarshalFunc[T]
// UnmarshalFunc converts the bytes read from the backend to the resource.
UnmarshalFunc UnmarshalFunc[T]
// ValidateFunc optionally validates the resource prior to persisting it. Any errors
// returned from the validation function will prevent writes to the backend.
ValidateFunc func(T) error
// RunWhileLockedRetryInterval is the interval to retry the RunWhileLocked function.
// If set to 0, the default interval of 250ms will be used.
// WARNING: If set to a negative value, the RunWhileLocked function will retry immediately.
RunWhileLockedRetryInterval time.Duration
// KeyFunc optionally allows resource to have a custom key. If not provided the
// name of the resource will be used.
KeyFunc func(T) string
}

func (c *ServiceConfig[T]) CheckAndSetDefaults() error {
Expand All @@ -77,6 +89,14 @@ func (c *ServiceConfig[T]) CheckAndSetDefaults() error {
return trace.BadParameter("unmarshal func is missing")
}

if c.ValidateFunc == nil {
c.ValidateFunc = func(t T) error { return nil }
}

if c.KeyFunc == nil {
c.KeyFunc = func(t T) string { return t.GetName() }
}

return nil
}

Expand All @@ -88,7 +108,9 @@ type Service[T Resource] struct {
backendPrefix string
marshalFunc MarshalFunc[T]
unmarshalFunc UnmarshalFunc[T]
validateFunc func(T) error
runWhileLockedRetryInterval time.Duration
keyFunc func(T) string
}

// NewService will return a new generic service with the given config. This will
Expand All @@ -105,7 +127,9 @@ func NewService[T Resource](cfg *ServiceConfig[T]) (*Service[T], error) {
backendPrefix: cfg.BackendPrefix,
marshalFunc: cfg.MarshalFunc,
unmarshalFunc: cfg.UnmarshalFunc,
validateFunc: cfg.ValidateFunc,
runWhileLockedRetryInterval: cfg.RunWhileLockedRetryInterval,
keyFunc: cfg.KeyFunc,
}, nil
}

Expand All @@ -116,12 +140,15 @@ func (s *Service[T]) WithPrefix(parts ...string) *Service[T] {
}

return &Service[T]{
backend: s.backend,
resourceKind: s.resourceKind,
pageLimit: s.pageLimit,
backendPrefix: strings.Join(append([]string{s.backendPrefix}, parts...), string(backend.Separator)),
marshalFunc: s.marshalFunc,
unmarshalFunc: s.unmarshalFunc,
backend: s.backend,
resourceKind: s.resourceKind,
pageLimit: s.pageLimit,
backendPrefix: strings.Join(append([]string{s.backendPrefix}, parts...), string(backend.Separator)),
marshalFunc: s.marshalFunc,
unmarshalFunc: s.unmarshalFunc,
validateFunc: s.validateFunc,
runWhileLockedRetryInterval: s.runWhileLockedRetryInterval,
keyFunc: s.keyFunc,
}
}

Expand Down Expand Up @@ -228,7 +255,11 @@ func (s *Service[T]) GetResource(ctx context.Context, name string) (resource T,
// CreateResource creates a new resource.
func (s *Service[T]) CreateResource(ctx context.Context, resource T) (T, error) {
var t T
item, err := s.MakeBackendItem(resource, resource.GetName())
if err := s.validateFunc(resource); err != nil {
return t, trace.Wrap(err)
}

item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
if err != nil {
return t, trace.Wrap(err)
}
Expand All @@ -248,7 +279,12 @@ func (s *Service[T]) CreateResource(ctx context.Context, resource T) (T, error)
// UpdateResource updates an existing resource.
func (s *Service[T]) UpdateResource(ctx context.Context, resource T) (T, error) {
var t T
item, err := s.MakeBackendItem(resource, resource.GetName())

if err := s.validateFunc(resource); err != nil {
return t, trace.Wrap(err)
}

item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
if err != nil {
return t, trace.Wrap(err)
}
Expand All @@ -268,7 +304,12 @@ func (s *Service[T]) UpdateResource(ctx context.Context, resource T) (T, error)
// UpsertResource upserts a resource.
func (s *Service[T]) UpsertResource(ctx context.Context, resource T) (T, error) {
var t T
item, err := s.MakeBackendItem(resource, resource.GetName())

if err := s.validateFunc(resource); err != nil {
return t, trace.Wrap(err)
}

item, err := s.MakeBackendItem(resource, s.keyFunc(resource))
if err != nil {
return t, trace.Wrap(err)
}
Expand Down Expand Up @@ -317,8 +358,11 @@ func (s *Service[T]) UpdateAndSwapResource(ctx context.Context, name string, mod
return t, trace.Wrap(err)
}

err = modify(resource)
if err != nil {
if err := modify(resource); err != nil {
return t, trace.Wrap(err)
}

if err := s.validateFunc(resource); err != nil {
return t, trace.Wrap(err)
}

Expand Down
118 changes: 118 additions & 0 deletions lib/services/local/generic/generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,121 @@ func TestGenericListResourcesReturnNextResource(t *testing.T) {
))
require.Nil(t, next)
}


Check failure on line 330 in lib/services/local/generic/generic_test.go

View workflow job for this annotation

GitHub Actions / Lint (Go)

File is not `gci`-ed with --skip-generated -s standard -s default -s prefix(github.com/gravitational/teleport) --custom-order (gci)

Check failure on line 330 in lib/services/local/generic/generic_test.go

View workflow job for this annotation

GitHub Actions / Lint (Go)

File is not `goimports`-ed (goimports)
func TestGenericValidation(t *testing.T) {

ctx := context.Background()

memBackend, err := memory.New(memory.Config{
Context: ctx,
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

validationErr := trace.BadParameter("invalid test resource")
service, err := NewService(&ServiceConfig[*testResource]{
Backend: memBackend,
ResourceKind: "generic resource",
PageLimit: 200,
BackendPrefix: "generic_prefix",
UnmarshalFunc: unmarshalResource,
MarshalFunc: marshalResource,
ValidateFunc: func(tr *testResource) error { return validationErr },
})
require.NoError(t, err)

r1 := newTestResource("r1")

_, err = service.CreateResource(ctx, r1)
require.ErrorIs(t, err, validationErr)

_, err = service.UpdateResource(ctx, r1)
require.ErrorIs(t, err, validationErr)

_, err = service.UpsertResource(ctx, r1)
require.ErrorIs(t, err, validationErr)

}

func TestGenericKeyOverride(t *testing.T) {
ctx := context.Background()
memBackend, err := memory.New(memory.Config{
Context: ctx,
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

service, err := NewService(&ServiceConfig[*testResource]{
Backend: memBackend,
ResourceKind: "generic resource",
PageLimit: 200,
BackendPrefix: "generic_prefix",
UnmarshalFunc: unmarshalResource,
MarshalFunc: marshalResource,
KeyFunc: func(tr *testResource) string { return "llama" },
})
require.NoError(t, err)

r1 := newTestResource("r1")

// Create the test resource
created, err := service.CreateResource(ctx, r1)
require.NoError(t, err)
require.Empty(t, cmp.Diff(r1, created, cmpopts.IgnoreFields(types.Metadata{}, "Revision")))

// Validate that the resource is stored under the custom key
item, err := memBackend.Get(ctx, backend.NewKey("generic_prefix", "llama"))
require.NoError(t, err)
require.NotNil(t, item)

// Validate that the default key does not exist
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", r1.GetName()))
require.Error(t, err)
require.Nil(t, item)

// Update the test resource
updated, err := service.UpdateResource(ctx, created)
require.NoError(t, err)
require.Empty(t, cmp.Diff(r1, updated, cmpopts.IgnoreFields(types.Metadata{}, "Revision")))

// Validate that the resource is stored under the custom key
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", "llama"))
require.NoError(t, err)
require.NotNil(t, item)

// Validate that the default key still does not exist
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", r1.GetName()))
require.Error(t, err)
require.Nil(t, item)

// Upsert the test resource
upserted, err := service.UpsertResource(ctx, updated)
require.NoError(t, err)
require.Empty(t, cmp.Diff(r1, upserted, cmpopts.IgnoreFields(types.Metadata{}, "Revision")))

// Validate that the resource is stored under the custom key
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", "llama"))
require.NoError(t, err)
require.NotNil(t, item)

// Validate that the default key still does not exist
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", r1.GetName()))
require.Error(t, err)
require.Nil(t, item)

// Compare and swap the resource
swapped, err := service.UpdateAndSwapResource(ctx, "llama", func(tr *testResource) error { return nil })
require.NoError(t, err)
require.Empty(t, cmp.Diff(r1, swapped, cmpopts.IgnoreFields(types.Metadata{}, "Revision")))

// Validate that the resource is stored under the custom key
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", "llama"))
require.NoError(t, err)
require.NotNil(t, item)

// Validate that the default key still does not exist
item, err = memBackend.Get(ctx, backend.NewKey("generic_prefix", r1.GetName()))
require.Error(t, err)
require.Nil(t, item)
}
79 changes: 59 additions & 20 deletions lib/services/local/generic/generic_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package generic
import (
"context"
"strings"
"time"

"github.com/gravitational/trace"

Expand All @@ -27,27 +28,62 @@ import (
"github.com/gravitational/teleport/lib/services"
)

// ServiceWrapperConfig is the configuration for the service wrapper.
type ServiceWrapperConfig[T types.ResourceMetadata] struct {
// Backend used to persist the resource.
Backend backend.Backend
// ResourceKind is the friendly name of the resource.
ResourceKind string
// PageLimit
PageLimit uint
// BackendPrefix used when constructing the [backend.Item.Key].
BackendPrefix string
// MarshlFunc converts the resource to bytes for persistence.
MarshalFunc MarshalFunc[T]
// UnmarshalFunc converts the bytes read from the backend to the resource.
UnmarshalFunc UnmarshalFunc[T]
// ValidateFunc optionally validates the resource prior to persisting it. Any errors
// returned from the validation function will prevent writes to the backend.
ValidateFunc func(T) error
// RunWhileLockedRetryInterval is the interval to retry the RunWhileLocked function.
// If set to 0, the default interval of 250ms will be used.
// WARNING: If set to a negative value, the RunWhileLocked function will retry immediately.
RunWhileLockedRetryInterval time.Duration
// KeyFunc optionally allows resource to have a custom key. If not provided the
// name of the resource will be used.
KeyFunc func(T) string
}

// NewServiceWrapper will return a new generic service wrapper. It is compatible with resources aligned with RFD 153.
func NewServiceWrapper[T types.ResourceMetadata](
backend backend.Backend,
resourceKind string,
backendPrefix string,
marshalFunc MarshalFunc[T],
unmarshalFunc UnmarshalFunc[T]) (*ServiceWrapper[T], error) {

cfg := &ServiceConfig[resourceMetadataAdapter[T]]{
Backend: backend,
ResourceKind: resourceKind,
BackendPrefix: backendPrefix,
func NewServiceWrapper[T types.ResourceMetadata](cfg ServiceWrapperConfig[T]) (*ServiceWrapper[T], error) {
serviceConfig := &ServiceConfig[resourceMetadataAdapter[T]]{
Backend: cfg.Backend,
ResourceKind: cfg.ResourceKind,
PageLimit: cfg.PageLimit,
BackendPrefix: cfg.BackendPrefix,
MarshalFunc: func(w resourceMetadataAdapter[T], option ...services.MarshalOption) ([]byte, error) {
return marshalFunc(w.resource, option...)
return cfg.MarshalFunc(w.resource, option...)
},
UnmarshalFunc: func(bytes []byte, option ...services.MarshalOption) (resourceMetadataAdapter[T], error) {
r, err := unmarshalFunc(bytes, option...)
r, err := cfg.UnmarshalFunc(bytes, option...)
return newResourceMetadataAdapter(r), trace.Wrap(err)
},
RunWhileLockedRetryInterval: cfg.RunWhileLockedRetryInterval,
}

if cfg.ValidateFunc != nil {
serviceConfig.ValidateFunc = func(rma resourceMetadataAdapter[T]) error {
return cfg.ValidateFunc(rma.resource)
}
}
service, err := NewService[resourceMetadataAdapter[T]](cfg)

if cfg.KeyFunc != nil {
serviceConfig.KeyFunc = func(rma resourceMetadataAdapter[T]) string {
return cfg.KeyFunc(rma.resource)
}
}

service, err := NewService[resourceMetadataAdapter[T]](serviceConfig)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -71,12 +107,15 @@ func (s ServiceWrapper[T]) WithPrefix(parts ...string) *ServiceWrapper[T] {

return &ServiceWrapper[T]{
service: &Service[resourceMetadataAdapter[T]]{
backend: s.service.backend,
resourceKind: s.service.resourceKind,
pageLimit: s.service.pageLimit,
backendPrefix: strings.Join(append([]string{s.service.backendPrefix}, parts...), string(backend.Separator)),
marshalFunc: s.service.marshalFunc,
unmarshalFunc: s.service.unmarshalFunc,
backend: s.service.backend,
resourceKind: s.service.resourceKind,
pageLimit: s.service.pageLimit,
backendPrefix: strings.Join(append([]string{s.service.backendPrefix}, parts...), string(backend.Separator)),
marshalFunc: s.service.marshalFunc,
unmarshalFunc: s.service.unmarshalFunc,
validateFunc: s.service.validateFunc,
keyFunc: s.service.keyFunc,
runWhileLockedRetryInterval: s.service.runWhileLockedRetryInterval,
},
}
}
Expand Down
Loading

0 comments on commit c98c20f

Please sign in to comment.