From 2edaa678a89b66ebf78e9195be4244e70bf92cd1 Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Mon, 24 Jun 2024 18:09:07 +0800 Subject: [PATCH 1/8] auth: support oidc authentication provider --- cmd/server/cmd.go | 4 + common/container/container.go | 38 +++++- coordinator/coordinator_rpc_server.go | 3 +- go.mod | 3 + go.sum | 6 + server/auth/authentication.go | 45 ++++++++ server/auth/interceptor.go | 90 +++++++++++++++ server/auth/oidc.go | 159 ++++++++++++++++++++++++++ server/internal_rpc_server.go | 3 +- server/public_rpc_server.go | 5 +- server/server.go | 10 +- server/standalone.go | 3 +- 12 files changed, 355 insertions(+), 14 deletions(-) create mode 100644 server/auth/authentication.go create mode 100644 server/auth/interceptor.go create mode 100644 server/auth/oidc.go diff --git a/cmd/server/cmd.go b/cmd/server/cmd.go index aab02b2d..72488ad5 100644 --- a/cmd/server/cmd.go +++ b/cmd/server/cmd.go @@ -43,6 +43,8 @@ var ( ) func init() { + Cmd.Flags().SortFlags = false + flag.PublicAddr(Cmd, &conf.PublicServiceAddr) flag.InternalAddr(Cmd, &conf.InternalServiceAddr) flag.MetricsAddr(Cmd, &conf.MetricsServiceAddr) @@ -52,6 +54,8 @@ func init() { Cmd.Flags().BoolVar(&conf.WalSyncData, "wal-sync-data", true, "Whether to sync data in write-ahead-log") Cmd.Flags().Int64Var(&conf.DbBlockCacheMB, "db-cache-size-mb", kv.DefaultFactoryOptions.CacheSizeMB, "Max size of the shared DB cache") + Cmd.Flags().StringVar(&conf.AuthOptions.ProviderName, "auth-provider-name", "", "Authentication provider name. supported: oidc") + Cmd.Flags().StringVar(&conf.AuthOptions.ProviderParams, "auth-provider-params", "", "Authentication provider params. \n oidc: "+"{\"allowedIssueURLs\":\"required1,required2\",\"allowedAudiences\":\"required1,required2\",\"userNameClaim\":\"optional(default:sub)\"}") // server TLS section Cmd.Flags().StringVar(&serverTLS.CertFile, "tls-cert-file", "", "Tls certificate file") diff --git a/common/container/container.go b/common/container/container.go index d85a3619..782f3685 100644 --- a/common/container/container.go +++ b/common/container/container.go @@ -17,6 +17,7 @@ package container import ( "context" "crypto/tls" + "github.com/streamnative/oxia/server/auth" "io" "log/slog" "net" @@ -44,7 +45,7 @@ type GrpcServer interface { } type GrpcProvider interface { - StartGrpcServer(name, bindAddress string, registerFunc func(grpc.ServiceRegistrar), tlsConf *tls.Config) (GrpcServer, error) + StartGrpcServer(name, bindAddress string, registerFunc func(grpc.ServiceRegistrar), tlsConf *tls.Config, options *auth.Options) (GrpcServer, error) } var Default = &defaultProvider{} @@ -52,8 +53,8 @@ var Default = &defaultProvider{} type defaultProvider struct { } -func (*defaultProvider) StartGrpcServer(name, bindAddress string, registerFunc func(grpc.ServiceRegistrar), tlsConf *tls.Config) (GrpcServer, error) { - return newDefaultGrpcProvider(name, bindAddress, registerFunc, tlsConf) +func (*defaultProvider) StartGrpcServer(name, bindAddress string, registerFunc func(grpc.ServiceRegistrar), tlsConf *tls.Config, options *auth.Options) (GrpcServer, error) { + return newDefaultGrpcProvider(name, bindAddress, registerFunc, tlsConf, options) } type defaultGrpcServer struct { @@ -64,16 +65,41 @@ type defaultGrpcServer struct { } func newDefaultGrpcProvider(name, bindAddress string, registerFunc func(grpc.ServiceRegistrar), - tlsConf *tls.Config) (GrpcServer, error) { + tlsConf *tls.Config, authOptions *auth.Options) (GrpcServer, error) { tcs := insecure.NewCredentials() if tlsConf != nil { tcs = credentials.NewTLS(tlsConf) } + streamInterceptors := []grpc.StreamServerInterceptor{ + grpcprometheus.StreamServerInterceptor, + } + unaryInterceptors := []grpc.UnaryServerInterceptor{ + grpcprometheus.UnaryServerInterceptor, + } + if authOptions.IsEnabled() { + provider, err := auth.NewAuthenticationProvider(context.Background(), *authOptions) + if err != nil { + slog.Error("Failed to init authentication provider", + slog.Any("authOptions", *authOptions), + slog.Any("error", err)) + return nil, err + } + delegator, err := auth.NewGrpcAuthenticationDelegator(provider) + if err != nil { + slog.Error("Failed to init grpc authentication delegator", + slog.Any("authOptions", *authOptions), + slog.Any("error", err)) + return nil, err + } + unaryInterceptors = append(unaryInterceptors, delegator.GetUnaryInterceptor()) + streamInterceptors = append(streamInterceptors, delegator.GetStreamInterceptor()) + } + c := &defaultGrpcServer{ server: grpc.NewServer( grpc.Creds(tcs), - grpc.ChainStreamInterceptor(grpcprometheus.StreamServerInterceptor), - grpc.ChainUnaryInterceptor(grpcprometheus.UnaryServerInterceptor), + grpc.ChainStreamInterceptor(streamInterceptors...), + grpc.ChainUnaryInterceptor(unaryInterceptors...), grpc.MaxRecvMsgSize(maxGrpcFrameSize), ), } diff --git a/coordinator/coordinator_rpc_server.go b/coordinator/coordinator_rpc_server.go index feed201e..ed2747be 100644 --- a/coordinator/coordinator_rpc_server.go +++ b/coordinator/coordinator_rpc_server.go @@ -16,6 +16,7 @@ package coordinator import ( "crypto/tls" + "github.com/streamnative/oxia/server/auth" "google.golang.org/grpc" "google.golang.org/grpc/health" @@ -37,7 +38,7 @@ func newRpcServer(bindAddress string, tlsConf *tls.Config) (*rpcServer, error) { var err error server.grpcServer, err = container.Default.StartGrpcServer("coordinator", bindAddress, func(registrar grpc.ServiceRegistrar) { grpc_health_v1.RegisterHealthServer(registrar, server.healthServer) - }, tlsConf) + }, tlsConf, &auth.Disabled) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 977fda75..64eaa2df 100644 --- a/go.mod +++ b/go.mod @@ -66,10 +66,12 @@ require ( github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/redact v1.1.5 // indirect github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect + github.com/coreos/go-oidc/v3 v3.10.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/getsentry/sentry-go v0.21.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.1 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect @@ -116,6 +118,7 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect go.opentelemetry.io/otel/sdk v1.27.0 // indirect go.opentelemetry.io/otel/trace v1.27.0 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/oauth2 v0.19.0 // indirect golang.org/x/term v0.20.0 // indirect golang.org/x/text v0.16.0 // indirect diff --git a/go.sum b/go.sum index a12f049a..914eca3d 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/cockroachdb/redact v1.1.5 h1:u1PMllDkdFfPWaNGMyLD1+so+aq3uUItthCFqzwP github.com/cockroachdb/redact v1.1.5/go.mod h1:BVNblN9mBWFyMyqK1k3AAiSxhvhfK2oOZZ2lK+dpvRg= github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 h1:zuQyyAKVxetITBuuhv3BI9cMrmStnpT18zmgmTxunpo= github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06/go.mod h1:7nc4anLGjupUW/PeY5qiNYsdNXj7zopG+eqsS7To5IQ= +github.com/coreos/go-oidc/v3 v3.10.0 h1:tDnXHnLyiTVyT/2zLDGj09pFPkhND8Gl8lnTRhoEaJU= +github.com/coreos/go-oidc/v3 v3.10.0/go.mod h1:5j11xcw0D3+SGxn6Z/WFADsgcWVMyNAlSQupk0KK3ac= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -51,6 +53,8 @@ github.com/getsentry/sentry-go v0.21.0 h1:c9l5F1nPF30JIppulk4veau90PK6Smu3abgVtV github.com/getsentry/sentry-go v0.21.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U= +github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -227,6 +231,8 @@ go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc h1:O9NuF4s+E/PvMIy+9IUZB9znFwUIXEWSstNjek6VpVg= golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/server/auth/authentication.go b/server/auth/authentication.go new file mode 100644 index 00000000..966f8c96 --- /dev/null +++ b/server/auth/authentication.go @@ -0,0 +1,45 @@ +package auth + +import ( + "context" + "github.com/pkg/errors" +) + +const ( + ProviderOIDC = "oidc" + + ProviderParamTypeToken = "token" +) + +var ( + ErrUnsupportedProvider = errors.New("Unsupported authentication provider.") + ErrUnMatchedAuthenticationParamType = errors.New("Unmatched authentication parameter type.") + ErrEmptyToken = errors.New("Empty token") + ErrMalformedToken = errors.New("Malformed token") +) + +var Disabled = Options{} + +type Options struct { + ProviderName string + ProviderParams string +} + +func (op *Options) IsEnabled() bool { + return op.ProviderName != "" +} + +// todo: add metrics +type AuthenticationProvider interface { + AcceptParamType() string + Authenticate(ctx context.Context, param interface{}) (string, error) +} + +func NewAuthenticationProvider(ctx context.Context, options Options) (AuthenticationProvider, error) { + switch options.ProviderName { + case ProviderOIDC: + return NewOIDCProvider(ctx, options.ProviderParams) + default: + return nil, ErrUnsupportedProvider + } +} diff --git a/server/auth/interceptor.go b/server/auth/interceptor.go new file mode 100644 index 00000000..67a9d3a3 --- /dev/null +++ b/server/auth/interceptor.go @@ -0,0 +1,90 @@ +package auth + +import ( + "context" + "errors" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + "log/slog" + "strings" +) + +const ( + MetadataAuthorizationKey = "authorization" + TokenPrefix = "Bearer " +) + +var ( + ErrMetadataFetchFailed = errors.New("metadata fetch failed") +) + +type GrpcAuthenticationDelegator struct { + provider AuthenticationProvider + + validate func(ctx context.Context, provider AuthenticationProvider) (string, error) +} + +func (delegator *GrpcAuthenticationDelegator) GetUnaryInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + _, err := delegator.validate(ctx, delegator.provider) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, err.Error()) + } + // todo: set username to metadata to support authorization + return handler(ctx, req) + } +} + +func (delegator *GrpcAuthenticationDelegator) GetStreamInterceptor() grpc.StreamServerInterceptor { + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + _, err := delegator.validate(ss.Context(), delegator.provider) + if err != nil { + return status.Errorf(codes.Unauthenticated, err.Error()) + } + // todo: set username to metadata to support authorization + return handler(srv, ss) + } +} + +func NewGrpcAuthenticationDelegator(provider AuthenticationProvider) (*GrpcAuthenticationDelegator, error) { + delegator := &GrpcAuthenticationDelegator{ + provider: provider, + } + switch provider.AcceptParamType() { + case ProviderParamTypeToken: + delegator.validate = validateTokenWithContext + default: + return nil, ErrUnMatchedAuthenticationParamType + } + return delegator, nil +} + +func validateTokenWithContext(ctx context.Context, provider AuthenticationProvider) (string, error) { + meta, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", ErrMetadataFetchFailed + } + peerMeta, ok := peer.FromContext(ctx) + if !ok { + return "", ErrMetadataFetchFailed + } + val := meta.Get(MetadataAuthorizationKey) + if len(val) < 1 { + slog.Debug("Receive empty token from the client", + slog.String("peer", peerMeta.Addr.String())) + return "", ErrEmptyToken + } + token := strings.TrimPrefix(val[0], TokenPrefix) + var userName string + var err error + if userName, err = provider.Authenticate(ctx, token); err != nil { + slog.Debug("Failed to authenticate token", + slog.String("peer", peerMeta.Addr.String()), + slog.String("token", token)) + return "", err + } + return userName, nil +} diff --git a/server/auth/oidc.go b/server/auth/oidc.go new file mode 100644 index 00000000..2e8e5277 --- /dev/null +++ b/server/auth/oidc.go @@ -0,0 +1,159 @@ +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/pkg/errors" + "net/http" + "strings" + "time" +) + +const ( + DefaultUserNameCalm = "sub" + AllowedAudienceDefaultValue = "" +) + +var ( + ErrEmptyIssueURL = errors.New("empty issue URL") + ErrEmptyAllowedAudiences = errors.New("empty allowed audiences") + ErrUnknownIssuer = errors.New("unknown issuer") + ErrUserNameNotFound = errors.New("username not found") + ErrForbiddenAudience = errors.New("forbidden audience") +) + +type OIDCOptions struct { + AllowedIssueURLs string `json:"allowedIssueURLs,omitempty"` + AllowedAudiences string `json:"allowedAudiences,omitempty"` + UserNameClaim string `json:"userNameClaim,omitempty"` +} + +func (op *OIDCOptions) Validate() error { + if op.AllowedIssueURLs == "" { + return ErrEmptyIssueURL + } + if op.AllowedAudiences == "" { + return ErrEmptyAllowedAudiences + } + return nil +} + +func (op *OIDCOptions) withDefault() { + if op.UserNameClaim == "" { + op.UserNameClaim = DefaultUserNameCalm + } +} + +type ProviderWithVerifier struct { + provider *oidc.Provider + verifier *oidc.IDTokenVerifier +} + +type OIDCProvider struct { + userNameClaim string + allowedAudiences map[string]string + + providers map[string]*ProviderWithVerifier +} + +func (p *OIDCProvider) AcceptParamType() string { + return ProviderParamTypeToken +} + +func (p *OIDCProvider) Authenticate(ctx context.Context, param interface{}) (string, error) { + token, ok := param.(string) + if !ok { + return "", ErrUnMatchedAuthenticationParamType + } + tokenParts := strings.Split(token, ".") + if len(tokenParts) != 3 { + return "", ErrMalformedToken + } + payload, err := base64.RawURLEncoding.DecodeString(tokenParts[1]) + if err != nil { + return "", err + } + unsecureJwtPayload := &struct { + Issuer string `json:"iss"` + }{} + if err = json.Unmarshal(payload, unsecureJwtPayload); err != nil { + return "", err + } + issuer := unsecureJwtPayload.Issuer + oidcProvider, exist := p.providers[issuer] + if !exist { + return "", ErrUnknownIssuer + } + idToken, err := oidcProvider.verifier.Verify(ctx, token) + if err != nil { + return "", err + } + rawClaims := map[string]json.RawMessage{} + if err = idToken.Claims(&rawClaims); err != nil { + return "", err + } + rawMessage, ok := rawClaims[p.userNameClaim] + if !ok { + return "", ErrUserNameNotFound + } + var userName string + if err = json.Unmarshal(rawMessage, &userName); err != nil { + return "", err + } + + // any of the client audience in the allowed is passed + audienceAllowed := false + audienceArr := idToken.Audience + for _, audience := range audienceArr { + if _, ok := p.allowedAudiences[audience]; ok { + audienceAllowed = true + } + } + if !audienceAllowed { + return "", ErrForbiddenAudience + } + return userName, nil +} + +func NewOIDCProvider(ctx context.Context, jsonParam string) (AuthenticationProvider, error) { + oidcParams := &OIDCOptions{} + if err := json.Unmarshal([]byte(jsonParam), oidcParams); err != nil { + return nil, err + } + oidcParams.withDefault() + if err := oidcParams.Validate(); err != nil { + return nil, err + } + allowedAudienceMap := map[string]string{} + allowedAudienceArr := strings.Split(oidcParams.AllowedAudiences, ",") + for i := range allowedAudienceArr { + allowedAudience := allowedAudienceArr[i] + allowedAudienceMap[allowedAudience] = AllowedAudienceDefaultValue + } + oidcProvider := &OIDCProvider{ + userNameClaim: oidcParams.UserNameClaim, + allowedAudiences: allowedAudienceMap, + } + + ctx = oidc.ClientContext(ctx, &http.Client{Timeout: 30 * time.Second}) + urlArr := strings.Split(oidcParams.AllowedIssueURLs, ",") + for i := 0; i < len(urlArr); i++ { + issueURL := urlArr[i] + provider, err := oidc.NewProvider(ctx, issueURL) + if err != nil { + return nil, err + } + config := &oidc.Config{ + SkipClientIDCheck: true, + Now: time.Now, + } + verifier := provider.Verifier(config) + oidcProvider.providers[issueURL] = &ProviderWithVerifier{ + provider: provider, + verifier: verifier, + } + } + return oidcProvider, nil +} diff --git a/server/internal_rpc_server.go b/server/internal_rpc_server.go index 6d997441..2c51b64b 100644 --- a/server/internal_rpc_server.go +++ b/server/internal_rpc_server.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "fmt" + "github.com/streamnative/oxia/server/auth" "io" "log/slog" @@ -61,7 +62,7 @@ func newInternalRpcServer(grpcProvider container.GrpcProvider, bindAddress strin proto.RegisterOxiaCoordinationServer(registrar, server) proto.RegisterOxiaLogReplicationServer(registrar, server) grpc_health_v1.RegisterHealthServer(registrar, server.healthServer) - }, tlsConf) + }, tlsConf, &auth.Disabled) if err != nil { return nil, err } diff --git a/server/public_rpc_server.go b/server/public_rpc_server.go index 722c84a8..8b9d8da7 100644 --- a/server/public_rpc_server.go +++ b/server/public_rpc_server.go @@ -17,6 +17,7 @@ package server import ( "context" "crypto/tls" + "github.com/streamnative/oxia/server/auth" "log/slog" "github.com/pkg/errors" @@ -45,7 +46,7 @@ type publicRpcServer struct { } func newPublicRpcServer(provider container.GrpcProvider, bindAddress string, shardsDirector ShardsDirector, assignmentDispatcher ShardAssignmentsDispatcher, - tlsConf *tls.Config) (*publicRpcServer, error) { + tlsConf *tls.Config, options *auth.Options) (*publicRpcServer, error) { server := &publicRpcServer{ shardsDirector: shardsDirector, assignmentDispatcher: assignmentDispatcher, @@ -57,7 +58,7 @@ func newPublicRpcServer(provider container.GrpcProvider, bindAddress string, sha var err error server.grpcServer, err = provider.StartGrpcServer("public", bindAddress, func(registrar grpc.ServiceRegistrar) { proto.RegisterOxiaClientServer(registrar, server) - }, tlsConf) + }, tlsConf, options) if err != nil { return nil, err } diff --git a/server/server.go b/server/server.go index eb34a499..4d705586 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ package server import ( "crypto/tls" + "github.com/streamnative/oxia/server/auth" "log/slog" "time" @@ -35,8 +36,11 @@ type Config struct { ServerTLS *tls.Config InternalServerTLS *tls.Config MetricsServiceAddr string - DataDir string - WalDir string + + AuthOptions auth.Options + + DataDir string + WalDir string WalRetentionTime time.Duration WalSyncData bool @@ -99,7 +103,7 @@ func NewWithGrpcProvider(config Config, provider container.GrpcProvider, replica } s.publicRpcServer, err = newPublicRpcServer(provider, config.PublicServiceAddr, s.shardsDirector, - s.shardAssignmentDispatcher, config.ServerTLS) + s.shardAssignmentDispatcher, config.ServerTLS, &config.AuthOptions) if err != nil { return nil, err } diff --git a/server/standalone.go b/server/standalone.go index a38e0e97..f6edc6c5 100644 --- a/server/standalone.go +++ b/server/standalone.go @@ -16,6 +16,7 @@ package server import ( "context" + "github.com/streamnative/oxia/server/auth" "log/slog" "path/filepath" @@ -85,7 +86,7 @@ func NewStandalone(config StandaloneConfig) (*Standalone, error) { } s.rpc, err = newPublicRpcServer(container.Default, config.PublicServiceAddr, s.shardsDirector, - nil, config.ServerTLS) + nil, config.ServerTLS, &auth.Disabled) if err != nil { return nil, err } From 073b3c69af06d7b948772dea17102edb3dc9281b Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Mon, 24 Jun 2024 21:53:29 +0800 Subject: [PATCH 2/8] fix lint --- .golangci.yaml | 2 +- cmd/health/cmd.go | 2 +- cmd/health/cmd_test.go | 2 +- common/client_pool.go | 22 +++++++++++----- common/container/container.go | 3 ++- coordinator/coordinator.go | 2 +- coordinator/coordinator_rpc_server.go | 1 + coordinator/impl/cluster_rebalance.go | 2 +- coordinator/impl/coordinator_e2e_test.go | 18 ++++++------- maelstrom/grpc_provider.go | 4 ++- oxia/async_client_impl.go | 2 +- oxia/auth/authentication.go | 7 +++++ oxia/auth/token.go | 31 +++++++++++++++++++++++ oxia/internal/shard_manager_test.go | 2 +- oxia/options_client.go | 14 ++++++++++ server/auth/authentication.go | 11 ++++---- server/auth/interceptor.go | 5 ++-- server/auth/oidc.go | 9 ++++--- server/internal_rpc_server.go | 3 ++- server/leader_controller.go | 2 +- server/public_rpc_server.go | 3 ++- server/rpc_provider.go | 2 +- server/server.go | 3 ++- server/standalone.go | 3 ++- tests/security/tls/tls_encryption_test.go | 12 ++++----- 25 files changed, 119 insertions(+), 48 deletions(-) create mode 100644 oxia/auth/authentication.go create mode 100644 oxia/auth/token.go diff --git a/.golangci.yaml b/.golangci.yaml index 41895441..1e91338e 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -92,7 +92,7 @@ linters-settings: severity: warning disabled: false arguments: - - 12 + - 15 - name: cyclomatic severity: warning disabled: false diff --git a/cmd/health/cmd.go b/cmd/health/cmd.go index 15798ff0..84d25149 100644 --- a/cmd/health/cmd.go +++ b/cmd/health/cmd.go @@ -63,7 +63,7 @@ func init() { } func exec(*cobra.Command, []string) error { - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) serverAddress := fmt.Sprintf("%s:%d", config.Host, config.Port) diff --git a/cmd/health/cmd_test.go b/cmd/health/cmd_test.go index bcf51ffa..8040a3d3 100644 --- a/cmd/health/cmd_test.go +++ b/cmd/health/cmd_test.go @@ -32,7 +32,7 @@ func TestHealthCmd(t *testing.T) { _health := health.NewServer() server, err := container.Default.StartGrpcServer("health", "localhost:0", func(registrar grpc.ServiceRegistrar) { grpc_health_v1.RegisterHealthServer(registrar, _health) - }, nil) + }, nil, nil) assert.NoError(t, err) defer func() { _ = server.Close() diff --git a/common/client_pool.go b/common/client_pool.go index 822e14e4..4c482af4 100644 --- a/common/client_pool.go +++ b/common/client_pool.go @@ -22,6 +22,8 @@ import ( "sync" "time" + "github.com/streamnative/oxia/oxia/auth" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -48,14 +50,16 @@ type clientPool struct { sync.RWMutex connections map[string]grpc.ClientConnInterface - tls *tls.Config - log *slog.Logger + tls *tls.Config + authentication auth.Authentication + log *slog.Logger } -func NewClientPool(tlsConf *tls.Config) ClientPool { +func NewClientPool(tlsConf *tls.Config, authentication auth.Authentication) ClientPool { return &clientPool{ - connections: make(map[string]grpc.ClientConnInterface), - tls: tlsConf, + connections: make(map[string]grpc.ClientConnInterface), + tls: tlsConf, + authentication: authentication, log: slog.With( slog.String("component", "client-pool"), ), @@ -142,11 +146,15 @@ func (cp *clientPool) getConnection(target string) (grpc.ClientConnInterface, er tcs = credentials.NewTLS(cp.tls) } - cnx, err := grpc.NewClient(target, + options := []grpc.DialOption{ grpc.WithTransportCredentials(tcs), grpc.WithStreamInterceptor(grpcprometheus.StreamClientInterceptor), grpc.WithUnaryInterceptor(grpcprometheus.UnaryClientInterceptor), - ) + } + if cp.authentication != nil { + options = append(options, grpc.WithPerRPCCredentials(cp.authentication)) + } + cnx, err := grpc.NewClient(target, options...) if err != nil { return nil, errors.Wrapf(err, "error connecting to %s", target) } diff --git a/common/container/container.go b/common/container/container.go index 782f3685..b40d775d 100644 --- a/common/container/container.go +++ b/common/container/container.go @@ -17,12 +17,13 @@ package container import ( "context" "crypto/tls" - "github.com/streamnative/oxia/server/auth" "io" "log/slog" "net" "os" + "github.com/streamnative/oxia/server/auth" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" diff --git a/coordinator/coordinator.go b/coordinator/coordinator.go index 8ad6d3c3..9b34e4b3 100644 --- a/coordinator/coordinator.go +++ b/coordinator/coordinator.go @@ -90,7 +90,7 @@ func New(config Config) (*Coordinator, error) { ) s := &Coordinator{ - clientPool: common.NewClientPool(config.PeerTLS), + clientPool: common.NewClientPool(config.PeerTLS, nil), } var metadataProvider impl.MetadataProvider diff --git a/coordinator/coordinator_rpc_server.go b/coordinator/coordinator_rpc_server.go index ed2747be..173e0fd6 100644 --- a/coordinator/coordinator_rpc_server.go +++ b/coordinator/coordinator_rpc_server.go @@ -16,6 +16,7 @@ package coordinator import ( "crypto/tls" + "github.com/streamnative/oxia/server/auth" "google.golang.org/grpc" diff --git a/coordinator/impl/cluster_rebalance.go b/coordinator/impl/cluster_rebalance.go index 76563eaf..38e0b8df 100644 --- a/coordinator/impl/cluster_rebalance.go +++ b/coordinator/impl/cluster_rebalance.go @@ -131,7 +131,7 @@ outer: return res } -func getShardsPerServer(servers []model.ServerAddress, currentStatus *model.ClusterStatus) ( //nolint:revive +func getShardsPerServer(servers []model.ServerAddress, currentStatus *model.ClusterStatus) ( existingServers map[model.ServerAddress]common.Set[int64], deletedServers map[model.ServerAddress]common.Set[int64]) { existingServers = map[model.ServerAddress]common.Set[int64]{} diff --git a/coordinator/impl/coordinator_e2e_test.go b/coordinator/impl/coordinator_e2e_test.go index 46635dc8..b62f6ecc 100644 --- a/coordinator/impl/coordinator_e2e_test.go +++ b/coordinator/impl/coordinator_e2e_test.go @@ -68,7 +68,7 @@ func TestCoordinatorE2E(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) coordinator, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, NewRpcProvider(clientPool)) @@ -106,7 +106,7 @@ func TestCoordinatorE2E_ShardsRanges(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) coordinator, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, NewRpcProvider(clientPool)) assert.NoError(t, err) @@ -159,7 +159,7 @@ func TestCoordinator_LeaderFailover(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) coordinator, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, NewRpcProvider(clientPool)) assert.NoError(t, err) @@ -263,7 +263,7 @@ func TestCoordinator_MultipleNamespaces(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) coordinator, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, NewRpcProvider(clientPool)) assert.NoError(t, err) @@ -354,7 +354,7 @@ func TestCoordinator_DeleteNamespace(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) coordinator, err := NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, NewRpcProvider(clientPool)) assert.NoError(t, err) @@ -436,7 +436,7 @@ func TestCoordinator_DynamicallAddNamespace(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) configChangesCh := make(chan any) configProvider := func() (model.ClusterConfig, error) { @@ -524,7 +524,7 @@ func TestCoordinator_RebalanceCluster(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) mutex := &sync.Mutex{} configProvider := func() (model.ClusterConfig, error) { @@ -622,7 +622,7 @@ func TestCoordinator_AddRemoveNodes(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) configProvider := func() (model.ClusterConfig, error) { return clusterConfig, nil @@ -684,7 +684,7 @@ func TestCoordinator_ShrinkCluster(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3, sa4}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) configProvider := func() (model.ClusterConfig, error) { return clusterConfig, nil diff --git a/maelstrom/grpc_provider.go b/maelstrom/grpc_provider.go index efb069e6..25ecbe0e 100644 --- a/maelstrom/grpc_provider.go +++ b/maelstrom/grpc_provider.go @@ -23,6 +23,8 @@ import ( "os" "sync" + "github.com/streamnative/oxia/server/auth" + "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -53,7 +55,7 @@ func newMaelstromGrpcProvider() *maelstromGrpcProvider { } func (m *maelstromGrpcProvider) StartGrpcServer(name, _ string, registerFunc func(grpc.ServiceRegistrar), - _ *tls.Config) (container.GrpcServer, error) { + _ *tls.Config, _ *auth.Options) (container.GrpcServer, error) { slog.Info( "Start Grpc server", slog.String("name", name), diff --git a/oxia/async_client_impl.go b/oxia/async_client_impl.go index 1b04af0f..1f4446ca 100644 --- a/oxia/async_client_impl.go +++ b/oxia/async_client_impl.go @@ -65,7 +65,7 @@ func NewAsyncClient(serviceAddress string, opts ...ClientOption) (AsyncClient, e return nil, err } - clientPool := common.NewClientPool(options.tls) + clientPool := common.NewClientPool(options.tls, options.authentication) shardManager, err := internal.NewShardManager(internal.NewShardStrategy(), clientPool, serviceAddress, options.namespace, options.requestTimeout) diff --git a/oxia/auth/authentication.go b/oxia/auth/authentication.go new file mode 100644 index 00000000..657dac1d --- /dev/null +++ b/oxia/auth/authentication.go @@ -0,0 +1,7 @@ +package auth + +import "google.golang.org/grpc/credentials" + +type Authentication interface { + credentials.PerRPCCredentials +} diff --git a/oxia/auth/token.go b/oxia/auth/token.go new file mode 100644 index 00000000..90e2bd45 --- /dev/null +++ b/oxia/auth/token.go @@ -0,0 +1,31 @@ +package auth + +import ( + "context" +) + +type tokenAuthentication struct { + requireTransportSecurity bool + tokenGetFunc func() string +} + +func (tokenAuth *tokenAuthentication) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + token := tokenAuth.tokenGetFunc() + return map[string]string{ + "authorization": "Bearer" + " " + token, + }, nil +} + +func (tokenAuth *tokenAuthentication) RequireTransportSecurity() bool { + return tokenAuth.requireTransportSecurity +} + +func NewTokenAuthenticationWithFunc(tokenGetFunc func() string, requireTransportSecurity bool) Authentication { + return &tokenAuthentication{tokenGetFunc: tokenGetFunc, requireTransportSecurity: requireTransportSecurity} +} + +func NewTokenAuthenticationWithToken(token string, requireTransportSecurity bool) Authentication { + return NewTokenAuthenticationWithFunc(func() string { + return token + }, requireTransportSecurity) +} diff --git a/oxia/internal/shard_manager_test.go b/oxia/internal/shard_manager_test.go index 1dc1b4c9..eb980701 100644 --- a/oxia/internal/shard_manager_test.go +++ b/oxia/internal/shard_manager_test.go @@ -38,7 +38,7 @@ func TestWithStandalone(t *testing.T) { standaloneServer, err := server.NewStandalone(server.NewTestConfig(t.TempDir())) assert.NoError(t, err) - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) serviceAddress := fmt.Sprintf("localhost:%d", standaloneServer.RpcPort()) shardManager, err := NewShardManager(&testShardStrategy{}, clientPool, serviceAddress, common.DefaultNamespace, 30*time.Second) assert.NoError(t, err) diff --git a/oxia/options_client.go b/oxia/options_client.go index 9e8ef59e..2ec7e1bb 100644 --- a/oxia/options_client.go +++ b/oxia/options_client.go @@ -18,6 +18,8 @@ import ( "crypto/tls" "time" + "github.com/streamnative/oxia/oxia/auth" + "github.com/google/uuid" "github.com/pkg/errors" "go.opentelemetry.io/otel" @@ -46,6 +48,7 @@ var ( ErrInvalidOptionIdentity = errors.New("Identity must be non-empty") ErrInvalidOptionNamespace = errors.New("Namespace cannot be empty") ErrInvalidOptionTLS = errors.New("Tls cannot be empty") + ErrInvalidOptionAuthentication = errors.New("Authentication cannot be empty") ) // clientOptions contains options for the Oxia client. @@ -60,6 +63,7 @@ type clientOptions struct { sessionTimeout time.Duration identity string tls *tls.Config + authentication auth.Authentication } func defaultIdentity() string { @@ -199,3 +203,13 @@ func WithTLS(tlsConf *tls.Config) ClientOption { return options, nil }) } + +func WithAuthentication(authentication auth.Authentication) ClientOption { + return clientOptionFunc(func(options clientOptions) (clientOptions, error) { + if authentication == nil { + return options, ErrInvalidOptionAuthentication + } + options.authentication = authentication + return options, nil + }) +} diff --git a/server/auth/authentication.go b/server/auth/authentication.go index 966f8c96..927f0654 100644 --- a/server/auth/authentication.go +++ b/server/auth/authentication.go @@ -2,6 +2,7 @@ package auth import ( "context" + "github.com/pkg/errors" ) @@ -12,10 +13,10 @@ const ( ) var ( - ErrUnsupportedProvider = errors.New("Unsupported authentication provider.") - ErrUnMatchedAuthenticationParamType = errors.New("Unmatched authentication parameter type.") - ErrEmptyToken = errors.New("Empty token") - ErrMalformedToken = errors.New("Malformed token") + ErrUnsupportedProvider = errors.New("unsupported authentication provider") + ErrUnMatchedAuthenticationParamType = errors.New("unmatched authentication parameter type") + ErrEmptyToken = errors.New("empty token") + ErrMalformedToken = errors.New("malformed token") ) var Disabled = Options{} @@ -32,7 +33,7 @@ func (op *Options) IsEnabled() bool { // todo: add metrics type AuthenticationProvider interface { AcceptParamType() string - Authenticate(ctx context.Context, param interface{}) (string, error) + Authenticate(ctx context.Context, param any) (string, error) } func NewAuthenticationProvider(ctx context.Context, options Options) (AuthenticationProvider, error) { diff --git a/server/auth/interceptor.go b/server/auth/interceptor.go index 67a9d3a3..d5cb73f4 100644 --- a/server/auth/interceptor.go +++ b/server/auth/interceptor.go @@ -3,13 +3,14 @@ package auth import ( "context" "errors" + "log/slog" + "strings" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "log/slog" - "strings" ) const ( diff --git a/server/auth/oidc.go b/server/auth/oidc.go index 2e8e5277..17e8a3f4 100644 --- a/server/auth/oidc.go +++ b/server/auth/oidc.go @@ -4,11 +4,12 @@ import ( "context" "encoding/base64" "encoding/json" - "github.com/coreos/go-oidc/v3/oidc" - "github.com/pkg/errors" "net/http" "strings" "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/pkg/errors" ) const ( @@ -58,11 +59,11 @@ type OIDCProvider struct { providers map[string]*ProviderWithVerifier } -func (p *OIDCProvider) AcceptParamType() string { +func (*OIDCProvider) AcceptParamType() string { return ProviderParamTypeToken } -func (p *OIDCProvider) Authenticate(ctx context.Context, param interface{}) (string, error) { +func (p *OIDCProvider) Authenticate(ctx context.Context, param any) (string, error) { token, ok := param.(string) if !ok { return "", ErrUnMatchedAuthenticationParamType diff --git a/server/internal_rpc_server.go b/server/internal_rpc_server.go index 2c51b64b..7d485e04 100644 --- a/server/internal_rpc_server.go +++ b/server/internal_rpc_server.go @@ -18,10 +18,11 @@ import ( "context" "crypto/tls" "fmt" - "github.com/streamnative/oxia/server/auth" "io" "log/slog" + "github.com/streamnative/oxia/server/auth" + "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/server/leader_controller.go b/server/leader_controller.go index 6f507da8..25eae4b2 100644 --- a/server/leader_controller.go +++ b/server/leader_controller.go @@ -198,7 +198,7 @@ func (lc *leaderController) Term() int64 { // // Any existing follow cursors are destroyed as is any state // regarding reconfigurations. -func (lc *leaderController) NewTerm(req *proto.NewTermRequest) (*proto.NewTermResponse, error) { //nolint:revive +func (lc *leaderController) NewTerm(req *proto.NewTermRequest) (*proto.NewTermResponse, error) { lc.Lock() defer lc.Unlock() diff --git a/server/public_rpc_server.go b/server/public_rpc_server.go index 8b9d8da7..94d70ddc 100644 --- a/server/public_rpc_server.go +++ b/server/public_rpc_server.go @@ -17,9 +17,10 @@ package server import ( "context" "crypto/tls" - "github.com/streamnative/oxia/server/auth" "log/slog" + "github.com/streamnative/oxia/server/auth" + "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/status" diff --git a/server/rpc_provider.go b/server/rpc_provider.go index 8a67f6c7..dddbb296 100644 --- a/server/rpc_provider.go +++ b/server/rpc_provider.go @@ -42,7 +42,7 @@ type replicationRpcProvider struct { func NewReplicationRpcProvider(tlsConf *tls.Config) ReplicationRpcProvider { return &replicationRpcProvider{ - pool: common.NewClientPool(tlsConf), + pool: common.NewClientPool(tlsConf, nil), } } diff --git a/server/server.go b/server/server.go index 4d705586..e51bc6d1 100644 --- a/server/server.go +++ b/server/server.go @@ -16,10 +16,11 @@ package server import ( "crypto/tls" - "github.com/streamnative/oxia/server/auth" "log/slog" "time" + "github.com/streamnative/oxia/server/auth" + "go.uber.org/multierr" "google.golang.org/grpc/health" diff --git a/server/standalone.go b/server/standalone.go index f6edc6c5..646a8a80 100644 --- a/server/standalone.go +++ b/server/standalone.go @@ -16,10 +16,11 @@ package server import ( "context" - "github.com/streamnative/oxia/server/auth" "log/slog" "path/filepath" + "github.com/streamnative/oxia/server/auth" + "go.uber.org/multierr" "github.com/streamnative/oxia/common" diff --git a/tests/security/tls/tls_encryption_test.go b/tests/security/tls/tls_encryption_test.go index e2a43d37..121b0667 100644 --- a/tests/security/tls/tls_encryption_test.go +++ b/tests/security/tls/tls_encryption_test.go @@ -132,7 +132,7 @@ func TestClusterHandshakeSuccess(t *testing.T) { tlsConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - clientPool := common.NewClientPool(tlsConf) + clientPool := common.NewClientPool(tlsConf, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) @@ -162,7 +162,7 @@ func TestClientHandshakeFailByNoTlsConfig(t *testing.T) { tlsConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - clientPool := common.NewClientPool(tlsConf) + clientPool := common.NewClientPool(tlsConf, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) @@ -196,7 +196,7 @@ func TestClientHandshakeByAuthFail(t *testing.T) { tlsConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - clientPool := common.NewClientPool(tlsConf) + clientPool := common.NewClientPool(tlsConf, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) @@ -236,7 +236,7 @@ func TestClientHandshakeWithInsecure(t *testing.T) { tlsConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - clientPool := common.NewClientPool(tlsConf) + clientPool := common.NewClientPool(tlsConf, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) @@ -277,7 +277,7 @@ func TestClientHandshakeSuccess(t *testing.T) { tlsConf, err := option.MakeClientTLSConf() assert.NoError(t, err) - clientPool := common.NewClientPool(tlsConf) + clientPool := common.NewClientPool(tlsConf, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) @@ -314,7 +314,7 @@ func TestOnlyEnablePublicTls(t *testing.T) { }}, Servers: []model.ServerAddress{sa1, sa2, sa3}, } - clientPool := common.NewClientPool(nil) + clientPool := common.NewClientPool(nil, nil) defer clientPool.Close() coordinator, err := impl.NewCoordinator(metadataProvider, func() (model.ClusterConfig, error) { return clusterConfig, nil }, nil, impl.NewRpcProvider(clientPool)) From 3a09e9f657cc445a042c991fae202581bad654a9 Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Mon, 24 Jun 2024 23:16:15 +0800 Subject: [PATCH 3/8] add test --- go.mod | 3 + go.sum | 9 ++ oxia/internal/shard_manager.go | 3 +- server/auth/oidc.go | 1 + tests/security/auth/auth_oidc_test.go | 186 ++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 tests/security/auth/auth_oidc_test.go diff --git a/go.mod b/go.mod index 64eaa2df..fbacd13d 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/getsentry/sentry-go v0.21.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.1 // indirect github.com/go-jose/go-jose/v4 v4.0.1 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -78,6 +79,7 @@ require ( github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.0 // indirect github.com/golang/glog v1.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect @@ -100,6 +102,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.1 // indirect diff --git a/go.sum b/go.sum index 914eca3d..53d4612a 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/getsentry/sentry-go v0.21.0 h1:c9l5F1nPF30JIppulk4veau90PK6Smu3abgVtV github.com/getsentry/sentry-go v0.21.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U= github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -71,6 +73,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4= github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -80,6 +84,7 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -136,6 +141,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 h1:9bCMuD3TcnjeqjPT2gSlha4asp8NvgcFRYExCaikCxk= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25/go.mod h1:eDjgYHYDJbPLBLsyZ6qRaugP0mX8vePOhZ5id1fdzJw= github.com/onsi/ginkgo/v2 v2.15.0 h1:79HwNRBAZHOEwrczrgSOPy+eFTTlIGELKy5as+ClttY= github.com/onsi/ginkgo/v2 v2.15.0/go.mod h1:HlxMHtYF57y6Dpf+mc5529KKmSq9h2FpCF+/ZkwUxKM= github.com/onsi/gomega v1.31.0 h1:54UJxxj6cPInHS3a35wm6BK/F9nHYueZ1NVujHDrnXE= @@ -198,6 +205,7 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= @@ -229,6 +237,7 @@ go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnw go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= diff --git a/oxia/internal/shard_manager.go b/oxia/internal/shard_manager.go index eff2d695..b1ec64f1 100644 --- a/oxia/internal/shard_manager.go +++ b/oxia/internal/shard_manager.go @@ -241,7 +241,8 @@ func isErrorRetryable(err error) bool { switch status.Code(err) { case common.CodeNamespaceNotFound: return false - + case codes.Unauthenticated: + return false default: return true } diff --git a/server/auth/oidc.go b/server/auth/oidc.go index 17e8a3f4..ea26fa76 100644 --- a/server/auth/oidc.go +++ b/server/auth/oidc.go @@ -136,6 +136,7 @@ func NewOIDCProvider(ctx context.Context, jsonParam string) (AuthenticationProvi oidcProvider := &OIDCProvider{ userNameClaim: oidcParams.UserNameClaim, allowedAudiences: allowedAudienceMap, + providers: make(map[string]*ProviderWithVerifier), } ctx = oidc.ClientContext(ctx, &http.Client{Timeout: 30 * time.Second}) diff --git a/tests/security/auth/auth_oidc_test.go b/tests/security/auth/auth_oidc_test.go new file mode 100644 index 00000000..7bb198b2 --- /dev/null +++ b/tests/security/auth/auth_oidc_test.go @@ -0,0 +1,186 @@ +package auth + +import ( + "context" + "fmt" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/oauth2-proxy/mockoidc" + "github.com/pkg/errors" + "github.com/streamnative/oxia/common" + "github.com/streamnative/oxia/coordinator/impl" + "github.com/streamnative/oxia/coordinator/model" + "github.com/streamnative/oxia/oxia" + clientAuth "github.com/streamnative/oxia/oxia/auth" + "github.com/streamnative/oxia/server" + "github.com/streamnative/oxia/server/auth" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/json" + "testing" + "time" +) + +func newOxiaClusterWithAuth(t *testing.T, issueURL string, audiences string) (address string, closeFunc func()) { + t.Helper() + options := auth.OIDCOptions{ + AllowedIssueURLs: issueURL, + AllowedAudiences: audiences, + } + jsonParams, err := json.Marshal(options) + assert.NoError(t, err) + authParams := auth.Options{ + ProviderName: auth.ProviderOIDC, + ProviderParams: string(jsonParams), + } + s1, err := server.New(server.Config{ + PublicServiceAddr: "localhost:0", + InternalServiceAddr: "localhost:0", + MetricsServiceAddr: "", // Disable metrics to avoid conflict + DataDir: t.TempDir(), + WalDir: t.TempDir(), + NotificationsRetentionTime: 1 * time.Minute, + AuthOptions: authParams, + }) + assert.NoError(t, err) + s1Addr := model.ServerAddress{ + Public: fmt.Sprintf("localhost:%d", s1.PublicPort()), + Internal: fmt.Sprintf("localhost:%d", s1.InternalPort()), + } + s2, err := server.New(server.Config{ + PublicServiceAddr: "localhost:0", + InternalServiceAddr: "localhost:0", + MetricsServiceAddr: "", // Disable metrics to avoid conflict + DataDir: t.TempDir(), + WalDir: t.TempDir(), + NotificationsRetentionTime: 1 * time.Minute, + AuthOptions: authParams, + }) + assert.NoError(t, err) + s2Addr := model.ServerAddress{ + Public: fmt.Sprintf("localhost:%d", s2.PublicPort()), + Internal: fmt.Sprintf("localhost:%d", s2.InternalPort()), + } + s3, err := server.New(server.Config{ + PublicServiceAddr: "localhost:0", + InternalServiceAddr: "localhost:0", + MetricsServiceAddr: "", // Disable metrics to avoid conflict + DataDir: t.TempDir(), + WalDir: t.TempDir(), + NotificationsRetentionTime: 1 * time.Minute, + AuthOptions: authParams, + }) + assert.NoError(t, err) + s3Addr := model.ServerAddress{ + Public: fmt.Sprintf("localhost:%d", s3.PublicPort()), + Internal: fmt.Sprintf("localhost:%d", s3.InternalPort()), + } + + metadataProvider := impl.NewMetadataProviderMemory() + clusterConfig := model.ClusterConfig{ + Namespaces: []model.NamespaceConfig{{ + Name: common.DefaultNamespace, + ReplicationFactor: 3, + InitialShardCount: 1, + }}, + Servers: []model.ServerAddress{s1Addr, s2Addr, s3Addr}, + } + + clientPool := common.NewClientPool(nil, nil) + + coordinator, err := impl.NewCoordinator(metadataProvider, + func() (model.ClusterConfig, error) { return clusterConfig, nil }, + nil, impl.NewRpcProvider(clientPool)) + assert.NoError(t, err) + + return s1Addr.Public, func() { + clientPool.Close() + coordinator.Close() + } +} + +func TestOIDCWithStaticToken(t *testing.T) { + mockOIDC, err := mockoidc.Run() + assert.NoError(t, err) + defer mockOIDC.Shutdown() + + audience := generateRandomStr(t) + audience2 := generateRandomStr(t) + id := generateRandomStr(t) + subject := generateRandomStr(t) + registeredClaims := &jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{audience, audience2}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(1) * time.Hour)), + ID: id, + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: mockOIDC.Issuer(), + NotBefore: jwt.NewNumericDate(time.Time{}), + Subject: subject, + } + signedToken, err := mockOIDC.Keypair.SignJWT(registeredClaims) + assert.NoError(t, err) + + addr, clusterCloseFunc := newOxiaClusterWithAuth(t, mockOIDC.Issuer(), audience) + defer clusterCloseFunc() + + // assert connection failed with empty token + client, err := oxia.NewSyncClient(addr) + assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) + + // assert connection failed with malformed token + client, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken("wrongToken", false))) + assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) + + // assert connection failed with unknown issue + client, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(illegalToken, false))) + assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) + + cutToken := signedToken[5:] + client, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(cutToken, false))) + assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) + + // assert connection failed with expired token + expiredToken, err := mockOIDC.Keypair.SignJWT(&jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{audience}, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(1) * time.Second)), + ID: id, + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: mockOIDC.Issuer(), + NotBefore: jwt.NewNumericDate(time.Time{}), + Subject: subject, + }) + assert.NoError(t, err) + time.Sleep(3 * time.Second) + client, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(expiredToken, false))) + assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) + + // assert connection success with correct token + client, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(signedToken, false))) + assert.NoError(t, err) + ctx := context.Background() + key := "hi" + payload := []byte("matt") + _, pVersion, err := client.Put(ctx, key, payload) + assert.NoError(t, err) + key, gValue, gVersion, err := client.Get(ctx, key) + assert.Equal(t, pVersion, gVersion) + assert.Equal(t, gValue, payload) + client.Close() +} + +func generateRandomStr(t *testing.T) string { + t.Helper() + random, err := uuid.NewRandom() + assert.NoError(t, err) + return random.String() +} + +const ( + illegalToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRIWFRTQ3lvdXE2RGlXYVF3bFh0TlA1NC1DNzVtdzNJY29Za0VSZmwzZlEiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwOi8vMTI3LjAuMC4xOjYxMzM4L29pZGMiLCJzdWIiOiJjZGJiYWE4NC0xODg0LTQ5NTktOTgwNS02NGRiY2NiMzdlZTIiLCJhdWQiOlsiNDc5YzRmYTktNDZjMC00NjY5LTkyZTktY2QxYmM3Mjc2MWNlIl0sImV4cCI6MTcxOTI0MjM4OSwibmJmIjotNjIxMzU1OTY4MDAsImlhdCI6MTcxOTIzODc4OSwianRpIjoiODY1MTlkOGEtMmNiYy00NTY0LTlmZjMtNTUyZTAxYjQzNzc2In0.netDk-UFqBwlxJZlDc3Any2tSqBHXsLxdorM3MrL171Xql6Mms6KCiNabpWbx--xvPtvlzs3v1O1R5LO3bZbI1VO-efumOpvjDBxe6WRqeezGp1spcJ_s0M90MjF7d6uRDxlOfEmPaB1Oryb8kYlyErrdXM3P1jRN_i2HMdju0tKjEVcqIbuzBs5et3RrLHmcP5yMFB9D9xN4zeTd_Rf7Qyl1JdiA2qD-1KDfeVtGAahyuNiR0-VOncY1VU3sqi-h8cviyB7cn2j4Iuo5D-DIuvrbC-jS51NUSLb_nSD8LjuGoc76n3-_zB2svTFVv-1tiLESASqna4HaI_AyRSDNQ" +) From 092082912e729a2ae12b41a04b9df0f8e9709ec5 Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Mon, 24 Jun 2024 23:18:41 +0800 Subject: [PATCH 4/8] fix lint --- coordinator/impl/cluster_rebalance.go | 2 +- server/leader_controller.go | 2 +- tests/security/auth/auth_oidc_test.go | 45 +++++++++++++++------------ 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/coordinator/impl/cluster_rebalance.go b/coordinator/impl/cluster_rebalance.go index 38e0b8df..7cf4f92b 100644 --- a/coordinator/impl/cluster_rebalance.go +++ b/coordinator/impl/cluster_rebalance.go @@ -131,7 +131,7 @@ outer: return res } -func getShardsPerServer(servers []model.ServerAddress, currentStatus *model.ClusterStatus) ( +func getShardsPerServer(servers []model.ServerAddress, currentStatus *model.ClusterStatus) ( existingServers map[model.ServerAddress]common.Set[int64], deletedServers map[model.ServerAddress]common.Set[int64]) { existingServers = map[model.ServerAddress]common.Set[int64]{} diff --git a/server/leader_controller.go b/server/leader_controller.go index 25eae4b2..2a4465b5 100644 --- a/server/leader_controller.go +++ b/server/leader_controller.go @@ -198,7 +198,7 @@ func (lc *leaderController) Term() int64 { // // Any existing follow cursors are destroyed as is any state // regarding reconfigurations. -func (lc *leaderController) NewTerm(req *proto.NewTermRequest) (*proto.NewTermResponse, error) { +func (lc *leaderController) NewTerm(req *proto.NewTermRequest) (*proto.NewTermResponse, error) { lc.Lock() defer lc.Unlock() diff --git a/tests/security/auth/auth_oidc_test.go b/tests/security/auth/auth_oidc_test.go index 7bb198b2..fbc5d3a6 100644 --- a/tests/security/auth/auth_oidc_test.go +++ b/tests/security/auth/auth_oidc_test.go @@ -3,23 +3,25 @@ package auth import ( "context" "fmt" + "testing" + "time" + "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/oauth2-proxy/mockoidc" "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/json" + "github.com/streamnative/oxia/common" "github.com/streamnative/oxia/coordinator/impl" "github.com/streamnative/oxia/coordinator/model" "github.com/streamnative/oxia/oxia" - clientAuth "github.com/streamnative/oxia/oxia/auth" + clientauth "github.com/streamnative/oxia/oxia/auth" "github.com/streamnative/oxia/server" "github.com/streamnative/oxia/server/auth" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "k8s.io/apimachinery/pkg/util/json" - "testing" - "time" ) func newOxiaClusterWithAuth(t *testing.T, issueURL string, audiences string) (address string, closeFunc func()) { @@ -103,7 +105,9 @@ func newOxiaClusterWithAuth(t *testing.T, issueURL string, audiences string) (ad func TestOIDCWithStaticToken(t *testing.T) { mockOIDC, err := mockoidc.Run() assert.NoError(t, err) - defer mockOIDC.Shutdown() + defer func(mockOIDC *mockoidc.MockOIDC) { + _ = mockOIDC.Shutdown() + }(mockOIDC) audience := generateRandomStr(t) audience2 := generateRandomStr(t) @@ -125,22 +129,22 @@ func TestOIDCWithStaticToken(t *testing.T) { defer clusterCloseFunc() // assert connection failed with empty token - client, err := oxia.NewSyncClient(addr) + _, err = oxia.NewSyncClient(addr) assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) // assert connection failed with malformed token - client, err = oxia.NewSyncClient(addr, - oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken("wrongToken", false))) + _, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientauth.NewTokenAuthenticationWithToken("wrongToken", false))) assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) // assert connection failed with unknown issue - client, err = oxia.NewSyncClient(addr, - oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(illegalToken, false))) + _, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientauth.NewTokenAuthenticationWithToken(illegalToken, false))) assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) cutToken := signedToken[5:] - client, err = oxia.NewSyncClient(addr, - oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(cutToken, false))) + _, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientauth.NewTokenAuthenticationWithToken(cutToken, false))) assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) // assert connection failed with expired token @@ -155,20 +159,21 @@ func TestOIDCWithStaticToken(t *testing.T) { }) assert.NoError(t, err) time.Sleep(3 * time.Second) - client, err = oxia.NewSyncClient(addr, - oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(expiredToken, false))) + _, err = oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientauth.NewTokenAuthenticationWithToken(expiredToken, false))) assert.Equal(t, codes.Unauthenticated, status.Code(errors.Unwrap(err))) // assert connection success with correct token - client, err = oxia.NewSyncClient(addr, - oxia.WithAuthentication(clientAuth.NewTokenAuthenticationWithToken(signedToken, false))) + client, err := oxia.NewSyncClient(addr, + oxia.WithAuthentication(clientauth.NewTokenAuthenticationWithToken(signedToken, false))) assert.NoError(t, err) ctx := context.Background() key := "hi" payload := []byte("matt") _, pVersion, err := client.Put(ctx, key, payload) assert.NoError(t, err) - key, gValue, gVersion, err := client.Get(ctx, key) + _, gValue, gVersion, err := client.Get(ctx, key) + assert.NoError(t, err) assert.Equal(t, pVersion, gVersion) assert.Equal(t, gValue, payload) client.Close() From 202a67dbe71ff7ed09c9de5767a7a1c5e64643fc Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Mon, 24 Jun 2024 23:33:25 +0800 Subject: [PATCH 5/8] fix license --- oxia/auth/authentication.go | 14 ++++++++++++++ oxia/auth/token.go | 14 ++++++++++++++ server/auth/authentication.go | 14 ++++++++++++++ server/auth/interceptor.go | 14 ++++++++++++++ server/auth/oidc.go | 14 ++++++++++++++ tests/security/auth/auth_oidc_test.go | 14 ++++++++++++++ 6 files changed, 84 insertions(+) diff --git a/oxia/auth/authentication.go b/oxia/auth/authentication.go index 657dac1d..f7568079 100644 --- a/oxia/auth/authentication.go +++ b/oxia/auth/authentication.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import "google.golang.org/grpc/credentials" diff --git a/oxia/auth/token.go b/oxia/auth/token.go index 90e2bd45..bd992d96 100644 --- a/oxia/auth/token.go +++ b/oxia/auth/token.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/server/auth/authentication.go b/server/auth/authentication.go index 927f0654..74cf8456 100644 --- a/server/auth/authentication.go +++ b/server/auth/authentication.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/server/auth/interceptor.go b/server/auth/interceptor.go index d5cb73f4..5cdb5604 100644 --- a/server/auth/interceptor.go +++ b/server/auth/interceptor.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/server/auth/oidc.go b/server/auth/oidc.go index ea26fa76..bb0c1452 100644 --- a/server/auth/oidc.go +++ b/server/auth/oidc.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( diff --git a/tests/security/auth/auth_oidc_test.go b/tests/security/auth/auth_oidc_test.go index fbc5d3a6..c8664d54 100644 --- a/tests/security/auth/auth_oidc_test.go +++ b/tests/security/auth/auth_oidc_test.go @@ -1,3 +1,17 @@ +// Copyright 2024 StreamNative, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package auth import ( From e663b2f3b146217155ea118f5aeb6f8facaa24dd Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Tue, 25 Jun 2024 06:48:22 +0800 Subject: [PATCH 6/8] fix npe --- cmd/health/cmd_test.go | 3 ++- server/auth/authentication.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/health/cmd_test.go b/cmd/health/cmd_test.go index 8040a3d3..48424d64 100644 --- a/cmd/health/cmd_test.go +++ b/cmd/health/cmd_test.go @@ -16,6 +16,7 @@ package health import ( "fmt" + "github.com/streamnative/oxia/server/auth" "testing" "github.com/stretchr/testify/assert" @@ -32,7 +33,7 @@ func TestHealthCmd(t *testing.T) { _health := health.NewServer() server, err := container.Default.StartGrpcServer("health", "localhost:0", func(registrar grpc.ServiceRegistrar) { grpc_health_v1.RegisterHealthServer(registrar, _health) - }, nil, nil) + }, nil, &auth.Options{}) assert.NoError(t, err) defer func() { _ = server.Close() diff --git a/server/auth/authentication.go b/server/auth/authentication.go index 74cf8456..06d7b8db 100644 --- a/server/auth/authentication.go +++ b/server/auth/authentication.go @@ -41,7 +41,7 @@ type Options struct { } func (op *Options) IsEnabled() bool { - return op.ProviderName != "" + return op != nil && op.ProviderName != "" } // todo: add metrics From 6c691343c061618e553914247b59f16ce6fd620f Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Tue, 25 Jun 2024 07:38:39 +0800 Subject: [PATCH 7/8] cleanup resources --- tests/security/auth/auth_oidc_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/security/auth/auth_oidc_test.go b/tests/security/auth/auth_oidc_test.go index c8664d54..18fe210e 100644 --- a/tests/security/auth/auth_oidc_test.go +++ b/tests/security/auth/auth_oidc_test.go @@ -111,6 +111,10 @@ func newOxiaClusterWithAuth(t *testing.T, issueURL string, audiences string) (ad assert.NoError(t, err) return s1Addr.Public, func() { + s1.Close() + s1.Close() + s3.Close() + clientPool.Close() coordinator.Close() } From 6772bf7924ff9f5614e7be0c257c5a9d89e3dd6b Mon Sep 17 00:00:00 2001 From: Qiang Zhao Date: Tue, 25 Jun 2024 07:41:51 +0800 Subject: [PATCH 8/8] fix wrong close --- tests/security/auth/auth_oidc_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/security/auth/auth_oidc_test.go b/tests/security/auth/auth_oidc_test.go index 18fe210e..734ceac9 100644 --- a/tests/security/auth/auth_oidc_test.go +++ b/tests/security/auth/auth_oidc_test.go @@ -112,7 +112,7 @@ func newOxiaClusterWithAuth(t *testing.T, issueURL string, audiences string) (ad return s1Addr.Public, func() { s1.Close() - s1.Close() + s2.Close() s3.Close() clientPool.Close()