From 206a7a348de5bccfe61cf2138e4e45a8666d2d26 Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Thu, 4 Apr 2024 10:44:42 +0200 Subject: [PATCH 1/3] add v2 directory --- v2/auth.go | 421 ++++++++++++++++++ v2/auth_test.go | 596 +++++++++++++++++++++++++ v2/avatar/avatar.go | 238 ++++++++++ v2/avatar/avatar_test.go | 280 ++++++++++++ v2/avatar/bolt.go | 141 ++++++ v2/avatar/bolt_test.go | 103 +++++ v2/avatar/gridfs.go | 157 +++++++ v2/avatar/gridfs_test.go | 110 +++++ v2/avatar/localfs.go | 124 ++++++ v2/avatar/localfs_test.go | 177 ++++++++ v2/avatar/noop.go | 35 ++ v2/avatar/noop_test.go | 89 ++++ v2/avatar/store.go | 133 ++++++ v2/avatar/store_test.go | 122 ++++++ v2/avatar/testdata/circles.jpg | Bin 0 -> 23983 bytes v2/avatar/testdata/circles.png | Bin 0 -> 11392 bytes v2/go.mod | 49 +++ v2/go.sum | 187 ++++++++ v2/logger/interface.go | 22 + v2/logger/interface_test.go | 43 ++ v2/middleware/auth.go | 262 +++++++++++ v2/middleware/auth_test.go | 571 ++++++++++++++++++++++++ v2/middleware/user_updater.go | 38 ++ v2/middleware/user_updater_test.go | 71 +++ v2/provider/apple.go | 540 +++++++++++++++++++++++ v2/provider/apple_pubkeys.go | 178 ++++++++ v2/provider/apple_pubkeys_test.go | 226 ++++++++++ v2/provider/apple_test.go | 681 +++++++++++++++++++++++++++++ v2/provider/custom_server.go | 359 +++++++++++++++ v2/provider/custom_server_test.go | 234 ++++++++++ v2/provider/dev_provider.go | 318 ++++++++++++++ v2/provider/dev_provider_test.go | 111 +++++ v2/provider/direct.go | 193 ++++++++ v2/provider/direct_test.go | 277 ++++++++++++ v2/provider/oauth1.go | 195 +++++++++ v2/provider/oauth1_test.go | 293 +++++++++++++ v2/provider/oauth2.go | 254 +++++++++++ v2/provider/oauth2_test.go | 371 ++++++++++++++++ v2/provider/providers.go | 267 +++++++++++ v2/provider/providers_test.go | 210 +++++++++ v2/provider/sender/email.go | 90 ++++ v2/provider/sender/email_test.go | 70 +++ v2/provider/service.go | 95 ++++ v2/provider/service_test.go | 98 +++++ v2/provider/telegram.go | 484 ++++++++++++++++++++ v2/provider/telegram_moq_test.go | 225 ++++++++++ v2/provider/telegram_test.go | 585 +++++++++++++++++++++++++ v2/provider/verify.go | 219 ++++++++++ v2/provider/verify_test.go | 329 ++++++++++++++ v2/token/jwt.go | 405 +++++++++++++++++ v2/token/jwt_test.go | 639 +++++++++++++++++++++++++++ v2/token/user.go | 169 +++++++ v2/token/user_test.go | 129 ++++++ 53 files changed, 12213 insertions(+) create mode 100644 v2/auth.go create mode 100644 v2/auth_test.go create mode 100644 v2/avatar/avatar.go create mode 100644 v2/avatar/avatar_test.go create mode 100644 v2/avatar/bolt.go create mode 100644 v2/avatar/bolt_test.go create mode 100644 v2/avatar/gridfs.go create mode 100644 v2/avatar/gridfs_test.go create mode 100644 v2/avatar/localfs.go create mode 100644 v2/avatar/localfs_test.go create mode 100644 v2/avatar/noop.go create mode 100644 v2/avatar/noop_test.go create mode 100644 v2/avatar/store.go create mode 100644 v2/avatar/store_test.go create mode 100644 v2/avatar/testdata/circles.jpg create mode 100644 v2/avatar/testdata/circles.png create mode 100644 v2/go.mod create mode 100644 v2/go.sum create mode 100644 v2/logger/interface.go create mode 100644 v2/logger/interface_test.go create mode 100644 v2/middleware/auth.go create mode 100644 v2/middleware/auth_test.go create mode 100644 v2/middleware/user_updater.go create mode 100644 v2/middleware/user_updater_test.go create mode 100644 v2/provider/apple.go create mode 100644 v2/provider/apple_pubkeys.go create mode 100644 v2/provider/apple_pubkeys_test.go create mode 100644 v2/provider/apple_test.go create mode 100644 v2/provider/custom_server.go create mode 100644 v2/provider/custom_server_test.go create mode 100644 v2/provider/dev_provider.go create mode 100644 v2/provider/dev_provider_test.go create mode 100644 v2/provider/direct.go create mode 100644 v2/provider/direct_test.go create mode 100644 v2/provider/oauth1.go create mode 100644 v2/provider/oauth1_test.go create mode 100644 v2/provider/oauth2.go create mode 100644 v2/provider/oauth2_test.go create mode 100644 v2/provider/providers.go create mode 100644 v2/provider/providers_test.go create mode 100644 v2/provider/sender/email.go create mode 100644 v2/provider/sender/email_test.go create mode 100644 v2/provider/service.go create mode 100644 v2/provider/service_test.go create mode 100644 v2/provider/telegram.go create mode 100644 v2/provider/telegram_moq_test.go create mode 100644 v2/provider/telegram_test.go create mode 100644 v2/provider/verify.go create mode 100644 v2/provider/verify_test.go create mode 100644 v2/token/jwt.go create mode 100644 v2/token/jwt_test.go create mode 100644 v2/token/user.go create mode 100644 v2/token/user_test.go diff --git a/v2/auth.go b/v2/auth.go new file mode 100644 index 00000000..91faa7e4 --- /dev/null +++ b/v2/auth.go @@ -0,0 +1,421 @@ +// Package auth provides "social login" with Github, Google, Facebook, Microsoft, Yandex and Battle.net as well as custom auth providers. +package auth + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/go-pkgz/rest" + + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/middleware" + "github.com/go-pkgz/auth/provider" + "github.com/go-pkgz/auth/token" +) + +// Client is a type of auth client +type Client struct { + Cid string + Csecret string +} + +// Service provides higher level wrapper allowing to construct everything and get back token middleware +type Service struct { + logger logger.L + opts Opts + jwtService *token.Service + providers []provider.Service + authMiddleware middleware.Authenticator + avatarProxy *avatar.Proxy + issuer string + useGravatar bool +} + +// Opts is a full set of all parameters to initialize Service +type Opts struct { + SecretReader token.Secret // reader returns secret for given site id (aud), required + ClaimsUpd token.ClaimsUpdater // updater for jwt to add/modify values stored in the token + SecureCookies bool // makes jwt cookie secure + TokenDuration time.Duration // token's TTL, refreshed automatically + CookieDuration time.Duration // cookie's TTL. This cookie stores JWT token + + DisableXSRF bool // disable XSRF protection, useful for testing/debugging + DisableIAT bool // disable IssuedAt claim + + // optional (custom) names for cookies and headers + JWTCookieName string // default "JWT" + JWTCookieDomain string // default empty + JWTHeaderKey string // default "X-JWT" + XSRFCookieName string // default "XSRF-TOKEN" + XSRFHeaderKey string // default "X-XSRF-TOKEN" + JWTQuery string // default "token" + SendJWTHeader bool // if enabled send JWT as a header instead of cookie + SameSiteCookie http.SameSite // limit cross-origin requests with SameSite cookie attribute + + Issuer string // optional value for iss claim, usually the application name, default "go-pkgz/auth" + + URL string // root url for the rest service, i.e. http://blah.example.com, required + Validator token.Validator // validator allows to reject some valid tokens with user-defined logic + + AvatarStore avatar.Store // store to save/load avatars, required (use avatar.NoOp to disable avatars support) + AvatarResizeLimit int // resize avatar's limit in pixels + AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar` + UseGravatar bool // for email based auth (verified provider) use gravatar service + + AdminPasswd string // if presented, allows basic auth with user admin and given password + BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored + AudienceReader token.Audience // list of allowed aud values, default (empty) allows any + AudSecrets bool // allow multiple secrets (secret per aud) + Logger logger.L // logger interface, default is no logging at all + RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens +} + +// NewService initializes everything +func NewService(opts Opts) (res *Service) { + + res = &Service{ + opts: opts, + logger: opts.Logger, + authMiddleware: middleware.Authenticator{ + Validator: opts.Validator, + AdminPasswd: opts.AdminPasswd, + BasicAuthChecker: opts.BasicAuthChecker, + RefreshCache: opts.RefreshCache, + }, + issuer: opts.Issuer, + useGravatar: opts.UseGravatar, + } + + if opts.Issuer == "" { + res.issuer = "go-pkgz/auth" + } + + if opts.Logger == nil { + res.logger = logger.NoOp + } + + jwtService := token.NewService(token.Opts{ + SecretReader: opts.SecretReader, + ClaimsUpd: opts.ClaimsUpd, + SecureCookies: opts.SecureCookies, + TokenDuration: opts.TokenDuration, + CookieDuration: opts.CookieDuration, + DisableXSRF: opts.DisableXSRF, + DisableIAT: opts.DisableIAT, + JWTCookieName: opts.JWTCookieName, + JWTCookieDomain: opts.JWTCookieDomain, + JWTHeaderKey: opts.JWTHeaderKey, + XSRFCookieName: opts.XSRFCookieName, + XSRFHeaderKey: opts.XSRFHeaderKey, + SendJWTHeader: opts.SendJWTHeader, + JWTQuery: opts.JWTQuery, + Issuer: res.issuer, + AudienceReader: opts.AudienceReader, + AudSecrets: opts.AudSecrets, + SameSite: opts.SameSiteCookie, + }) + + if opts.SecretReader == nil { + jwtService.SecretReader = token.SecretFunc(func(string) (string, error) { + return "", fmt.Errorf("secrets reader not available") + }) + res.logger.Logf("[WARN] no secret reader defined") + } + + res.jwtService = jwtService + res.authMiddleware.JWTService = jwtService + res.authMiddleware.L = res.logger + + if opts.AvatarStore != nil { + res.avatarProxy = &avatar.Proxy{ + Store: opts.AvatarStore, + URL: opts.URL, + RoutePath: opts.AvatarRoutePath, + ResizeLimit: opts.AvatarResizeLimit, + L: res.logger, + } + if res.avatarProxy.RoutePath == "" { + res.avatarProxy.RoutePath = "/avatar" + } + } + + return res +} + +// Handlers gets http.Handler for all providers and avatars +func (s *Service) Handlers() (authHandler, avatarHandler http.Handler) { + + ah := func(w http.ResponseWriter, r *http.Request) { + elems := strings.Split(r.URL.Path, "/") + if len(elems) < 2 { + w.WriteHeader(http.StatusBadRequest) + return + } + + // list all providers + if elems[len(elems)-1] == "list" { + list := []string{} + for _, p := range s.providers { + list = append(list, p.Name()) + } + rest.RenderJSON(w, list) + return + } + + // allow logout without specifying provider + if elems[len(elems)-1] == "logout" { + if len(s.providers) == 0 { + w.WriteHeader(http.StatusBadRequest) + rest.RenderJSON(w, rest.JSON{"error": "providers not defined"}) + return + } + s.providers[0].Handler(w, r) + return + } + + // show user info + if elems[len(elems)-1] == "user" { + claims, _, err := s.jwtService.Get(r) + if err != nil || claims.User == nil { + w.WriteHeader(http.StatusUnauthorized) + msg := "user is nil" + if err != nil { + msg = err.Error() + } + rest.RenderJSON(w, rest.JSON{"error": msg}) + return + } + rest.RenderJSON(w, claims.User) + return + } + + // status of logged-in user + if elems[len(elems)-1] == "status" { + claims, _, err := s.jwtService.Get(r) + if err != nil || claims.User == nil { + rest.RenderJSON(w, rest.JSON{"status": "not logged in"}) + return + } + rest.RenderJSON(w, rest.JSON{"status": "logged in", "user": claims.User.Name}) + return + } + + // regular auth handlers + provName := elems[len(elems)-2] + p, err := s.Provider(provName) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + rest.RenderJSON(w, rest.JSON{"error": fmt.Sprintf("provider %s not supported", provName)}) + return + } + p.Handler(w, r) + } + + return http.HandlerFunc(ah), http.HandlerFunc(s.avatarProxy.Handler) +} + +// Middleware returns auth middleware +func (s *Service) Middleware() middleware.Authenticator { + return s.authMiddleware +} + +// AddProviderWithUserAttributes adds provider with user attributes mapping +func (s *Service) AddProviderWithUserAttributes(name, cid, csecret string, userAttributes provider.UserAttributes) { + p := provider.Params{ + URL: s.opts.URL, + JwtService: s.jwtService, + Issuer: s.issuer, + AvatarSaver: s.avatarProxy, + Cid: cid, + Csecret: csecret, + L: s.logger, + UserAttributes: userAttributes, + } + s.addProvider(name, p) +} + +func (s *Service) addProvider(name string, p provider.Params) { + switch strings.ToLower(name) { + case "github": + s.providers = append(s.providers, provider.NewService(provider.NewGithub(p))) + case "google": + s.providers = append(s.providers, provider.NewService(provider.NewGoogle(p))) + case "facebook": + s.providers = append(s.providers, provider.NewService(provider.NewFacebook(p))) + case "yandex": + s.providers = append(s.providers, provider.NewService(provider.NewYandex(p))) + case "battlenet": + s.providers = append(s.providers, provider.NewService(provider.NewBattlenet(p))) + case "microsoft": + s.providers = append(s.providers, provider.NewService(provider.NewMicrosoft(p))) + case "twitter": + s.providers = append(s.providers, provider.NewService(provider.NewTwitter(p))) + case "patreon": + s.providers = append(s.providers, provider.NewService(provider.NewPatreon(p))) + case "dev": + s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) + default: + return + } + + s.authMiddleware.Providers = s.providers +} + +// AddProvider adds provider for given name +func (s *Service) AddProvider(name, cid, csecret string) { + + p := provider.Params{ + URL: s.opts.URL, + JwtService: s.jwtService, + Issuer: s.issuer, + AvatarSaver: s.avatarProxy, + Cid: cid, + Csecret: csecret, + L: s.logger, + UserAttributes: map[string]string{}, + } + + s.addProvider(name, p) +} + +// AddDevProvider with a custom host and port +func (s *Service) AddDevProvider(host string, port int) { + p := provider.Params{ + URL: s.opts.URL, + JwtService: s.jwtService, + Issuer: s.issuer, + AvatarSaver: s.avatarProxy, + L: s.logger, + Port: port, + Host: host, + } + s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) +} + +// AddAppleProvider allow SignIn with Apple ID +func (s *Service) AddAppleProvider(appleConfig provider.AppleConfig, privKeyLoader provider.PrivateKeyLoaderInterface) error { + p := provider.Params{ + URL: s.opts.URL, + JwtService: s.jwtService, + Issuer: s.issuer, + AvatarSaver: s.avatarProxy, + L: s.logger, + } + + // Error checking at create need for catch one when apple private key init + appleProvider, err := provider.NewApple(p, appleConfig, privKeyLoader) + if err != nil { + return fmt.Errorf("an AppleProvider creating failed: %w", err) + } + + s.providers = append(s.providers, provider.NewService(appleProvider)) + return nil +} + +// AddCustomProvider adds custom provider (e.g. https://gopkg.in/oauth2.v3) +func (s *Service) AddCustomProvider(name string, client Client, copts provider.CustomHandlerOpt) { + p := provider.Params{ + URL: s.opts.URL, + JwtService: s.jwtService, + Issuer: s.issuer, + AvatarSaver: s.avatarProxy, + Cid: client.Cid, + Csecret: client.Csecret, + L: s.logger, + } + + s.providers = append(s.providers, provider.NewService(provider.NewCustom(name, p, copts))) + s.authMiddleware.Providers = s.providers +} + +// AddDirectProvider adds provider with direct check against data store +// it doesn't do any handshake and uses provided credChecker to verify user and password from the request +func (s *Service) AddDirectProvider(name string, credChecker provider.CredChecker) { + dh := provider.DirectHandler{ + L: s.logger, + ProviderName: name, + Issuer: s.issuer, + TokenService: s.jwtService, + CredChecker: credChecker, + AvatarSaver: s.avatarProxy, + } + s.providers = append(s.providers, provider.NewService(dh)) + s.authMiddleware.Providers = s.providers +} + +// AddDirectProviderWithUserIDFunc adds provider with direct check against data store and sets custom UserIDFunc allows +// to modify user's ID on the client side. +// it doesn't do any handshake and uses provided credChecker to verify user and password from the request +func (s *Service) AddDirectProviderWithUserIDFunc(name string, credChecker provider.CredChecker, ufn provider.UserIDFunc) { + dh := provider.DirectHandler{ + L: s.logger, + ProviderName: name, + Issuer: s.issuer, + TokenService: s.jwtService, + CredChecker: credChecker, + AvatarSaver: s.avatarProxy, + UserIDFunc: ufn, + } + s.providers = append(s.providers, provider.NewService(dh)) + s.authMiddleware.Providers = s.providers +} + +// AddVerifProvider adds provider user's verification sent by sender +func (s *Service) AddVerifProvider(name, msgTmpl string, sender provider.Sender) { + dh := provider.VerifyHandler{ + L: s.logger, + ProviderName: name, + Issuer: s.issuer, + TokenService: s.jwtService, + AvatarSaver: s.avatarProxy, + Sender: sender, + Template: msgTmpl, + UseGravatar: s.useGravatar, + } + s.providers = append(s.providers, provider.NewService(dh)) + s.authMiddleware.Providers = s.providers +} + +// AddCustomHandler adds user-defined self-implemented handler of auth provider +func (s *Service) AddCustomHandler(handler provider.Provider) { + s.providers = append(s.providers, provider.NewService(handler)) + s.authMiddleware.Providers = s.providers +} + +// DevAuth makes dev oauth2 server, for testing and development only! +func (s *Service) DevAuth() (*provider.DevAuthServer, error) { + p, err := s.Provider("dev") // peak dev provider + if err != nil { + return nil, fmt.Errorf("dev provider not registered: %w", err) + } + // make and start dev auth server + return &provider.DevAuthServer{Provider: p.Provider.(provider.Oauth2Handler), L: s.logger}, nil +} + +// Provider gets provider by name +func (s *Service) Provider(name string) (provider.Service, error) { + for _, p := range s.providers { + if p.Name() == name { + return p, nil + } + } + return provider.Service{}, fmt.Errorf("provider %s not found", name) +} + +// Providers gets all registered providers +func (s *Service) Providers() []provider.Service { + return s.providers +} + +// TokenService returns token.Service +func (s *Service) TokenService() *token.Service { + return s.jwtService +} + +// AvatarProxy returns stored in service +func (s *Service) AvatarProxy() *avatar.Proxy { + return s.avatarProxy +} diff --git a/v2/auth_test.go b/v2/auth_test.go new file mode 100644 index 00000000..15b9efbe --- /dev/null +++ b/v2/auth_test.go @@ -0,0 +1,596 @@ +package auth + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/provider" + "github.com/go-pkgz/auth/token" +) + +func TestNewService(t *testing.T) { + + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24, + Issuer: "my-test-app", + URL: "http://127.0.0.1:8089", + AvatarStore: avatar.NewLocalFS("/tmp"), + Logger: logger.Std, + } + + svc := NewService(options) + assert.NotNil(t, svc) + assert.NotNil(t, svc.TokenService()) + assert.NotNil(t, svc.AvatarProxy()) +} + +func TestProvider(t *testing.T) { + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + URL: "http://127.0.0.1:8089", + Logger: logger.Std, + } + svc := NewService(options) + + _, err := svc.Provider("some provider") + assert.EqualError(t, err, "provider some provider not found") + + svc.AddProvider("dev", "cid", "csecret") + svc.AddProvider("github", "cid", "csecret") + svc.AddProvider("google", "cid", "csecret") + svc.AddProvider("facebook", "cid", "csecret") + svc.AddProvider("yandex", "cid", "csecret") + svc.AddProvider("microsoft", "cid", "csecret") + svc.AddProvider("battlenet", "cid", "csecret") + svc.AddProvider("patreon", "cid", "csecret") + svc.AddProvider("bad", "cid", "csecret") + + c := customHandler{} + svc.AddCustomHandler(c) + + p, err := svc.Provider("dev") + assert.NoError(t, err) + op := p.Provider.(provider.Oauth2Handler) + assert.Equal(t, "dev", op.Name()) + assert.Equal(t, "cid", op.Cid) + assert.Equal(t, "csecret", op.Csecret) + assert.Equal(t, "go-pkgz/auth", op.Issuer) + + p, err = svc.Provider("github") + assert.NoError(t, err) + op = p.Provider.(provider.Oauth2Handler) + assert.Equal(t, "github", op.Name()) + + pp := svc.Providers() + assert.Equal(t, 9, len(pp)) + + ch, err := svc.Provider("telegramBotMySiteCom") + assert.NoError(t, err) + chp := ch.Provider + assert.Equal(t, "telegramBotMySiteCom", chp.Name()) +} + +func TestService_AddAppleProvider(t *testing.T) { + + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + URL: "http://127.0.0.1:8089", + Logger: logger.Std, + } + svc := NewService(options) + + testValidKey := `-----BEGIN PRIVATE KEY----- +MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgTxaHXzyuM85Znw7y +SJ9XeeC8gqcpE/VLhZHGsnPPiPagCgYIKoZIzj0DAQehRANCAATnwlOv7I6eC3Ec +/+GeYXT+hbcmhEVveDqLmNcHiXCR9XxJZXtpMRlcRfY8eaJpUdig27dfsbvpnfX5 +Ivx5tHkv +-----END PRIVATE KEY-----` + testPrivKeyFileName := "privKeyTest.tmp" + + dir, err := os.MkdirTemp(os.TempDir(), testPrivKeyFileName) + assert.NoError(t, err) + assert.NotNil(t, dir) + if err != nil { + require.NoError(t, err) + return + } + + defer func() { + err = os.RemoveAll(dir) + require.NoError(t, err) + }() + + filePath := filepath.Join(dir, testPrivKeyFileName) + + if err = os.WriteFile(filePath, []byte(testValidKey), 0o600); err != nil { + require.NoError(t, err) + return + } + assert.NoError(t, err) + + appleCfg := provider.AppleConfig{ + ClientID: "111222", + TeamID: "3334445556", + KeyID: "0011002200", + } + + err = svc.AddAppleProvider(appleCfg, provider.LoadApplePrivateKeyFromFile(filePath)) + require.NoError(t, err) + p, err := svc.Provider("apple") + assert.NoError(t, err) + assert.Equal(t, p.Name(), "apple") + + err = svc.AddAppleProvider(appleCfg, nil) + require.Error(t, err) + +} + +func TestIntegrationProtected(t *testing.T) { + + _, teardown := prepService(t) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + resp, err := client.Get("http://127.0.0.1:8089/private") + require.Nil(t, err) + assert.Equal(t, 401, resp.StatusCode) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "Unauthorized\n", string(body)) + + // check non-admin, permanent + resp, err = client.Get("http://127.0.0.1:8089/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + require.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 86400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + resp, err = client.Get("http://127.0.0.1:8089/private") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestIntegrationBasicAuth(t *testing.T) { + + _, teardown := prepService(t) + defer teardown() + + client := &http.Client{Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", "http://127.0.0.1:8089/private", http.NoBody) + require.Nil(t, err) + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, 401, resp.StatusCode) + defer resp.Body.Close() + + req, err = http.NewRequest("GET", "http://127.0.0.1:8089/private", http.NoBody) + require.Nil(t, err) + req.SetBasicAuth("admin", "password") + resp, err = client.Do(req) + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestIntegrationAvatar(t *testing.T) { + + _, teardown := prepService(t) + defer teardown() + + // login + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + resp, err = http.Get("http://127.0.0.1:8089/api/v1/avatar/ccfa2abd01667605b4e1fc4fcb91b1e1af323240.image") + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, 569, len(b)) +} + +func TestIntegrationList(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + resp, err := http.Get("http://127.0.0.1:8089/auth/list") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, `["dev","github","custom123","direct","direct_custom","email"]`+"\n", string(b)) +} + +func TestIntegrationUserInfo(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + resp, err := http.Get("http://127.0.0.1:8089/auth/user") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 401, resp.StatusCode) + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // login + resp, err = client.Get("http://127.0.0.1:8089/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + + // get user info + req, err := http.NewRequest("GET", "http://127.0.0.1:8089/auth/user", http.NoBody) + require.NoError(t, err) + t.Log(resp.Cookies()) + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) + + u := token.User{} + err = json.NewDecoder(resp.Body).Decode(&u) + require.NoError(t, err) + + assert.Equal(t, token.User{Name: "dev_user", ID: "dev_user", Audience: "my-test-site", + Picture: "http://127.0.0.1:8089/api/v1/avatar/ccfa2abd01667605b4e1fc4fcb91b1e1af323240.image"}, u) +} + +func TestLogout(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + // login + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + // logout + resp, err = client.Get("http://127.0.0.1:8089/auth/logout") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + + resp, err = client.Get("http://127.0.0.1:8089/private") + require.Nil(t, err) + assert.Equal(t, 401, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestLogoutNoProviders(t *testing.T) { + svc := NewService(Opts{Logger: logger.Std}) + authRoute, _ := svc.Handlers() + + mux := http.NewServeMux() + mux.Handle("/auth/", authRoute) + ts := httptest.NewServer(mux) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/auth/logout") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 400, resp.StatusCode) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "{\"error\":\"providers not defined\"}\n", string(b)) +} + +func TestBadRequests(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/bad/login") + require.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + defer resp.Body.Close() + + resp, err = client.Get("http://127.0.0.1:8089/auth") + require.Nil(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) + +} + +func TestDirectProvider(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + // login + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/direct/login?user=dev_direct&passwd=bad") + require.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, 403, resp.StatusCode) + + resp, err = client.Get("http://127.0.0.1:8089/auth/direct/login?user=dev_direct&passwd=password") + require.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_38773b45e3a477434abb6d08668358a2b0ddd2f5"`) + + require.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 86400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + resp, err = client.Get("http://127.0.0.1:8089/private") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + // login + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") + require.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, 403, resp.StatusCode) + + resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") + require.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + + require.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 86400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + resp, err = client.Get("http://127.0.0.1:8089/private") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestVerifProvider(t *testing.T) { + _, teardown := prepService(t) + defer teardown() + + // login + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get("http://127.0.0.1:8089/auth/email/login?user=dev&address=xyz@gmail.com") + require.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + tkn := sender.text + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client = &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err = client.Get("http://127.0.0.1:8089/auth/email/login?token=" + tkn) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 200, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + u := token.User{} + err = json.Unmarshal(body, &u) + require.NoError(t, err) + assert.Equal(t, token.User{Name: "dev", ID: "email_84714ea398a960df03e2619d1b850dfac25f585e", + Picture: "http://127.0.0.1:8089/api/v1/avatar/e8eb81cc51b1123059ab29575296cbfd8a6a1b6e.image"}, u) + + require.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 86400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") +} + +func TestStatus(t *testing.T) { + + svc, teardown := prepService(t) + defer teardown() + + authRoute, _ := svc.Handlers() + + mux := http.NewServeMux() + mux.Handle("/auth/", authRoute) + ts := httptest.NewServer(mux) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/auth/status") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "{\"status\":\"not logged in\"}\n", string(b)) + + // login + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + resp, err = client.Get("http://127.0.0.1:8089/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + resp, err = client.Get(ts.URL + "/auth/status") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + b, err = io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "{\"status\":\"logged in\",\"user\":\"dev_user\"}\n", string(b)) + +} + +func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unparam + + options := Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24, + Issuer: "my-test-app", + URL: "http://127.0.0.1:8089", + DisableXSRF: true, + DisableIAT: true, + Validator: token.ValidatorFunc(func(_ string, claims token.Claims) bool { + return claims.User != nil && strings.HasPrefix(claims.User.Name, "dev_") // allow only dev_ names + }), + AvatarStore: avatar.NewLocalFS("/tmp/auth-pkgz"), + AvatarResizeLimit: 120, + AvatarRoutePath: "/api/v1/avatar", + AdminPasswd: "password", + Logger: logger.Std, + } + + svc = NewService(options) + svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 + svc.AddProvider("github", "cid", "csec") // add github provider + + // add go-oauth2/oauth2 provider + svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + + // add direct provider + svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + })) + + // add direct provider with custom user id func + svc.AddDirectProviderWithUserIDFunc("direct_custom", + provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + }), + func(user string, r *http.Request) string { + return "blah" + }, + ) + + svc.AddVerifProvider("email", "{{.Token}}", &sender) + + // run dev/test oauth2 server on :18084 + devAuth, err := svc.DevAuth() + require.NoError(t, err) + devAuth.Automatic = true // eliminate form + go devAuth.Run(context.TODO()) + time.Sleep(time.Millisecond * 50) + + // setup http server + m := svc.Middleware() + mux := http.NewServeMux() + mux.Handle("/open", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // no token required + _, _ = w.Write([]byte("open route, no token needed\n")) + })) + mux.Handle("/private", m.Auth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // token required + _, _ = w.Write([]byte("open route, no token needed\n")) + }))) + + // setup auth routes + authRoute, avaRoutes := svc.Handlers() + mux.Handle("/auth/", authRoute) // add token handlers + mux.Handle("/api/v1/avatar/", http.StripPrefix("/api/v1/avatar", avaRoutes)) // add avatar handler + + l, err := net.Listen("tcp", "127.0.0.1:8089") + require.Nil(t, err) + ts := httptest.NewUnstartedServer(mux) + assert.NoError(t, ts.Listener.Close()) + ts.Listener = l + ts.Start() + + return svc, func() { + ts.Close() + devAuth.Shutdown() + _ = os.RemoveAll("/tmp/auth-pkgz") + } +} + +var sender = mockSender{} + +type mockSender struct { + err error + + to string + text string +} + +func (m *mockSender) Send(to, text string) error { + if m.err != nil { + return m.err + } + m.to = to + m.text = text + return nil +} + +type customHandler struct{} + +func (c customHandler) Name() string { + return "telegramBotMySiteCom" +} +func (c customHandler) LoginHandler(http.ResponseWriter, *http.Request) {} +func (c customHandler) AuthHandler(http.ResponseWriter, *http.Request) {} +func (c customHandler) LogoutHandler(http.ResponseWriter, *http.Request) {} diff --git a/v2/avatar/avatar.go b/v2/avatar/avatar.go new file mode 100644 index 00000000..ed691fa8 --- /dev/null +++ b/v2/avatar/avatar.go @@ -0,0 +1,238 @@ +// Package avatar implements avatart proxy for oauth and +// defines store interface and implements local (fs), gridfs (mongo) and boltdb stores. +package avatar + +import ( + "bytes" + "crypto/md5" //nolint gosec + "encoding/hex" + "fmt" + "image" + "image/png" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/go-pkgz/rest" + "github.com/rrivera/identicon" + "golang.org/x/image/draw" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// Proxy provides http handler for avatars from avatar.Store +// On user login token will call Put and it will retrieve and save picture locally. +type Proxy struct { + logger.L + Store Store + RoutePath string + URL string + ResizeLimit int +} + +// Put stores retrieved avatar to avatar.Store. Gets image from user info. Returns proxied url +func (p *Proxy) Put(u token.User, client *http.Client) (avatarURL string, err error) { + + genIdenticon := func(userID string) (avatarURL string, err error) { + b, e := GenerateAvatar(userID) + if e != nil { + return "", fmt.Errorf("no picture for %s: %w", userID, e) + } + // put returns avatar base name, like 123456.image + avatarID, e := p.Store.Put(userID, p.resize(bytes.NewBuffer(b), p.ResizeLimit)) + if e != nil { + return "", e + } + + p.Logf("[DEBUG] saved identicon avatar to %s, user %q", avatarID, u.Name) + return p.URL + p.RoutePath + "/" + avatarID, nil + } + + // no picture for user, try to generate identicon avatar + if u.Picture == "" { + return genIdenticon(u.ID) + } + + body, err := p.load(u.Picture, client) + if err != nil { + p.Logf("[DEBUG] failed to fetch avatar from the orig %s, %v", u.Picture, err) + return genIdenticon(u.ID) + } + + defer func() { + if e := body.Close(); e != nil { + p.Logf("[WARN] can't close response body, %s", e) + } + }() + + avatarID, err := p.Store.Put(u.ID, p.resize(body, p.ResizeLimit)) // put returns avatar base name, like 123456.image + if err != nil { + return "", err + } + + p.Logf("[DEBUG] saved avatar from %s to %s, user %q", u.Picture, avatarID, u.Name) + return p.URL + p.RoutePath + "/" + avatarID, nil +} + +// load avatar from remote url and return body. Caller has to close the reader +func (p *Proxy) load(url string, client *http.Client) (rc io.ReadCloser, err error) { + // load avatar from remote location + var resp *http.Response + err = retry(5, time.Second, func() error { + var e error + resp, e = client.Get(url) + return e + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch avatar from the orig: %w", err) + } + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() // caller won't close on error + return nil, fmt.Errorf("failed to get avatar from the orig, status %s", resp.Status) + } + + return resp.Body, nil +} + +// Handler returns token routes for given provider +func (p *Proxy) Handler(w http.ResponseWriter, r *http.Request) { + + if r.Method != "GET" { + w.WriteHeader(http.StatusMethodNotAllowed) + } + elems := strings.Split(r.URL.Path, "/") + avatarID := elems[len(elems)-1] + if !reValidAvatarID.MatchString(avatarID) { + rest.SendErrorJSON(w, r, p.L, http.StatusForbidden, fmt.Errorf("invalid avatar id from %s", r.URL.Path), "can't load avatar") + return + } + + // enforce client-side caching + etag := `"` + p.Store.ID(avatarID) + `"` + w.Header().Set("Etag", etag) + w.Header().Set("Cache-Control", "max-age=604800") // 7 days + if match := r.Header.Get("If-None-Match"); match != "" { + etag = strings.TrimPrefix(etag, `"`) + etag = strings.TrimSuffix(etag, `"`) + if match == etag { + w.WriteHeader(http.StatusNotModified) + return + } + } + + avReader, size, err := p.Store.Get(avatarID) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusBadRequest, err, "can't load avatar") + return + } + + defer func() { + if e := avReader.Close(); e != nil { + p.Logf("[WARN] can't close avatar reader for %s, %s", avatarID, e) + } + }() + + w.Header().Set("Content-Type", "image/*") + w.Header().Set("Content-Length", strconv.Itoa(size)) + w.WriteHeader(http.StatusOK) + if _, err = io.Copy(w, avReader); err != nil { + p.Logf("[WARN] can't send response to %s, %s", r.RemoteAddr, err) + } +} + +// resize an image of supported format (PNG, JPG, GIF) to the size of "limit" px of the biggest side +// (width or height) preserving aspect ratio. +// Returns original reader if resizing is not needed or failed. +func (p *Proxy) resize(reader io.Reader, limit int) io.Reader { + if reader == nil { + p.Logf("[WARN] avatar resize(): reader is nil") + return nil + } + if limit <= 0 { + p.Logf("[DEBUG] avatar resize(): limit should be greater than 0") + return reader + } + + var teeBuf bytes.Buffer + tee := io.TeeReader(reader, &teeBuf) + src, _, err := image.Decode(tee) + if err != nil { + p.Logf("[WARN] avatar resize(): can't decode avatar image, %s", err) + return &teeBuf + } + + bounds := src.Bounds() + w, h := bounds.Dx(), bounds.Dy() + if w <= limit && h <= limit || w <= 0 || h <= 0 { + p.Logf("[DEBUG] resizing image is smaller that the limit or has 0 size") + return &teeBuf + } + newW, newH := w*limit/h, limit + if w > h { + newW, newH = limit, h*limit/w + } + m := image.NewRGBA(image.Rect(0, 0, newW, newH)) + // Slower than `draw.ApproxBiLinear.Scale()` but better quality. + draw.BiLinear.Scale(m, m.Bounds(), src, src.Bounds(), draw.Src, nil) + + var out bytes.Buffer + if err = png.Encode(&out, m); err != nil { + p.Logf("[WARN] avatar resize(): can't encode resized avatar to PNG, %s", err) + return &teeBuf + } + return &out +} + +// GenerateAvatar for give user with identicon +func GenerateAvatar(user string) ([]byte, error) { + + iconGen, err := identicon.New("pkgz/auth", 5, 5) + if err != nil { + return nil, fmt.Errorf("can't create identicon service: %w", err) + } + + ii, err := iconGen.Draw(user) // generate an IdentIcon + if err != nil { + return nil, fmt.Errorf("failed to draw avatar for %s: %w", user, err) + } + + buf := &bytes.Buffer{} + err = ii.Png(300, buf) + return buf.Bytes(), err +} + +// GetGravatarURL returns url to gravatar picture for given email +func GetGravatarURL(email string) (res string, err error) { + + hash := md5.Sum([]byte(strings.ToLower(strings.TrimSpace(email)))) + hexHash := hex.EncodeToString(hash[:]) + + client := http.Client{Timeout: 5 * time.Second} + res = "https://www.gravatar.com/avatar/" + hexHash + resp, err := client.Get(res + "?d=404&s=80") + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint gosec // we don't care about response body + if resp.StatusCode != 200 { + return "", fmt.Errorf("%s", resp.Status) + } + return res, nil +} + +func retry(retries int, delay time.Duration, fn func() error) (err error) { + for i := 0; i < retries; i++ { + if err = fn(); err == nil { + return nil + } + time.Sleep(delay) + } + if err != nil { + return fmt.Errorf("retry failed: %w", err) + } + return nil +} diff --git a/v2/avatar/avatar_test.go b/v2/avatar/avatar_test.go new file mode 100644 index 00000000..a5e0bc2b --- /dev/null +++ b/v2/avatar/avatar_test.go @@ -0,0 +1,280 @@ +package avatar + +import ( + "bytes" + "fmt" + "image" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +func TestAvatar_Put(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/pic.png" { + w.Header().Set("Content-Type", "image/*") + fmt.Fprint(w, "some picture bin data") + return + } + http.Error(w, "not found", http.StatusNotFound) + })) + defer func() { + _ = os.RemoveAll("/tmp/avatars.test/") + ts.Close() + }() + + p := Proxy{RoutePath: "/avatar", URL: "http://localhost:8080", Store: NewLocalFS("/tmp/avatars.test"), L: logger.NoOp} + assert.NoError(t, os.MkdirAll("/tmp/avatars.test", 0o700)) + defer os.RemoveAll("/tmp/avatars.test") + + client := &http.Client{Timeout: time.Second} + u := token.User{ID: "user1", Name: "user1 name", Picture: ts.URL + "/pic.png"} + res, err := p.Put(u, client) + assert.NoError(t, err) + assert.Equal(t, "http://localhost:8080/avatar/b3daa77b4c04a9551b8781d03191fe098f325e67.image", res) + fi, err := os.Stat("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image") + assert.NoError(t, err) + assert.Equal(t, int64(21), fi.Size()) + + u.ID = "user2" + res, err = p.Put(u, client) + assert.NoError(t, err) + assert.Equal(t, "http://localhost:8080/avatar/a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", res) + fi, err = os.Stat("/tmp/avatars.test/84/a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image") + assert.NoError(t, err) + assert.Equal(t, int64(21), fi.Size()) +} + +func TestAvatar_PutIdenticon(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Print("request: ", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + })) + defer func() { + _ = os.RemoveAll("/tmp/avatars.test/") + ts.Close() + }() + p := Proxy{RoutePath: "/avatar", URL: "http://localhost:8080", Store: NewLocalFS("/tmp/avatars.test"), L: logger.Std} + client := &http.Client{Timeout: time.Second} + + u := token.User{ID: "user1", Name: "user1 name"} + res, err := p.Put(u, client) + assert.NoError(t, err) + assert.Equal(t, "http://localhost:8080/avatar/b3daa77b4c04a9551b8781d03191fe098f325e67.image", res) + fi, err := os.Stat("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image") + assert.NoError(t, err) + assert.Equal(t, int64(999), fi.Size()) + +} + +func TestAvatar_PutFailed(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Print("request: ", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + })) + defer func() { + _ = os.RemoveAll("/tmp/avatars.test/") + ts.Close() + }() + + p := Proxy{RoutePath: "/avatar", URL: "http://localhost:8080", Store: NewLocalFS("/tmp/avatars.test"), L: logger.Std} + client := &http.Client{Timeout: time.Second} + + u := token.User{ID: "user2", Name: "user2 name", Picture: "http://127.0.0.1:22345/avater/pic"} + res, err := p.Put(u, client) + require.NoError(t, err) + assert.Equal(t, "http://localhost:8080/avatar/a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", res) + fi, err := os.Stat("/tmp/avatars.test/84/a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image") + require.NoError(t, err) + assert.Equal(t, int64(992), fi.Size()) +} + +func TestAvatar_Routes(t *testing.T) { + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/pic.png" { + w.Header().Set("Content-Type", "image/*") + w.Header().Set("Custom-Header", "xyz") + _, err := fmt.Fprint(w, "some picture bin data") + require.NoError(t, err) + return + } + http.Error(w, "not found", http.StatusNotFound) + })) + defer ts.Close() + + p := Proxy{RoutePath: "/avatar", Store: NewLocalFS("/tmp/avatars.test"), L: logger.Std} + assert.NoError(t, os.MkdirAll("/tmp/avatars.test", 0o700)) + defer os.RemoveAll("/tmp/avatars.test") + client := &http.Client{Timeout: time.Second} + + u := token.User{ID: "user1", Name: "user1 name", Picture: ts.URL + "/pic.png"} + _, err := p.Put(u, client) + assert.NoError(t, err) + + { + // status 400 + req, err := http.NewRequest("GET", "/123aa77b4c04a9551b8781d03191fe098f325e67.image", http.NoBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + handler := http.Handler(http.HandlerFunc(p.Handler)) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + } + + { + // status 403 + req, err := http.NewRequest("GET", "../not-allowed.txt", http.NoBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + handler := http.Handler(http.HandlerFunc(p.Handler)) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + } + + { // status 200 + req, err := http.NewRequest("GET", "/b3daa77b4c04a9551b8781d03191fe098f325e67.image", http.NoBody) + require.NoError(t, err) + rr := httptest.NewRecorder() + handler := http.Handler(http.HandlerFunc(p.Handler)) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, []string{"image/*"}, rr.Header()["Content-Type"]) + assert.Equal(t, []string{"21"}, rr.Header()["Content-Length"]) + assert.Equal(t, []string(nil), rr.Header()["Custom-Header"], "strip all custom headers") + assert.NotNil(t, rr.Header()["Etag"]) + + bb := bytes.Buffer{} + sz, err := io.Copy(&bb, rr.Body) + assert.NoError(t, err) + assert.Equal(t, int64(21), sz) + assert.Equal(t, "some picture bin data", bb.String()) + } + + { + // status 304 + req, err := http.NewRequest("GET", "/b3daa77b4c04a9551b8781d03191fe098f325e67.image", http.NoBody) + require.NoError(t, err) + id := p.Store.ID("b3daa77b4c04a9551b8781d03191fe098f325e67.image") + req.Header.Add("If-None-Match", p.Store.ID(id)) // hash of `some_random_name.image` since the file doesn't exist + + rr := httptest.NewRecorder() + handler := http.Handler(http.HandlerFunc(p.Handler)) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotModified, rr.Code) + assert.Equal(t, []string{`"` + id + `"`}, rr.Header()["Etag"]) + } + +} + +func TestAvatar_resize(t *testing.T) { + checkC := func(t *testing.T, r io.Reader, cExp []byte) { + content, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, cExp, content) + } + + p := Proxy{L: logger.Std} + // Reader is nil. + resizedR := p.resize(nil, 100) + assert.Nil(t, resizedR) + + // Negative limit error. + resizedR = p.resize(strings.NewReader("some picture bin data"), -1) + require.NotNil(t, resizedR) + checkC(t, resizedR, []byte("some picture bin data")) + + // Decode error. + resizedR = p.resize(strings.NewReader("invalid image content"), 100) + assert.NotNil(t, resizedR) + checkC(t, resizedR, []byte("invalid image content")) + + cases := []struct { + file string + wr, hr int + }{ + {"testdata/circles.png", 400, 300}, // full size: 800x600 px + {"testdata/circles.jpg", 300, 400}, // full size: 600x800 px + } + + for _, c := range cases { + img, err := os.ReadFile(c.file) + require.Nil(t, err, "can't open test file %s", c.file) + + // No need for resize, avatar dimensions are smaller than resize limit. + resizedR = p.resize(bytes.NewReader(img), 800) + assert.NotNil(t, resizedR, "file %s", c.file) + checkC(t, resizedR, img) + + // Resizing to half of width. Check resizedR avatar format PNG. + resizedR = p.resize(bytes.NewReader(img), 400) + assert.NotNil(t, resizedR, "file %s", c.file) + + imgRz, format, err := image.Decode(resizedR) + assert.NoError(t, err, "file %s", c.file) + assert.Equal(t, "png", format, "file %s", c.file) + bounds := imgRz.Bounds() + assert.Equal(t, c.wr, bounds.Dx(), "file %s", c.file) + assert.Equal(t, c.hr, bounds.Dy(), "file %s", c.file) + } +} + +func TestAvatar_GetGravatarURL(t *testing.T) { + tbl := []struct { + email string + err error + url string + }{ + {"eefretsoul@gmail.com", nil, "https://www.gravatar.com/avatar/c82739de14cf64affaf30856ca95b851"}, + {"umputun-xyz@example.com", fmt.Errorf("404 Not Found"), ""}, + } + + for i, tt := range tbl { + tt := tt + t.Run("test-"+strconv.Itoa(i), func(t *testing.T) { + url, err := GetGravatarURL(tt.email) + if tt.err != nil { + assert.EqualError(t, err, tt.err.Error()) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.url, url) + }) + } +} + +func TestAvatar_Retry(t *testing.T) { + i := 0 + err := retry(5, time.Millisecond, func() error { + if i == 3 { + return nil + } + i++ + return fmt.Errorf("err") + }) + assert.NoError(t, err) + assert.Equal(t, 3, i) + + st := time.Now() + err = retry(5, time.Millisecond, func() error { + return fmt.Errorf("err") + }) + assert.Error(t, err) + assert.True(t, time.Since(st) >= time.Microsecond*5) +} diff --git a/v2/avatar/bolt.go b/v2/avatar/bolt.go new file mode 100644 index 00000000..001f214a --- /dev/null +++ b/v2/avatar/bolt.go @@ -0,0 +1,141 @@ +package avatar + +import ( + "bytes" + "fmt" + "io" + "log" + + bolt "go.etcd.io/bbolt" +) + +// BoltDB implements avatar store with bolt +// using separate db (file) with "avatars" bucket to keep image bin and "metas" bucket +// to keep sha1 of picture. avatarID (base file name) used as a key for both. +type BoltDB struct { + fileName string // full path to boltdb + db *bolt.DB +} + +const avatarsBktName = "avatars" +const metasBktName = "metas" + +// NewBoltDB makes bolt avatar store +func NewBoltDB(fileName string, options bolt.Options) (*BoltDB, error) { + db, err := bolt.Open(fileName, 0600, &options) //nolint + if err != nil { + return nil, fmt.Errorf("failed to make boltdb for %s: %w", fileName, err) + } + err = db.Update(func(tx *bolt.Tx) error { + if _, e := tx.CreateBucketIfNotExists([]byte(avatarsBktName)); e != nil { + return fmt.Errorf("failed to create top level bucket %s: %w", avatarsBktName, e) + } + if _, e := tx.CreateBucketIfNotExists([]byte(metasBktName)); e != nil { + return fmt.Errorf("failed to create top metas bucket %s: %w", metasBktName, e) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize boltdb db %q buckets: %w", fileName, err) + } + return &BoltDB{db: db, fileName: fileName}, nil +} + +// Put avatar to bolt, key by avatarID. Trying to resize image and lso calculates sha1 of the file for ID func +func (b *BoltDB) Put(userID string, reader io.Reader) (avatar string, err error) { + id := encodeID(userID) + + avatarID := id + imgSfx + err = b.db.Update(func(tx *bolt.Tx) error { + buf := &bytes.Buffer{} + if _, err = io.Copy(buf, reader); err != nil { + return fmt.Errorf("can't read avatar %s: %w", avatarID, err) + } + + if err = tx.Bucket([]byte(avatarsBktName)).Put([]byte(avatarID), buf.Bytes()); err != nil { + return fmt.Errorf("can't put to bucket with %s: %w", avatarID, err) + } + // store sha1 of the image + return tx.Bucket([]byte(metasBktName)).Put([]byte(avatarID), []byte(hash(buf.Bytes(), avatarID))) + }) + return avatarID, err +} + +// Get avatar reader for avatar id.image, avatarID used as the direct key +func (b *BoltDB) Get(avatarID string) (reader io.ReadCloser, size int, err error) { + buf := &bytes.Buffer{} + err = b.db.View(func(tx *bolt.Tx) error { + data := tx.Bucket([]byte(avatarsBktName)).Get([]byte(avatarID)) + if data == nil { + return fmt.Errorf("can't load avatar %s", avatarID) + } + size, err = buf.Write(data) + if err != nil { + return fmt.Errorf("failed to write for %s: %w", avatarID, err) + } + return nil + }) + return io.NopCloser(buf), size, err +} + +// ID returns a fingerprint of the avatar content. +func (b *BoltDB) ID(avatarID string) (id string) { + data := []byte{} + err := b.db.View(func(tx *bolt.Tx) error { + if data = tx.Bucket([]byte(metasBktName)).Get([]byte(avatarID)); data == nil { + return fmt.Errorf("can't load avatar's id for %s", avatarID) + } + return nil + }) + + if err != nil { // failed to get ID, use encoded avatarID + log.Printf("[DEBUG] can't get avatar info '%s', %s", avatarID, err) + return encodeID(avatarID) + } + + return string(data) +} + +// Remove avatar from bolt +func (b *BoltDB) Remove(avatarID string) (err error) { + return b.db.Update(func(tx *bolt.Tx) error { + bkt := tx.Bucket([]byte(avatarsBktName)) + if bkt.Get([]byte(avatarID)) == nil { + return fmt.Errorf("avatar key not found, %s", avatarID) + } + if err = tx.Bucket([]byte(avatarsBktName)).Delete([]byte(avatarID)); err != nil { + return fmt.Errorf("can't delete avatar object %s: %w", avatarID, err) + } + if err = tx.Bucket([]byte(metasBktName)).Delete([]byte(avatarID)); err != nil { + return fmt.Errorf("can't delete meta object %s: %w", avatarID, err) + } + return nil + }) +} + +// List all avatars (ids) from metas bucket +// note: id includes .image suffix +func (b *BoltDB) List() (ids []string, err error) { + err = b.db.View(func(tx *bolt.Tx) error { + return tx.Bucket([]byte(metasBktName)).ForEach(func(k, _ []byte) error { + ids = append(ids, string(k)) + return nil + }) + }) + if err != nil { + return nil, fmt.Errorf("failed to list: %w", err) + } + return ids, nil +} + +// Close bolt store +func (b *BoltDB) Close() error { + if err := b.db.Close(); err != nil { + return fmt.Errorf("failed to close %s: %w", b.fileName, err) + } + return nil +} + +func (b *BoltDB) String() string { + return fmt.Sprintf("boltdb, path=%s", b.fileName) +} diff --git a/v2/avatar/bolt_test.go b/v2/avatar/bolt_test.go new file mode 100644 index 00000000..a4bddbcc --- /dev/null +++ b/v2/avatar/bolt_test.go @@ -0,0 +1,103 @@ +package avatar + +import ( + "io" + "os" + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + bolt "go.etcd.io/bbolt" +) + +var testDB = "/tmp/test-remark-avatars.db" + +func TestBoltDB_PutAndGet(t *testing.T) { + var b Store + b, teardown := prepBoltStore(t) + defer teardown() + + avatar, err := b.Put("user1", strings.NewReader("some picture bin data")) + require.Nil(t, err) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", avatar) + + rd, size, err := b.Get(avatar) + require.Nil(t, err) + assert.Equal(t, 21, size) + data, err := io.ReadAll(rd) + require.Nil(t, err) + assert.Equal(t, "some picture bin data", string(data)) + + _, _, err = b.Get("bad avatar") + assert.Error(t, err) + + // check IDs + assert.Equal(t, "fddae9ce556712a6ece0e8951a6e7a05c51ed6bf", b.ID(avatar)) + assert.Equal(t, "70c881d4a26984ddce795f6f71817c9cf4480e79", b.ID("aaaa"), "no data, encoded avatar id") + + l, err := b.List() + require.Nil(t, err) + assert.Equal(t, 1, len(l)) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", l[0]) +} + +func TestBoltDB_Remove(t *testing.T) { + b, teardown := prepBoltStore(t) + defer teardown() + + assert.Error(t, b.Remove("no-such-thing.image")) + + avatar, err := b.Put("user1", strings.NewReader("some picture bin data")) + require.Nil(t, err) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", avatar) + assert.NoError(t, b.Remove("b3daa77b4c04a9551b8781d03191fe098f325e67.image"), "remove real one") + assert.Error(t, b.Remove("b3daa77b4c04a9551b8781d03191fe098f325e67.image"), "already removed") +} + +func TestBoltDB_List(t *testing.T) { + b, teardown := prepBoltStore(t) + defer teardown() + + // write some avatars + _, err := b.Put("user1", strings.NewReader("some picture bin data 1")) + require.Nil(t, err) + _, err = b.Put("user2", strings.NewReader("some picture bin data 2")) + require.Nil(t, err) + _, err = b.Put("user3", strings.NewReader("some picture bin data 3")) + require.Nil(t, err) + + l, err := b.List() + assert.NoError(t, err) + assert.Equal(t, 3, len(l), "3 avatars listed") + sort.Strings(l) + assert.Equal(t, []string{"0b7f849446d3383546d15a480966084442cd2193.image", "a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", "b3daa77b4c04a9551b8781d03191fe098f325e67.image"}, l) + + r, size, err := b.Get("0b7f849446d3383546d15a480966084442cd2193.image") + assert.NoError(t, err) + assert.Equal(t, 23, size) + data, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "some picture bin data 3", string(data)) +} + +func TestBoltDB_DoubleClose(t *testing.T) { + _ = os.Remove(testDB) + boltStore, err := NewBoltDB(testDB, bolt.Options{}) + require.Nil(t, err) + assert.NoError(t, boltStore.Close()) + assert.NoError(t, boltStore.Close(), "second call should not result in panic or errors") + _ = os.Remove(testDB) +} + +// makes new boltdb, put two records +func prepBoltStore(t *testing.T) (blt *BoltDB, teardown func()) { + _ = os.Remove(testDB) + boltStore, err := NewBoltDB(testDB, bolt.Options{}) + require.Nil(t, err) + return boltStore, func() { + assert.NoError(t, boltStore.Close()) + _ = os.Remove(testDB) + } +} diff --git a/v2/avatar/gridfs.go b/v2/avatar/gridfs.go new file mode 100644 index 00000000..fc6e2fb1 --- /dev/null +++ b/v2/avatar/gridfs.go @@ -0,0 +1,157 @@ +package avatar + +import ( + "bytes" + "context" + "fmt" + "io" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/gridfs" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// NewGridFS makes gridfs (mongo) avatar store +func NewGridFS(client *mongo.Client, dbName, bucketName string, timeout time.Duration) *GridFS { + return &GridFS{client: client, db: client.Database(dbName), bucketName: bucketName, timeout: timeout} +} + +// GridFS implements Store for GridFS +type GridFS struct { + client *mongo.Client + db *mongo.Database + bucketName string + timeout time.Duration +} + +// Put avatar to gridfs object, try to resize +func (gf *GridFS) Put(userID string, reader io.Reader) (avatar string, err error) { + id := encodeID(userID) + bucket, err := gridfs.NewBucket(gf.db, &options.BucketOptions{Name: &gf.bucketName}) + if err != nil { + return "", err + } + + buf := &bytes.Buffer{} + if _, err = io.Copy(buf, reader); err != nil { + return "", fmt.Errorf("can't read avatar for %s: %w", userID, err) + } + + avaHash := hash(buf.Bytes(), id) + _, err = bucket.UploadFromStream(id+imgSfx, buf, &options.UploadOptions{Metadata: bson.M{"hash": avaHash}}) + return id + imgSfx, err +} + +// Get avatar reader for avatar id.image +func (gf *GridFS) Get(avatar string) (reader io.ReadCloser, size int, err error) { + bucket, err := gridfs.NewBucket(gf.db, &options.BucketOptions{Name: &gf.bucketName}) + if err != nil { + return nil, 0, err + } + buf := &bytes.Buffer{} + sz, e := bucket.DownloadToStreamByName(avatar, buf) + if e != nil { + return nil, 0, fmt.Errorf("can't read avatar %s: %w", avatar, e) + } + return io.NopCloser(buf), int(sz), nil +} + +// ID returns a fingerprint of the avatar content. Uses MD5 because gridfs provides it directly +func (gf *GridFS) ID(avatar string) (id string) { + + finfo := struct { + ID primitive.ObjectID `bson:"_id"` + Len int `bson:"length"` + FileName string `bson:"filename"` + MetaData struct { + Hash string `bson:"hash"` + } `bson:"metadata"` + }{} + + bucket, err := gridfs.NewBucket(gf.db, &options.BucketOptions{Name: &gf.bucketName}) + if err != nil { + return encodeID(avatar) + } + cursor, err := bucket.Find(bson.M{"filename": avatar}) + if err != nil { + return encodeID(avatar) + } + + ctx, cancel := context.WithTimeout(context.Background(), gf.timeout) + defer cancel() + if found := cursor.Next(ctx); found { + if err = cursor.Decode(&finfo); err != nil { + return encodeID(avatar) + } + return finfo.MetaData.Hash + } + return encodeID(avatar) +} + +// Remove avatar from gridfs +func (gf *GridFS) Remove(avatar string) error { + bucket, err := gridfs.NewBucket(gf.db, &options.BucketOptions{Name: &gf.bucketName}) + if err != nil { + return err + } + cursor, err := bucket.Find(bson.M{"filename": avatar}) + if err != nil { + return err + } + + r := struct { + ID primitive.ObjectID `bson:"_id"` + }{} + ctx, cancel := context.WithTimeout(context.Background(), gf.timeout) + defer cancel() + if found := cursor.Next(ctx); found { + if err := cursor.Decode(&r); err != nil { + return err + } + return bucket.Delete(r.ID) + } + return fmt.Errorf("avatar %s not found", avatar) +} + +// List all avatars (ids) on gfs +// note: id includes .image suffix +func (gf *GridFS) List() (ids []string, err error) { + bucket, err := gridfs.NewBucket(gf.db, &options.BucketOptions{Name: &gf.bucketName}) + if err != nil { + return nil, err + } + + gfsFile := struct { + Filename string `bson:"filename,omitempty"` + }{} + cursor, err := bucket.Find(bson.M{}) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(context.Background(), gf.timeout) + defer cancel() + for cursor.Next(ctx) { + if err := cursor.Decode(&gfsFile); err != nil { + return nil, err + } + ids = append(ids, gfsFile.Filename) + } + return ids, nil +} + +// Close gridfs store +func (gf *GridFS) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), gf.timeout) + defer cancel() + if err := gf.client.Disconnect(ctx); err != nil && err != mongo.ErrClientDisconnected { + return err + } + return nil +} + +func (gf *GridFS) String() string { + return fmt.Sprintf("mongo (grid fs), db=%s, bucket=%s", gf.db.Name(), gf.bucketName) +} diff --git a/v2/avatar/gridfs_test.go b/v2/avatar/gridfs_test.go new file mode 100644 index 00000000..68c703d5 --- /dev/null +++ b/v2/avatar/gridfs_test.go @@ -0,0 +1,110 @@ +package avatar + +import ( + "context" + "io" + "os" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func TestGridFS_PutAndGet(t *testing.T) { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + p := prepGFStore(t) + defer p.Close() + avatar, err := p.Put("user1", strings.NewReader("some picture bin data")) + require.Nil(t, err) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", avatar) + + rd, size, err := p.Get(avatar) + require.Nil(t, err) + assert.Equal(t, 21, size) + data, err := io.ReadAll(rd) + require.Nil(t, err) + assert.Equal(t, "some picture bin data", string(data)) + + _, _, err = p.Get("bad avatar") + assert.Error(t, err) + assert.Equal(t, "fddae9ce556712a6ece0e8951a6e7a05c51ed6bf", p.ID(avatar)) + assert.Equal(t, "70c881d4a26984ddce795f6f71817c9cf4480e79", p.ID("aaaa"), "no data, encode avatar id") + + l, err := p.List() + require.Nil(t, err) + assert.Equal(t, 1, len(l)) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", l[0]) +} + +func TestGridFS_Remove(t *testing.T) { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + p := prepGFStore(t) + defer p.Close() + assert.Error(t, p.Remove("no-such-thing.image")) + avatar, err := p.Put("user1", strings.NewReader("some picture bin data")) + require.Nil(t, err) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", avatar) + assert.NoError(t, p.Remove("b3daa77b4c04a9551b8781d03191fe098f325e67.image"), "remove real one") + assert.Error(t, p.Remove("b3daa77b4c04a9551b8781d03191fe098f325e67.image"), "already removed") +} + +func TestGridFS_List(t *testing.T) { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + p := prepGFStore(t) + defer p.Close() + + // write some avatars + _, err := p.Put("user1", strings.NewReader("some picture bin data 1")) + require.Nil(t, err) + _, err = p.Put("user2", strings.NewReader("some picture bin data 2")) + require.Nil(t, err) + _, err = p.Put("user3", strings.NewReader("some picture bin data 3")) + require.Nil(t, err) + + l, err := p.List() + assert.NoError(t, err) + assert.Equal(t, 3, len(l), "3 avatars listed") + sort.Strings(l) + assert.Equal(t, []string{"0b7f849446d3383546d15a480966084442cd2193.image", "a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", "b3daa77b4c04a9551b8781d03191fe098f325e67.image"}, l) + + r, size, err := p.Get("0b7f849446d3383546d15a480966084442cd2193.image") + assert.NoError(t, err) + assert.Equal(t, 23, size) + data, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "some picture bin data 3", string(data)) +} + +func TestGridFS_DoubleClose(t *testing.T) { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + p := prepGFStore(t) + assert.NoError(t, p.Close()) + assert.NoError(t, p.Close(), "second call should not result in panic or errors") +} + +func prepGFStore(t *testing.T) *GridFS { + const timeout = time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017").SetConnectTimeout(timeout)) + require.NoError(t, err) + + _ = client.Database("test").Collection("ava_fs.chunks").Drop(ctx) + _ = client.Database("test").Collection("ava_fs.files").Drop(ctx) + + return NewGridFS(client, "test", "ava_fs", time.Second) +} diff --git a/v2/avatar/localfs.go b/v2/avatar/localfs.go new file mode 100644 index 00000000..ee4828c0 --- /dev/null +++ b/v2/avatar/localfs.go @@ -0,0 +1,124 @@ +package avatar + +import ( + "fmt" + "hash/crc64" + "io" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" +) + +// LocalFS implements Store for local file system +type LocalFS struct { + storePath string + ctcTable *crc64.Table + once sync.Once +} + +// NewLocalFS makes file-system avatar store +func NewLocalFS(storePath string) *LocalFS { + return &LocalFS{storePath: storePath} +} + +// Put avatar for userID to file and return avatar's file name (base), like 12345678.image +// userID can be avatarID as well, in this case encoding just strip .image prefix +func (fs *LocalFS) Put(userID string, reader io.Reader) (avatar string, err error) { + if reader == nil { + return "", fmt.Errorf("empty reader") + } + id := encodeID(userID) + location := fs.location(id) // location adds partition to path + + if e := os.MkdirAll(location, 0o750); e != nil { + return "", fmt.Errorf("failed to mkdir avatar location %s: %w", location, e) + } + + avFile := path.Join(location, id+imgSfx) + fh, err := os.Create(avFile) //nolint + if err != nil { + return "", fmt.Errorf("can't create file %s: %w", avFile, err) + } + defer func() { //nolint + if e := fh.Close(); e != nil { + err = fmt.Errorf("can't close avatar file %s: %w", avFile, e) + } + }() + + if _, err = io.Copy(fh, reader); err != nil { + return "", fmt.Errorf("can't save file %s: %w", avFile, err) + } + return id + imgSfx, nil +} + +// Get avatar reader for avatar id.image +func (fs *LocalFS) Get(avatar string) (reader io.ReadCloser, size int, err error) { + location := fs.location(strings.TrimSuffix(avatar, imgSfx)) + fh, err := os.Open(path.Join(location, avatar)) //nolint + if err != nil { + return nil, 0, fmt.Errorf("can't load avatar %s, id: %w", avatar, err) + } + if fi, e := fh.Stat(); e == nil { + size = int(fi.Size()) + } + return fh, size, nil +} + +// ID returns a fingerprint of the avatar content. +func (fs *LocalFS) ID(avatar string) (id string) { + location := fs.location(strings.TrimSuffix(avatar, imgSfx)) + avFile := path.Join(location, avatar) + fi, err := os.Stat(avFile) + if err != nil { + return encodeID(avatar) + } + return encodeID(avatar + strconv.FormatInt(fi.ModTime().Unix(), 10)) +} + +// Remove avatar file +func (fs *LocalFS) Remove(avatar string) error { + location := fs.location(strings.TrimSuffix(avatar, imgSfx)) + avFile := path.Join(location, avatar) + return os.Remove(avFile) +} + +// List all avatars (ids) on local file system +// note: id includes .image suffix +func (fs *LocalFS) List() (ids []string, err error) { + err = filepath.Walk(fs.storePath, + func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() && strings.HasSuffix(info.Name(), imgSfx) { + ids = append(ids, info.Name()) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("can't list avatars: %w", err) + } + return ids, nil +} + +// Close LocalFS does nothing but satisfies interface +func (fs *LocalFS) Close() error { + return nil +} + +func (fs *LocalFS) String() string { + return fmt.Sprintf("localfs, path=%s", fs.storePath) +} + +// get location (directory) for user id by adding partition to final path in order to keep files +// in different subdirectories and avoid too many files in a single place. +// the end result is a full path like this - /tmp/avatars.test/92 +func (fs *LocalFS) location(id string) string { + fs.once.Do(func() { fs.ctcTable = crc64.MakeTable(crc64.ECMA) }) + checksum64 := crc64.Checksum([]byte(id), fs.ctcTable) + partition := checksum64 % 100 + return path.Join(fs.storePath, fmt.Sprintf("%02d", partition)) +} diff --git a/v2/avatar/localfs_test.go b/v2/avatar/localfs_test.go new file mode 100644 index 00000000..1616809b --- /dev/null +++ b/v2/avatar/localfs_test.go @@ -0,0 +1,177 @@ +package avatar + +import ( + "io" + "os" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAvatarStoreFS_Put(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + avatar, err := p.Put("user1", nil) + assert.Equal(t, "", avatar) + assert.EqualError(t, err, "empty reader") + + avatar, err = p.Put("user1", strings.NewReader("some picture bin data")) + require.Nil(t, err) + assert.Equal(t, "b3daa77b4c04a9551b8781d03191fe098f325e67.image", avatar) + fi, err := os.Stat("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image") + assert.NoError(t, err) + assert.Equal(t, int64(21), fi.Size()) + + avatar, err = p.Put("user2", strings.NewReader("some picture bin data 123")) + require.Nil(t, err) + assert.Equal(t, "a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", avatar) + fi, err = os.Stat("/tmp/avatars.test/84/a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image") + assert.NoError(t, err) + assert.Equal(t, int64(25), fi.Size()) + + // with encoded id + avatar, err = p.Put("f1881c06eec96db9901c7bbfe41c42a3f08e9cb8.image", strings.NewReader("some picture bin data 123")) + require.Nil(t, err) + assert.Equal(t, "f1881c06eec96db9901c7bbfe41c42a3f08e9cb8.image", avatar) + fi, err = os.Stat("/tmp/avatars.test/56/f1881c06eec96db9901c7bbfe41c42a3f08e9cb8.image") + assert.NoError(t, err) + assert.Equal(t, int64(25), fi.Size()) + + // with resize + file, e := os.Open("testdata/circles.png") + require.Nil(t, e) + avatar, err = p.Put("user3", file) + require.Nil(t, err) + assert.Equal(t, "0b7f849446d3383546d15a480966084442cd2193.image", avatar) + fi, err = os.Stat("/tmp/avatars.test/60/0b7f849446d3383546d15a480966084442cd2193.image") + assert.NoError(t, err) + assert.Equal(t, int64(11392), fi.Size()) + + p = NewLocalFS("/dev/null") + _, err = p.Put("user1", strings.NewReader("some picture bin data")) + assert.EqualError(t, err, "failed to mkdir avatar location /dev/null/30: mkdir /dev/null: not a directory") +} + +func TestAvatarStoreFS_Get(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test/30", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + // file not exists + r, size, err := p.Get("some_random_name.image") + assert.Nil(t, r) + assert.Equal(t, 0, size) + assert.EqualError(t, err, "can't load avatar some_random_name.image, id: open /tmp/avatars.test/91/some_random_name.image: no such file or directory") + // file exists + err = os.WriteFile("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image", []byte("something"), 0666) //nolint + assert.NoError(t, err) + r, size, err = p.Get("b3daa77b4c04a9551b8781d03191fe098f325e67.image") + + assert.NoError(t, err) + assert.Equal(t, 9, size) + data, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "something", string(data)) +} + +func TestAvatarStoreFS_Location(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + + tbl := []struct { + id string + res string + }{ + {"abc", "/tmp/avatars.test/35"}, + {"xyz", "/tmp/avatars.test/69"}, + {"blah blah", "/tmp/avatars.test/29"}, + {"f1881c06eec96db9901c7bbfe41c42a3f08e9cb8", "/tmp/avatars.test/56"}, + } + + for i, tt := range tbl { + assert.Equal(t, tt.res, p.location(tt.id), "test #%d", i) + } +} + +func TestAvatarStoreFS_ID(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test/30", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + // file not exists + id := p.ID("some_random_name.image") + assert.Equal(t, "a008de0a2ccb3308b5d99ffff66436e15538f701", id) // "some_random_name.image" + // file exists + err = os.WriteFile("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image", []byte("something"), 0666) //nolint + require.NoError(t, err) + touch := time.Date(2017, 7, 14, 2, 40, 0, 0, time.UTC) // 1500000000 + err = os.Chtimes("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image", touch, touch) + require.NoError(t, err) + id = p.ID("b3daa77b4c04a9551b8781d03191fe098f325e67.image") + assert.Equal(t, "325d5b451f32c2f8e7f30a9fd65bff6a42954d9a", id) +} + +func TestAvatarStoreFS_Remove(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test/30", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + assert.Error(t, p.Remove("no-such-avatar"), "remove non-existing avatar") + err = os.WriteFile("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image", []byte("something"), 0666) //nolint + require.NoError(t, err) + + assert.NoError(t, p.Remove("b3daa77b4c04a9551b8781d03191fe098f325e67.image")) + _, err = os.Stat("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image") + assert.Error(t, err, "removed for real") + t.Log(err) +} + +func TestAvatarStoreFS_List(t *testing.T) { + p := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + // write some avatars + _, err = p.Put("user1", strings.NewReader("some picture bin data 1")) + require.Nil(t, err) + _, err = p.Put("user2", strings.NewReader("some picture bin data 2")) + require.Nil(t, err) + _, err = p.Put("user3", strings.NewReader("some picture bin data 3")) + require.Nil(t, err) + + l, err := p.List() + assert.NoError(t, err) + assert.Equal(t, 3, len(l), "3 avatars listed") + sort.Strings(l) + assert.Equal(t, []string{"0b7f849446d3383546d15a480966084442cd2193.image", "a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", "b3daa77b4c04a9551b8781d03191fe098f325e67.image"}, l) + + r, size, err := p.Get("0b7f849446d3383546d15a480966084442cd2193.image") + assert.NoError(t, err) + assert.Equal(t, 23, size) + data, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "some picture bin data 3", string(data)) +} + +func BenchmarkAvatarStoreFS_ID(b *testing.B) { + p := NewLocalFS("/tmp/avatars.test") + _ = os.MkdirAll("/tmp/avatars.test/30", 0o700) + defer os.RemoveAll("/tmp/avatars.test") + err := os.WriteFile("/tmp/avatars.test/30/b3daa77b4c04a9551b8781d03191fe098f325e67.image", []byte("something"), 0666) //nolint + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + p.ID("b3daa77b4c04a9551b8781d03191fe098f325e67.image") + } +} diff --git a/v2/avatar/noop.go b/v2/avatar/noop.go new file mode 100644 index 00000000..560bcfc1 --- /dev/null +++ b/v2/avatar/noop.go @@ -0,0 +1,35 @@ +package avatar + +import ( + "bytes" + "io" +) + +// NoOp is an empty (no-op) implementation of Store interface +type NoOp struct{} + +// NewNoOp provides an empty (no-op) implementation of Store interface +func NewNoOp() *NoOp { return &NoOp{} } + +// String is a NoOp implementation +func (s *NoOp) String() string { return "" } + +// Put is a NoOp implementation +func (s *NoOp) Put(string, io.Reader) (avatarID string, err error) { return "", nil } + +// Get is a NoOp implementation +func (s *NoOp) Get(string) (reader io.ReadCloser, size int, err error) { + return io.NopCloser(bytes.NewBuffer([]byte(""))), 0, nil +} + +// ID is a NoOp implementation +func (s *NoOp) ID(string) (id string) { return "" } + +// Remove is a NoOp implementation +func (s *NoOp) Remove(string) error { return nil } + +// List is a NoOp implementation +func (s *NoOp) List() (ids []string, err error) { return nil, nil } + +// Close is a NoOp implementation +func (s *NoOp) Close() error { return nil } diff --git a/v2/avatar/noop_test.go b/v2/avatar/noop_test.go new file mode 100644 index 00000000..378b4cb4 --- /dev/null +++ b/v2/avatar/noop_test.go @@ -0,0 +1,89 @@ +package avatar + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNoOp_Close(t *testing.T) { + p := NewNoOp() + require.NoError(t, p.Close()) + require.NoError(t, p.Close(), "second call should not result in panic or errors") +} + +func TestNoOp_Get(t *testing.T) { + p := NewNoOp() + reader, size, err := p.Get("blah") + require.NoError(t, err) + require.Zero(t, size) + err = reader.Close() + require.NoError(t, err) + + proxy := Proxy{ + L: nil, + Store: p, + RoutePath: "/avatar", + URL: "http://127.0.0.1:8080", + ResizeLimit: 0, + } + + ts := httptest.NewServer(http.HandlerFunc(proxy.Handler)) + defer ts.Close() + + { + resp, err := http.Get(ts.URL + "/avatar/b3daa77b4c04a9551b8781d03191fe098f325e67.image") + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Zero(t, resp.ContentLength) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Empty(t, body) + err = resp.Body.Close() + require.NoError(t, err) + } + + { + resp, err := http.Get(ts.URL + "/avatar/invalid.image") + require.NoError(t, err) + require.Equal(t, http.StatusForbidden, resp.StatusCode) + err = resp.Body.Close() + require.NoError(t, err) + } + +} + +func TestNoOp_ID(t *testing.T) { + p := NewNoOp() + id := p.ID("blah") + require.Empty(t, id) +} + +func TestNoOp_List(t *testing.T) { + p := NewNoOp() + ids, err := p.List() + require.NoError(t, err) + require.Empty(t, ids) +} + +func TestNoOp_Put(t *testing.T) { + p := NewNoOp() + avatarID, err := p.Put("blah", nil) + require.NoError(t, err) + require.Empty(t, avatarID) +} + +func TestNoOp_Remove(t *testing.T) { + p := NewNoOp() + err := p.Remove("blah") + require.NoError(t, err) +} + +func TestNoOp_String(t *testing.T) { + p := NewNoOp() + s := p.String() + require.Empty(t, s) +} diff --git a/v2/avatar/store.go b/v2/avatar/store.go new file mode 100644 index 00000000..ab51a1b8 --- /dev/null +++ b/v2/avatar/store.go @@ -0,0 +1,133 @@ +package avatar + +import ( + "context" + "crypto/sha1" //nolint gosec + "encoding/hex" + "fmt" + _ "image/gif" // initializing packages for supporting GIF + _ "image/jpeg" // initializing packages for supporting JPEG. + _ "image/png" // initializing packages for supporting PNG. + "io" + "log" + "net/url" + "regexp" + "strings" + "time" + + bolt "go.etcd.io/bbolt" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + + "github.com/go-pkgz/auth/token" +) + +// imgSfx for avatars +const imgSfx = ".image" + +var reValidAvatarID = regexp.MustCompile(`^[a-fA-F0-9]{40}\.image$`) + +// Store defines interface to store and load avatars +type Store interface { + fmt.Stringer + Put(userID string, reader io.Reader) (avatarID string, err error) // save avatar data from the reader and return base name + Get(avatarID string) (reader io.ReadCloser, size int, err error) // load avatar via reader + ID(avatarID string) (id string) // unique id of stored avatar's data + Remove(avatarID string) error // remove avatar data + List() (ids []string, err error) // list all avatar ids + Close() error // close store +} + +// NewStore provides factory for all supported stores making the one +// based on uri protocol. Default (no protocol) is file-system +func NewStore(uri string) (Store, error) { + switch { + case strings.HasPrefix(uri, "file://"): + return NewLocalFS(strings.TrimPrefix(uri, "file://")), nil + case !strings.Contains(uri, "://"): + return NewLocalFS(uri), nil + case strings.HasPrefix(uri, "mongodb://"), strings.HasPrefix(uri, "mongodb+srv://"): + db, bucketName, u, err := parseExtMongoURI(uri) + if err != nil { + return nil, fmt.Errorf("can't parse mongo store uri %s: %w", uri, err) + } + + const timeout = time.Second * 30 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + client, err := mongo.Connect(ctx, options.Client().ApplyURI(u).SetConnectTimeout(timeout)) + if err != nil { + return nil, fmt.Errorf("failed to connect to mongo server: %w", err) + } + if err = client.Ping(ctx, nil); err != nil { + return nil, fmt.Errorf("failed to connect to mongo server: %w", err) + } + return NewGridFS(client, db, bucketName, time.Second*5), nil + case strings.HasPrefix(uri, "bolt://"): + return NewBoltDB(strings.TrimPrefix(uri, "bolt://"), bolt.Options{}) + } + return nil, fmt.Errorf("can't parse store url %s", uri) +} + +// Migrate avatars between stores +func Migrate(dst, src Store) (int, error) { + ids, err := src.List() + if err != nil { + return 0, err + } + for _, id := range ids { + srcReader, _, err := src.Get(id) + if err != nil { + log.Printf("[WARN] can't get reader for avatar %s", id) + continue + } + if _, err = dst.Put(id, srcReader); err != nil { + log.Printf("[WARN] can't put avatar %s", id) + } + if err = srcReader.Close(); err != nil { + log.Printf("[WARN] failed to close avatar %s", id) + } + } + return len(ids), nil +} + +// encodeID hashes id to sha1. Skip encoding for already processed +func encodeID(id string) string { + if reValidAvatarID.MatchString(id) { + return strings.TrimSuffix(id, imgSfx) // already encoded, strip .image + } + return token.HashID(sha1.New(), id) +} + +// parseExtMongoURI extracts extra params ava_db and ava_coll and remove +// from the url. Input example: mongodb://user:password@127.0.0.1:27017/test?ssl=true&ava_db=db1&ava_coll=coll1 +func parseExtMongoURI(uri string) (db, collection, cleanURI string, err error) { + + db, collection = "test", "avatars_fs" + u, err := url.Parse(uri) + if err != nil { + return "", "", "", err + } + if val := u.Query().Get("ava_db"); val != "" { + db = val + } + if val := u.Query().Get("ava_coll"); val != "" { + collection = val + } + + q := u.Query() + q.Del("ava_db") + q.Del("ava_coll") + u.RawQuery = q.Encode() + return db, collection, u.String(), nil +} + +func hash(data []byte, avatarID string) (id string) { + h := sha1.New() + if _, err := h.Write(data); err != nil { + log.Printf("[DEBUG] can't apply sha1 for content of '%s', %s", avatarID, err) + return encodeID(avatarID) + } + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/v2/avatar/store_test.go b/v2/avatar/store_test.go new file mode 100644 index 00000000..34b1a3df --- /dev/null +++ b/v2/avatar/store_test.go @@ -0,0 +1,122 @@ +package avatar + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAvatarStore_Migrate(t *testing.T) { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + // prep localfs + plocal := NewLocalFS("/tmp/avatars.test") + err := os.MkdirAll("/tmp/avatars.test", 0o700) + require.NoError(t, err) + defer os.RemoveAll("/tmp/avatars.test") + + // prep gridfs + pgfs := prepGFStore(t) + + // write to localfs + _, err = plocal.Put("user1", strings.NewReader("some picture bin data 1")) + require.Nil(t, err) + _, err = plocal.Put("user2", strings.NewReader("some picture bin data 2")) + require.Nil(t, err) + _, err = plocal.Put("user3", strings.NewReader("some picture bin data 3")) + require.Nil(t, err) + + // migrate and check reported count + count, err := Migrate(pgfs, plocal) + require.NoError(t, err) + assert.Equal(t, 3, count, "all 3 recs migrated") + + // list avatars + l, err := pgfs.List() + assert.NoError(t, err) + assert.Equal(t, 3, len(l), "3 avatars listed in destination store") + sort.Strings(l) + assert.Equal(t, []string{"0b7f849446d3383546d15a480966084442cd2193.image", "a1881c06eec96db9901c7bbfe41c42a3f08e9cb4.image", "b3daa77b4c04a9551b8781d03191fe098f325e67.image"}, l) + + // try to read one of migrated avatars + r, size, err := pgfs.Get("0b7f849446d3383546d15a480966084442cd2193.image") + assert.NoError(t, err) + assert.Equal(t, 23, size) + data, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "some picture bin data 3", string(data)) +} + +func TestStore_NewStore(t *testing.T) { + tbl := []struct { + name string + uri string + res string + err error + }{ + {"local fs, default", "/tmp/ava_tmp", "localfs, path=/tmp/ava_tmp", nil}, + {"local fs, file: prefix", "file:///tmp/ava_tmp", "localfs, path=/tmp/ava_tmp", nil}, + {"bolt", "bolt:///tmp/ava_tmp", "boltdb, path=/tmp/ava_tmp", nil}, + {"valid mongo", "mongodb://127.0.0.1:27017/test?ava_db=db1&ava_coll=coll1", "mongo (grid fs), db=db1, bucket=coll1", nil}, + {"invalid mongo", "mongodb://127.0.0.1:27018/test?ava_db=db1&ava_coll=coll1", "", fmt.Errorf("failed to connect to mongo server")}, + {"unknown store", "blah:///tmp/ava_tmp", "", fmt.Errorf("can't parse store url blah:///tmp/ava_tmp")}, + } + + for _, tt := range tbl { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if strings.Contains(tt.uri, "mongodb://") && tt.err == nil { + if _, ok := os.LookupEnv("ENABLE_MONGO_TESTS"); !ok { + t.Skip("ENABLE_MONGO_TESTS env variable is not set") + } + } + res, err := NewStore(tt.uri) + if tt.err != nil { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.err.Error()) + return + } + require.NoError(t, err) + assert.Equal(t, tt.res, res.String()) + }) + } +} + +func TestStore_parseExtMongoURI(t *testing.T) { + tbl := []struct { + name string + inp string + db, coll, u string + err error + }{ + {"simple", "blah", "test", "avatars_fs", "blah", nil}, + {"both", "mongodb://user:password@127.0.0.1:27017/test?ssl=true&ava_db=db1&ava_coll=coll1", "db1", "coll1", + "mongodb://user:password@127.0.0.1:27017/test?ssl=true", nil}, + {"default_both", "mongodb://user:password@127.0.0.1:27017/test?ssl=true&xyz=123", "test", "avatars_fs", + "mongodb://user:password@127.0.0.1:27017/test?ssl=true&xyz=123", nil}, + {"default_db", "mongodb://user:password@127.0.0.1:27017/test?ssl=true&xyz=123&ava_coll=coll1", "test", "coll1", + "mongodb://user:password@127.0.0.1:27017/test?ssl=true&xyz=123", nil}, + } + + for _, tt := range tbl { + tt := tt + t.Run(tt.name, func(t *testing.T) { + db, coll, u, err := parseExtMongoURI(tt.inp) + if tt.err != nil { + assert.EqualError(t, err, tt.err.Error()) + return + } + require.NoError(t, err) + assert.Equal(t, tt.db, db) + assert.Equal(t, tt.coll, coll) + assert.Equal(t, tt.u, u) + }) + } +} diff --git a/v2/avatar/testdata/circles.jpg b/v2/avatar/testdata/circles.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c7048c3a4b524ba2009a6820bd7fced0af45378 GIT binary patch literal 23983 zcmd5^2|Uzm_a9VBB8sFkqR3v^g}LQM2-&xo5JHh;H>MQI8bbE$S;xLJWzCYEv4pH+ zUuF!3G4r49d+&XPx^>^~=YPlNWA^8F&Uw!Hp6@x&^9-Sf@Ck5OPFhwPKtcilTqFJg z2*Uuas}5!+0Dyu5fD-@!><5sN&;rPak4T7r08cZ(p5Gn=0Bpp20D$Z@$$#H{P5Q^X zq*|}Z{&-Br{P}}~Ie^HGJ5~>^Z0=ayXXfWU4-mN`tFY(uXvEuZkI(-0@NnlTpI(1J z&Yr192VTvtU2!G61<>p#Vm1~d-sv= zKR|Ji_<{VxMA6B}NcWKK-TV16BoN~JfIT#OX-{9cyzhvz0Xd8H(Tk4*V)wIN$$m$t z(zbYp-|&Ik0SfwK496MIvaz4z;1m!P5*85^llt|lw9GYGIaM`v4NWa=9i!XECZ>1H z%xxap+SxleIz4&n?&10DxmRG&%ixgES7G6C@d=4Z$!}6pb8_?Y3kr*hODe0XYijG> z*Eh6xbar(kdV2eYM@GlSCyB^s7tshLq^*&9AMIsjas%rlEEgZ|KYAq~Hv8QHR(_R5I>QHT6!d2VhR-g2 zR_!;6yT85oWo zrKe{+br#5Yg7qXlJu^2mD;qlpC&zIh&v~A6=g*$wI7c83LPAbXPO+cj2nEHFbBy$i z=YIG`7$CZ(JfROjNk&3+CNdfT2(X^?hyrj9@S61h@s5np`tJ~+3 zo&OFoy8azvbpJcVK>Rzz_(ySL|CsI7kF!Uq5|)#KoLEdn-r30Z9xr5^53_o)8W4nk z$A`<+ts>WVmRT0RDMK^<%H9kC=5t0B%PG-My^#bjt+F7^!3NI3K8VqIZUL7gsowj5lB-lR~+3KIf=8i|2vglO-d zQ654`l3f^eQ7{>O;?FXVfsVm}s?o(IMa2-r5!urOfZ6jJr}N1a=B*l%m*{+8UZH}H z#USAaoUNn3r#IjpRDZ?VS%;x6ttZN2*5P@<&FPXXnM^mAZ;xyw8$p$pjOQxoCl5Z` zRB9M{g^w#+?==>kIdmKo24yt5zv6eAGhicT5U!3E`@p4>+xc?8aX+u9@0t}JKCG|c z!XpQy?I>Q~ry27JEoiQi2M?Yk06yV-osbe)yXRJF!aoD^`}p^&Xni7yjP5*ZEhoph zVC0K%QTJ^%^JI^|4qX8+!N3PY!GjsA&@k+vW362|B^=3zm%$A+7Qb&~m1h;t)b1I} zWm0gaar4RnUCPj~H9;yD49?99y+M9S6aY2XjL*=#XU%7W#S8dEf(7)#Iw$ox9lYM9 z&@L?>fzob5pfTNLk%~7H7gX9-BsW+`Ioi!VK!g5hTmAA)d(0)RL|G%5=fMFUQdCbh zea{t@O0=}6sJq|c=&<+h&3Hb==)Tia$qD~NS|UBJ$0BWi>rMQb=truWK{@;>Dgtpx z;oJFl(F?Q|D08Y--3#)+STC>_udKkJ2i(*`#qoUaaiPzO!jp{HrKsj)At(%Lw!*4$ zU-^;IUAfdDe`<@sDVsem zYs6;2a@_?P-fAL-pz&B@Xv(cFid>a-mc@Ysu%M&XC$l>oZt`#T4G;kQ1OV^A>YSxh zIqPuc*LY?3;zZWEUiZn60vj|ec|&Tj$oiG~Grt&G$#~D)sd2#!3A0Aq{tu3Li!rC` zD%{y=@%yBY?|H~lVVe{cw!tG7khWLkQgW*6ubsZz?F}OQ1IcI$42kqlTFY7Xx}a8#v{f27|TR z!&AsjqB_tLGFL0~WR$$dZt+u|KNV5#Qtrg-^O1TW$EQ-%kCq|0bimF6I+LBK#q|1% zG)6-B>8wNq%rs$mi%N&Ij=AgYkeO&9v{L~RX7FKyb_a}6x)L%yI3P` z!kcKn5&`9nd8}U`{%M;VZjV21vqH9x8*ZPVjlz_!Q3QZoZ8^wHp{jUtyeRoejcR5^ zqtZBHQR3b;W}5ecD9?qdq)g8n=eG{}vzZ7boJL{nPH3P#>OD}k4wJE%J&{2-;ef@n zPGfByVdEAH9T@SFiZ?ihznH+!s$zKdS%GA7Uc8X?A$5%ej)S?pJkM(L@voNo!MiW6 z(=p*vT)O>H0yAXYTQlZQPxw?eb+mWeCKfR z)Do~b(8;Y*Y@e&&-;+4@?xIzzA$XpjZzg62Ec~jrtAD|0#g#IxCAbpQFa|$?PriIp zrfX|V|MXN$WmHE1p!Rv@Mot^DsbJ~C?;;ZC|Lra zLZ|#>?pjL+MyAHP&x+3-9_6Yc!nsa{`oiOP_dBvy7m5pUv`}OkuvN3g1X_y)W-VT^ z;3=J9Uq5x$5!v(QeEf~Hx9{4tcj&G7!JHJ8qg@m2`iA#-8gk6(rw9NZVc?2|Tc(wJ z@(OHfQ!7;CT;_TXHO4v@T4t>eQ9PXl_uRLs2XFd-(4#px-5!9deKHe19&=a&N~y9j zgr%Li$eMC9+K;@IVj{%WJEUEH46f{~O(|RmLU(pk3)FzH4mERUKN^%2fl$2!fY^r# zhGYVO4>(xU!aKa;?T?*f-&`?+*dHMPKxbF8kfHZCuGZmJg76yx5a<8Am~GQo6gN-A z8B?N}Q$virz^J9oQDms~2!8D0OFXki+lTCcKByZ3aC$@5{*0c=C)+;PG+i0-nmT7t zD3-eaTClA>8S=SOTQ**>i|X{73K90nIlM9FLMpqz`ZNEiQ&64RrXC3uN+yPRg;)oMcw!F-<44w74eVap^OyCK9b_u0@%%6*uz{ zVJMt|eBZb_5S2#&oN!H!M6?6Fx6a!xrPyb;W(pc+D1%e?jZqdLaXmy7#YR4a>C4QO zvCH%ttHHV#7Oa+qxn8XQ@&$1I7$IN@xifq`xN_4Xx_~WsuT#!7MOn)m5$ab*j`GYq z(2XEkrt0J_z%yeIpG;RJ_YY1t(SR-4&ad(2Wqy81JiRIeNxV+8Iwb zYarrzB$qy(?N(j&?HPBYJZkh(>E z71qUy!ddM%y1yEHLQP)>RE+lxYiF&kxvimb>KQCCUMI={-5wY?AJq~DyOMa|fx5T% zY}eBazx!JaxFZe$q>ALKmt}=!{dGEqHw1mlBu0Hzmq;t5)3Hbnd(&kmQqQjYdBkdu zTV*4_^7ZUiu%-S1MnalcGkc!&prrwhOW%@9plUeJav_IBo%P7YQI+uPPdEEEY;iM4 zYtccLf|r|J24u9t8IxZ6koDI3HM2@Y^{|Ev)2KYnyvv) zy}GZ*F{mZ_pn_=hdH4~=>2?dF=kdm;{EgNS{kL+(dBwzg9FxbpTnZ&wPlOm|T^z5z z43qRnW*>55f3+2Nzl;J#ea1Xtr-2-!bG{RxR-5%`F8>p%y=_4YvCHD3j~zbNzGQ5Z zG@Z=H>4i37#IW(JUv{m(yNk@Fr&nVsX%jLVpK;$1O(CMbQ0~lA|FYBm!o1aZ1oNrq zlb{q>aLoi}N!V-EsMj@pB3P3p-%;9NXUk;Q^aHp-c3dh&1CTO(=q5Yo1Jx(FdN-04 z1zCm#AFlJWZ2anXy*#C7Xh7l+Bt*9clcE!S!KSWc2p99I#kXlq(pR@-KlW_z9NuC{ zUz)C_tY9JM_H9th6xi`;oxxJp$}2|>zQ(SHl4QpoRo_Z=7XS4m7&*BffUrio@2r+T zLyNuS-fYypGiBE3>-XXF{!5^SvXjk)0$qZwo4SutNEvGG+ugZsPJC!+pjZQ%x^$}p zlRg4{4;3Tvn|`MKj6y8DIq_PDcjQ3ezPaK6$CQwU{zXvv$4##BrQVOp4hDI8>$p;= zE`wtJ^7C-nGl373Y``E6(M#n`L;h+~oyG#YAWSC0HxUk2~>MJ~z0OcN2(p)%Qd?)*a}o(l4Mw z`CFzDTXNIJn!U+3BRW~Chj)++`sIjYomTOM{*UUohX#~8yBWM1e1(OaiZ%4a8|f&; z#pXA7Ve8XQQDw^gl>ySF%77c>x46c~K1Ku$#FQ$>Qyd+Y3(qr5FQ<`^&H7?=lJ95) zW!T%E$}w)%T2}5v?BfT@A2{)0rp=a3RrzC8adF^q4sv`p{x!_E)+lvk+_7AY*ONPG zyqM1V@YjRw%h3Ds*>R_k&{N&|pg=4A4;WZdy+dFH*jybl4Jm;sq54~~kux4|OoY?q z2U3ROF$93brlO_8?ZTG{4eIW$5l?H{YE{(VG%P{8%fqmK{N%cO)!SevWT86tQe0WX zU(O=n;Q1RDW#5vOv(4%#DmH4gTwSaU*o-i%P{s%G9*mjyYs3qpUo~kKMuk@T_pX_q z(bfls$8FB8qGsVOEp4avS_2ig=}7XI4I=&pSO{f;gBkR)KATa5)WMAsr{anqQ3z@QZ|jn%M{;!kc-lI1}VI|O7+ zxCp`enekvpPK#r?pF4|Kx-^9Nt*tW6xo`}x6d0FR+?4MYX<2lRfsX6bgHE(;q&nN0 zPDk=o@lkg3)0Z#gG7Va_L8QET`J-S!H48q6O) zQBY1n-GzTvo2dA#2^K2)71p%XoxbI2pWpNQRh#v;EoHLXp?lJKlM3nyb7yn;Go+)O zEQ8e@%G(;vOPsYU#_HYu$vb)w)4lW%%7&9c?QsM4XfVtOa0r-jTa47lkQ6jL3}V$(Qx-t~_W8Jr<@EN_|MIFb)iAGq-nQ zdwehHbr6$nvuJ>)(yEO4Vm3oPqQkgnQ4>*~Qz6u-R4-!$6%1#(aS^0JI&VpwRu$Mo>ibNK*Kk;gakj!xn2A2F- zog}8XG}%6RwWgWYgprQcP3y05zpp@`fJX-cb#u5!i&JVAKYCDF_1=^^XtO8 z99?C`wF{{X@4xbsJ8bi$uVJCBzi>o1cd;VT+HEPFYlaniIS7~5mh!62qZ#k+Q@ZNdf(~=qwuT4)LS5)&m1Q`$uNTCzClw_L>k?DT4 zL$A+dj`$+^?u?8i{oQ5uN1|*=D$0r~y(BH*n-O8TI5Ky-#V|j?QBE`@xDYe(K+&bi zt!(r=IuAbr*az1Ynrl%vaS6DJ+6wP8btw5Bg~=7n92)mwN^x*TXl91$(biP;heMNhxH*DI`ggu zi-@^O8t>DFo2f`)*03|CMj;Xw==J=w8kRN#k-h^Txmd33UNQd<%8^1OM`u^b*vCK| zFnOXf8YNc4P9QXvQ8koHOl+ened|R0_c~;|Qh@4i_95oyN(bxL5$};j6;570c7AyH zB9&8nY4(*0m$$o*^(~`+$9rTpaxu*XfS525kYd*gR+6iKxZSx<2_-fqH(9P<{Q3(b z{d+M+vP%MHiN1VP5gF%0WbLW$9O|LJa$MqyY`kE{U3b<`!PQ>0B#>ARj#!R#?*B*B zAZ-p44thue;=(1BUyrB5xLi>(yF&m-$vqekXhjU1`tBrn(KaP3?XZNa(K0N&Wl6~A z4sQv(V;yxNgV#gPCA;ZsknuZm{Ovn6@~ftNF_!)x7hUsL*C=JW??}fF!BYAU*l>k# z<-Oh|ws%P{qNP~z{uFA7x-JPX6HQr+?7M#Gy$id~1+h3eMU`3OENk1ml-Z zbgg*a*pY^S+X@)c6Y<_=N*nB0P2LOr)XK!xD*o3-6B}D>Z}l!1^bM)ZG|xeAvSdOVXcc?~>v9Zzo){BJVy zmpkl8MLi}Xot|hR`~^8=N(`M(esI0td^lmJ-~U7>IVSKp;@#{aXVH^Z-by(pd9JKZ z6`j*NqK0e-FQ-zu+BaL_Nr%n2_CTR7yRz$5{?DwBwt(>O!tmXR7p8(7TSt^vEtNY; zHe#fg?S;8!cc~>nJLor+`KFz|oCF$+cx$4z8O<4oh7YCG37TEx{KWi^;+c{Z81rG7 zv?9hc7AL#T*EIRM(4-t9TC@v9Y~QX$$zyim;fXFA@QIj?@#+-bG(1V8wsQ^VVw2mB zX#I&g2;~{~n+4&nD?mc7D%2fnT8FQZ>(R}J)6MT}9_-}Slo{6iN3ABUv502tLtK?~ zmo7mj=t7={siQnacIf;QfsI_oJX$Z$Y%*SOosA+iyQ=b3;3dQKK=#bIoyOV*#(>up z;Xba$kxzhFEz;5xTP@Hy=D`*Swx!J2KZ+Ii&74we#^!F^i zUHCJ(o-IYc&UIA;_dco;S8(WE8lvCzyYhKIP5j74HYBk_HWkh%f`m{L3{i^6Hz80V?9}-@=~uj8govixd8j zctf54^rK?7pJ&Q4U6)&xJe{@*7<46`$M6{ZVKcWXwi`3#X((IQR4n-1oQ%s~Cy!0}y>cXzW{ZB6*%HSNT_jaZx zZpga~HVwK}0h8}jaBj!>rQXOXj@9pdvr7~ZZA!k4a{nHqq#PsKakni!vZ{wipCwHx znQ-tLo7mV$x4wTraMcz1ELL1XM$tLmOFp>wxzH2X#(^6r248cn>%)49ZZ&J1_>5gk}o&>FHtT`R2R#O`;J*gU$BX-F_V}J0G z*D!dML99noiqvhpT29XQ;U)XN=Hlc#^2F~tkv_U$t1IDS>T}<9aal4S<=Cw^+h47s zBY`VxcFHO&^I6As{5;TQ+;@#E&@7=phs;?q|M1n~M`#9{xjFE?;9KpDcxV7Z z@nl-l*E)CqY2ZP(vHC6YF}^*|jG|UnJ(t2I@#t;my+gV4`i54-WgvEeMtxeB)-~Yp z*aJ0Bt`5U|M>d16dBE4K_?09(jf3;jjbxr+ z)Ewl!QvOHce$wip)!t-`4)A@&7${|!06>0p9TNfHgvkyq3G(f|y!rIkECw^HJ=eoC z`ZlM?{B7xca+<_TbSG?D*efOCts=Nvb&~TYNmJV4|RLbA<@}^!wdR=GW z56;=2_P;Y#a4B1$Y8+1|u{kq!F)}T@YbJx(1dKX5g)#X6QU2D{ea61s&gfKi*>328 z&Yk!W2_;=w4Bm}B;$agpKOP-)U6q7wC$nY$*d_Bh{9e~_Fx+Po>6qBB>UXsvxhA43jh$D34J1~T8i^>&%1%&hW;3-QRvW0lC1 z#t`$Sz}&x+-rG9T?G3d(8@tq2*AZ>wj!F_~nmP~la6_TU$%m@R5ba#B#%?XtKaKxe z_SoJbqgd6X(Wy3$s+1NB((X!QltY~~9Q(WiUu~g(HheKp$kB@`49^CWOc>W4ZL=c) zB=HeW(fA`=KKI1dD+WNH@X~wv`@WJ1L%Jt)O$Uw`bty?{8J_Q4qUexeV<(qC!ldJk z>>SFf@S3t-HLl~dn8BWVURpm9poSJ2CS3;QPd53qmyq7yNWke5YyOrJD zN&Ux7EkW)w8j@WtIyniYg*o^V#Ifoyt)K~oCvAl)0WDrH!X zv}|H_5SmR_p$@AslYmGDuN-&Pq~x3k$;Fj}nx@?aPdSCA^$2{&@Y~M#B&L9?l6CUX z(c}QCbjV(V{}@XkQBK_?mM%?vNEIG&qr@y(vZ4H?6UB(^K!Z-@9tw75Y@}^6nZzVMa8zI3n%K8iH{dn@R zy9h6bzjmJAn1oJlR-g;NOThXnNDi|odiADyBA*?oyEZFn?4Iu17i>MDWQbx|W|e`N zL2SLO)})V`dVt`=m*0bjBwLvdb5R)2u%q$kBW~TcoYF6ECJ01^%iB3t6sO9j>V=BZ z@U7=B14s4ka$L{@cqW<7Pg&m7hg{j=*rS~Ozw9!xUmAz;0OfuW)z67aig3t&fDi^fnbKf+BKxow@&uBSRue z4AV&?HpXPG)1V5OC1UE*DXZ^X>45j@LSrW3pQ7j0pXRKfvCtUIdgM5H$@k0wV8$%` zK*+Dt@C+jzAB)SCde<);WhL_E78BbLH1`ts^VMrG{|q@blVP5=#o);XcOq$yGk^ch z)l`@B@}ir4M;7`@?yrgUgqkbE$5%rcitn2Bo5>aNpQ_u4PlV7PIsH^(R;qNiK*;xK zp@1^oG4n@&S*{^>kq&=ux`0%T9>|^ZyVZ_cLifKjyk%=GC6v~jOsPZfrN1CvjR_BT zw;oBo#A!VU4h5sIQ6l1(VCquIV5#Nj5L5Rt6vCh)UY9G9A@O?94HfgVaGie`) z6;(kz3j))ZzUP7T%=eo?iJzPypX1nYm@g!rn|kz65N+D91&6nia+Y|}J}<3xiXzo0 znZ~{LV*;@32d6GNin=hb&gd5a&s2AHyH2BtC$81ylx(gX5@J}?va)iG33h&UHpkGi%UDOL0>{|Z8 zk<+zZF&R>R47-4q4CSiY5R+_M;A>3AhS8a(ewSkrRo?HT=0Q$LbvflTepfee%GS&_ zH!5oL-EDUWy)S;!PGV+))A*I2nX||Mn+p@yK~6zqs%l@#l+{sdjLujzV?N05kc*!t zeIadjmfMRz&B?rHkt1R7*#X)<`;F`iIIbiH?yE{3)bkuylwu7fgoaB}VVwkkhebiO ztatPjpC(?DLnVk^fq8T_2T#FO5koDAlih4w$9bz=%IVaxEwNRp+lR8NVVzWbE4|EHuAe&NZ>*2~+3O(n$hPBNoKOEK`1Tyk6TbBt`-;?I2nM^hjNX#(qR3Io!U-C_JbRQr2n+7R4zmjd8I_-@7GY zNL)U7YjOP6`S!u^4flGcszbrq1V9KYCOuly-DyK`&AYcQE2eh^j&39Xm;`}6YoIY4 z82kCvc%^(G>hr6hQ3YcL!4=RkyuoaIkFxPy(72aIK*M4{UpO$=-(X~rZi>qA9eB%Yys3%yB4?wb)k}|9V_|v%U2Q0qpIEGd{@CS0)-0|Iipr4Vf6Q)HA~P2$FPgC(&^JN1jDlSjb$7=?l$(T|T1iU$UtYTlEa14#`} zO9d*PimDjdw59#tDE~7qzxIVcIy6_TWlOoT#FLeSn&Cw@gHpJex~hQ=Z;0&%MiNd= zbzS^eAajP_U(uW|U0;H{lGZVto3%95VWx?_4^Ui zeA^E}zcUt(PjhanlL`3X7OzmA;0%2nJjhsubE_5 z@6i+*N*SLFGqdHMgQ1d>9&BEU?Tvm z;lU4#Km(~zcl@s{!=^7aMOPpH*tJOd*N)*YqoD|;DfOp#?#*M+wf8#y@z&T#aHT;|8jl-rv%b7+Cdw(DlKF{|Y?f`LDF%~|P zyqx?L&!cp8J%%p=T80aqeu$Z3y0bNRe<{olJ_G*VviG|>aDrOdt0CQS_}uChr}+AL zP2=jc&2Clr@jWx&#$z3C%AnXOyB?P&s~vZjhIQ7D8s1>wIedEhyaBLnff1K#W9*8Eiu$U z9?22kJNE9oZ%N$UcqEf}#%EXK`l-e&GqSi=E%o|wbL+t4-k}nXjUVugjF`S#yMOT; zKJ&FN{CPf@?8#CSlYackPu48AHm^G7ALB9OU4bS~twrM%RF+;|viO`rPR_CRf4_k! z04!i3waZ-*q;U_%yn`l)}KdW zsf26Sp%(^g<2M4KQu_BCoW;rp&SD9Gci@0oV?pfg$@1v5&ao-ruV3@%tzhu^nn$L( zeId?tCALQO%cnkE@0V*EnDo1;ZtIzVTy=_0+#^rP!AQwMNuuw|X&!m>nAB+hy@9~5 z@9f?_{CAjk%L*~Y>W#BqWiQ@e^b-niAZ{$CqO=b3R`HuepOR~kKUmNF64LPwbIR9o z@!k|qHxe;?;UJByf)%0@re^yc7b1Ryq~GDv^$o%7k}zv-_(LWlF-4?R$wcdS&E|pA zHjm+sjHAUh3sUNV9-y#J&DTFMD-q}8i%tFK%j6S$rCalELEC>Qs zc#GtZcDJ&v4yi`O53QrbG4E8=YZ>lep3_PSFOkf-s3ND_>wW$8Unej!qsn#JPTr@1 zeCq4>uPUGa#k#mQ+O-}2Ya0Bp(YOj;EmzV5a|L{|4 ze`6lMJ_B2-$j=c!=As^pjfL?XxkwUr*nX7`shKzn7I!@c4~xrqU8g}Q9yVz`vSeL#Wx+m2f~}#Oi|e)< zUFEsap~2G0cD48h_;5>5)Yx` zR%zitp*ZTs4l5(RCW)vNb*G0A_$T6dPIeW@&V!p^hoq0=k|L; zA)Ipr0NA?Tt48Y3!A93-%l9r;o)YUac|FAuQmvLed1SsQo5W-2bEh;O`7%6kgm7gRaGDV5@6-fLEC zd|(s-;3+<9BFSK7Yr#7-VZUzcBL*q*7&uR^`Z307K3-ZNB)qUx(mX@ED=2cH*RJsI ztSbCP!Abo`S+>Q37X=~o{VLbr3(1kw^68AHd8oLwwc%Zp6;B*frCcCSw)CV)y;JDq z9CJaqrcZrpZC#Vqk~+k^Wzp4t-BqS_IF~}pS18nU-JXAX*H$L|xrFdv0!|0C{1P7? zCVQ3mJ7y^q$`cVeVi_kYtKCCY#BC4zPn!>4v9%2?fgi(zV`4vzpfy2u9^PU87)FM-!Ju7qzC>_HTT!7wEO8oqjR^^A2DP0B`K?D1eTDIt1fmW z4(ij4H9$*S@pN_KtUB#b_ChSIu64d@LU7W|IkN@VY)B+sh?S==q90I?C@p3>%67CT z=%5FP2d9xni1_ff=%7{bFp-%>`Goqm6}(SWiak1d_N+r`^h@^RX*C0Tl@0z|btsu7 z4-A~teY=2fWdKIW*0piiz1}u6JZ(T0qTL<%QBeu0!zq{yFR56? zgZHzDk z;6aNASQqaAT_rs69u(VAf9F^EBuY6Zg;Y~Y%mwOp*0G8=G z7-}inG^C|czuA@kn#X=8P5t@~h96?GTTtWgd~Zvm`#a_SAJyLd>twrZKKncB{Lj|d TZTWk5k^jd_ee3IFgx>!HlUQuC literal 0 HcmV?d00001 diff --git a/v2/avatar/testdata/circles.png b/v2/avatar/testdata/circles.png new file mode 100644 index 0000000000000000000000000000000000000000..fb30946d44016008b5c28eed4719f6ef6e850367 GIT binary patch literal 11392 zcmcI~c{G&m|L|BEiisAxYz^wM7G+8HVklc=k5HCmi!5QrUdeXbiWrKrB#~|Gk{Ggt zME0HRBU^@fuiNu{pZEQq-}~2l&U=pI%>B8p>$9)deLvCD(Lf#HJc2+VP#8^h0|bJG z4}qXIp+^Ep%EMqc1VU5;qpo7;^K)rf<({Q+Wybg)DsR1>uJ~t8&N;k#`&8e@9NQIg z1?F{ZZ8bTARIZLuhwy!24#`x`tVg|3^Zt;D32J!Je>%j&OaJt-hI0mOpEWc1PPnU6 zTP<8#b5Bq=@XUI!Y5J&eHmKZn^iqx%!FgrFaiz3xg&Kj#;ICWP2Ar@fgMT#i2n3Fb z7J-n&{huIf1l2zvB=bMJ4nT~2;0XrRBLJpnqsF zcBM`R%6C}l4&|;)D6Sq(>5lWy86%Fn0#h*Zom|cjaVd8glWkY_*lov(AHevwZdxeL z4M#JcYTj5*rM$HIQ^*YzulmuW^YizRz|pnL$%;VAj8a|AY1k-2b3Jn%vqqETYYR8( zF2ER`X)L*3s%D`A#_hrl`2-NH-QS3ZMJ>R)b*8mLnoa1ArS?&Qo!sML4YsdB9_xOD z&wY+zTeiH2mhPT*Yv2c#%(C%aC~u!`cfDCRUd{spga{Jj!MCmREa`XE;1etrPHR?% z#(PP$fR?BJe#xFC!@j^HnQ9bF+_U9=B>AT2psNvZIL*}QC)#Xci{{^(+zeQO4`qgt z;(3HEO1H8^@pe>zHM7a^EB?YB783BHk2uTX$aEOY^3(WTq`2t^?_%xSR`mAkYKu(^ zCIN#wqF>s4ewS_dxUFK-@DBIfx9lnDgKLZSlxw?m6gEJ<C@m9%KEmsh;jNNspkBzHytE zYFhhdQ*IMtM}QMIMlkZfbnj=cZEul5M7n4HQL*hM{Sd`2k$y;G#ZZEP)i0u3v=sx@ zBdI-0haJnuN1l!IJj;H5PW>5ye4*WHkbI%TYJq%#t{(A;qegZvILv25?Xe91QrshU z4+*v9`;yMryw4CRM;Aq?Byp0#t|;E8c*SPF@tDvcHRN#+wBbw~s=uIJl*h5u)YmBP zB_2&uT_bs=n+GgD9&cRbSJQM?(@58uw##$DS0qwSaa{#gGCIo3CRXC5%J%8YZ=orB zc|z`hd9?g!^b~eW@jJC7&Tzc=P#Gitjd+bpPI=E>NvKW+wqA!3!@a1kD_FZP0@&8H zEzEYc*4v$HBG^J7DQauPe5kHij#s$&?&}9ao1QJ{$AF2y*A~b2$yF}JbmdM9dij## z2AeN*m=COP8Salz*CV=*xXAsG{WHXyZHpRx*#9F_Xfx7B1mHTGexh*7XN1_Y^1EiK zOM3Ka{{gfE#b6&6U-U&xSEQAzP%P9)b?4EySyKcuR-k{2KefAA#I@vBlyBrM7SsKa zucNO)l)d?MeKy6;>gI3z`G|fl{|75lBBhRa&7%T4kJtc%w*G$Em=(JDKW|M4K5cE| zj66>3igDcOC;=Q*=#$-UYEzVFZevJ9)by{Pd{8nMr#evjiOC)jxfE|t-*>qUY8My< zU^}k)W%`nwyLg|NhW%cio_9Sp!yJJMdK(9Qti`boeIco z^&&dv6B0PtS5&0=mCLE$)`jzyH)4Mp?bm};*Uhh3XJB6G-!Zzd6K}}Xp8YUMN!Be= z9(WzDP}J%Bt3YftSUO=CDN96pA!u3>hRbq{Ja?r;@fMh25#9G6RZfBUPv|V{EtBCh z*4W-l!&?kEh5+Rl&VGG3&uhC@}W< zLpt@)xSBzG&C_O98o)88T+HFlcI;HBC<7d(CpA{{#P$YG@j^oVW<#ZCfK9Kv(H8#f z!rOeK#&mx+kMi{n15t>na&Q-MRYcF*dxx>Bf(%5Z9?pb09}1Sl-416m@tNfVogQNPpkA%xmypY-DP>>E7 z`Of=mLi|9L?8;gAKjrdxAXXFE#UC0TC*|NnwJM z`@6$ivp#qA1#3mk z@w^2jGkZZ+Au`(1a`?nsloobJ`mv9tBY`PBd|k4@WgM zG(`8R$KUT@3oo$4vnFHK{v%rfn%&YmO+V$fv4G>mk-aQ z#AeNxv3iSX0?CsjStq`4O5Rl<3p%!nog!OSpXwBa7h)vDK3mG3!KD0QCaEo&CV;Tv ziJmIR_qMr|{+Wd-0P~lFnx6#g-+Oh%TJwM1`ub>0+aO4-Wx=HM43OvY;-TEvN~aVi z6y^kp!$3K9zW`*igNdtCFn96Rm(!8QC2>9s0^kwTKJ5bZY1ml9H$39AoGFf3a=0)f zrUmnn&$^K2e34b5S1|X=K&6z$6Z!v5zU+{0M^v*n!Q+fJCd1m5>j6;bW+)n$pb{zen%FK3C1qN?!5 zbIj7CjnCKklCq90GzA}`XIK8UDZu@0&=VVN&+`c~bl)wyPX$D-47f+Xwy6g7nk=l< zPk=~eS529ZDf%@of=LAFW+I~h<2DHE@K~Y@ulHqtLCo4;Jmzr7j5B)>GbH0ip4?}$ z#B^t5KQEdC`ywnaN(?zh|z`BF3->C1Ez z)9ngsj;SG#K!0k4Sth3nY@50ZyrCXpZrf!wQ52Dfp6^|13J;?N$ntd?_nE!1 zn$XZ;*gg-dx>Y_8^FnBE*g|LjDEz9tzGtSeDw&9m8tQ0Z2PS8K&9*jqO2D#CfA98J zpI9yE>U)rC89=H%PTC$B67Ft7EhyGO=p9A|F#m|gtkxVG9uS#W7r{WCAHbBOJtb^L zaIb0Kz2YAzi1FZL1w!NwRi{6?`jeE~huz#sL=){nsApaxy=o*|EErI!5yNCC+%hBN z2iabDD-86dYr$lri;jjq`_Hc~Ysk{S$CoaFA!5JOg_~2lR(>3Rhd65tq4BqasAy9@ z@DB2%Kv_izN?jn$yS)O5MePP8T}Gla2Q7beDd#%RbFiGK)DJIb@!E7Y8bH#0K-qMn z?8;*fgKJ-5z$P#?cP<+PPXdH4e-#1{k8fLIph=$w1a*`f83dz&So2FdF>qv-yG{g? zV&~g+*cuU3PD(h=q=|t7Kq?KoD*A;5Jam6<4bcJk7{P23@fz*;YFb4Y@PY&xf=T0_ zt#{vjsiu34CMtp~p)9pXiA7(~uXU$mP+fj zp(fgh$x0HZ_fdiM8;w5HYq&E5ng_jHg`?jZ^`I$LRe zLCi&1EY3-4RSXn02e|!GfzyMeqsk^o|Ju%%St-3f+<`Ug$D$$N9)0VI=K9!}?hNjS>fU>Ou8I9=?RXT(O2L!Ah zJ=(@VZ%hM^FaiRhCqAGtI-jU%=%s}uaHQ3E^z}BI}H5F z8WD#!G73BnpTk9@qlx;};L?KTF*-L7ST5o`yYig8#wv)xCJ=*?IFQepK##O#J}3+N8Hr-D2LW=Y#$>fV9!+2XRi}&58ywO4O9=(oWLHLGJlH^v86zt! zqA`M{K>C*rh+ryp5G6stZj5{ong*Z8Nf<$KrRCu(tEnQ&_je&*1 z0m?vvzd1|s1boLLX#iFXlyVjiv`-7vaImp{Kn(bsJO+Y!>|$h+`}-HHI}T;we+Ng- zC>{GpNGH%a9!LcX^S1>5uy-HcHW$SYg&nfVW?3q2rp{6}3vRwJP)-qk|F=oo*TAnt!&?woIFPAjkm-Pf3NV{> z#iLo`&AUL9G*9cMD6FC}*+WzZLY^vT^}hTYjcF=5)d@S2YBB+32^n^CF_&k60EAT* zm8kKV2FWb%EX;ZtoNNO604%*6>=ML?^V0&?or&l_ps=$q7A9ZxHG zLW-YGKoe+Sqq8R;(}J`d@-$YLSD5zqJ`NAhhTJ|gIJCzbGCNhlxxa9?5nBhkTXw*Q zB`L9-4v%*H>8dPL%Z!z>b6(jO>D!{nvD4rDgQcvE)=>iAIp5;??X}X_@2|L9Rv;{6 zyERTGv|-a4DU~}bQsga#YNorgIa8R#@DqUJ?tGucabe7_-ec9fyAq?dM3V(9cUu*2 zQ};4mNYlO>qJL-N=I@o!Cl%g>&ue!VCmWQE_ty&~EP2W6Gj0iL7ztKn=OXtj5;F&o z#VmRPv@=yhF+Mky7FRwm?_)pb?@>J}0-EX9MB7;UUpGl_!87?JlhV~sf>#tEa(XB$J5d-+Hy;cFIpv!CR#6R!QTmODacB1*VHaap!zU&~FR zqJ0{NJ?D!6CX%JP=ekAN-+nV`J^pf6*6GhXr|#6j^{E1K=IhQ9Mm|Zeog2rPR50Dr zy7jH``5w9vGN9kMTT(vM^=8 z|AcJE3^O5OorKAjSFBgjav9BxEVpGa)p@RikqguahAG}_dWJ3o3b6H9|s z0EQE`r$O>Jz?*X@{8W0pE!~EowzY>ri+BqT1$n8*Q~Xo}jn8+zXy~6sMa@B3Yt?H$ z60Ex`sEf(w)!v{fymRN?-PP}*kJN$StlH-$$IoKqHL$11Swl0VWqeBMP@UqoA1D+r zBe}i_V(i6S#XD~#h;}WUiTLW@mUxd2e}_YP?CY&2TPA%>w*%f{W0ZI3-F^LAE8l7w zJBuZ8^(~~F%S<;g-HvT-3l;)7qNCxW%`bb30XcalL(oPot*`c}34=aM2FP>Toj()w zU*opzLJ!g=?rSci+zo*3beh6Kp{&VpCXEHBgwv0by_G!TE2(G$Qv*OQ(Ez&N(-P~uv#FQw+T%Z z6Dp#Z#6$))2D=MhhOS^zsbJ(=jUN)7xh;x`?oSl@-WQ)5vVlZ0v$q((=4pWpn^%k0 zo5UDwuHS#~M~kwe({diC%&3#Qs+W}IQr?fnen6Bk_ercMfi*I`a{fN+ncYgwtrtM8`K$D# z6i+|Yn>pT?%9MC(MD1hQ<=F?KH1wr#-uemk#&y20-qFh+33YDY=59l%M%`-8qbo;q zC}!m@9ya}DwJ!7-0EeE6W5Ssn&G5DZ&s`9is-XNaCub!c+nc?Axwtl;*O&<6Ql;8OfXzq1C7EmiZ+xcLM#UD=feJXta1x{4NOFawzBYIemT9 zrEPE<QEg>9YjyHw-=NELHfLjHb&EWXk!%Nl9qlR(`YB&JOpiz4ROtZ=28Kb2dV zVQ&n_zf~id=~KT6=I-1m99f@F^nB_OWB5a{33y5cK#}6IK)fEQo!fl_n0+-I-uQx) zPQ#$>SfX15+F3w9s(lXX{W=n}zjui>(V&U#B&YHivj$14+is!YtI0Jj`!jS~WHXRZ zMuIgVotiar09R-BgAo*ze1=Oo!nTl{y~?>_b%eK+mmQ09$z_>Ga|v| z+Jq_;IJus`qPW}tN-yQp-QNl|JwV|2LK!h}G_ifwz;Vgn{hpi0Ti}IfItR;IVzXY4 z2EGHn#hx(q%moqT-ukgj>^vGCi7u&tGWuM>bx>6e3LGKzu|z6^Q47iZoVFPlL?U`* z7~45-48}B&2QkspiCK*_ExRH3S~G-J&(gw3L=!gr4kGA5i&VLo0(!oAkCzo|z>$0$ zT=^oI=kk6OIep_N4#w-p`oAHkWkyUJzn88`D` zUnceGb9D=T-$#C4t1^(A4p6v6mwy8I&2s=Xx{UihT0R$l#A$2X|6B7Kr?U06p;&Pu zn1D8e;CaABGue(3kp&$*Q7F>*%-{#^>0leJDSR@t7cC*Xy|)C=bAuH+vLCMq{_SETb);F zK)LDNIGe<9M9?dM=8uf-tmyA7Rp7%X3CD&(&6^5V#aL2kMI;nHg8iZ~ErBdCfLQ7f z;%=ruZa&_6I?SHc&wwEFD%s~aNTx!I;QfU&^+&z9UV;qY-j*mS>yz%Mzt*mLf=bR; zS65;Z*ISB|5i_3xGZ6^47u@Y4#ADSEC$BQ;JrCfe?OQceSwcCj+h0^4->URfMt4RU zhuh{m*M=UbHmCr>RjE0Il$Fy*UgLB+VCFryIQ{1L^z)B7adwC}={_X(7K*Q4f7A97>*Uk0C^YDWY3REcg>Av4 z1c)J`JkBTm1I($GZlTtIz$ipT>+meT2(;wj4|lbo$&%p|m>a-f21oR4BESn~B~D&d z?oJURLI4ad(n2w_7eZm)Idj7f(a>Ad$MuHoa{?Acm$N<%2FUV)>rez$$MQyl)oE5m zNQ+(B8rKAfg%UF55~NI5#MPx zh6zi?>Y(^={D56`gVP0tUz_rMoWTso5B;p8n*yQ5>SHio0GjKU8IA{FGbP;Ar!UmH zfB7Q@QvM;puAKBx=|g?|k-`#lU<-gg>j5ww%h29#5x6|(Q(9)H#Q`+x@8Z_L$oECEnTyNm+f_or-_3nP~wsaf2RyUFQ_5pCx zTyLD{u;}u^K&6UyoFc_X^BOWeD=p#4MaH(FyAV0|4COMa5w6rZ)8fPXNx8s~9) z^31K=p(LA`CA={Wy(G&{ax6gY8y;h2wpizyDa_KpwMEgNIt&s{jwe0wN@!C;mVRqz zK{TY)>vlZJaO|*&)zDcIEocc0zcLHVdkR}x%01j>-?uD=+Z0725`@O;An58CL@~lg zc?sA>wA@Q)MSSU4{E>V2AH8ou&4GcjU>=ITeAk;Y=K6RPqInuMXW*MyNdZNo+EQ8e zjd8a`%)sS>2G!3fNUdhzI_G)XabIAWa@*PUQ) z_2Sa{Yy$9HdKaDKO2Vk^&}hVLuzkRM<)IAe>7+kDzL0K$_RQ_-@U?thl9TQ7SY@9*K2!0qy9YsG|KOR)}w(B$6Px3B%O zH`6@6e~x=eSas(voLUzF<|+yczfqh;GLNB)+Q-@WwiV2Emnhc#wf7q=y1Dqe;=rok z#T5LelxX=Y?E_pVHo#@$Du%^bS}7_ms%-st{8goLX&j=d1z#~^g<}ow_H>c%>=e+B#WGbS=sCRy4UNq{nM)a$ zTZ3}9DC`|FkHPR3UdpTKe>~FqEDV&@y$FVaid|q)A1(pkg>(8loijSjxi8vh30tiQ zYPUJz%#m=-;h_ch}_Qah&?wWfDsy6~kEX2<9qUh{zZ8A@KDt6G$q1l3ck=i8#%@&xiw-gS+8Jc5*$a2MKeCTeG* zA#s1yd|mDh73~=<@%J8acZ$yLkv;Rv0$irW9)rdNB(p>d=K!4vS-I|wf|Q*xhfkS#)x`Hoyo`s*tggY|aQRSNhwn;r1M_Iu-fa*%d;%AK zzc4BYY1R{e3A&Sxo0R=DlgO&H?J97C#jc)(439~ni)kc%cw(Vxs;Bk3w9L0g3L61b za`ULb#r*v1sck48rGdz(ZO}d*%9d6I4UqXG{#+{BxtHAPNwys}#!V=mowE1g?bG&g z47nD1rY2f%{dbEV<=?>40W>--u5FeMImgAB7bh((*}INRZgDi*p$@}D=|fVXWpuIY4yeTfoZ%h z-UicLUVNd4SZ528B#Y}6l5p?lQwMG~x|z3ugt(M7^1xT!B2npVdBqwKBHwFSij39m zhX=>+H{0!BtK?Guv+>qAbrhTT1w`X?wpNLb~p|6iQHi>+-uBMOh(W2u!0~Y61TG zIrSKrMFx7})WA&Nlg#iwvHc!vMqVYjUZKPCdZkg(j)mT~tFWJ1eJkfQ`_`zDgzcn~ zqPwo9Vd^n{N9nQC$kw-o(fo-{Fno?Bkc!ptPNlm&dE}L@tE6JN6ypm?c-|TN$ED;w zPzhxx2jeWh&b;bVHt67r9;YbLMf*QUbp230?J$Zj0N%Mor*FBd==afR=`zS+&)oTZ zIbs%h9k=82#_QwiHj&XTEn)pX-IC+=*yFk(-w zFI^2hYCNb`B^0t}EF#64Q#?O}pstri#`}*=3CQc-KJP$Yvd8OA7ZO*oGlN!7*=%pu?jK}Rmyx%}#QXd`&l>5{N{S~=tOeFo`&-nz(CQX0)eT<&}uKugw zDRIv zTt&jqy_P^&4fQ+Ix+Uv5w_MdchmKV#5@6Av8Mz2`;N51g_puJQg6j1d%qkj-AORqR zPS-7T4))Zb^j)OATVZk;$1WAk-Pde#y?pax;eZ2d8;4tge(|nq9{Nu1BYP*_dlWW% zd_^49W8b6i-beM+j{Fj8sz3yJ{r?nS4+b=Fc~tO?Udf(Ei2g>uxV7UI@Kr zK1ic-<`&J?W9i;cs?MNM0|l{R@QloJE+Bqhf>KbY?X3sw#Ll2a7LLj3pe@R=!Oj)N z5!vakI5n%lqc9B|&CZP-i!&0gmTOP?d^aDG7XsgHxu7;5P8GE*ufo3Q7CbI&rhQJz zM+7Q=w|i}2s`e47v)`hUqh|P0h>X_eL)bxNL;OX~b6+LNj)JjGvB&S@TiaP;b4aFb z0i`GtlwU-B#dN`tuZ6RCS%j;G2Vd^QQcFYqIsdjeE@A}Uc;Fgriwypq6BLWkS@&Y? zy?d@oKU!_?4%T9)c02busz5h{d{;Vb9ug!HBl*zkfOS6tU~a12mJe=AlS?P-UTB* x*mVFpzyO25ebGPP+5U_40Ca#92H^zASe-`>amiT^u>b=Cb45q}y{dKa{{lfs0S*8F literal 0 HcmV?d00001 diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 00000000..2b6075d4 --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,49 @@ +module github.com/go-pkgz/auth + +go 1.21 + +require ( + github.com/dghubble/oauth1 v0.7.3 + github.com/go-oauth2/oauth2/v4 v4.5.2 + github.com/go-pkgz/email v0.5.0 + github.com/go-pkgz/repeater v1.1.3 + github.com/go-pkgz/rest v1.19.0 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/rrivera/identicon v0.0.0-20240116195454-d5ba35832c0d + github.com/stretchr/testify v1.9.0 + go.etcd.io/bbolt v1.3.9 + go.mongodb.org/mongo-driver v1.14.0 + golang.org/x/image v0.15.0 + golang.org/x/oauth2 v0.18.0 +) + +require ( + cloud.google.com/go/compute v1.25.1 // indirect + cloud.google.com/go/compute/metadata v0.2.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/klauspost/compress v1.17.7 // indirect + github.com/montanaflynn/stats v0.7.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/btree v1.7.0 // indirect + github.com/tidwall/buntdb v1.3.0 // indirect + github.com/tidwall/gjson v1.17.1 // indirect + github.com/tidwall/grect v0.1.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/rtred v0.1.2 // indirect + github.com/tidwall/tinyqueue v0.1.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect + golang.org/x/crypto v0.21.0 // indirect + golang.org/x/sync v0.6.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 00000000..97892557 --- /dev/null +++ b/v2/go.sum @@ -0,0 +1,187 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go/compute v1.25.1/go.mod h1:oopOIR53ly6viBYxaDhBfJwzUAxf1zE//uf3IB011ls= +cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dghubble/oauth1 v0.7.3/go.mod h1:oxTe+az9NSMIucDPDCCtzJGsPhciJV33xocHfcR2sVY= +github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= +github.com/go-oauth2/oauth2/v4 v4.5.2/go.mod h1:wk/2uLImWIa9VVQDgxz99H2GDbhmfi/9/Xr+GvkSUSQ= +github.com/go-pkgz/email v0.5.0/go.mod h1:BdxglsQnymzhfdbnncEE72a6DrucZHy6I+42LK2jLEc= +github.com/go-pkgz/repeater v1.1.3/go.mod h1:hVTavuO5x3Gxnu8zW7d6sQBfAneKV8X2FjU48kGfpKw= +github.com/go-pkgz/rest v1.19.0/go.mod h1:Po+W6zQzpMPP6XDGLdAN2aW7UKk1IyrLSb48Lp1N3oQ= +github.com/go-session/session v3.1.2+incompatible/go.mod h1:8B3iivBQjrz/JtC68Np2T1yBBLxTan3mn/3OM0CyRt0= +github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= +github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rrivera/identicon v0.0.0-20240116195454-d5ba35832c0d/go.mod h1:TbpErkob6SY7cyozRVSGoB3OlO2qOAgVN8O3KAJ4fMI= +github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI= +github.com/tidwall/buntdb v1.3.0/go.mod h1:lZZrZUWzlyDJKlLQ6DKAy53LnG7m5kHyrEHvvcDmBpU= +github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= +github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M= +github.com/tidwall/grect v0.1.4/go.mod h1:9FBsaYRaR0Tcy4UwefBX/UDcDcDy9V5jUcxHzv2jd5Q= +github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/rtred v0.1.2/go.mod h1:hd69WNXQ5RP9vHd7dqekAz+RIdtfBogmglkZSRxCHFQ= +github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao= +github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= +github.com/tidwall/tinyqueue v0.1.1/go.mod h1:O/QNHwrnjqr6IHItYrzoHAKYhBkLI67Q096fQP5zMYw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= +github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= +github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= +github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= +github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= +go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/logger/interface.go b/v2/logger/interface.go new file mode 100644 index 00000000..85724318 --- /dev/null +++ b/v2/logger/interface.go @@ -0,0 +1,22 @@ +// Package logger defines interface for logging. Implementation should be passed by user. +// Also provides NoOp (do-nothing) and Std (redirect to std log) predefined loggers. +package logger + +import "log" + +// L defined logger interface used everywhere in the package +type L interface { + Logf(format string, args ...interface{}) +} + +// Func type is an adapter to allow the use of ordinary functions as Logger. +type Func func(format string, args ...interface{}) + +// Logf calls f(id) +func (f Func) Logf(format string, args ...interface{}) { f(format, args...) } + +// NoOp logger +var NoOp = Func(func(string, ...interface{}) {}) + +// Std logger sends to std default logger directly +var Std = Func(func(format string, args ...interface{}) { log.Printf(format, args...) }) diff --git a/v2/logger/interface_test.go b/v2/logger/interface_test.go new file mode 100644 index 00000000..46f956cf --- /dev/null +++ b/v2/logger/interface_test.go @@ -0,0 +1,43 @@ +package logger + +import ( + "bytes" + "fmt" + "log" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLogger(t *testing.T) { + buff := bytes.NewBufferString("") + lg := Func(func(format string, args ...interface{}) { + fmt.Fprintf(buff, format, args...) + }) + + lg.Logf("blah %s %d something", "str", 123) + assert.Equal(t, "blah str 123 something", buff.String()) + + Std.Logf("blah %s %d something", "str", 123) + Std.Logf("[DEBUG] auth failed, %s", fmt.Errorf("blah blah")) +} + +func TestStd(t *testing.T) { + buff := bytes.NewBufferString("") + log.SetOutput(buff) + defer log.SetOutput(os.Stdout) + + Std.Logf("blah %s %d something", "str", 123) + assert.True(t, strings.HasSuffix(buff.String(), "blah str 123 something\n"), buff.String()) +} + +func TestNoOp(t *testing.T) { + buff := bytes.NewBufferString("") + log.SetOutput(buff) + defer log.SetOutput(os.Stdout) + + NoOp.Logf("blah %s %d something", "str", 123) + assert.Equal(t, "", buff.String()) +} diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go new file mode 100644 index 00000000..64d6c7ce --- /dev/null +++ b/v2/middleware/auth.go @@ -0,0 +1,262 @@ +// Package middleware provides login middlewares: +// - Auth: adds auth from session and populates user info +// - Trace: populates user info if token presented +// - AdminOnly: restrict access to admin users only +package middleware + +import ( + "crypto/subtle" + "fmt" + "net/http" + "strings" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/provider" + "github.com/go-pkgz/auth/token" +) + +// Authenticator is top level auth object providing middlewares +type Authenticator struct { + logger.L + JWTService TokenService + Providers []provider.Service + Validator token.Validator + AdminPasswd string + BasicAuthChecker BasicAuthFunc + RefreshCache RefreshCache +} + +// RefreshCache defines interface storing and retrieving refreshed tokens +type RefreshCache interface { + Get(key interface{}) (value interface{}, ok bool) + Set(key, value interface{}) +} + +// TokenService defines interface accessing tokens +type TokenService interface { + Parse(tokenString string) (claims token.Claims, err error) + Set(w http.ResponseWriter, claims token.Claims) (token.Claims, error) + Get(r *http.Request) (claims token.Claims, token string, err error) + IsExpired(claims token.Claims) bool + Reset(w http.ResponseWriter) +} + +// BasicAuthFunc type is an adapter to allow the use of ordinary functions as BasicAuth. +// The second return parameter `User` need for add user claims into context of request. +type BasicAuthFunc func(user, passwd string) (ok bool, userInfo token.User, err error) + +// adminUser sets claims for an optional basic auth +var adminUser = token.User{ + ID: "admin", + Name: "admin", + Attributes: map[string]interface{}{ + "admin": true, + }, +} + +// Auth middleware adds auth from session and populates user info +func (a *Authenticator) Auth(next http.Handler) http.Handler { + return a.auth(true)(next) +} + +// Trace middleware doesn't require valid user but if user info presented populates info +func (a *Authenticator) Trace(next http.Handler) http.Handler { + return a.auth(false)(next) +} + +// auth implements all logic for authentication (reqAuth=true) and tracing (reqAuth=false) +func (a *Authenticator) auth(reqAuth bool) func(http.Handler) http.Handler { + + onError := func(h http.Handler, w http.ResponseWriter, r *http.Request, err error) { + if !reqAuth { // if no auth required allow to proceeded on error + h.ServeHTTP(w, r) + return + } + a.Logf("[DEBUG] auth failed, %v", err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + + // use admin user basic auth if enabled but ignore when BasicAuthChecker defined + if a.BasicAuthChecker == nil && a.basicAdminUser(r) { + r = token.SetUserInfo(r, adminUser) + h.ServeHTTP(w, r) + return + } + + // use custom basic auth if BasicAuthChecker defined + if a.BasicAuthChecker != nil { + if user, passwd, isBasicAuth := r.BasicAuth(); isBasicAuth { + ok, userInfo, err := a.BasicAuthChecker(user, passwd) + if err != nil { + onError(h, w, r, fmt.Errorf("basic auth check failed: %w", err)) + return + } + if !ok { + onError(h, w, r, fmt.Errorf("credentials are wrong for basic auth: %w", err)) + return + } + r = token.SetUserInfo(r, userInfo) // pass user claims into context of incoming request + h.ServeHTTP(w, r) + return + } + } + + claims, tkn, err := a.JWTService.Get(r) + if err != nil { + onError(h, w, r, fmt.Errorf("can't get token: %w", err)) + return + } + + if claims.Handshake != nil { // handshake in token indicates special use cases, not for login + onError(h, w, r, fmt.Errorf("invalid kind of token")) + return + } + + if claims.User == nil { + onError(h, w, r, fmt.Errorf("no user info presented in the claim")) + return + } + + if claims.User != nil { // if uinfo in token populate it to context + // validator passed by client and performs check on token or/and claims + if a.Validator != nil && !a.Validator.Validate(tkn, claims) { + onError(h, w, r, fmt.Errorf("user %s/%s blocked", claims.User.Name, claims.User.ID)) + a.JWTService.Reset(w) + return + } + + // check if user provider is allowed + if !a.isProviderAllowed(claims.User.ID) { + onError(h, w, r, fmt.Errorf("user %s/%s provider is not allowed", claims.User.Name, claims.User.ID)) + a.JWTService.Reset(w) + return + } + + if a.JWTService.IsExpired(claims) { + if claims, err = a.refreshExpiredToken(w, claims, tkn); err != nil { + a.JWTService.Reset(w) + onError(h, w, r, fmt.Errorf("can't refresh token: %w", err)) + return + } + } + + r = token.SetUserInfo(r, *claims.User) // populate user info to request context + } + + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} + +// isProviderAllowed checks if user provider is allowed, user id looks like "provider_1234567890" +// this check is needed to reject users from providers what are used to be allowed but not anymore. +// Such users made token before the provider was disabled and should not be allowed to login anymore. +func (a *Authenticator) isProviderAllowed(userID string) bool { + userProvider := strings.Split(userID, "_")[0] + for _, p := range a.Providers { + if p.Name() == userProvider { + return true + } + } + return false +} + +// refreshExpiredToken makes a new token with passed claims +func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token.Claims, tkn string) (token.Claims, error) { + + // cache refreshed claims for given token in order to eliminate multiple refreshes for concurrent requests + if a.RefreshCache != nil { + if c, ok := a.RefreshCache.Get(tkn); ok { + // already in cache + return c.(token.Claims), nil + } + } + + claims.ExpiresAt = 0 // this will cause now+duration for refreshed token + c, err := a.JWTService.Set(w, claims) // Set changes token + if err != nil { + return token.Claims{}, err + } + + if a.RefreshCache != nil { + a.RefreshCache.Set(tkn, c) + } + + a.Logf("[DEBUG] token refreshed for %+v", claims.User) + return c, nil +} + +// AdminOnly middleware allows access for admins only +// this handler internally wrapped with auth(true) to avoid situation if AdminOnly defined without prior Auth +func (a *Authenticator) AdminOnly(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + user, err := token.GetUserInfo(r) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + if !user.IsAdmin() { + http.Error(w, "Access denied", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + } + return a.auth(true)(http.HandlerFunc(fn)) // enforce auth +} + +// basic auth for admin user +func (a *Authenticator) basicAdminUser(r *http.Request) bool { + + if a.AdminPasswd == "" { + return false + } + + user, passwd, ok := r.BasicAuth() + if !ok { + return false + } + + // using ConstantTimeCompare to avoid timing attack + if user != "admin" || subtle.ConstantTimeCompare([]byte(passwd), []byte(a.AdminPasswd)) != 1 { + a.Logf("[WARN] admin basic auth failed, user/passwd mismatch, %s:%s", user, passwd) + return false + } + + return true +} + +// RBAC middleware allows role based control for routes +// this handler internally wrapped with auth(true) to avoid situation if RBAC defined without prior Auth +func (a *Authenticator) RBAC(roles ...string) func(http.Handler) http.Handler { + + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + user, err := token.GetUserInfo(r) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + var matched bool + for _, role := range roles { + if strings.EqualFold(role, user.Role) { + matched = true + break + } + } + if !matched { + http.Error(w, "Access denied", http.StatusForbidden) + return + } + h.ServeHTTP(w, r) + } + return a.auth(true)(http.HandlerFunc(fn)) // enforce auth + } + return f +} diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go new file mode 100644 index 00000000..18208aee --- /dev/null +++ b/v2/middleware/auth_test.go @@ -0,0 +1,571 @@ +package middleware + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/provider" + "github.com/go-pkgz/auth/token" +) + +var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.orBYt_pVA4uvCCw0JMQLla3DA0mpjRTl_U9vT_wtI30" + +var testJwtValidWrongProvider = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjNfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.p0w7GmXKwujm0ROn0RIACnBwN4KmPcqXDMS9YoFq4jQ" + +var testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6MTE4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.PlRRc5YA6pvoVOT4NLLOoTwU2Kn3GaTfbjr6j-P6RhA" + +var testJwtWithHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0._2X1cAEoxjLA7XuN8xW8V9r7rYfP_m9lSRz_9_UFzac" + +var testJwtNoUser = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4MjIsImp0aSI6InJhbmRvbSBpZCIsImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyfQ.sBpblkbBRzZsBSPPNrTWqA5h7h54solrw5L4IypJT_o" + +var testJwtWithRole = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9LCJyb2xlIjoiZW1wbG95ZWUifX0.o95raB0aNl2TWUs43Tu6xyX5Y3Fa5wv6_6RFJuN-d6g" + +func TestAuthJWTCookie(t *testing.T) { + a := makeTestAuth(t) + + mux := http.NewServeMux() + handler := func(w http.ResponseWriter, r *http.Request) { + u, err := token.GetUserInfo(r) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png", + IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys", + Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, u) + w.WriteHeader(201) + } + mux.Handle("/auth", a.Auth(http.HandlerFunc(handler))) + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 5 * time.Second} + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + + t.Run("valid token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) + + t.Run("valid token, wrong provider", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValidWrongProvider, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "user name1/provider3_id1 provider is not allowed") + }) + + t.Run("xsrf mismatch", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "wrong id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "xsrf mismatch") + }) + + t.Run("token expired and refreshed", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") + }) + + t.Run("no user info in the token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtNoUser, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "no user info in the token") + }) +} + +func TestAuthJWTHeader(t *testing.T) { + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + client := &http.Client{Timeout: 5 * time.Second} + t.Run("valid token", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtValid) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + }) + + t.Run("valid token, wrong provider", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtValidWrongProvider) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "wrong provider") + }) + + t.Run("token expired", func(t *testing.T) { + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtExpired) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "token expired") + }) +} + +func TestAuthJWTRefresh(t *testing.T) { + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "token expired and refreshed") + + cookies := resp.Cookies() + assert.Equal(t, 2, len(cookies)) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + t.Log(resp.Cookies()[0].Value) + assert.True(t, resp.Cookies()[0].Value != testJwtExpired, "jwt token changed") + + claims, err := a.JWTService.Parse(resp.Cookies()[0].Value) + assert.NoError(t, err) + ts := time.Unix(claims.ExpiresAt, 0) + assert.True(t, ts.After(time.Now()), "expiration in the future") + log.Print(time.Unix(claims.ExpiresAt, 0)) +} + +func TestAuthJWTRefreshConcurrentWithCache(t *testing.T) { + + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + var refreshCount int32 + var wg sync.WaitGroup + a.RefreshCache = newTestRefreshCache() + wg.Add(100) + for i := 0; i < 100; i++ { + time.Sleep(1 * time.Millisecond) // TODO! not sure how testRefreshCache may have misses without this delay + go func() { + defer wg.Done() + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode) + + cookies := resp.Cookies() + if len(cookies) == 2 && resp.Cookies()[0].Name == "JWT" && resp.Cookies()[0].Value != testJwtExpired { + atomic.AddInt32(&refreshCount, 1) + } + }() + } + wg.Wait() + assert.Equal(t, int32(1), a.RefreshCache.(*testRefreshCache).misses, "1 cache miss") + assert.Equal(t, int32(99), a.RefreshCache.(*testRefreshCache).hits, "99 cache hits") + assert.Equal(t, int32(1), atomic.LoadInt32(&refreshCount), "should make one refresh only") + + // make another expired token + c, err := a.JWTService.Parse(testJwtExpired) + require.NoError(t, err) + c.User.ID = "provider1_other ID" + tkSvc := a.JWTService.(*token.Service) + tkn, err := tkSvc.Token(c) + require.NoError(t, err) + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: tkn, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode) + + cookies := resp.Cookies() + require.Equal(t, 2, len(cookies)) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, tkn, resp.Cookies()[0].Value) + t.Log(resp.Cookies()[0].Value) +} + +type badJwtService struct { + *token.Service +} + +func (b *badJwtService) Set(http.ResponseWriter, token.Claims) (token.Claims, error) { + return token.Claims{}, fmt.Errorf("jwt set fake error") +} + +func TestAuthJWTRefreshFailed(t *testing.T) { + + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + a.JWTService = &badJwtService{Service: a.JWTService.(*token.Service)} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtExpired, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + require.Nil(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 401, resp.StatusCode) + + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "Unauthorized\n", string(data)) +} + +func TestAuthJWtBlocked(t *testing.T) { + a := makeTestAuth(t) + a.Validator = token.ValidatorFunc(func(token string, claims token.Claims) bool { return false }) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtValid) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "blocked user") +} + +func TestAuthJWtWithHandshake(t *testing.T) { + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.Nil(t, err) + req.Header.Add("X-JWT", testJwtWithHandshake) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "blocked user") +} + +func TestAuthWithBasic(t *testing.T) { + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + client := &http.Client{Timeout: 1 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("dev", "xyz") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "wrong token creds") + + a.AdminPasswd = "" // disable admin + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "admin with basic not allowed") +} + +func TestAuthWithBasicChecker(t *testing.T) { + a := makeTestAuth(t) + a.AdminPasswd = "" // disable admin + a.BasicAuthChecker = func(user, passwd string) (bool, token.User, error) { + if user == "basic_user" && passwd == "123456" { + return true, token.User{Name: user, Role: "test_r"}, nil + } + return false, token.User{}, fmt.Errorf("basic auth credentials check failed") + } + + server := httptest.NewServer(makeTestMux(t, &a, true)) + defer server.Close() + + client := &http.Client{Timeout: 1 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("basic_user", "123456") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid basic user") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("dev", "xyz") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "wrong basic auth creds") + + a.BasicAuthChecker = nil // disable basicAuthChecker + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "auth with basic not allowed") +} + +func TestAuthNotRequired(t *testing.T) { + a := makeTestAuth(t) + server := httptest.NewServer(makeTestMux(t, &a, false)) + defer server.Close() + + client := &http.Client{Timeout: 1 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "no token user") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", testJwtValid) + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", testJwtWithHandshake) + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "wrong token") +} + +func TestAdminRequired(t *testing.T) { + a := makeTestAuth(t) + mux := http.NewServeMux() + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + } + mux.Handle("/auth", a.AdminOnly(http.HandlerFunc(handler))) + + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user, admin") + + adminUser.SetAdmin(false) + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.SetBasicAuth("admin", "123456") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 403, resp.StatusCode, "valid token user, not admin") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "not authorized") + + req, err = http.NewRequest("GET", server.URL+"/auth", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", "bad bad token") + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 401, resp.StatusCode, "not authorized") +} + +func TestRBAC(t *testing.T) { + a := makeTestAuth(t) + + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u, err := token.GetUserInfo(r) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "name1", ID: "provider1_id1", Picture: "http://example.com/pic.png", + IP: "127.0.0.1", Email: "me@example.com", Audience: "test_sys", + Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}, + Role: "employee"}, u) + w.WriteHeader(201) + }) + + mux.Handle("/authForEmployees", a.RBAC("someone", "employee")(handler)) + mux.Handle("/authForExternals", a.RBAC("external")(handler)) + server := httptest.NewServer(mux) + defer server.Close() + + // employee route only, token with employee role + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req, err := http.NewRequest("GET", server.URL+"/authForEmployees", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtWithRole, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "valid token user") + + // employee route only, token without an employee role + expiration = int(365 * 24 * time.Hour.Seconds()) //nolint + req, err = http.NewRequest("GET", server.URL+"/authForEmployees", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", + MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + + client = &http.Client{Timeout: 5 * time.Second} + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, 403, resp.StatusCode, "valid token user, incorrect role") + + // external route only, token with employee role + req, err = http.NewRequest("GET", server.URL+"/authForExternals", http.NoBody) + require.Nil(t, err) + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtWithRole, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, 403, resp.StatusCode) + + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "Access denied\n", string(data)) +} + +func makeTestMux(_ *testing.T, a *Authenticator, required bool) http.Handler { + mux := http.NewServeMux() + authMiddleware := a.Auth + if !required { + authMiddleware = a.Trace + } + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + } + mux.Handle("/auth", authMiddleware(http.HandlerFunc(handler))) + return mux +} + +func makeTestAuth(_ *testing.T) Authenticator { + + j := token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "xyz 12345", nil }), + SecureCookies: false, + TokenDuration: time.Second, + CookieDuration: time.Hour * 24 * 31, + ClaimsUpd: token.ClaimsUpdFunc(func(claims token.Claims) token.Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + return Authenticator{ + AdminPasswd: "123456", + JWTService: j, + Validator: token.ValidatorFunc(func(token string, claims token.Claims) bool { return true }), + L: logger.Std, + Providers: []provider.Service{ + {Provider: provider.DirectHandler{ProviderName: "provider1"}}, + {Provider: provider.DirectHandler{ProviderName: "provider2"}}, + }, + } +} + +type testRefreshCache struct { + data map[interface{}]interface{} + sync.RWMutex + hits, misses int32 +} + +func newTestRefreshCache() *testRefreshCache { + return &testRefreshCache{data: make(map[interface{}]interface{})} +} + +func (c *testRefreshCache) Get(key interface{}) (value interface{}, ok bool) { + c.RLock() + defer c.RUnlock() + value, ok = c.data[key] + if ok { + atomic.AddInt32(&c.hits, 1) + } else { + atomic.AddInt32(&c.misses, 1) + } + return value, ok +} + +func (c *testRefreshCache) Set(key, value interface{}) { + c.Lock() + defer c.Unlock() + c.data[key] = value +} diff --git a/v2/middleware/user_updater.go b/v2/middleware/user_updater.go new file mode 100644 index 00000000..34dc5dcd --- /dev/null +++ b/v2/middleware/user_updater.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "net/http" + + "github.com/go-pkgz/auth/token" +) + +// UserUpdater defines interface adding extras or modifying UserInfo in request context +type UserUpdater interface { + Update(claims token.User) token.User +} + +// UserUpdFunc type is an adapter to allow the use of ordinary functions as UserUpdater. If f is a function +// with the appropriate signature, UserUpdFunc(f) is a Handler that calls f. +type UserUpdFunc func(user token.User) token.User + +// Update calls f(user) +func (f UserUpdFunc) Update(user token.User) token.User { + return f(user) +} + +// UpdateUser update user info with UserUpdater if it exists in request's context. Otherwise do nothing. +// should be placed after either Auth, Trace. AdminOnly or RBAC middleware. +func (a *Authenticator) UpdateUser(upd UserUpdater) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // call update only if user info exists, otherwise do nothing + if user, err := token.GetUserInfo(r); err == nil { + r = token.SetUserInfo(r, upd.Update(user)) + } + + h.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } + return f +} diff --git a/v2/middleware/user_updater_test.go b/v2/middleware/user_updater_test.go new file mode 100644 index 00000000..3797ab81 --- /dev/null +++ b/v2/middleware/user_updater_test.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/token" +) + +func TestUserUpdate(t *testing.T) { + a := makeTestAuth(t) + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userInfo, err := token.GetUserInfo(r) + require.NoError(t, err) + assert.Equal(t, "testValue", userInfo.StrAttr("testAttr")) + + w.WriteHeader(201) + }) + upd := UserUpdFunc(func(user token.User) token.User { + user.SetStrAttr("testAttr", "testValue") + return user + }) + updateUserHandler := a.UpdateUser(upd)(handler) + mux.Handle("/trace", a.Trace(updateUserHandler)) + + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 10 * time.Second} + + // check everything works if there is no Trace/Auth/AdminOnly middleware + req, err := http.NewRequest("GET", server.URL+"/trace", http.NoBody) + require.NoError(t, err) + req.Header.Add("X-JWT", testJwtValid) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "trace with userUpdate") + +} + +func TestUserUpdate_WithoutAuth(t *testing.T) { + a := makeTestAuth(t) + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + }) + upd := UserUpdFunc(func(user token.User) token.User { + t.Fatal("should not be called without auth") + return user + }) + updateUserHandler := a.UpdateUser(upd)(handler) + mux.Handle("/no_auth", updateUserHandler) + + server := httptest.NewServer(mux) + defer server.Close() + + client := &http.Client{Timeout: 10 * time.Second} + + // check everything works if there is no Trace/Auth/AdminOnly middleware + req, err := http.NewRequest("GET", server.URL+"/no_auth", http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 201, resp.StatusCode, "no auth") +} diff --git a/v2/provider/apple.go b/v2/provider/apple.go new file mode 100644 index 00000000..2663cf11 --- /dev/null +++ b/v2/provider/apple.go @@ -0,0 +1,540 @@ +package provider + +// Implementation sign in with Apple for allow users to sign in to web services using their Apple ID. +// For correct work this provider user must has Apple developer account and correct configure "sign in with Apple" at in +// See more: https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api +// and https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_js/incorporating_sign_in_with_apple_into_other_platforms + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/sha1" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + "golang.org/x/oauth2" + + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +const ( + // appleAuthURL is the base authentication URL for sign in with Apple ID and fetch request code for user validation request. + appleAuthURL = "https://appleid.apple.com/auth/authorize" + + // appleTokenURL is the endpoint for verifying tokens and get user unique ID and E-mail + appleTokenURL = "https://appleid.apple.com/auth/token" // #nosec + + // appleRequestContentType is the valid type which apple REST API accept only + appleRequestContentType = "application/x-www-form-urlencoded" + + // UserAgent required to every request to Apple REST API + defaultUserAgent = "github.com/go-pkgz/auth" + + // AcceptJSONHeader is the content to accept from response + AcceptJSONHeader = "application/json" +) + +// appleVerificationResponse is based on https://developer.apple.com/documentation/signinwithapplerestapi/tokenresponse +type appleVerificationResponse struct { + // A token used to access allowed user data, but now not implemented public interface for it. + AccessToken string `json:"access_token"` + + // Access token type, always equal the "bearer". + TokenType string `json:"token_type"` + + // Access token expires time in seconds. Always equal 3600 seconds (1 hour) + ExpiresIn int `json:"expires_in"` + + // The refresh token used to regenerate new access tokens. + RefreshToken string `json:"refresh_token"` + + // Main JSON Web Token that contains the user’s identity information. + IDToken string `json:"id_token"` + + // Used to capture any error returned in response. Always check error for empty + Error string `json:"error"` +} + +// AppleConfig is the main oauth2 required parameters for "Sign in with Apple" +type AppleConfig struct { + ClientID string // the identifier Services ID for your app created in Apple developer account. + TeamID string // developer Team ID (10 characters), required for create JWT. It available, after signed in at developer account, by link: https://developer.apple.com/account/#/membership + KeyID string // private key ID assigned to private key obtain in Apple developer account + ResponseMode string // changes method of receiving data in callback. Default value "form_post" (https://developer.apple.com/documentation/sign_in_with_apple/request_an_authorization_to_the_sign_in_with_apple_server?changes=_1_2#4066168) + + scopes []string // for this package allow only username scope and UID in token claims. Apple service API provide only "email" and "name" scope values (https://developer.apple.com/documentation/sign_in_with_apple/clientconfigi/3230955-scope) + privateKey interface{} // private key from Apple obtained in developer account (the keys section). Required for create the Client Secret (https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048) + publicKey crypto.PublicKey // need for validate sign of token + clientSecret string // is the JWT client secret will create after first call and then used until expired + jwkURL string // URL for fetch JWK Apple keys, need redefine for tests +} + +// AppleHandler implements login via Apple ID +type AppleHandler struct { + Params + + // all of these fields specific to particular oauth2 provider + name string + // infoURL string not implemented at Apple side + endpoint oauth2.Endpoint + + mapUser func(jwt.MapClaims) token.User // map info from InfoURL to User + conf AppleConfig // main config for Apple auth provider + + PrivateKeyLoader PrivateKeyLoaderInterface // custom function interface for load private key + +} + +// PrivateKeyLoaderInterface interface for implement custom loader for Apple private key from user source +type PrivateKeyLoaderInterface interface { + LoadPrivateKey() ([]byte, error) +} + +// LoadFromFileFunc is the type for use pre-defined private key loader function +// Path field must be set with actual path to private key file +type LoadFromFileFunc struct { + Path string +} + +// LoadApplePrivateKeyFromFile return instance for pre-defined loader function from local file +func LoadApplePrivateKeyFromFile(path string) LoadFromFileFunc { + return LoadFromFileFunc{ + Path: path, + } +} + +// LoadPrivateKey implement pre-defined (built-in) PrivateKeyLoaderInterface interface method for load private key from local file +func (lf LoadFromFileFunc) LoadPrivateKey() ([]byte, error) { + if lf.Path == "" { + return nil, fmt.Errorf("empty private key path not allowed") + } + + keyFile, err := os.Open(lf.Path) + if err != nil { + return nil, err + } + keyValue, err := io.ReadAll(keyFile) + if err != nil { + return nil, err + } + err = keyFile.Close() + return keyValue, err +} + +// NewApple create new AppleProvider instance with a user parameters +// Private key must be set, when instance create call, for create `client_secret` +func NewApple(p Params, appleCfg AppleConfig, privateKeyLoader PrivateKeyLoaderInterface) (*AppleHandler, error) { + + if p.L == nil { + p.L = logger.NoOp + } + var emptyParams []string + + // check required parameters filled + if appleCfg.ClientID == "" { + emptyParams = append(emptyParams, "ClientID") + } + if appleCfg.TeamID == "" { + emptyParams = append(emptyParams, "TeamID") + } + if appleCfg.KeyID == "" { + emptyParams = append(emptyParams, "KeyID") + } + if len(emptyParams) > 0 { + return nil, fmt.Errorf("required params missed: %s", strings.Join(emptyParams, ", ")) + } + + responseMode := "form_post" + if appleCfg.ResponseMode != "" { + responseMode = appleCfg.ResponseMode + } + + ah := AppleHandler{ + Params: p, + name: "apple", // static name for an Apple provider + + conf: AppleConfig{ + ClientID: appleCfg.ClientID, + TeamID: appleCfg.TeamID, + KeyID: appleCfg.KeyID, + scopes: []string{"name"}, + jwkURL: appleKeysURL, + ResponseMode: responseMode, + }, + + endpoint: oauth2.Endpoint{ + AuthURL: appleAuthURL, + TokenURL: appleTokenURL, + }, + + mapUser: func(claims jwt.MapClaims) token.User { + var usr token.User + if uid, ok := claims["sub"]; ok { + usr.ID = "apple_" + token.HashID(sha1.New(), uid.(string)) + } + return usr + }, + } + + if privateKeyLoader == nil { + return nil, fmt.Errorf("private key loader undefined") + } + + ah.PrivateKeyLoader = privateKeyLoader + + err := ah.initPrivateKey() + return &ah, err +} + +// initPrivateKey parse Apple private key and assign to AppleHandler +func (ah *AppleHandler) initPrivateKey() error { + + sKey, err := ah.PrivateKeyLoader.LoadPrivateKey() + if err != nil { + return fmt.Errorf("problem with private key loading: %w", err) + } + + block, _ := pem.Decode(sKey) + if block == nil { + return fmt.Errorf("empty block after decoding") + } + ah.conf.privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return err + } + publicKey, ok := ah.conf.privateKey.(*ecdsa.PrivateKey) + if !ok { + return fmt.Errorf("provided private key is not ECDSA") + } + ah.conf.publicKey = publicKey.Public() + ah.conf.clientSecret, err = ah.createClientSecret() + if err != nil { + return err + } + return nil +} + +// tokenKeyFunc use for verify JWT sign, it receives the parsed token and should return the key for validating. +func (ah *AppleHandler) tokenKeyFunc(jwtToken *jwt.Token) (interface{}, error) { + if jwtToken == nil { + return nil, fmt.Errorf("failed to call token keyFunc, because token is nil") + } + return ah.conf.publicKey, nil // extract public key from private key +} + +// Name of the provider +func (ah *AppleHandler) Name() string { return ah.name } + +// LoginHandler - GET */{provider-name}/login +func (ah *AppleHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { + + ah.Logf("[DEBUG] login with %s", ah.Name()) + // make state (random) and store in session + state, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to make oauth2 state") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + + claims := token.Claims{ + Handshake: &token.Handshake{ + State: state, + From: r.URL.Query().Get("from"), + }, + SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", + StandardClaims: jwt.StandardClaims{ + Id: cid, + Audience: r.URL.Query().Get("site"), + ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), + NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + }, + } + + if _, err = ah.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + // return login url + loginURL, err := ah.prepareLoginURL(state, r.URL.Path) + if err != nil { + errMsg := fmt.Sprintf("prepare login url for [%s] provider failed", ah.name) + ah.Logf("[ERROR] %s", errMsg) + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, errMsg) + return + } + ah.Logf("[DEBUG] login url %s, claims=%+v", loginURL, claims) + + http.Redirect(w, r, loginURL, http.StatusFound) +} + +// AuthHandler fills user info and redirects to "from" url. This is callback url redirected locally by browser +// GET /callback +func (ah AppleHandler) AuthHandler(w http.ResponseWriter, r *http.Request) { + + // read response form data + if err := r.ParseForm(); err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "read callback response from data failed") + return + } + + state := r.FormValue("state") // state value which sent with auth request + code := r.FormValue("code") // client code for validation + + // response with user name filed return only one time at first login, next login field user doesn't exist + // until user delete sign with Apple ID in account profile (security section) + // example response: {"name":{"firstName":"Chan","lastName":"Lu"},"email":"user@email.com"} + jUser := r.FormValue("user") // json string with user name + + oauthClaims, _, err := ah.JwtService.Get(r) + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to get token") + return + } + + if oauthClaims.Handshake == nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusForbidden, nil, "invalid handshake token") + return + } + + retrievedState := oauthClaims.Handshake.State + if retrievedState == "" || retrievedState != state { + rest.SendErrorJSON(w, r, ah.L, http.StatusForbidden, nil, "unexpected state") + return + } + + var resp appleVerificationResponse + err = ah.exchange(context.Background(), code, ah.makeRedirURL(r.URL.Path), &resp) + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "exchange failed") + return + } + ah.Logf("[DEBUG] response data %+v", resp) + if resp.Error != "" { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, nil, fmt.Sprintf("fetch IDtoken response error: %s", resp.Error)) + return + } + + // trying to fetch Apple public key (JWK) for verify token signature, it need for verify IDToken received from Apple + keySet, err := fetchAppleJWK(r.Context(), ah.conf.jwkURL) + if err != nil { + ah.L.Logf("[ERROR] failed to fetch JWK from Apple key service: " + err.Error()) + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, nil, fmt.Sprintf("failed to fetch JWK from Apple key service: %s", resp.Error)) + return + } + + // get token claims for extract uid (and email or name if they exist in scope) + tokenClaims := jwt.MapClaims{} + _, err = jwt.ParseWithClaims(resp.IDToken, tokenClaims, keySet.keyFunc) + if err != nil { + ah.L.Logf("[ERROR] failed to get claims: " + err.Error()) + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, nil, fmt.Sprintf("failed to token validation, key is invalid: %s", resp.Error)) + return + } + + u := ah.mapUser(tokenClaims) + + u, err = setAvatar(ah.AvatarSaver, u, &http.Client{Timeout: 5 * time.Second}) + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + // try parse username if one exist at response or noname assign + ah.parseUserData(&u, jUser) + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + + claims := token.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Issuer: ah.Issuer, + Id: cid, + Audience: oauthClaims.Audience, + }, + SessionOnly: false, + } + + if _, err = ah.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + ah.Logf("[DEBUG] user info %+v", u) + + // redirect to back url if presented in login query params + if oauthClaims.Handshake != nil && oauthClaims.Handshake.From != "" { + http.Redirect(w, r, oauthClaims.Handshake.From, http.StatusTemporaryRedirect) + return + } + rest.RenderJSON(w, &u) + +} + +// LogoutHandler - GET /logout +func (ah AppleHandler) LogoutHandler(w http.ResponseWriter, r *http.Request) { + if _, _, err := ah.JwtService.Get(r); err != nil { + rest.SendErrorJSON(w, r, ah.L, http.StatusForbidden, err, "logout not allowed") + return + } + ah.JwtService.Reset(w) +} + +// exchange sends the validation token request and gets access token and user claims +// (e.g. https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens) +func (ah *AppleHandler) exchange(ctx context.Context, code, redirectURI string, result *appleVerificationResponse) error { + + // check client_secret for valid and recreate new (client_secret JWT) if required + if tkn, err := jwt.Parse(ah.conf.clientSecret, ah.tokenKeyFunc); err != nil || tkn == nil { + ah.conf.clientSecret, err = ah.createClientSecret() + if err != nil { + return fmt.Errorf("client secret create failed: %w", err) + } + } + + data := url.Values{} + data.Set("client_id", ah.conf.ClientID) + data.Set("client_secret", ah.conf.clientSecret) // JWT signed with Apple private key + data.Set("code", code) + data.Set("redirect_uri", redirectURI) // redirect URL can't refer to localhost and must have trusted certificate and https protocol + data.Set("grant_type", "authorization_code") + + client := http.Client{Timeout: time.Second * 5} + req, err := http.NewRequestWithContext(ctx, "POST", ah.endpoint.TokenURL, strings.NewReader(data.Encode())) + if err != nil { + return err + } + + req.Header.Add("content-type", appleRequestContentType) + req.Header.Add("accept", AcceptJSONHeader) + req.Header.Add("user-agent", defaultUserAgent) // apple requires a user agent + + res, err := client.Do(req) + if err != nil { + return err + } + + // Trying to decode (unmarshal json) data of response + err = json.NewDecoder(res.Body).Decode(result) + if err != nil { + return fmt.Errorf("unmarshalling data from apple service response failed: %w", err) + } + + defer func() { + if err = res.Body.Close(); err != nil { + ah.L.Logf("[ERROR] close request body failed when get access token: %v", err) + } + }() + + // If above operation done successfully checking a response code and error descriptions, if one exist. + // Apple service will response either 200 (OK) or 400 (any error). + if res.StatusCode != http.StatusOK || result.Error != "" { + return fmt.Errorf("apple token service error: %s", result.Error) + } + + return err +} + +// createClientSecret use for create the JWT client secret required to make requests to the Apple validation server. +// for more details go to link: https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048 +func (ah *AppleHandler) createClientSecret() (string, error) { + + if ah.conf.privateKey == nil { + return "", fmt.Errorf("private key can't be empty") + } + // Create a claims + now := time.Now() + exp := now.Add(time.Minute * 30).Unix() // default value + + claims := &jwt.StandardClaims{ + Issuer: ah.conf.TeamID, + IssuedAt: now.Unix(), + ExpiresAt: exp, + Audience: "https://appleid.apple.com", + Subject: ah.conf.ClientID, + } + + tkn := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + tkn.Header["alg"] = "ES256" + tkn.Header["kid"] = ah.conf.KeyID + + return tkn.SignedString(ah.conf.privateKey) +} + +func (ah *AppleHandler) parseUserData(user *token.User, jUser string) { + + type UserData struct { + Name struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + } `json:"name"` + Email string `json:"email"` + } + + var userData UserData + + // Catch error for log only. No need break flow if user name doesn't exist + if err := json.Unmarshal([]byte(jUser), &userData); err != nil { + ah.L.Logf("[DEBUG] failed to parse user data %s: %v", user, err) + user.Name = "noname_" + user.ID[6:12] // paste noname if user name failed to parse + return + } + + user.Name = fmt.Sprintf("%s %s", userData.Name.FirstName, userData.Name.LastName) +} + +func (ah *AppleHandler) prepareLoginURL(state, path string) (string, error) { + + scopesList := strings.Join(ah.conf.scopes, " ") + + if scopesList != "" && ah.conf.ResponseMode != "form_post" { + return "", fmt.Errorf("response_mode must be form_post if scope is not empty") + } + + authURL, err := url.Parse(ah.endpoint.AuthURL) + if err != nil { + return "", err + } + + query := authURL.Query() + query.Set("state", state) + query.Set("response_type", "code") + query.Set("response_mode", ah.conf.ResponseMode) + query.Set("client_id", ah.conf.ClientID) + query.Set("scope", scopesList) + query.Set("redirect_uri", ah.makeRedirURL(path)) + authURL.RawQuery = query.Encode() + + return authURL.String(), nil + +} + +func (ah AppleHandler) makeRedirURL(path string) string { + elems := strings.Split(path, "/") + newPath := strings.Join(elems[:len(elems)-1], "/") + + return strings.TrimRight(ah.URL, "/") + strings.TrimSuffix(newPath, "/") + urlCallbackSuffix +} diff --git a/v2/provider/apple_pubkeys.go b/v2/provider/apple_pubkeys.go new file mode 100644 index 00000000..ce0ccde0 --- /dev/null +++ b/v2/provider/apple_pubkeys.go @@ -0,0 +1,178 @@ +package provider + +// This is implementation need for fetch and parse Apple public key to verify the ID token signature. +// Apple endpoint can return multiple keys, and the count of keys can vary over time. +// From this set of keys, select the key with the matching key identifier (kid) to verify the signature of any JSON Web Token (JWT) issued by Apple. +// For more details go to link https://developer.apple.com/documentation/sign_in_with_apple/fetch_apple_s_public_key_for_verifying_token_signature + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net/http" + "time" + + "github.com/golang-jwt/jwt" +) + +// appleKeysURL is the endpoint URL for fetch Apple’s public key +const appleKeysURL = "https://appleid.apple.com/auth/keys" + +// applePublicKey is the Apple public key object +// Apple public key is a data structure that represents a cryptographic key as JSON Web Key (JWK) +// based on RFC-7517 https://datatracker.ietf.org/doc/html/rfc7517 +type applePublicKey struct { + ID string `json:"id"` + KeyType string `json:"kty"` + Usage string `json:"use"` + Algorithm string `json:"alg"` + + publicKey *rsa.PublicKey +} + +// appleRawKey is raw json object +type appleRawKey struct { + KTY string `json:"kty"` + KID string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` +} + +// fetchAppleJWK to make web request to Apple service for get Apple public keys (JWK) +func fetchAppleJWK(ctx context.Context, keyURL string) (set appleKeySet, err error) { + client := http.Client{Timeout: time.Second * 5} + + if keyURL == "" { + keyURL = appleKeysURL + } + + req, err := http.NewRequestWithContext(ctx, "GET", keyURL, http.NoBody) + + if err != nil { + return set, fmt.Errorf("failed to prepare new request for fetch Apple public keys: %w", err) + } + + req.Header.Add("accept", AcceptJSONHeader) + req.Header.Add("user-agent", defaultUserAgent) // apple requires a user agent + + res, err := client.Do(req) + if err != nil { + return set, fmt.Errorf("failed to fetch Apple public keys: %w", err) + } + + data, err := io.ReadAll(res.Body) + if err != nil { + return set, fmt.Errorf("failed read data after Apple public key fetched: %w", err) + } + defer func() { _ = res.Body.Close() }() + + set, err = parseAppleJWK(data) + if err != nil { + return set, fmt.Errorf("get set of apple public key failed: %w", err) + } + + return set, nil +} + +// parseAppleJWK try parse keys data for return set of Apple public keys, if no errors +func parseAppleJWK(keyData []byte) (set appleKeySet, err error) { + + var rawKeys struct { + Keys []appleRawKey `json:"keys"` + } + + set = appleKeySet{} // init key sets + keys := make(map[string]*applePublicKey) + + if err = json.Unmarshal(keyData, &rawKeys); err != nil { + return set, fmt.Errorf("parse json data with Apple keys failed: %w", err) + } + for _, rawKey := range rawKeys.Keys { + key, err := parseApplePublicKey(rawKey) + if err != nil { + return set, err // no idea to continue iterate keys if at least one return error, need will check all public keys + } + keys[key.ID] = key + } + + set.keys = keys + return set, nil +} + +// parseApplePublicKey to make parse JWK data for create an Apple public key +func parseApplePublicKey(rawKey appleRawKey) (key *applePublicKey, err error) { + + key = &applePublicKey{ + KeyType: rawKey.KTY, + ID: rawKey.KID, + Usage: rawKey.Use, + Algorithm: rawKey.Alg, + } + + // parse and create public key + if err := key.createApplePublicKey(rawKey.N, rawKey.E); err != nil { + return nil, err + } + + return key, nil +} + +// createApplePublicKey need to decodes a base64-encoded larger integer from Apple's key format. +func (apk *applePublicKey) createApplePublicKey(n, e string) error { + + bufferN, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(n) // decode modulus + if err != nil { + return fmt.Errorf("failed to decode Apple public key modulus (n): %w", err) + } + + bufferE, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(e) // decode exponent + if err != nil { + return fmt.Errorf("failed to decode Apple public key exponent (e): %w", err) + } + + // create rsa public key from JWK data + apk.publicKey = &rsa.PublicKey{ + N: big.NewInt(0).SetBytes(bufferN), + E: int(big.NewInt(0).SetBytes(bufferE).Int64()), + } + return nil +} + +// appleKeySet is a set of Apple public keys +type appleKeySet struct { + keys map[string]*applePublicKey +} + +// get return Apple public key with specific KeyID (kid) +func (aks *appleKeySet) get(kid string) (keys *applePublicKey, err error) { + if aks.keys == nil || len(aks.keys) == 0 { + return nil, fmt.Errorf("failed to get key in appleKeySet, key set is nil or empty") + } + + if val, ok := aks.keys[kid]; ok { + return val, nil + } + return nil, fmt.Errorf("key with ID %s not found", kid) +} + +// keyFunc use for JWT verify with specific public key +func (aks *appleKeySet) keyFunc(token *jwt.Token) (interface{}, error) { + + keyID, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("get JWT kid header not found") + } + key, err := aks.get(keyID) + + if err != nil { + return nil, err + } + + return key.publicKey, nil +} diff --git a/v2/provider/apple_pubkeys_test.go b/v2/provider/apple_pubkeys_test.go new file mode 100644 index 00000000..3b6e810d --- /dev/null +++ b/v2/provider/apple_pubkeys_test.go @@ -0,0 +1,226 @@ +package provider + +import ( + "context" + "fmt" + "log" + "net/http" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplePublicKey_Fetch(t *testing.T) { + teardown := prepareAppleKeysTestServer(t, 8982) + + defer teardown() + + // valid response checking + ctx := context.Background() + url := fmt.Sprintf("http://127.0.0.1:%d/keys", 8982) + set, err := fetchAppleJWK(ctx, url) + assert.NoError(t, err) + assert.NotEqual(t, appleKeySet{}, set) + + // check service response error + url = fmt.Sprintf("http://127.0.0.1:%d/error", 8982) + _, err = fetchAppleJWK(ctx, url) + assert.Error(t, err) + + url = fmt.Sprintf("http://127.0.0.1:%d/no-answer", 8982) + ctx, cancelFunc := context.WithTimeout(ctx, time.Second*2) + _, err = fetchAppleJWK(ctx, url) + defer cancelFunc() + assert.Error(t, err) + +} + +func TestParseAppleJWK(t *testing.T) { + testKeys := `{ + "keys": [ + { + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "iGaLqP6y-SJCCBq5Hv6pGDbG_SQ11MNjH7rWHcCFYz4hGwHC4lcSurTlV8u3avoVNM8jXevG1Iu1SY11qInqUvjJur--hghr1b56OPJu6H1iKulSxGjEIyDP6c5BdE1uwprYyr4IO9th8fOwCPygjLFrh44XEGbDIFeImwvBAGOhmMB2AD1n1KviyNsH0bEB7phQtiLk-ILjv1bORSRl8AK677-1T8isGfHKXGZ_ZGtStDe7Lu0Ihp8zoUt59kx2o9uWpROkzF56ypresiIl4WprClRCjz8x6cPZXU2qNWhu71TQvUFwvIvbkE1oYaJMb0jcOTmBRZA2QuYw-zHLwQ", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "eXaunmL", + "use": "sig", + "alg": "RS256", + "n": "4dGQ7bQK8LgILOdLsYzfZjkEAoQeVC_aqyc8GC6RX7dq_KvRAQAWPvkam8VQv4GK5T4ogklEKEvj5ISBamdDNq1n52TpxQwI2EqxSk7I9fKPKhRt4F8-2yETlYvye-2s6NeWJim0KBtOVrk0gWvEDgd6WOqJl_yt5WBISvILNyVg1qAAM8JeX6dRPosahRVDjA52G2X-Tip84wqwyRpUlq2ybzcLh3zyhCitBOebiRWDQfG26EH9lTlJhll-p_Dg8vAXxJLIJ4SNLcqgFeZe4OfHLgdzMvxXZJnPp_VgmkcpUdRotazKZumj6dBPcXI_XID4Z4Z3OM1KrZPJNdUhxw", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "YuyXoY", + "use": "sig", + "alg": "RS256", + "n": "1JiU4l3YCeT4o0gVmxGTEK1IXR-Ghdg5Bzka12tzmtdCxU00ChH66aV-4HRBjF1t95IsaeHeDFRgmF0lJbTDTqa6_VZo2hc0zTiUAsGLacN6slePvDcR1IMucQGtPP5tGhIbU-HKabsKOFdD4VQ5PCXifjpN9R-1qOR571BxCAl4u1kUUIePAAJcBcqGRFSI_I1j_jbN3gflK_8ZNmgnPrXA0kZXzj1I7ZHgekGbZoxmDrzYm2zmja1MsE5A_JX7itBYnlR41LOtvLRCNtw7K3EFlbfB6hkPL-Swk5XNGbWZdTROmaTNzJhV-lWT0gGm6V1qWAK2qOZoIDa_3Ud0Gw", + "e": "AQAB" + } + ] + }` + testKeySet, err := parseAppleJWK([]byte(testKeys)) + assert.NoError(t, err) + + key, err := testKeySet.get("YuyXoY") + assert.NoError(t, err) + assert.Equal(t, key.ID, "YuyXoY") + + testKeySet = appleKeySet{} // reset previous value + testKeySet, err = parseAppleJWK([]byte(`{"keys":[]}`)) + assert.NoError(t, err) + assert.Equal(t, 0, len(testKeySet.keys)) + + testKeySet = appleKeySet{} // reset previous value + testKeySet, err = parseAppleJWK([]byte(`{"keys":[{ + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "invalid-value", + "e": "invalid-value" + }]}`)) + assert.Error(t, err, fmt.Errorf("failed to decode Apple public key modulus (n)")) + + testKeySet, err = parseAppleJWK([]byte(`{"keys":[{ + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "1JiU4l3YCeT4o0gVmxGTEK1IXR-Ghdg5Bzka12tzmtdCxU00ChH66aV-4HRBjF1t95IsaeHeDFRgmF0lJbTDTqa6_VZo2hc0zTiUAsGLacN6slePvDcR1IMucQGtPP5tGhIbU-HKabsKOFdD4VQ5PCXifjpN9R-1qOR571BxCAl4u1kUUIePAAJcBcqGRFSI_I1j_jbN3gflK_8ZNmgnPrXA0kZXzj1I7ZHgekGbZoxmDrzYm2zmja1MsE5A_JX7itBYnlR41LOtvLRCNtw7K3EFlbfB6hkPL-Swk5XNGbWZdTROmaTNzJhV-lWT0gGm6V1qWAK2qOZoIDa_3Ud0Gw", + "e": "invalid-value" + }]}`)) + + assert.Error(t, err, fmt.Errorf("failed to decode Apple public key modulus (e)")) + testKeySet, err = parseAppleJWK([]byte(`{invalid-json}`)) + assert.Error(t, err) +} + +func TestAppleKeySet_Get(t *testing.T) { + testKeySet := appleKeySet{} + _, err := testKeySet.get("some-kid") + assert.Error(t, err, "failed to get key in appleKeySet, key set is nil or empty") + + testKeySet, err = parseAppleJWK([]byte(`{"keys":[{ + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "iGaLqP6y-SJCCBq5Hv6pGDbG_SQ11MNjH7rWHcCFYz4hGwHC4lcSurTlV8u3avoVNM8jXevG1Iu1SY11qInqUvjJur--hghr1b56OPJu6H1iKulSxGjEIyDP6c5BdE1uwprYyr4IO9th8fOwCPygjLFrh44XEGbDIFeImwvBAGOhmMB2AD1n1KviyNsH0bEB7phQtiLk-ILjv1bORSRl8AK677-1T8isGfHKXGZ_ZGtStDe7Lu0Ihp8zoUt59kx2o9uWpROkzF56ypresiIl4WprClRCjz8x6cPZXU2qNWhu71TQvUFwvIvbkE1oYaJMb0jcOTmBRZA2QuYw-zHLwQ", + "e": "AQAB" + }]}`)) + require.Nil(t, err) + + apk, err := testKeySet.get("86D88Kf") + assert.NoError(t, err) + assert.Equal(t, apk.ID, "86D88Kf") + + _, err = testKeySet.get("not-found-kid") + assert.Error(t, err, "key with ID some-kid not found") + +} + +func TestAppleKeySet_KeyFunc(t *testing.T) { + + tokenHdr := map[string]interface{}{"kid": "86D88Kf"} + validToken := jwt.Token{Header: tokenHdr} + testKeySet, err := parseAppleJWK([]byte(`{"keys":[{ + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "iGaLqP6y-SJCCBq5Hv6pGDbG_SQ11MNjH7rWHcCFYz4hGwHC4lcSurTlV8u3avoVNM8jXevG1Iu1SY11qInqUvjJur--hghr1b56OPJu6H1iKulSxGjEIyDP6c5BdE1uwprYyr4IO9th8fOwCPygjLFrh44XEGbDIFeImwvBAGOhmMB2AD1n1KviyNsH0bEB7phQtiLk-ILjv1bORSRl8AK677-1T8isGfHKXGZ_ZGtStDe7Lu0Ihp8zoUt59kx2o9uWpROkzF56ypresiIl4WprClRCjz8x6cPZXU2qNWhu71TQvUFwvIvbkE1oYaJMb0jcOTmBRZA2QuYw-zHLwQ", + "e": "AQAB" + }]}`)) + require.Nil(t, err) + assert.IsType(t, appleKeySet{}, testKeySet) + _, err = testKeySet.keyFunc(&validToken) + assert.NoError(t, err) + + testKeySet, err = parseAppleJWK([]byte(`{"keys":[{ + "kty": "RSA", + "kid": "eXaunmL", + "use": "sig", + "alg": "RS256", + "n": "4dGQ7bQK8LgILOdLsYzfZjkEAoQeVC_aqyc8GC6RX7dq_KvRAQAWPvkam8VQv4GK5T4ogklEKEvj5ISBamdDNq1n52TpxQwI2EqxSk7I9fKPKhRt4F8-2yETlYvye-2s6NeWJim0KBtOVrk0gWvEDgd6WOqJl_yt5WBISvILNyVg1qAAM8JeX6dRPosahRVDjA52G2X-Tip84wqwyRpUlq2ybzcLh3zyhCitBOebiRWDQfG26EH9lTlJhll-p_Dg8vAXxJLIJ4SNLcqgFeZe4OfHLgdzMvxXZJnPp_VgmkcpUdRotazKZumj6dBPcXI_XID4Z4Z3OM1KrZPJNdUhxw", + "e": "AQAB" + }]}`)) + require.NoError(t, err) + assert.NotNil(t, testKeySet) + + _, err = testKeySet.keyFunc(&validToken) + assert.Error(t, err, "key with ID 86D88Kf not found") + + _, err = testKeySet.keyFunc(&jwt.Token{}) + assert.Error(t, err, "get JWT kid header not found") +} + +//nolint:gosec //this is a test, we don't care about ReadHeaderTimeout +func prepareAppleKeysTestServer(t *testing.T, authPort int) func() { + ts := &http.Server{ + Addr: fmt.Sprintf(":%d", authPort), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[MOCK APPLE KEYS SERVER] request %s %s %+v", r.Method, r.URL, r.Header) + switch { + case strings.HasPrefix(r.URL.Path, "/keys"): + + testKeys := `{ + "keys": [ + { + "kty": "RSA", + "kid": "86D88Kf", + "use": "sig", + "alg": "RS256", + "n": "iGaLqP6y-SJCCBq5Hv6pGDbG_SQ11MNjH7rWHcCFYz4hGwHC4lcSurTlV8u3avoVNM8jXevG1Iu1SY11qInqUvjJur--hghr1b56OPJu6H1iKulSxGjEIyDP6c5BdE1uwprYyr4IO9th8fOwCPygjLFrh44XEGbDIFeImwvBAGOhmMB2AD1n1KviyNsH0bEB7phQtiLk-ILjv1bORSRl8AK677-1T8isGfHKXGZ_ZGtStDe7Lu0Ihp8zoUt59kx2o9uWpROkzF56ypresiIl4WprClRCjz8x6cPZXU2qNWhu71TQvUFwvIvbkE1oYaJMb0jcOTmBRZA2QuYw-zHLwQ", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "eXaunmL", + "use": "sig", + "alg": "RS256", + "n": "4dGQ7bQK8LgILOdLsYzfZjkEAoQeVC_aqyc8GC6RX7dq_KvRAQAWPvkam8VQv4GK5T4ogklEKEvj5ISBamdDNq1n52TpxQwI2EqxSk7I9fKPKhRt4F8-2yETlYvye-2s6NeWJim0KBtOVrk0gWvEDgd6WOqJl_yt5WBISvILNyVg1qAAM8JeX6dRPosahRVDjA52G2X-Tip84wqwyRpUlq2ybzcLh3zyhCitBOebiRWDQfG26EH9lTlJhll-p_Dg8vAXxJLIJ4SNLcqgFeZe4OfHLgdzMvxXZJnPp_VgmkcpUdRotazKZumj6dBPcXI_XID4Z4Z3OM1KrZPJNdUhxw", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "YuyXoY", + "use": "sig", + "alg": "RS256", + "n": "1JiU4l3YCeT4o0gVmxGTEK1IXR-Ghdg5Bzka12tzmtdCxU00ChH66aV-4HRBjF1t95IsaeHeDFRgmF0lJbTDTqa6_VZo2hc0zTiUAsGLacN6slePvDcR1IMucQGtPP5tGhIbU-HKabsKOFdD4VQ5PCXifjpN9R-1qOR571BxCAl4u1kUUIePAAJcBcqGRFSI_I1j_jbN3gflK_8ZNmgnPrXA0kZXzj1I7ZHgekGbZoxmDrzYm2zmja1MsE5A_JX7itBYnlR41LOtvLRCNtw7K3EFlbfB6hkPL-Swk5XNGbWZdTROmaTNzJhV-lWT0gGm6V1qWAK2qOZoIDa_3Ud0Gw", + "e": "AQAB" + } + ] + }` + w.Header().Set("Content-Type", "application/json; charset=utf-8") + _, err := w.Write([]byte(testKeys)) + assert.NoError(t, err) + case strings.HasPrefix(r.URL.Path, "/error"): + _, err := w.Write([]byte("test error")) + assert.NoError(t, err) + case strings.HasPrefix(r.URL.Path, "/no-answer"): + time.Sleep(time.Second * 3) + return + default: + t.Fatalf("unexpected oauth request %s %s", r.Method, r.URL) + } + }), + } + + go func() { _ = ts.ListenAndServe() }() + + time.Sleep(time.Millisecond * 100) // let them start + + return func() { + assert.NoError(t, ts.Close()) + } +} diff --git a/v2/provider/apple_test.go b/v2/provider/apple_test.go new file mode 100644 index 00000000..f448ae77 --- /dev/null +++ b/v2/provider/apple_test.go @@ -0,0 +1,681 @@ +package provider + +import ( + "context" + "crypto/rsa" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/cookiejar" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +type customLoader struct{} // implement custom private key loader interface + +func TestAppleHandler_NewApple(t *testing.T) { + + testIDToken := `eyJhbGciOiJIUzI1NiJ9.eyJhdWQiOiJ0ZXN0LmF1dGguZXhhbXBsZS5jb20iLCJzdWIiOiIwMDExMjIuNzg5M2Y3NmViZWRjNDExOGE3OTE3ZGFiOWE4YTllYTkuMTEyMiIsImlzcyI6Imh0dHBzOi8vYXBwbGVpZC5hcHBsZS5jb20iLCJleHAiOiIxOTIwNjQ3MTgyIiwiaWF0IjoiMTYyMDYzNzE4MiIsImVtYWlsIjoidGVzdEBlbWFpbC5jb20ifQ.CQCPa7ov-IdZ5bEKfhhnxEXafMAM_t6mj5OAnaoyy0A` // #nosec + p := Params{ + URL: "http://localhost", + Issuer: "test-issuer", + Cid: "cid", + Csecret: "cs", + } + + aCfg := AppleConfig{ + ClientID: "auth.example.com", + TeamID: "AA11BB22CC", + KeyID: "BS2A79VCTT", + } + cl := customLoader{} + + ah, err := NewApple(p, aCfg, cl) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + assert.Equal(t, ah.name, "apple") + assert.Equal(t, ah.conf.ClientID, aCfg.ClientID) + assert.NotEmpty(t, ah.conf.privateKey) + assert.NotEmpty(t, ah.conf.clientSecret) + + testTokenClaims := jwt.MapClaims{} + _, err = jwt.ParseWithClaims(testIDToken, testTokenClaims, ah.tokenKeyFunc) + assert.Error(t, err) // no need check token is valid for test token + + // testing mapUser + u := ah.mapUser(testTokenClaims) + t.Logf("%+v", u) + assert.Equal(t, u.ID, "apple_"+token.HashID(sha1.New(), "001122.7893f76ebedc4118a7917dab9a8a9ea9.1122")) + + _, err = NewApple(p, aCfg, nil) + require.Error(t, err) + + // check empty params + aCfg.ClientID = "" + _, err = NewApple(p, aCfg, cl) + assert.Error(t, err, "required params missed: ClientID") + aCfg.TeamID = "" + _, err = NewApple(p, aCfg, cl) + assert.Error(t, err, "required params missed: ClientID, TeamID") + aCfg.KeyID = "" + _, err = NewApple(p, aCfg, cl) + assert.Error(t, err, "required params missed: ClientID, TeamID, KeyID") +} + +// TestAppleHandler_LoadPrivateKey need for testing pre-defined loader from local file +func TestAppleHandler_LoadPrivateKey(t *testing.T) { + testValidKey := `-----BEGIN PRIVATE KEY----- +MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgTxaHXzyuM85Znw7y +SJ9XeeC8gqcpE/VLhZHGsnPPiPagCgYIKoZIzj0DAQehRANCAATnwlOv7I6eC3Ec +/+GeYXT+hbcmhEVveDqLmNcHiXCR9XxJZXtpMRlcRfY8eaJpUdig27dfsbvpnfX5 +Ivx5tHkv +-----END PRIVATE KEY-----` // #nosec + testInvalidKey := `-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAKNwapOQ6rQJHetP +HRlJBIh1OsOsUBiXb3rXXE3xpWAxAha0MH+UPRblOko+5T2JqIb+xKf9Vi3oTM3t +KvffaOPtzKXZauscjq6NGzA3LgeiMy6q19pvkUUOlGYK6+Xfl+B7Xw6+hBMkQuGE +nUS8nkpR5mK4ne7djIyfHFfMu4ptAgMBAAECgYA+s0PPtMq1osG9oi4xoxeAGikf +JB3eMUptP+2DYW7mRibc+ueYKhB9lhcUoKhlQUhL8bUUFVZYakP8xD21thmQqnC4 +f63asad0ycteJMLb3r+z26LHuCyOdPg1pyLk3oQ32lVQHBCYathRMcVznxOG16VK +I8BFfstJTaJu0lK/wQJBANYFGusBiZsJQ3utrQMVPpKmloO2++4q1v6ZR4puDQHx +TjLjAIgrkYfwTJBLBRZxec0E7TmuVQ9uJ+wMu/+7zaUCQQDDf2xMnQqYknJoKGq+ +oAnyC66UqWC5xAnQS32mlnJ632JXA0pf9pb1SXAYExB1p9Dfqd3VAwQDwBsDDgP6 +HD8pAkEA0lscNQZC2TaGtKZk2hXkdcH1SKru/g3vWTkRHxfCAznJUaza1fx0wzdG +GcES1Bdez0tbW4llI5By/skZc2eE3QJAFl6fOskBbGHde3Oce0F+wdZ6XIJhEgCP +iukIcKZoZQzoiMJUoVRrA5gqnmaYDI5uRRl/y57zt6YksR3KcLUIuQJAd242M/WF +6YAZat3q/wEeETeQq1wrooew+8lHl05/Nt0cCpV48RGEhJ83pzBm3mnwHf8lTBJH +x6XroMXsmbnsEw== +-----END PRIVATE KEY-----` + testPrivKeyFileName := "privKeyTest.tmp" + testBadPrivKeyFileName := "privKeyBadTest.tmp" + + dir, err := os.MkdirTemp(os.TempDir(), testPrivKeyFileName) + require.NoError(t, err) + + defer func() { + require.NoError(t, os.RemoveAll(dir)) + }() + + tmpFn := filepath.Join(dir, testPrivKeyFileName) + err = os.WriteFile(tmpFn, []byte(testValidKey), 0o600) + require.NoError(t, err) + badTmpFn := filepath.Join(dir, testBadPrivKeyFileName) + err = os.WriteFile(badTmpFn, []byte(testInvalidKey), 0o600) + require.NoError(t, err) + p := Params{ + URL: "http://localhost", + Issuer: "test-issuer", + Cid: "cid", + Csecret: "cs", + } + + aCfg := AppleConfig{ + ClientID: "auth.example.com", + TeamID: "AA11BB22CC", + KeyID: "BS2A79VCTT", + } + + // test good scenario + ah, err := NewApple(p, aCfg, LoadApplePrivateKeyFromFile(tmpFn)) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + assert.Equal(t, ah.name, "apple") + assert.Equal(t, ah.conf.ClientID, aCfg.ClientID) + assert.NotEmpty(t, ah.conf.privateKey) + assert.NotEmpty(t, ah.conf.publicKey) + assert.NotEmpty(t, ah.conf.clientSecret) + + // test bad scenario, should not panic + ah, err = NewApple(p, aCfg, LoadApplePrivateKeyFromFile(badTmpFn)) + assert.Error(t, err) + assert.IsType(t, &AppleHandler{}, ah) + assert.Empty(t, ah.conf.clientSecret, "client secret was not loaded") + assert.Empty(t, ah.conf.publicKey, "public key was not loaded") + assert.Equal(t, ah.name, "apple") + assert.Equal(t, ah.conf.ClientID, aCfg.ClientID) + assert.NotEmpty(t, ah.conf.privateKey) +} + +func TestAppleHandlerCreateClientSecret(t *testing.T) { + ah := &AppleHandler{} + tkn, err := ah.createClientSecret() + assert.Error(t, err) + assert.Empty(t, tkn) + + ah, err = prepareAppleHandlerTest("", []string{}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + + tkn, err = ah.createClientSecret() + assert.NoError(t, err) + assert.NotEmpty(t, tkn) + + testClaims := jwt.MapClaims{} + + parsedToken, err := jwt.ParseWithClaims(tkn, testClaims, ah.tokenKeyFunc) + require.NoError(t, err) + require.NotNil(t, parsedToken) + assert.True(t, parsedToken.Valid) + + assert.Equal(t, "auth.example.com", testClaims["sub"]) +} + +func TestAppleParseUserData(t *testing.T) { + + ah := AppleHandler{Params: Params{L: logger.NoOp}} + + userNameClaimTest := `{"name":{"firstName":"test","lastName":"user"}}` + testUser := &token.User{ID: "", Email: "user@example.com"} + shaID := "apple_" + token.HashID(sha1.New(), testUser.ID) + + testUser.ID = shaID + testCheckUser := &token.User{ID: shaID, Name: "test user", Email: "user@example.com"} + + ah.parseUserData(testUser, userNameClaimTest) + assert.Equal(t, testUser, testCheckUser) + + testCheckUser.Name = "noname_" + shaID[6:12] + ah.parseUserData(testUser, "") + assert.Equal(t, testUser, testCheckUser) +} + +func TestPrepareLoginURL(t *testing.T) { + ah, err := prepareAppleHandlerTest("", []string{}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + + lURL, err := ah.prepareLoginURL("1112233", "apple-test/login") + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(lURL, ah.endpoint.AuthURL)) + + checkURL, err := url.Parse(lURL) + assert.NoError(t, err) + q := checkURL.Query() + assert.Equal(t, q.Get("state"), "1112233") + assert.Equal(t, q.Get("response_type"), "code") + assert.Equal(t, q.Get("response_mode"), "form_post") + assert.Equal(t, q.Get("client_id"), ah.conf.ClientID) +} + +func TestPrepareLoginURLWithCustomResponseMode(t *testing.T) { + ah, err := prepareAppleHandlerTest("query", []string{}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + ah.conf.scopes = []string{""} + lURL, err := ah.prepareLoginURL("1112233", "apple-test/login") + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(lURL, ah.endpoint.AuthURL)) + + checkURL, err := url.Parse(lURL) + assert.NoError(t, err) + q := checkURL.Query() + assert.Equal(t, q.Get("state"), "1112233") + assert.Equal(t, q.Get("response_type"), "code") + assert.Equal(t, q.Get("response_mode"), "query") + assert.Equal(t, q.Get("client_id"), ah.conf.ClientID) +} + +func TestThrowsWhenNotEmptyScopeAndWrongResponseMode(t *testing.T) { + ah, err := prepareAppleHandlerTest("query", []string{"email"}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + + lURL, err := ah.prepareLoginURL("1112233", "apple-test/login") + assert.Equal(t, "", lURL) + assert.Error(t, err) +} + +func TestAppleHandlerMakeRedirURL(t *testing.T) { + cases := []struct{ rootURL, route, out string }{ + {"localhost:8080/", "/my/auth/path/apple", "localhost:8080/my/auth/path/callback"}, + {"localhost:8080", "/auth/apple", "localhost:8080/auth/callback"}, + {"localhost:8080/", "/auth/apple", "localhost:8080/auth/callback"}, + {"localhost:8080", "/", "localhost:8080/callback"}, + {"localhost:8080/", "/", "localhost:8080/callback"}, + {"mysite.com", "", "mysite.com/callback"}, + } + + ah, err := prepareAppleHandlerTest("", []string{}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, ah) + + for i := range cases { + c := cases[i] + ah.URL = c.rootURL + assert.Equal(t, c.out, ah.makeRedirURL(c.route)) + } +} + +func TestAppleHandler_LoginHandler(t *testing.T) { + + teardown := prepareAppleOauthTest(t, 8981, 8982, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + resp, err := client.Get("http://localhost:8981/login?site=remark") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + u := token.User{} + err = json.Unmarshal(body, &u) + assert.NoError(t, err) + testHashID := token.HashID(sha1.New(), "userid1") + testUserID := "apple_" + testHashID + testUserName := "noname_" + testUserID[6:12] + assert.Equal(t, token.User{ID: testUserID, Name: testUserName}, u) + + tk := resp.Cookies()[0].Value + jwtSvc := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31}) + + claims, err := jwtSvc.Parse(tk) + require.NoError(t, err) + t.Log(claims) + assert.Equal(t, "go-pkgz/auth", claims.Issuer) + assert.Equal(t, "remark", claims.Audience) + +} + +func TestAppleHandler_LogoutHandler(t *testing.T) { + + teardown := prepareAppleOauthTest(t, 8691, 8692, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + req, err := http.NewRequest("GET", "http://localhost:8691/logout", http.NoBody) + require.Nil(t, err) + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, 403, resp.StatusCode, "user not lagged in") + + req, err = http.NewRequest("GET", "http://localhost:8691/logout", http.NoBody) + require.NoError(t, err) + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err = client.Do(req) + require.Nil(t, err) + require.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name, "token cookie cleared") + assert.Equal(t, "", resp.Cookies()[0].Value) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name, "xsrf cookie cleared") + assert.Equal(t, "", resp.Cookies()[1].Value) + +} + +func TestAppleHandler_Exchange(t *testing.T) { + var testResponseToken string + teardown := prepareAppleOauthTest(t, 8981, 8982, &testResponseToken) + defer teardown() + + ah, err := prepareAppleHandlerTest("", []string{}) + require.Nil(t, err) + + ah.endpoint = oauth2.Endpoint{ + AuthURL: fmt.Sprintf("http://localhost:%d/login/oauth/authorize", 8981), + TokenURL: fmt.Sprintf("http://localhost:%d/login/oauth/access_token", 8982), + } + + testAppleResponse := appleVerificationResponse{} + err = ah.exchange(context.Background(), "1122334455", "url/callback", &testAppleResponse) + assert.NoError(t, err) + assert.Equal(t, &appleVerificationResponse{ + AccessToken: "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", + ExpiresIn: 3600, + TokenType: "bearer", + RefreshToken: "IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk", + IDToken: testResponseToken, + }, &testAppleResponse) + + testAppleResponse = appleVerificationResponse{} // clear response for next checking + err = ah.exchange(context.Background(), "test-error", "url/callback", &testAppleResponse) + assert.Error(t, err) + assert.Equal(t, &appleVerificationResponse{ + Error: "test error occurred", + }, &testAppleResponse) + + err = ah.exchange(context.Background(), "test-json-error", "url/callback", &testAppleResponse) + assert.Error(t, err) + assert.EqualError(t, err, "unmarshalling data from apple service response failed: invalid character 'i' looking for beginning of value") +} + +func (cl customLoader) LoadPrivateKey() ([]byte, error) { + + // valid p8 key + testValidKey := []byte(`-----BEGIN PRIVATE KEY----- +MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgTxaHXzyuM85Znw7y +SJ9XeeC8gqcpE/VLhZHGsnPPiPagCgYIKoZIzj0DAQehRANCAATnwlOv7I6eC3Ec +/+GeYXT+hbcmhEVveDqLmNcHiXCR9XxJZXtpMRlcRfY8eaJpUdig27dfsbvpnfX5 +Ivx5tHkv +-----END PRIVATE KEY-----`) + + return testValidKey, nil +} + +func prepareTestPrivateKey(t *testing.T) (filePath string, cancelFunc context.CancelFunc) { + testValidKey := `-----BEGIN PRIVATE KEY----- +MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgTxaHXzyuM85Znw7y +SJ9XeeC8gqcpE/VLhZHGsnPPiPagCgYIKoZIzj0DAQehRANCAATnwlOv7I6eC3Ec +/+GeYXT+hbcmhEVveDqLmNcHiXCR9XxJZXtpMRlcRfY8eaJpUdig27dfsbvpnfX5 +Ivx5tHkv +-----END PRIVATE KEY-----` + testPrivKeyFileName := "privKeyTest.tmp" + + dir, err := os.MkdirTemp(os.TempDir(), testPrivKeyFileName) + assert.NoError(t, err) + assert.NotNil(t, dir) + if err != nil { + log.Fatal(err) + return "", nil + } + + filePath = filepath.Join(dir, testPrivKeyFileName) + err = os.WriteFile(filePath, []byte(testValidKey), 0o600) + require.NoError(t, err) + ctx, cancelCtx := context.WithTimeout(context.Background(), time.Second*60) + + go func() { + <-ctx.Done() + require.NoError(t, os.RemoveAll(dir)) + }() + return filePath, cancelCtx +} + +func prepareAppleHandlerTest(responseMode string, scopes []string) (*AppleHandler, error) { + p := Params{ + URL: "http://localhost", + Issuer: "test-issuer", + Cid: "cid", + Csecret: "cs", + } + + aCfg := AppleConfig{ + ClientID: "auth.example.com", + TeamID: "AA11BB22CC", + KeyID: "BS2A79VCTT", + ResponseMode: responseMode, + scopes: scopes, + } + + cl := customLoader{} + return NewApple(p, aCfg, cl) +} + +func prepareAppleOauthTest(t *testing.T, loginPort, authPort int, testToken *string) func() { + signKey, testJWK := createTestSignKeyPairs(t) + provider, err := prepareAppleHandlerTest("", []string{}) + assert.NoError(t, err) + assert.IsType(t, &AppleHandler{}, provider) + + filePath, cancelCtx := prepareTestPrivateKey(t) + if cancelCtx == nil { + t.Fatal(fmt.Errorf("failed to create test private key file")) + return nil + } + + provider.name = "mock" + provider.endpoint = oauth2.Endpoint{ + AuthURL: fmt.Sprintf("http://localhost:%d/login/oauth/authorize", authPort), + TokenURL: fmt.Sprintf("http://localhost:%d/login/oauth/access_token", authPort), + } + provider.conf.jwkURL = fmt.Sprintf("http://localhost:%d/keys", authPort) + + provider.PrivateKeyLoader = LoadApplePrivateKeyFromFile(filePath) + require.NoError(t, err) + + // create self-signed JWT + testResponseToken, err := createTestResponseToken(signKey) + require.NoError(t, err) + require.NotEmpty(t, testResponseToken) + if testToken != nil { + *testToken = testResponseToken + } + + jwtService := token.NewService(token.Opts{ + SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, TokenDuration: time.Hour, CookieDuration: days31, + ClaimsUpd: token.ClaimsUpdFunc(func(claims token.Claims) token.Claims { + if claims.User != nil { + switch claims.User.ID { + case "mock_myuser2": + claims.User.SetBoolAttr("admin", true) + case "mock_myuser1": + claims.User.Picture = "http://example.com/custom.png" + } + } + return claims + }), + }) + + params := Params{URL: "url", Cid: "cid", Csecret: "csecret", JwtService: jwtService, + Issuer: "go-pkgz/auth", L: logger.Std} + provider.Params = params + + svc := Service{Provider: provider} + + ts := &http.Server{Addr: fmt.Sprintf(":%d", loginPort), Handler: http.HandlerFunc(svc.Handler)} //nolint:gosec + + count := 0 + useIDs := []string{"myuser1", "myuser2"} // user for first ans second calls + + oauth := &http.Server{ //nolint:gosec + Addr: fmt.Sprintf(":%d", authPort), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[MOCK OAUTH] request %s %s %+v", r.Method, r.URL, r.Header) + switch { + case strings.HasPrefix(r.URL.Path, "/login/oauth/authorize"): + state := r.URL.Query().Get("state") + w.Header().Add("Location", fmt.Sprintf("http://localhost:%d/callback?state=%s", loginPort, state)) + w.WriteHeader(302) + case strings.HasPrefix(r.URL.Path, "/login/oauth/access_token"): + err := r.ParseForm() + assert.NoError(t, err) + + res := fmt.Sprintf(`{ + "access_token":"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", + "token_type":"bearer", + "expires_in":3600, + "refresh_token":"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk", + "id_token":"%s" + }`, testResponseToken) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + switch r.Form.Get("code") { + case "test-error": + + res = `{ + "error":"test error occurred" + }` + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write([]byte(res)) + assert.NoError(t, e) + return + + case "test-json-error": + res = `invalid json data` + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write([]byte(res)) + assert.NoError(t, e) + return + } + + w.WriteHeader(200) + _, err = w.Write([]byte(res)) + assert.NoError(t, err) + case strings.HasPrefix(r.URL.Path, "/user"): + res := fmt.Sprintf(`{ + "id": "%s", + "name":"blah", + "picture":"http://exmple.com/pic1.png" + }`, useIDs[count]) + count++ + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + _, err := w.Write([]byte(res)) + assert.NoError(t, err) + case strings.HasPrefix(r.URL.Path, "/keys"): + + testKeys := fmt.Sprintf(`{ + "keys": [ + %s, + { + "kty": "RSA", + "kid": "eXaunmL", + "use": "sig", + "alg": "RS256", + "n": "4dGQ7bQK8LgILOdLsYzfZjkEAoQeVC_aqyc8GC6RX7dq_KvRAQAWPvkam8VQv4GK5T4ogklEKEvj5ISBamdDNq1n52TpxQwI2EqxSk7I9fKPKhRt4F8-2yETlYvye-2s6NeWJim0KBtOVrk0gWvEDgd6WOqJl_yt5WBISvILNyVg1qAAM8JeX6dRPosahRVDjA52G2X-Tip84wqwyRpUlq2ybzcLh3zyhCitBOebiRWDQfG26EH9lTlJhll-p_Dg8vAXxJLIJ4SNLcqgFeZe4OfHLgdzMvxXZJnPp_VgmkcpUdRotazKZumj6dBPcXI_XID4Z4Z3OM1KrZPJNdUhxw", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "YuyXoY", + "use": "sig", + "alg": "RS256", + "n": "1JiU4l3YCeT4o0gVmxGTEK1IXR-Ghdg5Bzka12tzmtdCxU00ChH66aV-4HRBjF1t95IsaeHeDFRgmF0lJbTDTqa6_VZo2hc0zTiUAsGLacN6slePvDcR1IMucQGtPP5tGhIbU-HKabsKOFdD4VQ5PCXifjpN9R-1qOR571BxCAl4u1kUUIePAAJcBcqGRFSI_I1j_jbN3gflK_8ZNmgnPrXA0kZXzj1I7ZHgekGbZoxmDrzYm2zmja1MsE5A_JX7itBYnlR41LOtvLRCNtw7K3EFlbfB6hkPL-Swk5XNGbWZdTROmaTNzJhV-lWT0gGm6V1qWAK2qOZoIDa_3Ud0Gw", + "e": "AQAB" + } + ] + }`, testJWK) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + _, err := w.Write([]byte(testKeys)) + assert.NoError(t, err) + default: + t.Fatalf("unexpected oauth request %s %s", r.Method, r.URL) + } + }), + } + + go func() { _ = oauth.ListenAndServe() }() + go func() { _ = ts.ListenAndServe() }() + + time.Sleep(time.Millisecond * 100) // let them start + + return func() { + + assert.NoError(t, ts.Close()) + assert.NoError(t, oauth.Close()) + cancelCtx() // delete test private key file + } +} + +func createTestResponseToken(privKey interface{}) (string, error) { + claims := &jwt.MapClaims{ + "iss": "http://go.localhost.test", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Second * 30).Unix(), + "aud": "go-pkgz/auth", + "sub": "userid1", + "email": "test@example.go", + } + + tkn := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tkn.Header["alg"] = "RS256" + tkn.Header["kid"] = "112233" + + return tkn.SignedString(privKey) +} + +func createTestSignKeyPairs(t *testing.T) (privKey *rsa.PrivateKey, jwk string) { + //nolint:gosec // test example and not a real key + privateStr := `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn +SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i +cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC +PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR +ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA +Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3 +n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy +MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9 +POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE +KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM +IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn +FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY +mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj +FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U +I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs +2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn +/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT +OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86 +EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+ +hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0 +4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb +mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry +eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3 +CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+ +9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq +-----END RSA PRIVATE KEY-----` + publicStr := `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41 +fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7 +mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp +HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2 +XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b +ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy +7wIDAQAB +-----END PUBLIC KEY-----` + + signKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateStr)) + require.NoError(t, err) + + publicKey, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicStr)) + require.NoError(t, err) + + // convert modulus + n := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(publicKey.N.Bytes()) + + // convert exponent + eBuff := make([]byte, 4) + binary.LittleEndian.PutUint32(eBuff, uint32(publicKey.E)) + e := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(eBuff) + + JWK := struct { + Alg string `json:"alg"` + Kty string `json:"kty"` + Use string `json:"use"` + Kid string `json:"kid"` + E string `json:"e"` + N string `json:"n"` + }{Alg: "RS256", Kty: "RSA", Use: "sig", Kid: "112233", N: n, E: e[:4]} + + var buffJwk []byte + buffJwk, err = json.Marshal(JWK) + require.NoError(t, err) + jwk = string(buffJwk) + + return signKey, jwk +} diff --git a/v2/provider/custom_server.go b/v2/provider/custom_server.go new file mode 100644 index 00000000..0ca145e3 --- /dev/null +++ b/v2/provider/custom_server.go @@ -0,0 +1,359 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "html/template" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + goauth2 "github.com/go-oauth2/oauth2/v4/server" + "golang.org/x/oauth2" + + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// CustomHandlerOpt are options to initialize a handler for oauth2 server +type CustomHandlerOpt struct { + Endpoint oauth2.Endpoint + InfoURL string + MapUserFn func(UserData, []byte) token.User + BearerTokenHookFn BearerTokenHook + Scopes []string +} + +// CustomServerOpt are options to initialize a custom go-oauth2/oauth2 server +type CustomServerOpt struct { + logger.L + URL string + WithLoginPage bool + LoginPageHandler func(w http.ResponseWriter, r *http.Request) +} + +// NewCustomServer is helper function to initiate a customer server and prefill +// options needed for provider registration (see Service.AddCustomProvider) +func NewCustomServer(srv *goauth2.Server, sopts CustomServerOpt) *CustomServer { + copts := CustomHandlerOpt{ + Endpoint: oauth2.Endpoint{ + AuthURL: sopts.URL + "/authorize", + TokenURL: sopts.URL + "/access_token", + }, + InfoURL: sopts.URL + "/user", + MapUserFn: defaultMapUserFn, + } + + return &CustomServer{ + L: sopts.L, + URL: sopts.URL, + WithLoginPage: sopts.WithLoginPage, + LoginPageHandler: sopts.LoginPageHandler, + OauthServer: srv, + HandlerOpt: copts, + } +} + +// CustomServer is a wrapper over go-oauth2/oauth2 server running on its own port +type CustomServer struct { + logger.L + URL string // root url for custom oauth2 server + WithLoginPage bool // redirect to login html page if true + LoginPageHandler func(w http.ResponseWriter, r *http.Request) // handler for user-defined login page + OauthServer *goauth2.Server // an instance of go-oauth2/oauth2 server + HandlerOpt CustomHandlerOpt + httpServer *http.Server + lock sync.Mutex +} + +// Run starts serving on port from c.URL +func (c *CustomServer) Run(ctx context.Context) { + c.Logf("[INFO] run local go-oauth2/oauth2 server on %s", c.URL) + c.lock.Lock() + + u, err := url.Parse(c.URL) + if err != nil { + c.Logf("[ERROR] failed to parse service base URL=%s", c.URL) + return + } + + _, port, err := net.SplitHostPort(u.Host) + if err != nil { + c.Logf("[ERROR] failed to get port from URL=%s", c.URL) + return + } + + c.httpServer = &http.Server{ + Addr: fmt.Sprintf(":%s", port), + ReadHeaderTimeout: 5 * time.Second, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/authorize"): + c.handleAuthorize(w, r) + case strings.HasSuffix(r.URL.Path, "/access_token"): + if err = c.OauthServer.HandleTokenRequest(w, r); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + case strings.HasPrefix(r.URL.Path, "/user"): + c.handleUserInfo(w, r) + case strings.HasPrefix(r.URL.Path, "/avatar"): + c.handleAvatar(w, r) + default: + w.WriteHeader(http.StatusBadRequest) + return + } + }), + } + c.lock.Unlock() + + go func() { + <-ctx.Done() + c.Logf("[DEBUG] cancellation via context, %v", ctx.Err()) + c.Shutdown() + }() + + err = c.httpServer.ListenAndServe() + c.Logf("[WARN] go-oauth2/oauth2 server terminated, %s", err) +} + +func (c *CustomServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + // called for first time, ask for username + if c.WithLoginPage || c.LoginPageHandler != nil { + if r.ParseForm() != nil || r.Form.Get("username") == "" { + // show default template if user-defined function not specified + if c.LoginPageHandler != nil { + c.LoginPageHandler(w, r) + return + } + userLoginTmpl, err := template.New("page").Parse(defaultLoginTmpl) + if err != nil { + c.Logf("[ERROR] can't parse user login template, %s", err) + return + } + + formData := struct{ Query template.URL }{Query: template.URL(r.URL.RawQuery)} //nolint:gosec // query is safe + + if err := userLoginTmpl.Execute(w, formData); err != nil { + c.Logf("[WARN] can't write, %s", err) + } + return + } + } + + err := c.OauthServer.HandleAuthorizeRequest(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } +} + +func (c *CustomServer) handleUserInfo(w http.ResponseWriter, r *http.Request) { + ti, err := c.OauthServer.ValidationBearerToken(r) + if err != nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + userID := ti.GetUserID() + + user := token.User{ + ID: userID, + Name: userID, + Picture: fmt.Sprintf(c.URL+"/avatar?user=%s", url.QueryEscape(userID)), + } + res, err := json.Marshal(user) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if _, err := w.Write(res); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +func (c *CustomServer) handleAvatar(w http.ResponseWriter, r *http.Request) { + user := r.URL.Query().Get("user") + b, err := avatar.GenerateAvatar(user) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + if _, err = w.Write(b); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} + +// Shutdown go-oauth2/oauth2 server +func (c *CustomServer) Shutdown() { + c.Logf("[WARN] shutdown go-oauth2/oauth2 server") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + c.lock.Lock() + if c.httpServer != nil { + if err := c.httpServer.Shutdown(ctx); err != nil { + c.Logf("[DEBUG] go-oauth2/oauth2 shutdown error, %s", err) + } + } + c.Logf("[DEBUG] shutdown go-oauth2/oauth2 server completed") + c.lock.Unlock() +} + +// NewCustom creates a handler for go-oauth2/oauth2 server +func NewCustom(name string, p Params, copts CustomHandlerOpt) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: name, + endpoint: copts.Endpoint, + scopes: copts.Scopes, + infoURL: copts.InfoURL, + mapUser: copts.MapUserFn, + bearerTokenHook: copts.BearerTokenHookFn, + }) +} + +func defaultMapUserFn(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: data.Value("id"), + Name: data.Value("name"), + Picture: data.Value("picture"), + } + return userInfo +} + +var defaultLoginTmpl = ` + + + Dev OAuth + + + +
+
+

go-oauth2/oauth2

+

Golang OAuth 2.0 Server

+
+ + + + + +

+
+ + + +` diff --git a/v2/provider/custom_server_test.go b/v2/provider/custom_server_test.go new file mode 100644 index 00000000..3083920f --- /dev/null +++ b/v2/provider/custom_server_test.go @@ -0,0 +1,234 @@ +package provider + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/http/cookiejar" + "net/url" + "strings" + "testing" + "time" + + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/manage" + "github.com/go-oauth2/oauth2/v4/models" + goauth2 "github.com/go-oauth2/oauth2/v4/server" + "github.com/go-oauth2/oauth2/v4/store" + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +func TestCustomProvider(t *testing.T) { + srv := initGoauth2Srv(t) + + params := Params{ + URL: "http://127.0.0.1:8080", + JwtService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + DisableIAT: true, + }), + Issuer: "test-issuer", + Cid: "cid", + Csecret: "csecret", + L: logger.Std, + } + + var loginUsername string + var capturedUser token.User + + sopts := CustomServerOpt{ + URL: "http://127.0.0.1:9096", + L: logger.Std, + WithLoginPage: true, + LoginPageHandler: func(w http.ResponseWriter, r *http.Request) { + // // Simulate POST from login page + u, err := url.Parse("http://127.0.0.1:9096/authorize?" + r.URL.RawQuery) + if err != nil { + assert.Fail(t, "failed to parse url") + } + + jar, err := cookiejar.New(nil) + if err != nil { + assert.Fail(t, "failed initialize cookiesjar") + } + jar.SetCookies(u, r.Cookies()) + + form := url.Values{} + form.Add("username", loginUsername) + form.Add("password", "pwd1234") + + req, err := http.NewRequest("POST", "", strings.NewReader(form.Encode())) + if err != nil { + assert.Fail(t, "failed to simulate POST request in login page ,%s", err) + } + req.URL = u + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + c := &http.Client{Jar: jar, Timeout: time.Second * 10} + resp, err := c.Do(req) + if err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + claims, err := params.JwtService.Parse(resp.Cookies()[0].Value) + assert.NoError(t, err) + + capturedUser = *claims.User + }, + } + + prov := NewCustomServer(srv, sopts) + + h := NewCustom("myprov", params, prov.HandlerOpt) + s := Service{Provider: h} + + router := http.NewServeMux() + router.Handle("/auth/customprov/", http.HandlerFunc(s.Handler)) + ts := &http.Server{Addr: fmt.Sprintf("127.0.0.1:%d", 8080), Handler: router} //nolint:gosec + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go prov.Run(ctx) + go ts.ListenAndServe() + + defer func() { + prov.Shutdown() + _ = ts.Shutdown(context.TODO()) + }() + + time.Sleep(400 * time.Millisecond) + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: time.Second * 10} + + // check non-admin, permanent + loginUsername = "admin" + resp, err := client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + assert.Equal(t, token.User{Name: "admin", ID: "admin", + Picture: "http://127.0.0.1:9096/avatar?user=admin", IP: ""}, capturedUser) + + // check avatar + resp, err = client.Get("http://127.0.0.1:9096/avatar?user=dev_user") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 960, len(body)) + t.Logf("headers: %+v", resp.Header) + + // check malicious user ID + loginUsername = "attack" + resp, err = client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + // user ID in picture URL is encoded + assert.Equal(t, "http://127.0.0.1:9096/avatar?user=none%26attack%3Dvalue%22%3E%3Cscript%3Enasty+stuff%3C%2Fscript%3E", capturedUser.Picture) + + // check default login page + prov.LoginPageHandler = nil + resp, err = client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) +} + +func TestCustomProviderCancel(t *testing.T) { + srv := initGoauth2Srv(t) + prov := CustomServer{ + OauthServer: srv, + URL: "http://127.0.0.1:9096", + L: logger.Std, + WithLoginPage: true, + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan bool) + go func() { + prov.Run(ctx) + done <- true + }() + cancel() + + select { + case <-time.After(time.Second): + t.Fail() + case <-done: + } +} + +func initGoauth2Srv(t *testing.T) *goauth2.Server { + manager := manage.NewDefaultManager() + manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + + // token store + manager.MustTokenStorage(store.NewMemoryTokenStore()) + + // generate jwt access token + manager.MapAccessGenerate(generates.NewJWTAccessGenerate("", []byte("00000000"), jwt.SigningMethodHS512)) + + // client memory store + clientStore := store.NewClientStore() + err := clientStore.Set("cid", &models.Client{ + ID: "cid", + Secret: "csecret", + Domain: "http://127.0.0.1:8080", + }) + if err != nil { + assert.Fail(t, "failed to set up a client store for go-oauth2/oauth2 server, %s", err) + } + manager.MapClientStorage(clientStore) + + srv := goauth2.NewServer(goauth2.NewConfig(), manager) + + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) { + if r.ParseForm() != nil { + return "", fmt.Errorf("no username and password in request") + } + if r.Form.Get("username") == "admin" && r.Form.Get("password") == "pwd1234" { + return "admin", nil + } + if r.Form.Get("username") == "attack" && r.Form.Get("password") == "pwd1234" { + return "none&attack=value\">", nil + } + return "", fmt.Errorf("wrong creds") + }) + + srv.SetInternalErrorHandler(func(err error) (re *errors.Response) { + log.Println("Internal Error:", err.Error()) + return + }) + + srv.SetResponseErrorHandler(func(re *errors.Response) { + log.Println("Response Error:", re.Error.Error()) + }) + + return srv +} diff --git a/v2/provider/dev_provider.go b/v2/provider/dev_provider.go new file mode 100644 index 00000000..e93bb0e6 --- /dev/null +++ b/v2/provider/dev_provider.go @@ -0,0 +1,318 @@ +package provider + +import ( + "context" + "fmt" + "html/template" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/oauth2" + + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +const ( + defDevAuthPort = 8084 + defDevAuthHost = "127.0.0.1" +) + +// DevAuthServer is a fake oauth server for development +// it provides stand-alone server running on its own port and pretending to be the real oauth2. It also provides +// Dev Provider the same way as normal providers do, i.e. like github, google and others. +// can run in interactive and non-interactive mode. In interactive mode login attempts will show login form to select +// desired user name, this is the mode used for development. Non-interactive mode for tests only. +type DevAuthServer struct { + logger.L + Provider Oauth2Handler + Automatic bool + GetEmailFn func(string) string + username string // unsafe, but fine for dev + httpServer *http.Server + lock sync.Mutex +} + +// Run oauth2 dev server on port devAuthPort +func (d *DevAuthServer) Run(ctx context.Context) { // nolint (gocyclo) + if d.Provider.Port == 0 { + d.Provider.Port = defDevAuthPort + } + if d.Provider.Host == "" { + d.Provider.Host = defDevAuthHost + } + + d.username = "dev_user" + d.Logf("[INFO] run local oauth2 dev server on %d, redirect url=%s", d.Provider.Port, d.Provider.conf.RedirectURL) + d.lock.Lock() + var err error + + userFormTmpl, err := template.New("page").Parse(devUserFormTmpl) + if err != nil { + d.Logf("[WARN] can't parse user form template, %s", err) + return + } + + d.httpServer = &http.Server{ + Addr: fmt.Sprintf(":%d", d.Provider.Port), + ReadHeaderTimeout: 5 * time.Second, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + d.Logf("[DEBUG] dev oauth request %s %s %+v", r.Method, r.URL, r.Header) + switch { + + case strings.HasPrefix(r.URL.Path, "/login/oauth/authorize"): + + // first time it will be called without username and will ask for one + if !d.Automatic && (r.ParseForm() != nil || r.Form.Get("username") == "") { + formData := struct{ Query template.URL }{Query: template.URL(r.URL.RawQuery)} //nolint:gosec // query is safe + if err = userFormTmpl.Execute(w, formData); err != nil { + d.Logf("[WARN] can't write, %s", err) + } + return + } + + if !d.Automatic { + d.username = r.Form.Get("username") + } + + state := r.URL.Query().Get("state") + callbackURL := fmt.Sprintf("%s?code=g0ZGZmNjVmOWI&state=%s", d.Provider.conf.RedirectURL, state) + d.Logf("[DEBUG] callback url=%s", callbackURL) + w.Header().Add("Location", callbackURL) + w.WriteHeader(http.StatusFound) + + case strings.HasPrefix(r.URL.Path, "/login/oauth/access_token"): + res := `{ + "access_token":"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", + "token_type":"bearer", + "expires_in":3600, + "refresh_token":"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk", + "scope":"create", + "state":"12345678" + }` + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if _, err = w.Write([]byte(res)); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + case strings.HasPrefix(r.URL.Path, "/user"): + ava := fmt.Sprintf("http://%s:%d/avatar?user=%s", d.Provider.Host, d.Provider.Port, d.username) + res := fmt.Sprintf(`{ + "id": "%s", + "name":"%s", + "picture":"%s" + }`, d.username, d.username, ava) + + if d.GetEmailFn != nil { + email := d.GetEmailFn(d.username) + res = fmt.Sprintf(`{ + "id": "%s", + "name":"%s", + "picture":"%s", + "email": "%s" + }`, d.username, d.username, ava, email) + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if _, err = w.Write([]byte(res)); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + case strings.HasPrefix(r.URL.Path, "/avatar"): + user := r.URL.Query().Get("user") + b, e := avatar.GenerateAvatar(user) + if e != nil { + w.WriteHeader(http.StatusNotFound) + return + } + if _, err = w.Write(b); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + default: + w.WriteHeader(http.StatusBadRequest) + } + }), + } + d.lock.Unlock() + + go func() { + <-ctx.Done() + d.Logf("[DEBUG] cancellation via context, %v", ctx.Err()) + d.Shutdown() + }() + + err = d.httpServer.ListenAndServe() + d.Logf("[WARN] dev oauth2 server terminated, %s", err) +} + +// Shutdown oauth2 dev server +func (d *DevAuthServer) Shutdown() { + d.Logf("[WARN] shutdown oauth2 dev server") + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d.lock.Lock() + if d.httpServer != nil { + if err := d.httpServer.Shutdown(ctx); err != nil { + d.Logf("[DEBUG] oauth2 dev shutdown error, %s", err) + } + } + d.Logf("[DEBUG] shutdown dev oauth2 server completed") + d.lock.Unlock() +} + +// NewDev makes dev oauth2 provider for admin user +func NewDev(p Params) Oauth2Handler { + if p.Port == 0 { + p.Port = defDevAuthPort + } + if p.Host == "" { + p.Host = defDevAuthHost + } + oh := initOauth2Handler(p, Oauth2Handler{ + name: "dev", + endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("http://%s:%d/login/oauth/authorize", p.Host, p.Port), + TokenURL: fmt.Sprintf("http://%s:%d/login/oauth/access_token", p.Host, p.Port), + }, + scopes: []string{"user:email"}, + infoURL: fmt.Sprintf("http://%s:%d/user", p.Host, p.Port), + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: data.Value("id"), + Name: data.Value("name"), + Picture: data.Value("picture"), + Email: data.Value("email"), + } + return userInfo + }, + }) + + oh.conf.RedirectURL = p.URL + "/auth/dev/callback" + return oh +} + +var devUserFormTmpl = ` + + + Dev OAuth + + + +
+
+

GO-PKGZ/AUTH

+

Dev Provider

+
+ + +

Not for production use

+
+ + + +` diff --git a/v2/provider/dev_provider_test.go b/v2/provider/dev_provider_test.go new file mode 100644 index 00000000..f1fb8f88 --- /dev/null +++ b/v2/provider/dev_provider_test.go @@ -0,0 +1,111 @@ +package provider + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/cookiejar" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +func TestDevProvider(t *testing.T) { + params := Params{Cid: "cid", Csecret: "csecret", URL: "http://127.0.0.1:8080", L: logger.Std, Port: 18084, + JwtService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + DisableIAT: true, + }), + } + + devProvider := NewDev(params) + s := Service{Provider: devProvider} + devOauth2Srv := DevAuthServer{Provider: devProvider, Automatic: true, username: "dev_user", L: logger.Std} + devOauth2Srv.GetEmailFn = func(username string) string { + return username + "@example.com" + } + + router := http.NewServeMux() + router.Handle("/auth/dev/", http.HandlerFunc(s.Handler)) + + ts := &http.Server{Addr: fmt.Sprintf("127.0.0.1:%d", 8080), Handler: router} //nolint:gosec + go devOauth2Srv.Run(context.TODO()) + go ts.ListenAndServe() + defer func() { + devOauth2Srv.Shutdown() + _ = ts.Shutdown(context.TODO()) + }() + + time.Sleep(200 * time.Millisecond) + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // check non-admin, permanent + resp, err := client.Get("http://127.0.0.1:8080/auth/dev/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + claims, err := params.JwtService.Parse(resp.Cookies()[0].Value) + assert.NoError(t, err) + + assert.Equal(t, token.User{Name: "dev_user", ID: "dev_user", + Picture: "http://127.0.0.1:18084/avatar?user=dev_user", IP: "", Email: "dev_user@example.com"}, *claims.User) + + // check avatar + resp, err = client.Get("http://127.0.0.1:18084/avatar?user=dev_user") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 960, len(body)) + t.Logf("headers: %+v", resp.Header) +} + +func TestDevProviderCancel(t *testing.T) { + params := Params{Cid: "cid", Csecret: "csecret", URL: "http://127.0.0.1:8080", L: logger.Std, + JwtService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + DisableIAT: true, + }), + } + + devProvider := NewDev(params) + devOauth2Srv := DevAuthServer{Provider: devProvider, Automatic: true, username: "dev_user", L: logger.Std} + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan bool) + go func() { + devOauth2Srv.Run(ctx) + done <- true + }() + cancel() + + select { + case <-time.After(time.Second): + t.Fail() + case <-done: + } +} diff --git a/v2/provider/direct.go b/v2/provider/direct.go new file mode 100644 index 00000000..742ebd5a --- /dev/null +++ b/v2/provider/direct.go @@ -0,0 +1,193 @@ +package provider + +import ( + "crypto/sha1" //nolint + "encoding/json" + "fmt" + "mime" + "net/http" + "time" + + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +const ( + // MaxHTTPBodySize defines max http body size + MaxHTTPBodySize = 1024 * 1024 +) + +// DirectHandler implements non-oauth2 provider authorizing user in traditional way with storage +// with users and hashes +type DirectHandler struct { + logger.L + CredChecker CredChecker + ProviderName string + TokenService TokenService + Issuer string + AvatarSaver AvatarSaver + UserIDFunc UserIDFunc +} + +// CredChecker defines interface to check credentials +type CredChecker interface { + Check(user, password string) (ok bool, err error) +} + +// UserIDFunc allows to provide custom func making userID instead of the default based on user's name hash +type UserIDFunc func(user string, r *http.Request) string + +// CredCheckerFunc type is an adapter to allow the use of ordinary functions as CredsChecker. +type CredCheckerFunc func(user, password string) (ok bool, err error) + +// Check calls f(user,passwd) +func (f CredCheckerFunc) Check(user, password string) (ok bool, err error) { + return f(user, password) +} + +// credentials holds user credentials +type credentials struct { + User string `json:"user"` + Password string `json:"passwd"` + Audience string `json:"aud"` +} + +// Name of the handler +func (p DirectHandler) Name() string { return p.ProviderName } + +// LoginHandler checks "user" and "passwd" against data store and makes jwt if all passed. +// +// GET /something?user=name&passwd=xyz&aud=bar&sess=[0|1] +// +// POST /something?sess[0|1] +// Accepts application/x-www-form-urlencoded or application/json encoded requests. +// +// application/x-www-form-urlencoded body example: +// user=name&passwd=xyz&aud=bar +// +// application/json body example: +// +// { +// "user": "name", +// "passwd": "xyz", +// "aud": "bar", +// } +func (p DirectHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { + creds, err := p.getCredentials(w, r) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusBadRequest, err, "failed to parse credentials") + return + } + sessOnly := r.URL.Query().Get("sess") == "1" + if p.CredChecker == nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, + fmt.Errorf("no credential checker"), "no credential checker") + return + } + ok, err := p.CredChecker.Check(creds.User, creds.Password) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to check user credentials") + return + } + if !ok { + rest.SendErrorJSON(w, r, p.L, http.StatusForbidden, nil, "incorrect user or password") + return + } + + userID := p.ProviderName + "_" + token.HashID(sha1.New(), creds.User) + if p.UserIDFunc != nil { + userID = p.ProviderName + "_" + token.HashID(sha1.New(), p.UserIDFunc(creds.User, r)) + } + + u := token.User{ + Name: creds.User, + ID: userID, + } + u, err = setAvatar(p.AvatarSaver, u, &http.Client{Timeout: 5 * time.Second}) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "can't make token id") + return + } + + claims := token.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Id: cid, + Issuer: p.Issuer, + Audience: creds.Audience, + }, + SessionOnly: sessOnly, + } + + if _, err = p.TokenService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to set token") + return + } + rest.RenderJSON(w, claims.User) +} + +// getCredentials extracts user and password from request +func (p DirectHandler) getCredentials(w http.ResponseWriter, r *http.Request) (credentials, error) { + + // GET /something?user=name&passwd=xyz&aud=bar + if r.Method == "GET" { + return credentials{ + User: r.URL.Query().Get("user"), + Password: r.URL.Query().Get("passwd"), + Audience: r.URL.Query().Get("aud"), + }, nil + } + + if r.Method != "POST" { + return credentials{}, fmt.Errorf("method %s not supported", r.Method) + } + + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, MaxHTTPBodySize) + } + contentType := r.Header.Get("Content-Type") + if contentType != "" { + mt, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + return credentials{}, err + } + contentType = mt + } + + // POST with json body + if contentType == "application/json" { + var creds credentials + if err := json.NewDecoder(r.Body).Decode(&creds); err != nil { + return credentials{}, fmt.Errorf("failed to parse request body: %w", err) + } + return creds, nil + } + + // POST with form + if err := r.ParseForm(); err != nil { + return credentials{}, fmt.Errorf("failed to parse request: %w", err) + } + + return credentials{ + User: r.Form.Get("user"), + Password: r.Form.Get("passwd"), + Audience: r.Form.Get("aud"), + }, nil +} + +// AuthHandler doesn't do anything for direct login as it has no callbacks +func (p DirectHandler) AuthHandler(http.ResponseWriter, *http.Request) {} + +// LogoutHandler - GET /logout +func (p DirectHandler) LogoutHandler(w http.ResponseWriter, _ *http.Request) { + p.TokenService.Reset(w) +} diff --git a/v2/provider/direct_test.go b/v2/provider/direct_test.go new file mode 100644 index 00000000..334ceeee --- /dev/null +++ b/v2/provider/direct_test.go @@ -0,0 +1,277 @@ +package provider + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +func TestDirect_LoginHandler(t *testing.T) { + testCases := map[string]struct { + makeRequest func(t *testing.T) *http.Request + }{ + "GET": { + makeRequest: func(t *testing.T) *http.Request { + req, err := http.NewRequest("GET", "/login?user=myuser&passwd=pppp&aud=xyz123&from=http://example.com", http.NoBody) + require.NoError(t, err) + return req + }, + }, + "POST application/x-www-form-urlencoded": { + makeRequest: func(t *testing.T) *http.Request { + form := url.Values{ + "user": {"myuser"}, + "passwd": {"pppp"}, + "aud": {"xyz123"}, + } + req, err := http.NewRequest("POST", "/login?from=http://example.com", strings.NewReader(form.Encode())) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + return req + }, + }, + "POST application/json": { + makeRequest: func(t *testing.T) *http.Request { + jsonBody := `{"user":"myuser", "passwd":"pppp", "aud":"xyz123"}` + req, err := http.NewRequest("POST", "/login?from=http://example.com", strings.NewReader(jsonBody)) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/json") + return req + }, + }, + "POST application/json; charset=utf-8": { + makeRequest: func(t *testing.T) *http.Request { + jsonBody := `{"user":"myuser", "passwd":"pppp", "aud":"xyz123"}` + req, err := http.NewRequest("POST", "/login?from=http://example.com", strings.NewReader(jsonBody)) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/json") + return req + }, + }, + } + + for name, test := range testCases { + test := test + t.Run(name, func(t *testing.T) { + d := DirectHandler{ + ProviderName: "test", + CredChecker: &mockCredsChecker{ok: true}, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + assert.Equal(t, "test", d.Name()) + handler := http.HandlerFunc(d.LoginHandler) + + rr := httptest.NewRecorder() + req := test.makeRequest(t) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, `{"name":"myuser","id":"test_ed6307123e30cc7682328522d1d090d9c7525b32","picture":""}`+"\n", rr.Body.String()) + + request := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + c, err := request.Cookie("JWT") + require.NoError(t, err) + claims, err := d.TokenService.Parse(c.Value) + require.NoError(t, err) + t.Logf("%+v", claims) + assert.Equal(t, "xyz123", claims.Audience) + assert.Equal(t, "iss-test", claims.Issuer) + assert.True(t, claims.ExpiresAt > time.Now().Unix()) + assert.Equal(t, "myuser", claims.User.Name) + }) + } +} + +func TestDirect_LoginHandlerCustomUserID(t *testing.T) { + d := DirectHandler{ + ProviderName: "test", + CredChecker: &mockCredsChecker{ok: true}, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + UserIDFunc: func(user string, r *http.Request) string { + return user + "_custom_id" + }, + } + + assert.Equal(t, "test", d.Name()) + handler := http.HandlerFunc(d.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/login?user=myuser&passwd=pppp&aud=xyz123&from=http://example.com", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, `{"name":"myuser","id":"test_18c4eec1ecbe23902609e999c4d3da997e7ac10f","picture":""}`+"\n", rr.Body.String()) +} + +func TestDirect_LoginHandlerFailed(t *testing.T) { + testCases := map[string]struct { + makeRequest func(t *testing.T) *http.Request + credChecker CredChecker + wantCode int + wantBody string + }{ + "no credential checker": { + makeRequest: func(t *testing.T) *http.Request { + req, err := http.NewRequest("GET", "/login?user=myuser&passwd=pppp&aud=xyz123", http.NoBody) + require.NoError(t, err) + return req + }, + credChecker: nil, + wantCode: http.StatusInternalServerError, + wantBody: `{"error":"no credential checker"}`, + }, + "failed to check user credentials": { + makeRequest: func(t *testing.T) *http.Request { + req, err := http.NewRequest("GET", "/login?user=myuser&passwd=pppp&aud=xyz123", http.NoBody) + require.NoError(t, err) + return req + }, + credChecker: &mockCredsChecker{err: fmt.Errorf("some err"), ok: false}, + wantCode: http.StatusInternalServerError, + wantBody: `{"error":"failed to check user credentials"}`, + }, + "incorrect user or password": { + makeRequest: func(t *testing.T) *http.Request { + req, err := http.NewRequest("GET", "/login?user=myuser&passwd=pppp&aud=xyz123", http.NoBody) + require.NoError(t, err) + return req + }, + credChecker: &mockCredsChecker{err: nil, ok: false}, + wantCode: http.StatusForbidden, + wantBody: `{"error":"incorrect user or password"}`, + }, + "malformed json body": { + makeRequest: func(t *testing.T) *http.Request { + jsonBody := `{"user":"myuser"` + req, err := http.NewRequest("POST", "/login?from=http://example.com", strings.NewReader(jsonBody)) + require.NoError(t, err) + req.Header.Add("Content-Type", "application/json") + return req + }, + credChecker: &mockCredsChecker{err: nil, ok: true}, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"failed to parse credentials"}`, + }, + "malformed application/x-www-form-urlencoded body": { + makeRequest: func(t *testing.T) *http.Request { + req, err := http.NewRequest("POST", "/login?from=http://example.com", nil) //nolint + require.NoError(t, err) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + return req + }, + credChecker: &mockCredsChecker{err: nil, ok: true}, + wantCode: http.StatusBadRequest, + wantBody: `{"error":"failed to parse credentials"}`, + }, + } + + for name, test := range testCases { + test := test + t.Run(name, func(t *testing.T) { + d := DirectHandler{ + ProviderName: "test", + CredChecker: test.credChecker, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(d.LoginHandler) + rr := httptest.NewRecorder() + req := test.makeRequest(t) + handler.ServeHTTP(rr, req) + assert.Equal(t, test.wantCode, rr.Code) + assert.Equal(t, test.wantBody+"\n", rr.Body.String()) + }) + } +} + +func TestDirect_Logout(t *testing.T) { + d := DirectHandler{ + ProviderName: "test", + CredChecker: &mockCredsChecker{ok: true}, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(d.LogoutHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/logout", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, 2, len(rr.Header()["Set-Cookie"])) + + request := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + c, err := request.Cookie("JWT") + require.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) + + c, err = request.Cookie("XSRF-TOKEN") + require.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) +} + +func TestDirect_AuthHandler(t *testing.T) { + d := DirectHandler{} + handler := http.HandlerFunc(d.AuthHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/callback", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) +} + +func TestDirect_CredChecker(t *testing.T) { + ch := CredCheckerFunc(func(user string, password string) (ok bool, err error) { + if user == "dev" && password == "password" { + return true, nil + } + return false, nil + }) + + ok, err := ch.Check("user", "blah") + assert.NoError(t, err) + assert.False(t, ok) + + ok, err = ch.Check("dev", "password") + assert.NoError(t, err) + assert.True(t, ok) +} + +type mockCredsChecker struct { + ok bool + err error +} + +func (m *mockCredsChecker) Check(string, string) (ok bool, err error) { return m.ok, m.err } diff --git a/v2/provider/oauth1.go b/v2/provider/oauth1.go new file mode 100644 index 00000000..4aec7f56 --- /dev/null +++ b/v2/provider/oauth1.go @@ -0,0 +1,195 @@ +package provider + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/dghubble/oauth1" + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// Oauth1Handler implements /login, /callback and /logout handlers for oauth1 flow +type Oauth1Handler struct { + Params + name string + infoURL string + conf oauth1.Config + mapUser func(UserData, []byte) token.User // map info from InfoURL to User +} + +// Name returns provider name +func (h Oauth1Handler) Name() string { return h.name } + +// LoginHandler - GET /login?from=redirect-back-url&site=siteID&session=1 +func (h Oauth1Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { + h.Logf("[DEBUG] login with %s", h.Name()) + + // setting RedirectURL to {rootURL}/{routingPath}/{provider}/callback + // e.g. http://localhost:8080/auth/twitter/callback + h.conf.CallbackURL = h.makeRedirURL(r.URL.Path) + + requestToken, requestSecret, err := h.conf.RequestToken() + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to get request token") + return + } + + // use requestSecret as a state in oauth2 + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + + claims := token.Claims{ + Handshake: &token.Handshake{ + State: requestSecret, + From: r.URL.Query().Get("from"), + }, + SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", + StandardClaims: jwt.StandardClaims{ + Id: cid, + Audience: r.URL.Query().Get("site"), + ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), + NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + }, + } + + if _, err = h.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + authURL, err := h.conf.AuthorizationURL(requestToken) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to obtain oauth1 URL") + return + } + + http.Redirect(w, r, authURL.String(), http.StatusFound) +} + +// AuthHandler fills user info and redirects to "from" url. This is callback url redirected locally by browser +// GET /callback +func (h Oauth1Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { + oauthClaims, _, err := h.JwtService.Get(r) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to get token") + return + } + + requestToken, verifier, err := oauth1.ParseAuthorizationCallback(r) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to parse response from oauth1 server") + return + } + + accessToken, accessSecret, err := h.conf.AccessToken(requestToken, oauthClaims.Handshake.State, verifier) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to get accessToken and accessSecret") + return + } + + tok := oauth1.NewToken(accessToken, accessSecret) + client := h.conf.Client(context.Background(), tok) + + uinfo, err := client.Get(h.infoURL) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusServiceUnavailable, err, "failed to get client info") + return + } + + defer func() { + if e := uinfo.Body.Close(); e != nil { + h.Logf("[WARN] failed to close response body, %s", e) + } + }() + + data, err := io.ReadAll(uinfo.Body) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to read user info") + return + } + + jData := map[string]interface{}{} + if e := json.Unmarshal(data, &jData); e != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to unmarshal user info") + return + } + h.Logf("[DEBUG] got raw user info %+v", jData) + + u := h.mapUser(jData, data) + u, err = setAvatar(h.AvatarSaver, u, &http.Client{Timeout: 5 * time.Second}) + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + claims := token.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Issuer: h.Issuer, + Id: cid, + Audience: oauthClaims.Audience, + }, + SessionOnly: oauthClaims.SessionOnly, + } + + if _, err = h.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + h.Logf("[DEBUG] user info %+v", u) + + // redirect to back url if presented in login query params + if oauthClaims.Handshake != nil && oauthClaims.Handshake.From != "" { + http.Redirect(w, r, oauthClaims.Handshake.From, http.StatusTemporaryRedirect) + return + } + rest.RenderJSON(w, &u) +} + +// LogoutHandler - GET /logout +func (h Oauth1Handler) LogoutHandler(w http.ResponseWriter, r *http.Request) { + if _, _, err := h.JwtService.Get(r); err != nil { + rest.SendErrorJSON(w, r, h.L, http.StatusForbidden, err, "logout not allowed") + return + } + h.JwtService.Reset(w) +} + +func (h Oauth1Handler) makeRedirURL(path string) string { + elems := strings.Split(path, "/") + newPath := strings.Join(elems[:len(elems)-1], "/") + + return strings.TrimSuffix(h.URL, "/") + strings.TrimSuffix(newPath, "/") + urlCallbackSuffix +} + +// initOauth2Handler makes oauth1 handler for given provider +func initOauth1Handler(p Params, service Oauth1Handler) Oauth1Handler { + if p.L == nil { + p.L = logger.NoOp + } + p.Logf("[INFO] init oauth1 service %s", service.name) + service.Params = p + service.conf.ConsumerKey = p.Cid + service.conf.ConsumerSecret = p.Csecret + + p.Logf("[DEBUG] created %s oauth2, id=%s, redir=%s, endpoint=%s", + service.name, service.Cid, service.makeRedirURL("/{route}/"+service.name+"/"), service.conf.Endpoint) + return service +} diff --git a/v2/provider/oauth1_test.go b/v2/provider/oauth1_test.go new file mode 100644 index 00000000..7a8d1182 --- /dev/null +++ b/v2/provider/oauth1_test.go @@ -0,0 +1,293 @@ +package provider + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/cookiejar" + "strings" + "testing" + "time" + + "github.com/dghubble/oauth1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +const ( + timeout = 100 + loginPort = 8983 + authPort = 8984 +) + +func TestOauth1Login(t *testing.T) { + teardown := prepOauth1Test(t, loginPort, authPort) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: timeout * time.Second} + + // check non-admin, permanent + resp, err := client.Get(fmt.Sprintf("http://localhost:%d/login?site=remark", loginPort)) + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + u := token.User{} + err = json.Unmarshal(body, &u) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "blah", ID: "mock_myuser1", Picture: "http://example.com/custom.png", IP: ""}, u) + + tk := resp.Cookies()[0].Value + jwtSvc := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31}) + + claims, err := jwtSvc.Parse(tk) + require.NoError(t, err) + t.Log(claims) + assert.Equal(t, "remark42", claims.Issuer) + assert.Equal(t, "remark", claims.Audience) + + // check admin user + resp, err = client.Get(fmt.Sprintf("http://localhost:%d/login?site=remark", loginPort)) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + u = token.User{} + err = json.Unmarshal(body, &u) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "blah", ID: "mock_myuser2", Picture: "http://example.com/ava12345.png", + Attributes: map[string]interface{}{"admin": true}}, u) + +} + +func TestOauth1LoginSessionOnly(t *testing.T) { + + teardown := prepOauth1Test(t, loginPort, authPort) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: timeout * time.Second} + + // check non-admin, session + resp, err := client.Get(fmt.Sprintf("http://localhost:%d/login?site=remark&session=1", loginPort)) + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 0, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + req, err := http.NewRequest("GET", "http://example.com", http.NoBody) + require.Nil(t, err) + + req.AddCookie(resp.Cookies()[0]) + req.AddCookie(resp.Cookies()[1]) + req.Header.Add("X-XSRF-TOKEN", resp.Cookies()[1].Value) + + jwtService := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore)}) + res, _, err := jwtService.Get(req) + require.Nil(t, err) + assert.Equal(t, true, res.SessionOnly) + t.Logf("%+v", res) +} + +func TestOauth1Logout(t *testing.T) { + + teardown := prepOauth1Test(t, loginPort, authPort) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: timeout * time.Second} + + req, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d/logout", loginPort), http.NoBody) + require.Nil(t, err) + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, 403, resp.StatusCode, "user not lagged in") + + req, err = http.NewRequest("GET", fmt.Sprintf("http://localhost:%d/logout", loginPort), http.NoBody) + require.NoError(t, err) + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err = client.Do(req) + require.Nil(t, err) + require.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name, "token cookie cleared") + assert.Equal(t, "", resp.Cookies()[0].Value) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name, "xsrf cookie cleared") + assert.Equal(t, "", resp.Cookies()[1].Value) +} + +func TestOauth1InitProvider(t *testing.T) { + params := Params{URL: "url", Cid: "cid", Csecret: "csecret", Issuer: "app-test"} + provider := Oauth1Handler{name: "test"} + res := initOauth1Handler(params, provider) + assert.Equal(t, "cid", res.conf.ConsumerKey) + assert.Equal(t, "csecret", res.conf.ConsumerSecret) + assert.Equal(t, "test", res.name) + assert.Equal(t, "app-test", res.Issuer) +} + +func TestOauth1InvalidHandler(t *testing.T) { + teardown := prepOauth1Test(t, loginPort, authPort) + defer teardown() + + client := &http.Client{Timeout: timeout * time.Second} + resp, err := client.Get(fmt.Sprintf("http://localhost:%d/login_bad", loginPort)) + require.NoError(t, err) + assert.Equal(t, 404, resp.StatusCode) + + resp, err = client.Post(fmt.Sprintf("http://localhost:%d/login", loginPort), "", nil) + require.NoError(t, err) + assert.Equal(t, 500, resp.StatusCode) +} + +func TestOauth1MakeRedirURL(t *testing.T) { + cases := []struct{ rootURL, route, out string }{ + {"localhost:8080/", "/my/auth/path/google", "localhost:8080/my/auth/path/callback"}, + {"localhost:8080", "/auth/google", "localhost:8080/auth/callback"}, + {"localhost:8080/", "/auth/google", "localhost:8080/auth/callback"}, + {"localhost:8080", "/", "localhost:8080/callback"}, + {"localhost:8080/", "/", "localhost:8080/callback"}, + {"mysite.com", "", "mysite.com/callback"}, + } + + for i := range cases { + c := cases[i] + oh := initOauth1Handler(Params{URL: c.rootURL}, Oauth1Handler{}) + assert.Equal(t, c.out, oh.makeRedirURL(c.route)) + } +} + +func prepOauth1Test(t *testing.T, loginPort, authPort int) func() { //nolint + + provider := Oauth1Handler{ + name: "mock", + conf: oauth1.Config{ + Endpoint: oauth1.Endpoint{ + RequestTokenURL: fmt.Sprintf("http://localhost:%d/login/oauth/request_token", authPort), + AuthorizeURL: fmt.Sprintf("http://localhost:%d/login/oauth/authorize", authPort), + AccessTokenURL: fmt.Sprintf("http://localhost:%d/login/oauth/access_token", authPort), + }, + }, + infoURL: fmt.Sprintf("http://localhost:%d/user", authPort), + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "mock_" + data.Value("id"), + Name: data.Value("name"), + Picture: data.Value("picture"), + } + return userInfo + }, + } + + jwtService := token.NewService(token.Opts{ + SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, TokenDuration: time.Hour, CookieDuration: days31, + ClaimsUpd: token.ClaimsUpdFunc(func(claims token.Claims) token.Claims { + if claims.User != nil { + switch claims.User.ID { + case "mock_myuser2": + claims.User.SetBoolAttr("admin", true) + case "mock_myuser1": + claims.User.Picture = "http://example.com/custom.png" + } + } + return claims + }), + }) + + params := Params{URL: "url", Cid: "aFdj12348sdja", Csecret: "Dwehsq2387akss", JwtService: jwtService, + Issuer: "remark42", AvatarSaver: &mockAvatarSaver{}, L: logger.Std} + + provider = initOauth1Handler(params, provider) + svc := Service{Provider: provider} + + ts := &http.Server{Addr: fmt.Sprintf(":%d", loginPort), Handler: http.HandlerFunc(svc.Handler)} //nolint:gosec + + count := 0 + useIDs := []string{"myuser1", "myuser2"} // user for first ans second calls + + //nolint + var ( + requestToken = "sdjasd09AfdkzztyRadrdR" + requestSecret = "asd34q129sjdklAJJAs" + verifier = "gsjad032ajjjOIU" + accessToken = "g0ZGZmNjVmOWI" + accessSecret = "qfr1239UJAkmpaf3l" + ) + + oauth := &http.Server{ //nolint:gosec + Addr: fmt.Sprintf(":%d", authPort), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[MOCK OAUTH] request %s %s %+v", r.Method, r.URL, r.Header) + switch { + case strings.HasPrefix(r.URL.Path, "/login/oauth/request_token"): + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err := fmt.Fprintf(w, `oauth_token=%s&oauth_token_secret=%s&oauth_callback_confirmed=true`, requestToken, requestSecret) + if err != nil { + w.WriteHeader(500) + return + } + case strings.HasPrefix(r.URL.Path, "/login/oauth/authorize"): + w.Header().Add("Location", fmt.Sprintf("http://localhost:%d/callback?oauth_token=%s&oauth_verifier=%s", + loginPort, requestToken, verifier)) + w.WriteHeader(302) + case strings.HasPrefix(r.URL.Path, "/login/oauth/access_token"): + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") + _, err := fmt.Fprintf(w, "oauth_token=%s&oauth_token_secret=%s", accessToken, accessSecret) + if err != nil { + w.WriteHeader(500) + return + } + w.WriteHeader(200) + case strings.HasPrefix(r.URL.Path, "/user"): + res := fmt.Sprintf(`{ + "id": "%s", + "name":"blah", + "picture":"http://exmple.com/pic1.png" + }`, useIDs[count]) + count++ + w.Header().Set("Content-Type", "application/json; charset=utf-8") + _, err := w.Write([]byte(res)) + assert.NoError(t, err) + default: + t.Fatalf("unexpected oauth request %s %s", r.Method, r.URL) + } + }), + } + + go func() { _ = oauth.ListenAndServe() }() + go func() { _ = ts.ListenAndServe() }() + + time.Sleep(time.Millisecond * 400) // let them start + + return func() { + assert.NoError(t, ts.Close()) + assert.NoError(t, oauth.Close()) + } +} diff --git a/v2/provider/oauth2.go b/v2/provider/oauth2.go new file mode 100644 index 00000000..6161357e --- /dev/null +++ b/v2/provider/oauth2.go @@ -0,0 +1,254 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + "golang.org/x/oauth2" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// Oauth2Handler implements /login, /callback and /logout handlers from aouth2 flow +type Oauth2Handler struct { + Params + + // all of these fields specific to particular oauth2 provider + name string + infoURL string + endpoint oauth2.Endpoint + scopes []string + mapUser func(UserData, []byte) token.User // map info from InfoURL to User + bearerTokenHook BearerTokenHook // a way to get a Bearer token received from oauth2-provider + conf oauth2.Config +} + +// Params to make initialized and ready to use provider +type Params struct { + logger.L + URL string + JwtService TokenService + Cid string + Csecret string + Issuer string + AvatarSaver AvatarSaver + UserAttributes UserAttributes + + Port int // relevant for providers supporting port customization, for example dev oauth2 + Host string // relevant for providers supporting host customization, for example dev oauth2 +} + +// UserData is type for user information returned from oauth2 providers /info API method +type UserData map[string]interface{} + +// Value returns value for key or empty string if not found +func (u UserData) Value(key string) string { + // json.Unmarshal converts json "null" value to go's "nil", in this case return empty string + if val, ok := u[key]; ok && val != nil { + return fmt.Sprintf("%v", val) + } + return "" +} + +// BearerTokenHook accepts provider name, user and token, received during oauth2 authentication +type BearerTokenHook func(provider string, user token.User, token oauth2.Token) + +// initOauth2Handler makes oauth2 handler for given provider +func initOauth2Handler(p Params, service Oauth2Handler) Oauth2Handler { + if p.L == nil { + p.L = logger.NoOp + } + p.Logf("[INFO] init oauth2 service %s", service.name) + service.Params = p + service.conf = oauth2.Config{ + ClientID: service.Cid, + ClientSecret: service.Csecret, + Scopes: service.scopes, + Endpoint: service.endpoint, + } + + p.Logf("[DEBUG] created %s oauth2, id=%s, redir=%s, endpoint=%s", + service.name, service.Cid, service.makeRedirURL("/{route}/"+service.name+"/"), service.endpoint) + return service +} + +// Name returns provider name +func (p Oauth2Handler) Name() string { return p.name } + +// LoginHandler - GET /login?from=redirect-back-url&[site|aud]=siteID&session=1&noava=1 +func (p Oauth2Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { + + p.Logf("[DEBUG] login with %s", p.Name()) + // make state (random) and store in session + state, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to make oauth2 state") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + + aud := r.URL.Query().Get("site") // legacy, for back compat + if aud == "" { + aud = r.URL.Query().Get("aud") + } + + claims := token.Claims{ + Handshake: &token.Handshake{ + State: state, + From: r.URL.Query().Get("from"), + }, + SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", + StandardClaims: jwt.StandardClaims{ + Id: cid, + Audience: aud, + ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), + NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + }, + NoAva: r.URL.Query().Get("noava") == "1", + } + + if _, err := p.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + // setting RedirectURL to rootURL/routingPath/provider/callback + // e.g. http://localhost:8080/auth/github/callback + p.conf.RedirectURL = p.makeRedirURL(r.URL.Path) + + // return login url + loginURL := p.conf.AuthCodeURL(state) + p.Logf("[DEBUG] login url %s, claims=%+v", loginURL, claims) + + http.Redirect(w, r, loginURL, http.StatusFound) +} + +// AuthHandler fills user info and redirects to "from" url. This is callback url redirected locally by browser +// GET /callback +func (p Oauth2Handler) AuthHandler(w http.ResponseWriter, r *http.Request) { + oauthClaims, _, err := p.JwtService.Get(r) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to get token") + return + } + + if oauthClaims.Handshake == nil { + rest.SendErrorJSON(w, r, p.L, http.StatusForbidden, nil, "invalid handshake token") + return + } + + retrievedState := oauthClaims.Handshake.State + if retrievedState == "" || retrievedState != r.URL.Query().Get("state") { + rest.SendErrorJSON(w, r, p.L, http.StatusForbidden, nil, "unexpected state") + return + } + + p.conf.RedirectURL = p.makeRedirURL(r.URL.Path) + + p.Logf("[DEBUG] token with state %s", retrievedState) + tok, err := p.conf.Exchange(context.Background(), r.URL.Query().Get("code")) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "exchange failed") + return + } + + client := p.conf.Client(context.Background(), tok) + uinfo, err := client.Get(p.infoURL) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusServiceUnavailable, err, "failed to get client info") + return + } + + defer func() { + if e := uinfo.Body.Close(); e != nil { + p.Logf("[WARN] failed to close response body, %s", e) + } + }() + + data, err := io.ReadAll(uinfo.Body) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to read user info") + return + } + + jData := map[string]interface{}{} + if e := json.Unmarshal(data, &jData); e != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to unmarshal user info") + return + } + p.Logf("[DEBUG] got raw user info %+v", jData) + + u := p.mapUser(jData, data) + if oauthClaims.NoAva { + u.Picture = "" // reset picture on no avatar request + } + u, err = setAvatar(p.AvatarSaver, u, client) + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to make claim's id") + return + } + claims := token.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Issuer: p.Issuer, + Id: cid, + Audience: oauthClaims.Audience, + }, + SessionOnly: oauthClaims.SessionOnly, + NoAva: oauthClaims.NoAva, + } + + if _, err = p.JwtService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + if p.bearerTokenHook != nil && tok != nil { + p.Logf("[DEBUG] pass bearer token %s, %s", p.Name(), tok.TokenType) + p.bearerTokenHook(p.Name(), u, *tok) + } + + p.Logf("[DEBUG] user info %+v", u) + + // redirect to back url if presented in login query params + if oauthClaims.Handshake != nil && oauthClaims.Handshake.From != "" { + http.Redirect(w, r, oauthClaims.Handshake.From, http.StatusTemporaryRedirect) + return + } + rest.RenderJSON(w, &u) +} + +// LogoutHandler - GET /logout +func (p Oauth2Handler) LogoutHandler(w http.ResponseWriter, r *http.Request) { + if _, _, err := p.JwtService.Get(r); err != nil { + rest.SendErrorJSON(w, r, p.L, http.StatusForbidden, err, "logout not allowed") + return + } + p.JwtService.Reset(w) +} + +func (p Oauth2Handler) makeRedirURL(path string) string { + elems := strings.Split(path, "/") + newPath := strings.Join(elems[:len(elems)-1], "/") + + return strings.TrimSuffix(p.URL, "/") + strings.TrimSuffix(newPath, "/") + urlCallbackSuffix +} diff --git a/v2/provider/oauth2_test.go b/v2/provider/oauth2_test.go new file mode 100644 index 00000000..26a51a8d --- /dev/null +++ b/v2/provider/oauth2_test.go @@ -0,0 +1,371 @@ +package provider + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/cookiejar" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.NN7TK-IbzpNgHMtld9-7BDypMGDZdMpwCmUMSfd31Zk" + +var days31 = time.Hour * 24 * 31 + +type rememberLastBearerTokenHook struct { + LastProviderName string `json:"lastProviderName"` + LastUser token.User `json:"lastUser"` + LastToken oauth2.Token `json:"lastToken"` +} + +func (h *rememberLastBearerTokenHook) hook(s string, user token.User, o oauth2.Token) { + h.LastProviderName = s + h.LastUser = user + h.LastToken = o +} + +func TestOauth2Login(t *testing.T) { + + teardown := prepOauth2Test(t, 8981, 8982, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // check non-admin, permanent + resp, err := client.Get("http://localhost:8981/login?site=remark") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 2678400, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + u := token.User{} + err = json.Unmarshal(body, &u) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "blah", ID: "mock_myuser1", Picture: "http://example.com/custom.png", IP: ""}, u) + + tk := resp.Cookies()[0].Value + jwtSvc := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31}) + + claims, err := jwtSvc.Parse(tk) + require.NoError(t, err) + t.Log(claims) + assert.Equal(t, "remark42", claims.Issuer) + assert.Equal(t, "remark", claims.Audience) + + // check admin user + resp, err = client.Get("http://localhost:8981/login?site=remark") + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + u = token.User{} + err = json.Unmarshal(body, &u) + assert.NoError(t, err) + assert.Equal(t, token.User{Name: "blah", ID: "mock_myuser2", Picture: "http://example.com/ava12345.png", + Attributes: map[string]interface{}{"admin": true}}, u) +} + +func TestOauth2LoginBearerTokenHook(t *testing.T) { + + btHook := rememberLastBearerTokenHook{} + teardown := prepOauth2Test(t, 8981, 8982, btHook.hook) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // check non-admin, permanent + resp, err := client.Get("http://localhost:8981/login?site=remark") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, "mock", btHook.LastProviderName) + assert.Equal(t, "mock_myuser1", btHook.LastUser.ID) + assert.Equal(t, "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", btHook.LastToken.AccessToken) + + // check admin user + resp, err = client.Get("http://localhost:8981/login?site=remark") + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, "mock", btHook.LastProviderName) + assert.Equal(t, "mock_myuser2", btHook.LastUser.ID) + assert.Equal(t, "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", btHook.LastToken.AccessToken) +} + +func TestOauth2LoginSessionOnly(t *testing.T) { + + teardown := prepOauth2Test(t, 8981, 8982, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // check non-admin, session + resp, err := client.Get("http://localhost:8981/login?site=remark&session=1") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.Equal(t, 0, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + req, err := http.NewRequest("GET", "http://example.com", http.NoBody) + require.Nil(t, err) + + req.AddCookie(resp.Cookies()[0]) + req.AddCookie(resp.Cookies()[1]) + req.Header.Add("X-XSRF-TOKEN", resp.Cookies()[1].Value) + + jwtService := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore)}) + res, _, err := jwtService.Get(req) + require.Nil(t, err) + assert.Equal(t, true, res.SessionOnly) + t.Logf("%+v", res) +} + +func TestOauth2LoginNoAva(t *testing.T) { + + teardown := prepOauth2Test(t, 8981, 8982, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + // check non-admin, session + resp, err := client.Get("http://localhost:8981/login?site=remark&noava=1") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name) + assert.NotEqual(t, "", resp.Cookies()[0].Value, "token set") + assert.NotEqual(t, 0, resp.Cookies()[0].MaxAge) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name) + assert.NotEqual(t, "", resp.Cookies()[1].Value, "xsrf cookie set") + + req, err := http.NewRequest("GET", "http://example.com", http.NoBody) + require.Nil(t, err) + + req.AddCookie(resp.Cookies()[0]) + req.AddCookie(resp.Cookies()[1]) + req.Header.Add("X-XSRF-TOKEN", resp.Cookies()[1].Value) + + jwtService := token.NewService(token.Opts{SecretReader: token.SecretFunc(mockKeyStore)}) + res, _, err := jwtService.Get(req) + require.Nil(t, err) + assert.Equal(t, "http://example.com/fake.png", res.User.Picture) + assert.Equal(t, true, res.NoAva) + t.Logf("%+v", res) +} + +func TestOauth2Logout(t *testing.T) { + + teardown := prepOauth2Test(t, 8691, 8692, nil) + defer teardown() + + jar, err := cookiejar.New(nil) + require.Nil(t, err) + client := &http.Client{Jar: jar, Timeout: 5 * time.Second} + + req, err := http.NewRequest("GET", "http://localhost:8691/logout", http.NoBody) + require.Nil(t, err) + resp, err := client.Do(req) + require.Nil(t, err) + assert.Equal(t, 403, resp.StatusCode, "user not lagged in") + + req, err = http.NewRequest("GET", "http://localhost:8691/logout", http.NoBody) + require.NoError(t, err) + expiration := int(365 * 24 * time.Hour.Seconds()) //nolint + req.AddCookie(&http.Cookie{Name: "JWT", Value: testJwtValid, HttpOnly: true, Path: "/", MaxAge: expiration, Secure: false}) + req.Header.Add("X-XSRF-TOKEN", "random id") + resp, err = client.Do(req) + require.Nil(t, err) + require.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, 2, len(resp.Cookies())) + assert.Equal(t, "JWT", resp.Cookies()[0].Name, "token cookie cleared") + assert.Equal(t, "", resp.Cookies()[0].Value) + assert.Equal(t, "XSRF-TOKEN", resp.Cookies()[1].Name, "xsrf cookie cleared") + assert.Equal(t, "", resp.Cookies()[1].Value) +} + +func TestOauth2InitProvider(t *testing.T) { + params := Params{URL: "url", Cid: "cid", Csecret: "csecret", Issuer: "app-test"} + provider := Oauth2Handler{name: "test"} + res := initOauth2Handler(params, provider) + assert.Equal(t, "cid", res.conf.ClientID) + assert.Equal(t, "csecret", res.conf.ClientSecret) + assert.Equal(t, "test", res.name) + assert.Equal(t, "app-test", res.Issuer) +} + +func TestOauth2InvalidHandler(t *testing.T) { + teardown := prepOauth2Test(t, 8691, 8692, nil) + defer teardown() + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get("http://localhost:8691/login_bad") + require.Nil(t, err) + assert.Equal(t, 404, resp.StatusCode) + + resp, err = client.Post("http://localhost:8691/login", "", nil) + require.Nil(t, err) + assert.Equal(t, 500, resp.StatusCode) +} + +func TestMakeRedirURL(t *testing.T) { + cases := []struct{ rootURL, route, out string }{ + {"localhost:8080/", "/my/auth/path/google", "localhost:8080/my/auth/path/callback"}, + {"localhost:8080", "/auth/google", "localhost:8080/auth/callback"}, + {"localhost:8080/", "/auth/google", "localhost:8080/auth/callback"}, + {"localhost:8080", "/", "localhost:8080/callback"}, + {"localhost:8080/", "/", "localhost:8080/callback"}, + {"mysite.com", "", "mysite.com/callback"}, + } + + for i := range cases { + c := cases[i] + oh := initOauth2Handler(Params{URL: c.rootURL}, Oauth2Handler{}) + assert.Equal(t, c.out, oh.makeRedirURL(c.route)) + } +} + +func prepOauth2Test(t *testing.T, loginPort, authPort int, btHook BearerTokenHook) func() { + + provider := Oauth2Handler{ + name: "mock", + endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("http://localhost:%d/login/oauth/authorize", authPort), + TokenURL: fmt.Sprintf("http://localhost:%d/login/oauth/access_token", authPort), + }, + scopes: []string{"user:email"}, + infoURL: fmt.Sprintf("http://localhost:%d/user", authPort), + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "mock_" + data.Value("id"), + Name: data.Value("name"), + Picture: data.Value("picture"), + } + return userInfo + }, + bearerTokenHook: btHook, + } + + jwtService := token.NewService(token.Opts{ + SecretReader: token.SecretFunc(mockKeyStore), SecureCookies: false, TokenDuration: time.Hour, CookieDuration: days31, + ClaimsUpd: token.ClaimsUpdFunc(func(claims token.Claims) token.Claims { + if claims.User != nil { + switch claims.User.ID { + case "mock_myuser2": + claims.User.SetBoolAttr("admin", true) + case "mock_myuser1": + if !claims.NoAva { + claims.User.Picture = "http://example.com/custom.png" + } + } + } + return claims + }), + }) + + params := Params{URL: "url", Cid: "cid", Csecret: "csecret", JwtService: jwtService, + Issuer: "remark42", AvatarSaver: &mockAvatarSaver{}, L: logger.Std} + + provider = initOauth2Handler(params, provider) + svc := Service{Provider: provider} + + ts := &http.Server{Addr: fmt.Sprintf(":%d", loginPort), Handler: http.HandlerFunc(svc.Handler)} //nolint:gosec + + count := 0 + useIDs := []string{"myuser1", "myuser2"} // user for first ans second calls + + oauth := &http.Server{ //nolint:gosec + Addr: fmt.Sprintf(":%d", authPort), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("[MOCK OAUTH] request %s %s %+v", r.Method, r.URL, r.Header) + switch { + case strings.HasPrefix(r.URL.Path, "/login/oauth/authorize"): + state := r.URL.Query().Get("state") + w.Header().Add("Location", fmt.Sprintf("http://localhost:%d/callback?code=g0ZGZmNjVmOWI&state=%s", + loginPort, state)) + w.WriteHeader(302) + case strings.HasPrefix(r.URL.Path, "/login/oauth/access_token"): + res := `{ + "access_token":"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3", + "token_type":"bearer", + "expires_in":3600, + "refresh_token":"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk", + "scope":"create", + "state":"12345678" + }` + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + _, err := w.Write([]byte(res)) + assert.NoError(t, err) + case strings.HasPrefix(r.URL.Path, "/user"): + res := fmt.Sprintf(`{ + "id": "%s", + "name":"blah", + "picture":"http://exmple.com/pic1.png" + }`, useIDs[count]) + count++ + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + _, err := w.Write([]byte(res)) + assert.NoError(t, err) + default: + t.Fatalf("unexpected oauth request %s %s", r.Method, r.URL) + } + }), + } + + go func() { _ = oauth.ListenAndServe() }() + go func() { _ = ts.ListenAndServe() }() + + time.Sleep(time.Millisecond * 100) // let them start + + return func() { + assert.NoError(t, ts.Close()) + assert.NoError(t, oauth.Close()) + } +} + +func mockKeyStore(string) (string, error) { return "12345", nil } + +type mockAvatarSaver struct{} + +func (m *mockAvatarSaver) Put(u token.User, _ *http.Client) (avatarURL string, err error) { + if u.Picture != "" { + return "http://example.com/ava12345.png", nil + } + return "http://example.com/fake.png", nil + +} diff --git a/v2/provider/providers.go b/v2/provider/providers.go new file mode 100644 index 00000000..4487b4cb --- /dev/null +++ b/v2/provider/providers.go @@ -0,0 +1,267 @@ +// Package provider implements all oauth2, oauth1 as well as custom and direct providers +package provider + +import ( + "crypto/sha1" //nolint + "encoding/json" + "fmt" + + "github.com/dghubble/oauth1" + "github.com/dghubble/oauth1/twitter" + "github.com/go-pkgz/auth/token" + "golang.org/x/oauth2" + "golang.org/x/oauth2/facebook" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/microsoft" + "golang.org/x/oauth2/yandex" +) + +// UserAttributes is the type that will be used to map user data from provider to token.User +type UserAttributes map[string]string + +// NewGoogle makes google oauth2 provider +func NewGoogle(p Params) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: "google", + endpoint: google.Endpoint, + scopes: []string{"https://www.googleapis.com/auth/userinfo.profile"}, + infoURL: "https://www.googleapis.com/oauth2/v3/userinfo", + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + // encode email with provider name to avoid collision if same id returned by other provider + ID: "google_" + token.HashID(sha1.New(), data.Value("sub")), + Name: data.Value("name"), + Picture: data.Value("picture"), + } + if userInfo.Name == "" { + userInfo.Name = "noname_" + userInfo.ID[8:12] + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewGithub makes github oauth2 provider +func NewGithub(p Params) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: "github", + endpoint: github.Endpoint, + scopes: []string{}, + infoURL: "https://api.github.com/user", + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "github_" + token.HashID(sha1.New(), data.Value("login")), + Name: data.Value("name"), + Picture: data.Value("avatar_url"), + } + // github may have no user name, use login in this case + if userInfo.Name == "" { + userInfo.Name = data.Value("login") + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewFacebook makes facebook oauth2 provider +func NewFacebook(p Params) Oauth2Handler { + + // response format for fb /me call + type uinfo struct { + ID string `json:"id"` + Name string `json:"name"` + Picture struct { + Data struct { + URL string `json:"url"` + } `json:"data"` + } `json:"picture"` + } + + return initOauth2Handler(p, Oauth2Handler{ + name: "facebook", + endpoint: facebook.Endpoint, + scopes: []string{"public_profile"}, + infoURL: "https://graph.facebook.com/me?fields=id,name,picture", + mapUser: func(data UserData, bdata []byte) token.User { + userInfo := token.User{ + ID: "facebook_" + token.HashID(sha1.New(), data.Value("id")), + Name: data.Value("name"), + } + if userInfo.Name == "" { + userInfo.Name = userInfo.ID[0:16] + } + + uinfoJSON := uinfo{} + if err := json.Unmarshal(bdata, &uinfoJSON); err == nil { + userInfo.Picture = uinfoJSON.Picture.Data.URL + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewYandex makes yandex oauth2 provider +func NewYandex(p Params) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: "yandex", + endpoint: yandex.Endpoint, + scopes: []string{}, + // See https://tech.yandex.com/passport/doc/dg/reference/response-docpage/ + infoURL: "https://login.yandex.ru/info?format=json", + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "yandex_" + token.HashID(sha1.New(), data.Value("id")), + Name: data.Value("display_name"), // using Display Name by default + } + if userInfo.Name == "" { + userInfo.Name = data.Value("real_name") // using Real Name (== full name) if Display Name is empty + } + if userInfo.Name == "" { + userInfo.Name = data.Value("login") // otherwise using login + } + + if data.Value("default_avatar_id") != "" { + userInfo.Picture = fmt.Sprintf("https://avatars.yandex.net/get-yapic/%s/islands-200", data.Value("default_avatar_id")) + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewTwitter makes twitter oauth2 provider +func NewTwitter(p Params) Oauth1Handler { + return initOauth1Handler(p, Oauth1Handler{ + name: "twitter", + conf: oauth1.Config{ + Endpoint: twitter.AuthorizeEndpoint, + }, + infoURL: "https://api.twitter.com/1.1/account/verify_credentials.json", + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "twitter_" + token.HashID(sha1.New(), data.Value("id_str")), + Name: data.Value("screen_name"), + Picture: data.Value("profile_image_url_https"), + } + if userInfo.Name == "" { + userInfo.Name = data.Value("name") + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewBattlenet makes Battle.net oauth2 provider +func NewBattlenet(p Params) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: "battlenet", + endpoint: oauth2.Endpoint{ + AuthURL: "https://eu.battle.net/oauth/authorize", + TokenURL: "https://eu.battle.net/oauth/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + scopes: []string{}, + infoURL: "https://eu.battle.net/oauth/userinfo", + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "battlenet_" + token.HashID(sha1.New(), data.Value("id")), + Name: data.Value("battletag"), + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewMicrosoft makes microsoft azure oauth2 provider +func NewMicrosoft(p Params) Oauth2Handler { + return initOauth2Handler(p, Oauth2Handler{ + name: "microsoft", + endpoint: microsoft.AzureADEndpoint("common"), + scopes: []string{"User.Read"}, + infoURL: "https://graph.microsoft.com/v1.0/me", + // non-beta doesn't provide photo for consumers yet + // see https://github.com/microsoftgraph/microsoft-graph-docs/issues/3990 + mapUser: func(data UserData, _ []byte) token.User { + userInfo := token.User{ + ID: "microsoft_" + token.HashID(sha1.New(), data.Value("id")), + Name: data.Value("displayName"), + Picture: "https://graph.microsoft.com/beta/me/photo/$value", + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} + +// NewPatreon makes patreon oauth2 provider +func NewPatreon(p Params) Oauth2Handler { + type uinfo struct { + Data struct { + Attributes struct { + FullName string `json:"full_name"` + ImageURL string `json:"image_url"` + } `json:"attributes"` + ID string `json:"id"` + Relationships struct { + Pledges struct { + Data []struct { + ID string `json:"id"` + Type string `json:"type"` + } `json:"data"` + } `json:"pledges"` + } `json:"relationships"` + } `json:"data"` + } + + return initOauth2Handler(p, Oauth2Handler{ + name: "patreon", + // see https://docs.patreon.com/?shell#oauth + endpoint: oauth2.Endpoint{ + AuthURL: "https://www.patreon.com/oauth2/authorize", + TokenURL: "https://api.patreon.com/oauth2/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + scopes: []string{}, + infoURL: "https://www.patreon.com/api/oauth2/api/current_user", + + mapUser: func(data UserData, bdata []byte) token.User { + userInfo := token.User{} + + uinfoJSON := uinfo{} + if err := json.Unmarshal(bdata, &uinfoJSON); err == nil { + userInfo.ID = "patreon_" + token.HashID(sha1.New(), userInfo.ID) + userInfo.Name = uinfoJSON.Data.Attributes.FullName + userInfo.Picture = uinfoJSON.Data.Attributes.ImageURL + + // check if the user is your subscriber + if len(uinfoJSON.Data.Relationships.Pledges.Data) > 0 { + userInfo.SetPaidSub(true) + } + } + for k, v := range p.UserAttributes { + userInfo.SetStrAttr(v, data.Value(k)) + } + return userInfo + }, + }) +} diff --git a/v2/provider/providers_test.go b/v2/provider/providers_test.go new file mode 100644 index 00000000..88ea28ce --- /dev/null +++ b/v2/provider/providers_test.go @@ -0,0 +1,210 @@ +package provider + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/go-pkgz/auth/token" +) + +func TestProviders_NewGoogle(t *testing.T) { + r := NewGoogle(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "google", r.Name()) + + t.Run("with all data", func(t *testing.T) { + udata := UserData{"sub": "1234567890", "name": "test user", "picture": "http://demo.remark42.com/blah.png"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "test user", ID: "google_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with no name", func(t *testing.T) { + udata := UserData{"sub": "1234567890", "picture": "http://demo.remark42.com/blah.png"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "noname_1b30", ID: "google_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with extra scopes", func(t *testing.T) { + r := NewGoogle(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs", + UserAttributes: map[string]string{"email": "email"}}) + assert.Equal(t, "google", r.Name()) + udata := UserData{"sub": "1234567890", "name": "test user", "picture": "http://demo.remark42.com/blah.png", + "email": "test@email.com"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "test user", ID: "google_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "http://demo.remark42.com/blah.png", IP: "", Attributes: map[string]interface{}{"email": "test@email.com"}}, user, "got %+v", user) + }) +} + +func TestProviders_NewGithub(t *testing.T) { + r := NewGithub(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "github", r.Name()) + + t.Run("with all data", func(t *testing.T) { + udata := UserData{"login": "lll", "name": "test user", "avatar_url": "http://demo.remark42.com/blah.png"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "test user", ID: "github_e80b2d2608711cbb3312db7c4727a46fbad9601a", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with no name", func(t *testing.T) { + // nil name in data (json response contains `"name": null`); using login, it's always required + udata := UserData{"login": "lll", "name": nil, "avatar_url": "http://demo.remark42.com/blah.png"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "lll", ID: "github_e80b2d2608711cbb3312db7c4727a46fbad9601a", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with extra scopes", func(t *testing.T) { + r := NewGithub(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs", + UserAttributes: map[string]string{"email": "email"}}) + assert.Equal(t, "github", r.Name()) + udata := UserData{"login": "lll", "name": "test user", "avatar_url": "http://demo.remark42.com/blah.png", + "email": "test@email.com"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "test user", ID: "github_e80b2d2608711cbb3312db7c4727a46fbad9601a", + Picture: "http://demo.remark42.com/blah.png", IP: "", Attributes: map[string]interface{}{"email": "test@email.com"}}, user, "got %+v", user) + }) +} + +func TestProviders_NewFacebook(t *testing.T) { + r := NewFacebook(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "facebook", r.Name()) + + t.Run("with all data", func(t *testing.T) { + udata := UserData{"id": "myid", "name": "test user"} + user := r.mapUser(udata, []byte(`{"picture": {"data": {"url": "http://demo.remark42.com/blah.png"} }}`)) + assert.Equal(t, token.User{Name: "test user", ID: "facebook_6e34471f84557e1713012d64a7477c71bfdac631", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with no name", func(t *testing.T) { + udata := UserData{"id": "myid", "name": ""} + user := r.mapUser(udata, []byte(`{"picture": {"data": {"url": "http://demo.remark42.com/blah.png"} }}`)) + assert.Equal(t, token.User{Name: "facebook_6e34471", ID: "facebook_6e34471f84557e1713012d64a7477c71bfdac631", + Picture: "http://demo.remark42.com/blah.png", IP: ""}, user, "got %+v", user) + }) + + t.Run("with extra scopes", func(t *testing.T) { + r := NewFacebook(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs", + UserAttributes: map[string]string{"email": "email"}}) + assert.Equal(t, "facebook", r.Name()) + udata := UserData{"id": "myid", "name": "test user", "email": "test@email.com"} + user := r.mapUser(udata, []byte(`{"picture": {"data": {"url": "http://demo.remark42.com/blah.png"} }}`)) + assert.Equal(t, token.User{Name: "test user", ID: "facebook_6e34471f84557e1713012d64a7477c71bfdac631", + Picture: "http://demo.remark42.com/blah.png", IP: "", Attributes: map[string]interface{}{"email": "test@email.com"}}, + user, "got %+v", user) + + }) +} + +func TestProviders_NewYandex(t *testing.T) { + r := NewYandex(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "yandex", r.Name()) + + udata := UserData{"id": "1234567890", "display_name": "Vasya P", "default_avatar_id": "131652443"} + user := r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "Vasya P", ID: "yandex_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "https://avatars.yandex.net/get-yapic/131652443/islands-200", IP: ""}, user, "got %+v", user) + + // "display_name": null, "default_avatar_id": null + udata = UserData{"id": "1234567890", "login": "vasya", "display_name": nil, "real_name": "Vasya Pupkin", "default_avatar_id": nil} + user = r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "Vasya Pupkin", ID: "yandex_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "", IP: ""}, user, "got %+v", user) + + // empty "display_name", empty "default_avatar_id", empty "real_name" + udata = UserData{"id": "1234567890", "login": "vasya", "display_name": "", "real_name": "", "default_avatar_id": ""} + user = r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "vasya", ID: "yandex_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "", IP: ""}, user, "got %+v", user) + + // "real_name": null + udata = UserData{"id": "1234567890", "login": "vasya", "real_name": nil, "default_avatar_id": ""} + user = r.mapUser(udata, nil) + assert.Equal(t, token.User{Name: "vasya", ID: "yandex_01b307acba4f54f55aafc33bb06bbbf6ca803e9a", + Picture: "", IP: ""}, user, "got %+v", user) +} + +func TestProviders_NewTwitter(t *testing.T) { + r := NewTwitter(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "twitter", r.Name()) + + cases := []struct { + udata UserData + uopts []byte + expected token.User + }{ + {udata: UserData{"id_str": "myid", "name": "test user", "profile_image_url_https": "https://demo.remark42.com/blah.png"}, + uopts: []byte(``), + expected: token.User{Name: "test user", ID: "twitter_6e34471f84557e1713012d64a7477c71bfdac631", + Picture: "https://demo.remark42.com/blah.png", IP: ""}, + }, + {udata: UserData{"id_str": "124381237", "screen_name": "Bob", "name": "Robert Downey Jr.", "profile_image_url_https": ""}, + uopts: []byte(``), + expected: token.User{Name: "Bob", ID: "twitter_63a6b20b6e17fb5e17f6c58b6223e3b760ad510e", + Picture: "", IP: ""}, + }, + {udata: UserData{"id_str": "124381237", "name": "Robert Downey Jr.", "profile_image_url_https": "https://demo.remark42.com/blah.png"}, + uopts: []byte(``), + expected: token.User{Name: "Robert Downey Jr.", ID: "twitter_63a6b20b6e17fb5e17f6c58b6223e3b760ad510e", + Picture: "https://demo.remark42.com/blah.png", IP: ""}, + }, + } + + for i := range cases { + c := cases[i] + got := r.mapUser(c.udata, c.uopts) + assert.Equal(t, c.expected, got, "got %+v", got) + } + +} + +func TestProviders_NewPatreon(t *testing.T) { + r := NewPatreon(Params{URL: "http://demo.remark42.com", Cid: "cid", Csecret: "cs"}) + assert.Equal(t, "patreon", r.Name()) + + udata := UserData{} + user := r.mapUser(udata, []byte(`{ + "data": { + "attributes": { + "email": "corgi@example.com", + "full_name": "Corgi The Dev", + "image_url": "https://c8.patreon.com/2/400/0000000" + }, + "id": "0000000" + }}`)) + assert.Equal(t, token.User{Name: "Corgi The Dev", ID: "patreon_da39a3ee5e6b4b0d3255bfef95601890afd80709", + Picture: "https://c8.patreon.com/2/400/0000000", IP: ""}, user, "got %+v", user) + + udata = UserData{} + user = r.mapUser(udata, []byte(`{ + "data": { + "attributes": { + "email": "corgi@example.com", + "full_name": "Corgi The Dev", + "image_url": "https://c8.patreon.com/2/400/0000000" + }, + "id": "0000000", + "relationships": { + "pledges": { + "data": [ + { + "id": "0000000", + "type": "pledge" + } + ] + } + } + }}`)) + assert.Equal( + t, + token.User{Name: "Corgi The Dev", ID: "patreon_da39a3ee5e6b4b0d3255bfef95601890afd80709", + Picture: "https://c8.patreon.com/2/400/0000000", IP: "", Attributes: map[string]interface{}{"is_paid_sub": true}}, + user, + "got %+v", + user, + ) +} diff --git a/v2/provider/sender/email.go b/v2/provider/sender/email.go new file mode 100644 index 00000000..2c773702 --- /dev/null +++ b/v2/provider/sender/email.go @@ -0,0 +1,90 @@ +// Package sender provides email sender +package sender + +import ( + "time" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/email" +) + +// Email implements sender interface for VerifyHandler +// Uses common subject line and "from" for all messages +type Email struct { + EmailParams + logger.L + sender *email.Sender +} + +// EmailParams with all needed to make new Email client with smtp +type EmailParams struct { + Host string // SMTP host + Port int // SMTP port + From string // From email field + Subject string // Email subject + ContentType string // Content type + + TLS bool // TLS auth + StartTLS bool // StartTLS auth + InsecureSkipVerify bool // Skip certificate verification + Charset string // Character set + LoginAuth bool // LOGIN auth method instead of default PLAIN, needed for Office 365 and outlook.com + SMTPUserName string // username + SMTPPassword string // password + TimeOut time.Duration // TCP connection timeout +} + +// NewEmailClient creates email client +func NewEmailClient(emailParams EmailParams, l logger.L) *Email { + var opts []email.Option + + if emailParams.SMTPUserName != "" { + opts = append(opts, email.Auth(emailParams.SMTPUserName, emailParams.SMTPPassword)) + } + + if emailParams.ContentType != "" { + opts = append(opts, email.ContentType(emailParams.ContentType)) + } + + if emailParams.Charset != "" { + opts = append(opts, email.Charset(emailParams.Charset)) + } + + if emailParams.LoginAuth { + opts = append(opts, email.LoginAuth()) + } + + if emailParams.Port != 0 { + opts = append(opts, email.Port(emailParams.Port)) + } + + if emailParams.TimeOut != 0 { + opts = append(opts, email.TimeOut(emailParams.TimeOut)) + } + + if emailParams.TLS { + opts = append(opts, email.TLS(true)) + } + + if emailParams.StartTLS { + opts = append(opts, email.STARTTLS(true)) + } + + if emailParams.InsecureSkipVerify { + opts = append(opts, email.InsecureSkipVerify(true)) + } + + sender := email.NewSender(emailParams.Host, opts...) + + return &Email{EmailParams: emailParams, L: l, sender: sender} +} + +// Send email with given text +func (e *Email) Send(to, text string) error { + e.Logf("[DEBUG] send %q to %s", text, to) + return e.sender.Send(text, email.Params{ + From: e.From, + To: []string{to}, + Subject: e.Subject, + }) +} diff --git a/v2/provider/sender/email_test.go b/v2/provider/sender/email_test.go new file mode 100644 index 00000000..1ef73331 --- /dev/null +++ b/v2/provider/sender/email_test.go @@ -0,0 +1,70 @@ +package sender + +import ( + "os" + "testing" + "time" + + "github.com/go-pkgz/auth/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEmailSend(t *testing.T) { + if _, ok := os.LookupEnv("SEND_EMAIL_TEST"); !ok { + t.Skip() + } + p := EmailParams{ + From: "test@umputun.com", + ContentType: "text/html", + Host: "192.168.1.24", + Port: 25, + Subject: "test email", + } + client := NewEmailClient(p, logger.Std) + + msg := ` + + + +

rest

+
xyz
+ + +` + err := client.Send("sys@umputun.dev", msg) + assert.NoError(t, err) +} + +func TestEmail_New(t *testing.T) { + p := EmailParams{ + Host: "127.0.0.2", + From: "from@example.com", + SMTPUserName: "user", + SMTPPassword: "pass", + Subject: "subj", + ContentType: "text/html", + Charset: "UTF-8", + LoginAuth: true, + StartTLS: true, + TLS: true, + InsecureSkipVerify: true, + } + e := NewEmailClient(p, logger.Std) + assert.Equal(t, p, e.EmailParams) +} + +func TestEmail_SendFailed(t *testing.T) { + p := EmailParams{Host: "127.0.0.2", Port: 25, From: "from@example.com", + Subject: "subj", ContentType: "text/html", TimeOut: time.Millisecond * 200} + e := NewEmailClient(p, logger.Std) + assert.Equal(t, p, e.EmailParams) + err := e.Send("to@example.com", "some text") + require.NotNil(t, err, "failed to make smtp client") + + p = EmailParams{Host: "127.0.0.1", Port: 225, From: "from@example.com", Subject: "subj", ContentType: "text/html", + TLS: true} + e = NewEmailClient(p, logger.Std) + err = e.Send("to@example.com", "some text") + require.NotNil(t, err) +} diff --git a/v2/provider/service.go b/v2/provider/service.go new file mode 100644 index 00000000..953d4167 --- /dev/null +++ b/v2/provider/service.go @@ -0,0 +1,95 @@ +package provider + +import ( + "crypto/rand" + "crypto/sha1" + "fmt" + "net/http" + "strings" + + "github.com/go-pkgz/auth/token" +) + +const ( + urlLoginSuffix = "/login" + urlCallbackSuffix = "/callback" + urlLogoutSuffix = "/logout" +) + +// Service represents oauth2 provider. Adds Handler method multiplexing login, auth and logout requests +type Service struct { + Provider +} + +// NewService makes service for given provider +func NewService(p Provider) Service { + return Service{Provider: p} +} + +// AvatarSaver defines minimal interface to save avatar +type AvatarSaver interface { + Put(u token.User, client *http.Client) (avatarURL string, err error) +} + +// TokenService defines interface accessing tokens +type TokenService interface { + Parse(tokenString string) (claims token.Claims, err error) + Set(w http.ResponseWriter, claims token.Claims) (token.Claims, error) + Get(r *http.Request) (claims token.Claims, token string, err error) + Reset(w http.ResponseWriter) +} + +// Provider defines interface for auth handler +type Provider interface { + Name() string + LoginHandler(w http.ResponseWriter, r *http.Request) + AuthHandler(w http.ResponseWriter, r *http.Request) + LogoutHandler(w http.ResponseWriter, r *http.Request) +} + +// Handler returns auth routes for given provider +func (p Service) Handler(w http.ResponseWriter, r *http.Request) { + + if r.Method != http.MethodGet && r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if strings.HasSuffix(r.URL.Path, urlLoginSuffix) { + p.LoginHandler(w, r) + return + } + if strings.HasSuffix(r.URL.Path, urlCallbackSuffix) { + p.AuthHandler(w, r) + return + } + if strings.HasSuffix(r.URL.Path, urlLogoutSuffix) { + p.LogoutHandler(w, r) + return + } + w.WriteHeader(http.StatusNotFound) +} + +// setAvatar saves avatar and puts proxied URL to u.Picture +func setAvatar(ava AvatarSaver, u token.User, client *http.Client) (token.User, error) { + if ava != nil { + avatarURL, e := ava.Put(u, client) + if e != nil { + return u, fmt.Errorf("failed to save avatar for: %w", e) + } + u.Picture = avatarURL + return u, nil + } + return u, nil // empty AvatarSaver ok, just skipped +} + +func randToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("can't get random: %w", err) + } + s := sha1.New() + if _, err := s.Write(b); err != nil { + return "", fmt.Errorf("can't write randoms to sha1: %w", err) + } + return fmt.Sprintf("%x", s.Sum(nil)), nil +} diff --git a/v2/provider/service_test.go b/v2/provider/service_test.go new file mode 100644 index 00000000..7ad14d72 --- /dev/null +++ b/v2/provider/service_test.go @@ -0,0 +1,98 @@ +package provider + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/token" +) + +func TestHandler(t *testing.T) { + + tbl := []struct { + method string + url string + code int + resp string + }{ + {"GET", "/login", 200, "login"}, + {"POST", "/login", 200, "login"}, + {"GET", "/callback", 200, "callback"}, + {"GET", "/logout", 200, "logout"}, + {"GET", "/blah", 404, ""}, + {"PUT", "/login", 405, ""}, + } + svc := NewService(&mockHandler{}) + handler := http.HandlerFunc(svc.Handler) + + for n, tt := range tbl { + tt := tt + t.Run(fmt.Sprintf("check-%d", n), func(t *testing.T) { + rr := httptest.NewRecorder() + req, err := http.NewRequest(tt.method, tt.url, http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, tt.code, rr.Code) + assert.Equal(t, tt.resp, rr.Body.String()) + }) + } + +} + +func TestRandToken(t *testing.T) { + s1, err := randToken() + assert.NoError(t, err) + assert.NotEqual(t, "", s1) + t.Log(s1) + + s2, err := randToken() + assert.NoError(t, err) + assert.NotEqual(t, "", s2) + assert.NotEqual(t, s2, s1) + t.Log(s2) +} + +func TestSetAvatar(t *testing.T) { + client := &http.Client{Timeout: time.Second} + u, err := setAvatar(nil, token.User{Picture: "http://example.com/pic1.png"}, client) + assert.NoError(t, err, "nil ava allowed") + assert.Equal(t, token.User{Picture: "http://example.com/pic1.png"}, u) + + u, err = setAvatar(mockAva{true, "http://example.com/pic1px.png"}, token.User{Picture: "http://example.com/pic1.png"}, client) + assert.NoError(t, err) + assert.Equal(t, token.User{Picture: "http://example.com/pic1px.png"}, u) + + _, err = setAvatar(mockAva{false, ""}, token.User{Picture: "http://example.com/pic1.png"}, client) + assert.Error(t, err, "some error") +} + +type mockAva struct { + ok bool + res string +} + +func (m mockAva) Put(token.User, *http.Client) (avatarURL string, err error) { + if !m.ok { + return "", fmt.Errorf("some error") + } + return m.res, nil +} + +type mockHandler struct{} + +func (n *mockHandler) Name() string { return "mock-handler" } +func (n *mockHandler) LoginHandler(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("login")) +} +func (n *mockHandler) AuthHandler(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("callback")) +} +func (n *mockHandler) LogoutHandler(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("logout")) +} diff --git a/v2/provider/telegram.go b/v2/provider/telegram.go new file mode 100644 index 00000000..177c1c02 --- /dev/null +++ b/v2/provider/telegram.go @@ -0,0 +1,484 @@ +package provider + +//go:generate moq --out telegram_moq_test.go . TelegramAPI + +import ( + "context" + "crypto/sha1" + "encoding/json" + "fmt" + "io" + "net/http" + neturl "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-pkgz/repeater" + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + + "github.com/go-pkgz/auth/logger" + authtoken "github.com/go-pkgz/auth/token" +) + +// TelegramHandler implements login via telegram +type TelegramHandler struct { + logger.L + + ProviderName string + ErrorMsg, SuccessMsg string + + TokenService TokenService + AvatarSaver AvatarSaver + Telegram TelegramAPI + + run int32 // non-zero if Run goroutine has started + username string // bot username + requests struct { + sync.RWMutex + data map[string]tgAuthRequest + } +} + +type tgAuthRequest struct { + confirmed bool // whether login request has been confirmed and user info set + expires time.Time + user *authtoken.User +} + +// TelegramAPI is used for interacting with telegram API +type TelegramAPI interface { + GetUpdates(ctx context.Context) (*telegramUpdate, error) + Avatar(ctx context.Context, userID int) (string, error) + Send(ctx context.Context, id int, text string) error + BotInfo(ctx context.Context) (*botInfo, error) +} + +// changed in tests +var apiPollInterval = time.Second * 5 // interval to check updates from Telegram API and answer to users +var expiredCleanupInterval = time.Minute * 5 // interval to check and clean up expired notification requests + +// Run starts processing login requests sent in Telegram +// Blocks caller +func (th *TelegramHandler) Run(ctx context.Context) error { + // Initialization + atomic.AddInt32(&th.run, 1) + info, err := th.Telegram.BotInfo(ctx) + if err != nil { + return fmt.Errorf("failed to fetch bot info: %w", err) + } + th.username = info.Username + + th.requests.Lock() + th.requests.data = make(map[string]tgAuthRequest) + th.requests.Unlock() + + processUpdatedTicker := time.NewTicker(apiPollInterval) + cleanupTicker := time.NewTicker(expiredCleanupInterval) + + for { + select { + case <-ctx.Done(): + processUpdatedTicker.Stop() + cleanupTicker.Stop() + atomic.AddInt32(&th.run, -1) + return ctx.Err() + case <-processUpdatedTicker.C: + updates, err := th.Telegram.GetUpdates(ctx) + if err != nil { + th.Logf("Error while getting telegram updates: %v", err) + continue + } + th.processUpdates(ctx, updates) + case <-cleanupTicker.C: + now := time.Now() + th.requests.Lock() + for key, req := range th.requests.data { + if now.After(req.expires) { + delete(th.requests.data, key) + } + } + th.requests.Unlock() + } + } +} + +// telegramUpdate contains update information, which is used from whole telegram API response +type telegramUpdate struct { + Result []struct { + UpdateID int `json:"update_id"` + Message struct { + Chat struct { + ID int `json:"id"` + Name string `json:"first_name"` + Type string `json:"type"` + } `json:"chat"` + Text string `json:"text"` + } `json:"message"` + } `json:"result"` +} + +// ProcessUpdate is alternative to Run, it processes provided plain text update from Telegram +// so that caller could get updates and send it not only there but to multiple sources +func (th *TelegramHandler) ProcessUpdate(ctx context.Context, textUpdate string) error { + if atomic.LoadInt32(&th.run) != 0 { + return fmt.Errorf("Run goroutine should not be used with ProcessUpdate") + } + defer func() { + // as Run goroutine is not running, clean up old requests on each update + // even if we hit json decode error + now := time.Now() + th.requests.Lock() + for key, req := range th.requests.data { + if now.After(req.expires) { + delete(th.requests.data, key) + } + } + th.requests.Unlock() + }() + // initialize requests.data as usually it's initialized in Run + th.requests.Lock() + if th.requests.data == nil { + th.requests.data = make(map[string]tgAuthRequest) + } + th.requests.Unlock() + var updates telegramUpdate + if err := json.Unmarshal([]byte(textUpdate), &updates); err != nil { + return fmt.Errorf("failed to decode provided telegram update: %w", err) + } + th.processUpdates(ctx, &updates) + return nil +} + +// processUpdates processes a batch of updates from telegram servers +// Returns offset for subsequent calls +func (th *TelegramHandler) processUpdates(ctx context.Context, updates *telegramUpdate) { + for _, update := range updates.Result { + if update.Message.Chat.Type != "private" { + continue + } + + if !strings.HasPrefix(update.Message.Text, "/start ") { + continue + } + + token := strings.TrimPrefix(update.Message.Text, "/start ") + + th.requests.RLock() + authRequest, ok := th.requests.data[token] + if !ok { // No such token + th.requests.RUnlock() + err := th.Telegram.Send(ctx, update.Message.Chat.ID, th.ErrorMsg) + if err != nil { + th.Logf("failed to notify telegram peer: %v", err) + } + continue + } + th.requests.RUnlock() + + avatarURL, err := th.Telegram.Avatar(ctx, update.Message.Chat.ID) + if err != nil { + th.Logf("failed to get user avatar: %v", err) + continue + } + + id := th.ProviderName + "_" + authtoken.HashID(sha1.New(), fmt.Sprint(update.Message.Chat.ID)) + + authRequest.confirmed = true + authRequest.user = &authtoken.User{ + ID: id, + Name: update.Message.Chat.Name, + Picture: avatarURL, + } + + th.requests.Lock() + th.requests.data[token] = authRequest + th.requests.Unlock() + + err = th.Telegram.Send(ctx, update.Message.Chat.ID, th.SuccessMsg) + if err != nil { + th.Logf("failed to notify telegram peer: %v", err) + } + } +} + +// addToken adds token +func (th *TelegramHandler) addToken(token string, expires time.Time) error { + th.requests.Lock() + if th.requests.data == nil { + th.requests.Unlock() + return fmt.Errorf("run goroutine is not running") + } + th.requests.data[token] = tgAuthRequest{ + expires: expires, + } + th.requests.Unlock() + return nil +} + +// checkToken verifies incoming token, returns the user address if it's confirmed and empty string otherwise +func (th *TelegramHandler) checkToken(token string) (*authtoken.User, error) { + th.requests.RLock() + authRequest, ok := th.requests.data[token] + th.requests.RUnlock() + + if !ok { + return nil, fmt.Errorf("request is not found") + } + + if time.Now().After(authRequest.expires) { + th.requests.Lock() + delete(th.requests.data, token) + th.requests.Unlock() + return nil, fmt.Errorf("request expired") + } + + if !authRequest.confirmed { + return nil, fmt.Errorf("request is not verified yet") + } + + return authRequest.user, nil +} + +// Name of the provider +func (th *TelegramHandler) Name() string { return th.ProviderName } + +// String representation of the provider +func (th *TelegramHandler) String() string { return th.Name() } + +// Default token lifetime. Changed in tests +var tgAuthRequestLifetime = time.Minute * 10 + +// LoginHandler generates and verifies login requests +func (th *TelegramHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { + queryToken := r.URL.Query().Get("token") + if queryToken == "" { + // GET /login (No token supplied) + // Generate and send token + token, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, th.L, http.StatusInternalServerError, err, "failed to generate code") + return + } + + err = th.addToken(token, time.Now().Add(tgAuthRequestLifetime)) + if err != nil { + rest.SendErrorJSON(w, r, th.L, http.StatusInternalServerError, err, "failed to process login request") + return + } + + // verify that we have a username, which is not set if Run was not used + if th.username == "" { + info, err := th.Telegram.BotInfo(r.Context()) + if err != nil { + rest.SendErrorJSON(w, r, th.L, http.StatusInternalServerError, err, "failed to fetch bot username") + return + } + th.username = info.Username + } + + rest.RenderJSON(w, struct { + Token string `json:"token"` + Bot string `json:"bot"` + }{token, th.username}) + + return + } + + // GET /login?token=blah + authUser, err := th.checkToken(queryToken) + if err != nil { + rest.SendErrorJSON(w, r, nil, http.StatusNotFound, err, err.Error()) + return + } + + u, err := setAvatar(th.AvatarSaver, *authUser, &http.Client{Timeout: 5 * time.Second}) + if err != nil { + rest.SendErrorJSON(w, r, th.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + claims := authtoken.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Audience: r.URL.Query().Get("site"), + Id: queryToken, + Issuer: th.ProviderName, + ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), + NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + }, + SessionOnly: false, // TODO review? + } + + if _, err := th.TokenService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, th.L, http.StatusInternalServerError, err, "failed to set token") + return + } + + rest.RenderJSON(w, claims.User) + + // Delete request + th.requests.Lock() + defer th.requests.Unlock() + delete(th.requests.data, queryToken) +} + +// AuthHandler does nothing since we don't have any callbacks +func (th *TelegramHandler) AuthHandler(_ http.ResponseWriter, _ *http.Request) {} + +// LogoutHandler - GET /logout +func (th *TelegramHandler) LogoutHandler(w http.ResponseWriter, _ *http.Request) { + th.TokenService.Reset(w) +} + +// tgAPI implements TelegramAPI +type tgAPI struct { + logger.L + token string + client *http.Client + + // Identifier of the first update to be requested. + // Should be equal to LastSeenUpdateID + 1 + // See https://core.telegram.org/bots/api#getupdates + updateOffset int +} + +// NewTelegramAPI returns initialized TelegramAPI implementation +func NewTelegramAPI(token string, client *http.Client) TelegramAPI { + return &tgAPI{ + client: client, + token: token, + } +} + +// GetUpdates fetches incoming updates +func (tg *tgAPI) GetUpdates(ctx context.Context) (*telegramUpdate, error) { + url := `getUpdates?allowed_updates=["message"]` + if tg.updateOffset != 0 { + url += fmt.Sprintf("&offset=%d", tg.updateOffset) + } + + var result telegramUpdate + + err := tg.request(ctx, url, &result) + if err != nil { + return nil, fmt.Errorf("failed to fetch updates: %w", err) + } + + for _, u := range result.Result { + if u.UpdateID >= tg.updateOffset { + tg.updateOffset = u.UpdateID + 1 + } + } + + return &result, err +} + +// Send sends a message to telegram peer +func (tg *tgAPI) Send(ctx context.Context, id int, msg string) error { + url := fmt.Sprintf("sendMessage?chat_id=%d&text=%s", id, neturl.PathEscape(msg)) + return tg.request(ctx, url, &struct{}{}) +} + +// Avatar returns URL to user avatar +func (tg *tgAPI) Avatar(ctx context.Context, id int) (string, error) { + // Get profile pictures + url := fmt.Sprintf(`getUserProfilePhotos?user_id=%d`, id) + + var profilePhotos = struct { + Result struct { + Photos [][]struct { + ID string `json:"file_id"` + } `json:"photos"` + } `json:"result"` + }{} + + if err := tg.request(ctx, url, &profilePhotos); err != nil { + return "", err + } + + // User does not have profile picture set or it is hidden in privacy settings + if len(profilePhotos.Result.Photos) == 0 || len(profilePhotos.Result.Photos[0]) == 0 { + return "", nil + } + + // Get max possible picture size + last := len(profilePhotos.Result.Photos[0]) - 1 + fileID := profilePhotos.Result.Photos[0][last].ID + url = fmt.Sprintf(`getFile?file_id=%s`, fileID) + + var fileMetadata = struct { + Result struct { + Path string `json:"file_path"` + } `json:"result"` + }{} + + if err := tg.request(ctx, url, &fileMetadata); err != nil { + return "", err + } + + avatarURL := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", tg.token, fileMetadata.Result.Path) + + return avatarURL, nil +} + +// botInfo structure contains information about telegram bot, which is used from whole telegram API response +type botInfo struct { + Username string `json:"username"` +} + +// BotInfo returns info about configured bot +func (tg *tgAPI) BotInfo(ctx context.Context) (*botInfo, error) { + var resp = struct { + Result *botInfo `json:"result"` + }{} + + err := tg.request(ctx, "getMe", &resp) + if err != nil { + return nil, err + } + if resp.Result == nil { + return nil, fmt.Errorf("received empty result") + } + + return resp.Result, nil +} + +func (tg *tgAPI) request(ctx context.Context, method string, data interface{}) error { + return repeater.NewDefault(3, time.Millisecond*50).Do(ctx, func() error { + url := fmt.Sprintf("https://api.telegram.org/bot%s/%s", tg.token, method) + + req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := tg.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() //nolint gosec // we don't care about response body + + if resp.StatusCode != http.StatusOK { + return tg.parseError(resp.Body, resp.StatusCode) + } + + if err = json.NewDecoder(resp.Body).Decode(data); err != nil { + return fmt.Errorf("failed to decode json response: %w", err) + } + + return nil + }) +} + +func (tg *tgAPI) parseError(r io.Reader, statusCode int) error { + tgErr := struct { + Description string `json:"description"` + }{} + if err := json.NewDecoder(r).Decode(&tgErr); err != nil { + return fmt.Errorf("unexpected telegram API status code %d", statusCode) + } + return fmt.Errorf("unexpected telegram API status code %d, error: %q", statusCode, tgErr.Description) +} diff --git a/v2/provider/telegram_moq_test.go b/v2/provider/telegram_moq_test.go new file mode 100644 index 00000000..8148e4b9 --- /dev/null +++ b/v2/provider/telegram_moq_test.go @@ -0,0 +1,225 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package provider + +import ( + "context" + "sync" +) + +// Ensure, that TelegramAPIMock does implement TelegramAPI. +// If this is not the case, regenerate this file with moq. +var _ TelegramAPI = &TelegramAPIMock{} + +// TelegramAPIMock is a mock implementation of TelegramAPI. +// +// func TestSomethingThatUsesTelegramAPI(t *testing.T) { +// +// // make and configure a mocked TelegramAPI +// mockedTelegramAPI := &TelegramAPIMock{ +// AvatarFunc: func(ctx context.Context, userID int) (string, error) { +// panic("mock out the Avatar method") +// }, +// BotInfoFunc: func(ctx context.Context) (*botInfo, error) { +// panic("mock out the BotInfo method") +// }, +// GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { +// panic("mock out the GetUpdates method") +// }, +// SendFunc: func(ctx context.Context, id int, text string) error { +// panic("mock out the Send method") +// }, +// } +// +// // use mockedTelegramAPI in code that requires TelegramAPI +// // and then make assertions. +// +// } +type TelegramAPIMock struct { + // AvatarFunc mocks the Avatar method. + AvatarFunc func(ctx context.Context, userID int) (string, error) + + // BotInfoFunc mocks the BotInfo method. + BotInfoFunc func(ctx context.Context) (*botInfo, error) + + // GetUpdatesFunc mocks the GetUpdates method. + GetUpdatesFunc func(ctx context.Context) (*telegramUpdate, error) + + // SendFunc mocks the Send method. + SendFunc func(ctx context.Context, id int, text string) error + + // calls tracks calls to the methods. + calls struct { + // Avatar holds details about calls to the Avatar method. + Avatar []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID int + } + // BotInfo holds details about calls to the BotInfo method. + BotInfo []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetUpdates holds details about calls to the GetUpdates method. + GetUpdates []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // Send holds details about calls to the Send method. + Send []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID int + // Text is the text argument value. + Text string + } + } + lockAvatar sync.RWMutex + lockBotInfo sync.RWMutex + lockGetUpdates sync.RWMutex + lockSend sync.RWMutex +} + +// Avatar calls AvatarFunc. +func (mock *TelegramAPIMock) Avatar(ctx context.Context, userID int) (string, error) { + if mock.AvatarFunc == nil { + panic("TelegramAPIMock.AvatarFunc: method is nil but TelegramAPI.Avatar was just called") + } + callInfo := struct { + Ctx context.Context + UserID int + }{ + Ctx: ctx, + UserID: userID, + } + mock.lockAvatar.Lock() + mock.calls.Avatar = append(mock.calls.Avatar, callInfo) + mock.lockAvatar.Unlock() + return mock.AvatarFunc(ctx, userID) +} + +// AvatarCalls gets all the calls that were made to Avatar. +// Check the length with: +// +// len(mockedTelegramAPI.AvatarCalls()) +func (mock *TelegramAPIMock) AvatarCalls() []struct { + Ctx context.Context + UserID int +} { + var calls []struct { + Ctx context.Context + UserID int + } + mock.lockAvatar.RLock() + calls = mock.calls.Avatar + mock.lockAvatar.RUnlock() + return calls +} + +// BotInfo calls BotInfoFunc. +func (mock *TelegramAPIMock) BotInfo(ctx context.Context) (*botInfo, error) { + if mock.BotInfoFunc == nil { + panic("TelegramAPIMock.BotInfoFunc: method is nil but TelegramAPI.BotInfo was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockBotInfo.Lock() + mock.calls.BotInfo = append(mock.calls.BotInfo, callInfo) + mock.lockBotInfo.Unlock() + return mock.BotInfoFunc(ctx) +} + +// BotInfoCalls gets all the calls that were made to BotInfo. +// Check the length with: +// +// len(mockedTelegramAPI.BotInfoCalls()) +func (mock *TelegramAPIMock) BotInfoCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockBotInfo.RLock() + calls = mock.calls.BotInfo + mock.lockBotInfo.RUnlock() + return calls +} + +// GetUpdates calls GetUpdatesFunc. +func (mock *TelegramAPIMock) GetUpdates(ctx context.Context) (*telegramUpdate, error) { + if mock.GetUpdatesFunc == nil { + panic("TelegramAPIMock.GetUpdatesFunc: method is nil but TelegramAPI.GetUpdates was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetUpdates.Lock() + mock.calls.GetUpdates = append(mock.calls.GetUpdates, callInfo) + mock.lockGetUpdates.Unlock() + return mock.GetUpdatesFunc(ctx) +} + +// GetUpdatesCalls gets all the calls that were made to GetUpdates. +// Check the length with: +// +// len(mockedTelegramAPI.GetUpdatesCalls()) +func (mock *TelegramAPIMock) GetUpdatesCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetUpdates.RLock() + calls = mock.calls.GetUpdates + mock.lockGetUpdates.RUnlock() + return calls +} + +// Send calls SendFunc. +func (mock *TelegramAPIMock) Send(ctx context.Context, id int, text string) error { + if mock.SendFunc == nil { + panic("TelegramAPIMock.SendFunc: method is nil but TelegramAPI.Send was just called") + } + callInfo := struct { + Ctx context.Context + ID int + Text string + }{ + Ctx: ctx, + ID: id, + Text: text, + } + mock.lockSend.Lock() + mock.calls.Send = append(mock.calls.Send, callInfo) + mock.lockSend.Unlock() + return mock.SendFunc(ctx, id, text) +} + +// SendCalls gets all the calls that were made to Send. +// Check the length with: +// +// len(mockedTelegramAPI.SendCalls()) +func (mock *TelegramAPIMock) SendCalls() []struct { + Ctx context.Context + ID int + Text string +} { + var calls []struct { + Ctx context.Context + ID int + Text string + } + mock.lockSend.RLock() + calls = mock.calls.Send + mock.lockSend.RUnlock() + return calls +} diff --git a/v2/provider/telegram_test.go b/v2/provider/telegram_test.go new file mode 100644 index 00000000..5121c7cd --- /dev/null +++ b/v2/provider/telegram_test.go @@ -0,0 +1,585 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + authtoken "github.com/go-pkgz/auth/token" +) + +// same across all tests +var botInfoFunc = func(ctx context.Context) (*botInfo, error) { + return &botInfo{Username: "my_auth_bot"}, nil +} + +func TestTgLoginHandlerErrors(t *testing.T) { + tg := TelegramHandler{Telegram: NewTelegramAPI("test", http.DefaultClient)} + + r := httptest.NewRequest("GET", "/login?site=remark", nil) + w := httptest.NewRecorder() + tg.LoginHandler(w, r) + assert.Equal(t, http.StatusInternalServerError, w.Code, "request should fail") + + var resp = struct { + Error string `json:"error"` + }{} + + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + assert.Equal(t, "failed to process login request", resp.Error) +} + +func TestTelegramUnconfirmedRequest(t *testing.T) { + m := &TelegramAPIMock{ + GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { + return &telegramUpdate{}, nil + }, + BotInfoFunc: botInfoFunc, + } + + tg, cleanup := setupHandler(t, m) + defer cleanup() + + // Get token + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusOK, w.Code, "request should succeed") + + var resp = struct { + Token string `json:"token"` + Bot string `json:"bot"` + }{} + + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + + assert.Equal(t, "my_auth_bot", resp.Bot) + token := resp.Token + + // Make sure we get error without first confirming auth request + r = httptest.NewRequest("GET", fmt.Sprintf("/?token=%s", token), nil) + w = httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusNotFound, w.Code, "response code should be 404") + assert.Equal(t, `{"error":"request is not verified yet"}`+"\n", w.Body.String()) + + time.Sleep(tgAuthRequestLifetime) + + // Confirm auth request expired + r = httptest.NewRequest("GET", fmt.Sprintf("/?token=%s", token), nil) + w = httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusNotFound, w.Code, "response code should be 404") + assert.Equal(t, `{"error":"request expired"}`+"\n", w.Body.String()) +} + +func TestTelegramConfirmedRequest(t *testing.T) { + var servedToken string + var mu sync.Mutex + + m := &TelegramAPIMock{ + GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { + var upd telegramUpdate + + mu.Lock() + defer mu.Unlock() + if servedToken != "" { + resp := fmt.Sprintf(getUpdatesResp, servedToken) + + err := json.Unmarshal([]byte(resp), &upd) + if err != nil { + t.Fatal(err) + } + } + return &upd, nil + }, + AvatarFunc: func(ctx context.Context, userID int) (string, error) { + assert.Equal(t, 313131313, userID) + return "http://t.me/avatar.png", nil + }, + SendFunc: func(ctx context.Context, id int, text string) error { + assert.Equal(t, 313131313, id) + assert.Equal(t, "success", text) + return nil + }, + BotInfoFunc: botInfoFunc, + } + + tg, cleanup := setupHandler(t, m) + defer cleanup() + + // Get token + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusOK, w.Code, "request should succeed") + + var resp = struct { + Token string `json:"token"` + Bot string `json:"bot"` + }{} + + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + assert.Equal(t, "my_auth_bot", resp.Bot) + + mu.Lock() + servedToken = resp.Token + mu.Unlock() + + // Check the token confirmation + assert.Eventually(t, func() bool { + r = httptest.NewRequest("GET", fmt.Sprintf("/?token=%s", resp.Token), nil) + w = httptest.NewRecorder() + tg.LoginHandler(w, r) + return w.Code == http.StatusOK + }, apiPollInterval*10, apiPollInterval, "response code should be 200") + + info := struct { + Name string `name:"name"` + ID string `id:"id"` + Picture string `json:"picture"` + }{} + err = json.NewDecoder(w.Body).Decode(&info) + assert.NoError(t, err) + + assert.Equal(t, "Joe", info.Name) + assert.Contains(t, info.ID, "telegram_") + assert.Equal(t, "http://example.com/ava12345.png", info.Picture) + + // Test request has been invalidated + r = httptest.NewRequest("GET", fmt.Sprintf("/?token=%s", resp.Token), nil) + w = httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusNotFound, w.Code, "request should get revoked") + assert.Equal(t, `{"error":"request is not found"}`+"\n", w.Body.String()) +} + +func TestTelegramLogout(t *testing.T) { + m := &TelegramAPIMock{ + GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { + return &telegramUpdate{}, nil + }, + BotInfoFunc: botInfoFunc, + } + + tg, cleanup := setupHandler(t, m) + defer cleanup() + + // Same TestVerifyHandler_Logout + handler := http.HandlerFunc(tg.LogoutHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/logout", http.NoBody) + assert.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, 2, len(rr.Header()["Set-Cookie"])) + + request := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + c, err := request.Cookie("JWT") + assert.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) + + c, err = request.Cookie("XSRF-TOKEN") + assert.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) +} + +func TestTelegramHandler_Name(t *testing.T) { + tg := &TelegramHandler{ProviderName: "test telegram"} + assert.Equal(t, "test telegram", tg.Name()) + assert.Equal(t, "test telegram", tg.String()) +} + +func TestTelegram_ProcessUpdateFlow(t *testing.T) { + m := &TelegramAPIMock{ + GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { + return &telegramUpdate{}, nil + }, + SendFunc: func(ctx context.Context, id int, text string) error { + assert.Equal(t, 313131313, id) + return nil + }, + AvatarFunc: func(ctx context.Context, userID int) (string, error) { + assert.Equal(t, 313131313, userID) + return "http://t.me/avatar.png", nil + }, + BotInfoFunc: botInfoFunc, + } + + tg := &TelegramHandler{ + ProviderName: "telegram", + ErrorMsg: "error", + SuccessMsg: "success", + + L: t, + TokenService: authtoken.NewService(authtoken.Opts{ + SecretReader: authtoken.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + AvatarSaver: &mockAvatarSaver{}, + Telegram: m, + } + // we can't call addToken unless requests.data initialized either in Run or ProcessUpdate + assert.EqualError(t, tg.ProcessUpdate(context.Background(), ""), "failed to decode provided telegram update: unexpected end of JSON input") + + assert.NoError(t, tg.addToken("token", time.Now().Add(time.Minute))) + assert.NoError(t, tg.addToken("expired token", time.Now().Add(-time.Minute))) + assert.Len(t, tg.requests.data, 2) + _, err := tg.checkToken("token") + assert.Error(t, err) + assert.NoError(t, tg.ProcessUpdate(context.Background(), fmt.Sprintf(getUpdatesResp, "token"))) + assert.Len(t, tg.requests.data, 1, "expired token was cleaned up") + tgUser, err := tg.checkToken("token") + assert.NoError(t, err) + assert.NotNil(t, tgUser) + assert.Equal(t, "Joe", tgUser.Name) + assert.Len(t, tg.requests.data, 1) + + assert.NoError(t, tg.addToken("expired token", time.Now().Add(-time.Minute))) + assert.Len(t, tg.requests.data, 2) + assert.EqualError(t, tg.ProcessUpdate(context.Background(), ""), "failed to decode provided telegram update: unexpected end of JSON input") + assert.Len(t, tg.requests.data, 1, "expired token should be cleaned up despite the error") + + // Verify that get token will return bot name + r := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + tg.LoginHandler(w, r) + + assert.Equal(t, http.StatusOK, w.Code, "request should succeed") + + var resp = struct { + Token string `json:"token"` + Bot string `json:"bot"` + }{} + + err = json.Unmarshal(w.Body.Bytes(), &resp) + assert.NoError(t, err) + assert.Equal(t, "my_auth_bot", resp.Bot) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go tg.Run(ctx) + assert.Eventually(t, func() bool { + return tg.ProcessUpdate(ctx, "").Error() == "Run goroutine should not be used with ProcessUpdate" + }, time.Millisecond*100, time.Millisecond*10, "ProcessUpdate should not work same time as Run") +} + +func TestTelegram_TokenVerification(t *testing.T) { + m := &TelegramAPIMock{ + GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { + return &telegramUpdate{}, nil + }, + BotInfoFunc: botInfoFunc, + } + + tg, cleanup := setupHandler(t, m) + cleanup() // we don't need tg.Run goroutine + assert.NotNil(t, tg) + tg.requests.data = make(map[string]tgAuthRequest) // usually done in Run() + err := tg.addToken("token", time.Now().Add(time.Minute)) + assert.NoError(t, err) + assert.Len(t, tg.requests.data, 1) + + // wrong token + tgID, err := tg.checkToken("unknown token") + assert.Empty(t, tgID) + assert.EqualError(t, err, "request is not found") + + // right token, not verified yet + tgID, err = tg.checkToken("token") + assert.Empty(t, tgID) + assert.EqualError(t, err, "request is not verified yet") + + // confirm request + authRequest, ok := tg.requests.data["token"] + assert.True(t, ok) + authRequest.confirmed = true + authRequest.user = &authtoken.User{ + Name: "telegram user name", + } + tg.requests.data["token"] = authRequest + + // successful check + tgID, err = tg.checkToken("token") + assert.NoError(t, err) + assert.Equal(t, &authtoken.User{Name: "telegram user name"}, tgID) + + // expired token + err = tg.addToken("expired token", time.Now().Add(-time.Minute)) + assert.NoError(t, err) + tgID, err = tg.checkToken("expired token") + assert.Empty(t, tgID) + assert.EqualError(t, err, "request expired") + assert.Len(t, tg.requests.data, 1) + + // expired token, cleaned up by the cleanup + apiPollInterval = time.Hour + expiredCleanupInterval = time.Millisecond * 10 + ctx, cancel := context.WithCancel(context.Background()) + go tg.Run(ctx) + // that sleep is needed because Run() will create new requests.data map, and we need to be sure that + // it's created by the time addToken is called. + time.Sleep(expiredCleanupInterval) + err = tg.addToken("expired token", time.Now().Add(-time.Minute)) + assert.NoError(t, err) + tg.requests.RLock() + assert.Len(t, tg.requests.data, 1) + tg.requests.RUnlock() + time.Sleep(expiredCleanupInterval * 2) + tg.requests.RLock() + assert.Len(t, tg.requests.data, 0) + tg.requests.RUnlock() + cancel() + // give enough time for Run() to finish + time.Sleep(expiredCleanupInterval) +} + +func setupHandler(t *testing.T, m TelegramAPI) (tg *TelegramHandler, cleanup func()) { + apiPollInterval = time.Millisecond * 10 + tgAuthRequestLifetime = time.Millisecond * 100 + + tg = &TelegramHandler{ + ProviderName: "telegram", + ErrorMsg: "error", + SuccessMsg: "success", + + L: t, + TokenService: authtoken.NewService(authtoken.Opts{ + SecretReader: authtoken.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + AvatarSaver: &mockAvatarSaver{}, + Telegram: m, + } + + assert.Equal(t, "telegram", tg.Name()) + + ctx, cleanup := context.WithCancel(context.Background()) + go func() { + err := tg.Run(ctx) + if err != context.Canceled { + t.Errorf("Unexpected error: %v", err) + } + }() + time.Sleep(20 * time.Millisecond) + + return tg, cleanup +} + +const getUpdatesResp = `{ + "ok": true, + "result": [ + { + "update_id": 1000, + "message": { + "message_id": 4, + "from": { + "id": 313131313, + "is_bot": false, + "first_name": "Joe", + "username": "joe123", + "language_code": "en" + }, + "chat": { + "id": 313131313, + "first_name": "Joe", + "username": "joe123", + "type": "private" + }, + "date": 1601665548, + "text": "/start %s", + "entities": [ + { + "offset": 0, + "length": 6, + "type": "bot_command" + } + ] + } + } + ] +}` + +func TestTgAPI_GetUpdates(t *testing.T) { + first := true + tg, cleanup := prepareTgAPI(t, func(w http.ResponseWriter, r *http.Request) { + if first { + assert.Equal(t, "", r.URL.Query().Get("offset")) + first = false + } else { + assert.Equal(t, "1001", r.URL.Query().Get("offset")) + } + _, _ = fmt.Fprintf(w, getUpdatesResp, "token") + }) + defer cleanup() + + // send request with no offset + upd, err := tg.GetUpdates(context.Background()) + assert.NoError(t, err) + + assert.Len(t, upd.Result, 1) + + assert.Equal(t, 1001, tg.updateOffset) + assert.Equal(t, "/start token", upd.Result[len(upd.Result)-1].Message.Text) + + // send request with offset + _, err = tg.GetUpdates(context.Background()) + assert.NoError(t, err) +} + +const sendMessageResp = `{ + "ok": true, + "result": { + "message_id": 100, + "from": { + "id": 666666666, + "is_bot": true, + "first_name": "Test auth bot", + "username": "TestAuthBot" + }, + "chat": { + "id": 313131313, + "first_name": "Joe", + "username": "joe123", + "type": "private" + }, + "date": 1602430546, + "text": "123" + } +}` + +func TestTgAPI_Send(t *testing.T) { + tg, cleanup := prepareTgAPI(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "123", r.URL.Query().Get("chat_id")) + assert.Equal(t, "hello there", r.URL.Query().Get("text")) + _, _ = w.Write([]byte(sendMessageResp)) + })) + defer cleanup() + + err := tg.Send(context.Background(), 123, "hello there") + assert.NoError(t, err) +} + +const profilePhotosResp = `{ + "ok": true, + "result": { + "total_count": 1, + "photos": [ + [ + { + "file_id": "1", + "file_unique_id": "A", + "file_size": 8900, + "width": 200, + "height": 200 + } + ] + ] + } +}` + +const getFileResp = `{ + "ok": true, + "result": { + "file_id": "1", + "file_unique_id": "A", + "file_size": 8900, + "file_path": "photos/file_0.jpg" + } +}` + +func TestTgAPI_Avatar(t *testing.T) { + tg, cleanup := prepareTgAPI(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.String(), "getUserProfilePhotos") { + assert.Equal(t, "123", r.URL.Query().Get("user_id")) + _, _ = w.Write([]byte(profilePhotosResp)) + return + } + + assert.Equal(t, "1", r.URL.Query().Get("file_id")) + _, _ = w.Write([]byte(getFileResp)) + + })) + defer cleanup() + + avatarURL, err := tg.Avatar(context.Background(), 123) + assert.NoError(t, err) + + expected := fmt.Sprintf("https://api.telegram.org/file/bot%s/photos/file_0.jpg", tg.token) + assert.Equal(t, expected, avatarURL) +} + +const errorResp = `{"ok":false,"error_code":400,"description":"Very bad request"}` + +func TestTgAPI_Error(t *testing.T) { + tg, cleanup := prepareTgAPI(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(errorResp)) + })) + defer cleanup() + + _, err := tg.GetUpdates(context.Background()) + assert.EqualError(t, err, "failed to fetch updates: unexpected telegram API status code 400, error: \"Very bad request\"") +} + +// mockRoundTripper redirects all incoming requests to mock url +type mockRoundTripper struct{ url string } + +func (m mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + u, _ := url.Parse(m.url) + r.URL.Host = u.Host + r.URL.Scheme = u.Scheme + return http.DefaultClient.Do(r) +} + +const getMeResp = `{ + "ok": true, + "result": { + "id": 123456789, + "is_bot": true, + "first_name": "Test auth bot", + "username": "RemarkAuthBot", + "can_join_groups": true, + "can_read_all_group_messages": false, + "supports_inline_queries": false + } +} +` + +func prepareTgAPI(t *testing.T, h http.HandlerFunc) (tg *tgAPI, cleanup func()) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.String(), "xxxsupersecretxxx") + + if strings.Contains(r.URL.String(), "getMe") { + _, _ = w.Write([]byte(getMeResp)) + return + } + + h(w, r) + })) + + client := &http.Client{ + Transport: mockRoundTripper{srv.URL}, + } + + return NewTelegramAPI("xxxsupersecretxxx", client).(*tgAPI), srv.Close +} diff --git a/v2/provider/verify.go b/v2/provider/verify.go new file mode 100644 index 00000000..8b0a03dd --- /dev/null +++ b/v2/provider/verify.go @@ -0,0 +1,219 @@ +package provider + +import ( + "bytes" + "crypto/sha1" + "fmt" + "html/template" + "net/http" + "strings" + "time" + + "github.com/go-pkgz/rest" + "github.com/golang-jwt/jwt" + + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// VerifyHandler implements non-oauth2 provider authorizing users with some confirmation. +// can be email, IM or anything else implementing Sender interface +type VerifyHandler struct { + logger.L + ProviderName string + TokenService VerifTokenService + Issuer string + AvatarSaver AvatarSaver + Sender Sender + Template string + UseGravatar bool +} + +// Sender defines interface to send emails +type Sender interface { + Send(address, text string) error +} + +// SenderFunc type is an adapter to allow the use of ordinary functions as Sender. +type SenderFunc func(address, text string) error + +// Send calls f(address,text) to implement Sender interface +func (f SenderFunc) Send(address, text string) error { + return f(address, text) +} + +// VerifTokenService defines interface accessing tokens +type VerifTokenService interface { + Token(claims token.Claims) (string, error) + Parse(tokenString string) (claims token.Claims, err error) + IsExpired(claims token.Claims) bool + Set(w http.ResponseWriter, claims token.Claims) (token.Claims, error) + Reset(w http.ResponseWriter) +} + +// Name of the handler +func (e VerifyHandler) Name() string { return e.ProviderName } + +// LoginHandler gets name and address from query, makes confirmation token and sends it to user. +// In case if confirmation token presented in the query uses it to create auth token +func (e VerifyHandler) LoginHandler(w http.ResponseWriter, r *http.Request) { + + // GET /login?site=site&user=name&address=someone@example.com + tkn := r.URL.Query().Get("token") + if tkn == "" { // no token, ask confirmation via email + e.sendConfirmation(w, r) + return + } + + // confirmation token presented + // GET /login?token=confirmation-jwt&sess=1 + confClaims, err := e.TokenService.Parse(tkn) + if err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusForbidden, err, "failed to verify confirmation token") + return + } + + if e.TokenService.IsExpired(confClaims) { + rest.SendErrorJSON(w, r, e.L, http.StatusForbidden, fmt.Errorf("expired"), "failed to verify confirmation token") + return + } + + elems := strings.Split(confClaims.Handshake.ID, "::") + if len(elems) != 2 { + rest.SendErrorJSON(w, r, e.L, http.StatusBadRequest, fmt.Errorf("%s", confClaims.Handshake.ID), "invalid handshake token") + return + } + user, address := elems[0], elems[1] + sessOnly := r.URL.Query().Get("sess") == "1" + + u := token.User{ + Name: user, + ID: e.ProviderName + "_" + token.HashID(sha1.New(), address), + } + // try to get gravatar for email + if e.UseGravatar && strings.Contains(address, "@") { // TODO: better email check to avoid silly hits to gravatar api + if picURL, e := avatar.GetGravatarURL(address); e == nil { + u.Picture = picURL + } + } + + if u, err = setAvatar(e.AvatarSaver, u, &http.Client{Timeout: 5 * time.Second}); err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "failed to save avatar to proxy") + return + } + + cid, err := randToken() + if err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "can't make token id") + return + } + + claims := token.Claims{ + User: &u, + StandardClaims: jwt.StandardClaims{ + Id: cid, + Issuer: e.Issuer, + Audience: confClaims.Audience, + }, + SessionOnly: sessOnly, + } + + if _, err = e.TokenService.Set(w, claims); err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "failed to set token") + return + } + if confClaims.Handshake != nil && confClaims.Handshake.From != "" { + http.Redirect(w, r, confClaims.Handshake.From, http.StatusTemporaryRedirect) + return + } + rest.RenderJSON(w, claims.User) +} + +// GET /login?site=site&user=name&address=someone@example.com +func (e VerifyHandler) sendConfirmation(w http.ResponseWriter, r *http.Request) { + + user, address, site := r.URL.Query().Get("user"), r.URL.Query().Get("address"), r.URL.Query().Get("site") + + if user == "" || address == "" { + rest.SendErrorJSON(w, r, e.L, http.StatusBadRequest, fmt.Errorf("wrong request"), "can't get user and address") + return + } + + claims := token.Claims{ + Handshake: &token.Handshake{ + State: "", + ID: user + "::" + address, + }, + SessionOnly: r.URL.Query().Get("session") != "" && r.URL.Query().Get("session") != "0", + StandardClaims: jwt.StandardClaims{ + Audience: site, + ExpiresAt: time.Now().Add(30 * time.Minute).Unix(), + NotBefore: time.Now().Add(-1 * time.Minute).Unix(), + Issuer: e.Issuer, + }, + } + + tkn, err := e.TokenService.Token(claims) + if err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusForbidden, err, "failed to make login token") + return + } + + tmpl := msgTemplate + if e.Template != "" { + tmpl = e.Template + } + emailTmpl, err := template.New("confirm").Parse(tmpl) + if err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "can't parse confirmation template") + return + } + + tmplData := struct { + User string + Address string + Token string + Site string + }{ + User: trim(user), + Address: trim(address), + Token: tkn, + Site: site, + } + buf := bytes.Buffer{} + if err = emailTmpl.Execute(&buf, tmplData); err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "can't execute confirmation template") + return + } + + if err := e.Sender.Send(address, buf.String()); err != nil { + rest.SendErrorJSON(w, r, e.L, http.StatusInternalServerError, err, "failed to send confirmation") + return + } + + rest.RenderJSON(w, rest.JSON{"user": user, "address": address}) +} + +// AuthHandler doesn't do anything for direct login as it has no callbacks +func (e VerifyHandler) AuthHandler(http.ResponseWriter, *http.Request) {} + +// LogoutHandler - GET /logout +func (e VerifyHandler) LogoutHandler(w http.ResponseWriter, _ *http.Request) { + e.TokenService.Reset(w) +} + +var msgTemplate = ` +Confirmation for {{.User}} {{.Address}}, site {{.Site}} + +Token: {{.Token}} +` + +func trim(inp string) string { + res := strings.ReplaceAll(inp, "\n", "") + res = strings.TrimSpace(res) + if len(res) > 128 { + return res[:128] + } + return res +} diff --git a/v2/provider/verify_test.go b/v2/provider/verify_test.go new file mode 100644 index 00000000..291c8928 --- /dev/null +++ b/v2/provider/verify_test.go @@ -0,0 +1,329 @@ +package provider + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/token" +) + +// nolint +var ( + testConfirmedToken = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJyZW1hcms0MiIsImV4cCI6MTg2MDMwNzQxMiwibmJmIjoxNTYwMzA1NTUyLCJoYW5kc2hha2UiOnsiaWQiOiJ0ZXN0MTIzOjpibGFoQHVzZXIuY29tIn19.D8AvAunK7Tj-P6P56VyaoZ-hyA6U8duZ9HV8-ACEya8` + testConfirmedBadIDToken = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJyZW1hcms0MiIsImV4cCI6MTg2MDMwNzQxMiwibmJmIjoxNTYwMzA1NTUyLCJoYW5kc2hha2UiOnsiaWQiOiJibGFoQHVzZXIuY29tIn19.hB91-kyY9-Q2Ln6IJGR9StQi-QQiXYu8SV31YhOoTbc` + testConfirmedGravatar = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJyZW1hcms0MiIsImV4cCI6MTg2MDMwNzQxMiwibmJmIjoxNTYwMzA1NTUyLCJoYW5kc2hha2UiOnsiaWQiOiJncmF2YTo6ZWVmcmV0c291bEBnbWFpbC5jb20ifX0.yQTtG7neX3YjLZ-SGeiiNmwNfJWA7nR50KAxDw834XE` + testConfirmedExpired = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJyZW1hcms0MiIsImV4cCI6MTU2MDMwNzQxMiwibmJmIjoxNTYwMzA1NTUyLCJoYW5kc2hha2UiOnsiaWQiOiJ0ZXN0MTIzOjpibGFoQHVzZXIuY29tIn19.bCFMAwCg1_l4yuEzFYzd0q9PstY-auHe2rwLqltffqo` +) + +func TestVerifyHandler_LoginSendConfirm(t *testing.T) { + + emailer := mockSender{} + e := VerifyHandler{ + ProviderName: "test", + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + Sender: SenderFunc(emailer.Send), + Template: "{{.User}} {{.Address}} {{.Site}} token:{{.Token}}", + } + + handler := http.HandlerFunc(e.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/login?address=blah@user.com&user=test123&site=remark42", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, "blah@user.com", emailer.to) + assert.Contains(t, emailer.text, "test123 blah@user.com remark42 token:") + + tknStr := strings.Split(emailer.text, " token:")[1] + tkn, err := e.TokenService.Parse(tknStr) + assert.NoError(t, err) + t.Logf("%s %+v", tknStr, tkn) + assert.Equal(t, "test123::blah@user.com", tkn.Handshake.ID) + assert.Equal(t, "remark42", tkn.Audience) + assert.True(t, tkn.ExpiresAt > tkn.NotBefore) + + assert.Equal(t, "test", e.Name()) +} + +func TestVerifyHandler_LoginSendConfirmEscapesBadInput(t *testing.T) { + + emailer := mockSender{} + e := VerifyHandler{ + ProviderName: "test", + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + Sender: SenderFunc(emailer.Send), + Template: "{{.User}} {{.Address}} {{.Site}} token:{{.Token}}", + } + + handler := http.HandlerFunc(e.LoginHandler) + rr := httptest.NewRecorder() + badData := "<escaped>" + req, err := http.NewRequest("GET", "/login?address=blah@user.com&user="+url.QueryEscape(badData)+"&site="+url.QueryEscape(badData), http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, "blah@user.com", emailer.to) + expectedEscaped := "<html><script>nasty stuff</script>&lt;escaped&gt;</html>" + assert.Contains(t, emailer.text, expectedEscaped+" blah@user.com "+expectedEscaped+" token:") + + tknStr := strings.Split(emailer.text, " token:")[1] + tkn, err := e.TokenService.Parse(tknStr) + assert.NoError(t, err) + t.Logf("%s %+v", tknStr, tkn) + // not escaped in these fields as they are not rendered as HTML + assert.Equal(t, badData+"::blah@user.com", tkn.Handshake.ID) + assert.Equal(t, badData, tkn.Audience) + assert.True(t, tkn.ExpiresAt > tkn.NotBefore) + + assert.Equal(t, "test", e.Name()) +} + +func TestVerifyHandler_LoginAcceptConfirm(t *testing.T) { + e := VerifyHandler{ + ProviderName: "test", + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(e.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", fmt.Sprintf("/login?token=%s&sess=1", testConfirmedToken), http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, `{"name":"test123","id":"test_63c1017838e567a526800790805eae4dc975402b","picture":""}`+"\n", rr.Body.String()) + + request := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + c, err := request.Cookie("JWT") + require.NoError(t, err) + claims, err := e.TokenService.Parse(c.Value) + require.NoError(t, err) + t.Logf("%+v", claims) + assert.Equal(t, "remark42", claims.Audience) + assert.Equal(t, "iss-test", claims.Issuer) + assert.True(t, claims.ExpiresAt > time.Now().Unix()) + assert.Equal(t, "test123", claims.User.Name) + assert.Equal(t, true, claims.SessionOnly) +} + +func TestVerifyHandler_LoginAcceptConfirmWithAvatar(t *testing.T) { + e := VerifyHandler{ + ProviderName: "test", + UseGravatar: true, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(e.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", fmt.Sprintf("/login?token=%s&sess=1", testConfirmedGravatar), http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, `{"name":"grava","id":"test_47dbf92d92954b1297cae73a864c159b4d847b9f","picture":"https://www.gravatar.com/avatar/c82739de14cf64affaf30856ca95b851"}`+"\n", rr.Body.String()) +} + +func TestVerifyHandler_LoginAcceptConfirmWithGrAvatarDisabled(t *testing.T) { + e := VerifyHandler{ + ProviderName: "test", + UseGravatar: false, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(e.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", fmt.Sprintf("/login?token=%s&sess=1", testConfirmedGravatar), http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, `{"name":"grava","id":"test_47dbf92d92954b1297cae73a864c159b4d847b9f","picture":""}`+"\n", rr.Body.String()) +} + +func TestVerifyHandler_LoginHandlerFailed(t *testing.T) { + emailer := mockSender{} + d := VerifyHandler{ + ProviderName: "test", + Sender: &emailer, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(d.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/login?user=myuser&aud=xyz123", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 400, rr.Code) + assert.Equal(t, `{"error":"can't get user and address"}`+"\n", rr.Body.String()) + + d.Sender = &mockSender{err: fmt.Errorf("some err")} + handler = d.LoginHandler + rr = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/login?user=myuser&address=pppp&aud=xyz123", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 500, rr.Code) + assert.Equal(t, `{"error":"failed to send confirmation"}`+"\n", rr.Body.String()) + + rr = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/login?token=bad", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Equal(t, `{"error":"failed to verify confirmation token"}`+"\n", rr.Body.String()) + + rr = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/login?token="+testConfirmedBadIDToken, http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Equal(t, `{"error":"invalid handshake token"}`+"\n", rr.Body.String()) + + rr = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/login?token="+testConfirmedExpired, http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Equal(t, `{"error":"failed to verify confirmation token"}`+"\n", rr.Body.String()) + + d.Template = `{{.Blah}}` + d.Sender = &mockSender{} + handler = d.LoginHandler + rr = httptest.NewRecorder() + req, err = http.NewRequest("GET", "/login?user=myuser&address=pppp&aud=xyz123", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Equal(t, `{"error":"can't execute confirmation template"}`+"\n", rr.Body.String()) +} + +func TestVerifyHandler_LoginHandlerAvatarFailed(t *testing.T) { + emailer := mockSender{} + d := VerifyHandler{ + ProviderName: "test", + Sender: &emailer, + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + AvatarSaver: mockAvatarSaverVerif{err: fmt.Errorf("avatar save error")}, + } + + handler := http.HandlerFunc(d.LoginHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/login?token="+testConfirmedToken, http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 500, rr.Code) + assert.Equal(t, `{"error":"failed to save avatar to proxy"}`+"\n", rr.Body.String()) +} + +func TestVerifyHandler_AuthHandler(t *testing.T) { + d := VerifyHandler{} + handler := http.HandlerFunc(d.AuthHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/callback", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) +} + +func TestVerifyHandler_Logout(t *testing.T) { + d := VerifyHandler{ + ProviderName: "test", + TokenService: token.NewService(token.Opts{ + SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), + TokenDuration: time.Hour, + CookieDuration: time.Hour * 24 * 31, + }), + Issuer: "iss-test", + L: logger.Std, + } + + handler := http.HandlerFunc(d.LogoutHandler) + rr := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/logout", http.NoBody) + require.NoError(t, err) + handler.ServeHTTP(rr, req) + assert.Equal(t, 200, rr.Code) + assert.Equal(t, 2, len(rr.Header()["Set-Cookie"])) + + request := &http.Request{Header: http.Header{"Cookie": rr.Header()["Set-Cookie"]}} + c, err := request.Cookie("JWT") + require.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) + + c, err = request.Cookie("XSRF-TOKEN") + require.NoError(t, err) + assert.Equal(t, time.Time{}, c.Expires) +} + +type mockSender struct { + err error + + to string + text string +} + +func (m *mockSender) Send(to, text string) error { + if m.err != nil { + return m.err + } + m.to = to + m.text = text + return nil +} + +type mockAvatarSaverVerif struct { + err error + url string +} + +func (a mockAvatarSaverVerif) Put(token.User, *http.Client) (avatarURL string, err error) { + return a.url, a.err +} diff --git a/v2/token/jwt.go b/v2/token/jwt.go new file mode 100644 index 00000000..72fc7e89 --- /dev/null +++ b/v2/token/jwt.go @@ -0,0 +1,405 @@ +// Package token wraps jwt-go library and provides higher level abstraction to work with JWT. +package token + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt" +) + +// Service wraps jwt operations +// supports both header and cookie tokens +type Service struct { + Opts +} + +// Claims stores user info for token and state & from from login +type Claims struct { + jwt.StandardClaims + User *User `json:"user,omitempty"` // user info + SessionOnly bool `json:"sess_only,omitempty"` + Handshake *Handshake `json:"handshake,omitempty"` // used for oauth handshake + NoAva bool `json:"no-ava,omitempty"` // disable avatar, always use identicon +} + +// Handshake used for oauth handshake +type Handshake struct { + State string `json:"state,omitempty"` + From string `json:"from,omitempty"` + ID string `json:"id,omitempty"` +} + +const ( + // default names for cookies and headers + defaultJWTCookieName = "JWT" + defaultJWTCookieDomain = "" + defaultJWTHeaderKey = "X-JWT" + defaultXSRFCookieName = "XSRF-TOKEN" + defaultXSRFHeaderKey = "X-XSRF-TOKEN" + + defaultIssuer = "go-pkgz/auth" + + defaultTokenDuration = time.Minute * 15 + defaultCookieDuration = time.Hour * 24 * 31 + + defaultTokenQuery = "token" +) + +// Opts holds constructor params +type Opts struct { + SecretReader Secret + ClaimsUpd ClaimsUpdater + SecureCookies bool + TokenDuration time.Duration + CookieDuration time.Duration + DisableXSRF bool + DisableIAT bool // disable IssuedAt claim + // optional (custom) names for cookies and headers + JWTCookieName string + JWTCookieDomain string + JWTHeaderKey string + XSRFCookieName string + XSRFHeaderKey string + JWTQuery string + AudienceReader Audience // allowed aud values + Issuer string // optional value for iss claim, usually application name + AudSecrets bool // uses different secret for differed auds. important: adds pre-parsing of unverified token + SendJWTHeader bool // if enabled send JWT as a header instead of cookie + SameSite http.SameSite // define a cookie attribute making it impossible for the browser to send this cookie cross-site +} + +// NewService makes JWT service +func NewService(opts Opts) *Service { + res := Service{Opts: opts} + + setDefault := func(fld *string, def string) { + if *fld == "" { + *fld = def + } + } + + setDefault(&res.JWTCookieName, defaultJWTCookieName) + setDefault(&res.JWTHeaderKey, defaultJWTHeaderKey) + setDefault(&res.XSRFCookieName, defaultXSRFCookieName) + setDefault(&res.XSRFHeaderKey, defaultXSRFHeaderKey) + setDefault(&res.JWTQuery, defaultTokenQuery) + setDefault(&res.Issuer, defaultIssuer) + setDefault(&res.JWTCookieDomain, defaultJWTCookieDomain) + + if opts.TokenDuration == 0 { + res.TokenDuration = defaultTokenDuration + } + + if opts.CookieDuration == 0 { + res.CookieDuration = defaultCookieDuration + } + + return &res +} + +// Token makes token with claims +func (j *Service) Token(claims Claims) (string, error) { + + // make token for allowed aud values only, rejects others + + // update claims with ClaimsUpdFunc defined by consumer + if j.ClaimsUpd != nil { + claims = j.ClaimsUpd.Update(claims) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + if j.SecretReader == nil { + return "", fmt.Errorf("secret reader not defined") + } + + if err := j.checkAuds(&claims, j.AudienceReader); err != nil { + return "", fmt.Errorf("aud rejected: %w", err) + } + + secret, err := j.SecretReader.Get(claims.Audience) // get secret via consumer defined SecretReader + if err != nil { + return "", fmt.Errorf("can't get secret: %w", err) + } + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "", fmt.Errorf("can't sign token: %w", err) + } + return tokenString, nil +} + +// Parse token string and verify. Not checking for expiration +func (j *Service) Parse(tokenString string) (Claims, error) { + parser := jwt.Parser{SkipClaimsValidation: true} // allow parsing of expired tokens + + if j.SecretReader == nil { + return Claims{}, fmt.Errorf("secret reader not defined") + } + + aud := "ignore" + if j.AudSecrets { + var err error + aud, err = j.aud(tokenString) + if err != nil { + return Claims{}, fmt.Errorf("can't retrieve audience from the token") + } + } + + secret, err := j.SecretReader.Get(aud) + if err != nil { + return Claims{}, fmt.Errorf("can't get secret: %w", err) + } + + token, err := parser.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(secret), nil + }) + if err != nil { + return Claims{}, fmt.Errorf("can't parse token: %w", err) + } + + claims, ok := token.Claims.(*Claims) + if !ok { + return Claims{}, fmt.Errorf("invalid token") + } + + if err = j.checkAuds(claims, j.AudienceReader); err != nil { + return Claims{}, fmt.Errorf("aud rejected: %w", err) + } + return *claims, j.validate(claims) +} + +// aud pre-parse token and extracts aud from the claim +// important! this step ignores token verification, should not be used for any validations +func (j *Service) aud(tokenString string) (string, error) { + parser := jwt.Parser{} + token, _, err := parser.ParseUnverified(tokenString, &Claims{}) + if err != nil { + return "", fmt.Errorf("can't pre-parse token: %w", err) + } + claims, ok := token.Claims.(*Claims) + if !ok { + return "", fmt.Errorf("invalid token") + } + if strings.TrimSpace(claims.Audience) == "" { + return "", fmt.Errorf("empty aud") + } + return claims.Audience, nil +} + +func (j *Service) validate(claims *Claims) error { + cerr := claims.Valid() + + if cerr == nil { + return nil + } + + if e, ok := cerr.(*jwt.ValidationError); ok { + if e.Errors == jwt.ValidationErrorExpired { + return nil // allow expired tokens + } + } + + return cerr +} + +// Set creates token cookie with xsrf cookie and put it to ResponseWriter +// accepts claims and sets expiration if none defined. permanent flag means long-living cookie, +// false makes it session only. +func (j *Service) Set(w http.ResponseWriter, claims Claims) (Claims, error) { + if claims.ExpiresAt == 0 { + claims.ExpiresAt = time.Now().Add(j.TokenDuration).Unix() + } + + if claims.Issuer == "" { + claims.Issuer = j.Issuer + } + + if !j.DisableIAT { + claims.IssuedAt = time.Now().Unix() + } + + tokenString, err := j.Token(claims) + if err != nil { + return Claims{}, fmt.Errorf("failed to make token token: %w", err) + } + + if j.SendJWTHeader { + w.Header().Set(j.JWTHeaderKey, tokenString) + return claims, nil + } + + cookieExpiration := 0 // session cookie + if !claims.SessionOnly && claims.Handshake == nil { + cookieExpiration = int(j.CookieDuration.Seconds()) + } + + jwtCookie := http.Cookie{Name: j.JWTCookieName, Value: tokenString, HttpOnly: true, Path: "/", Domain: j.JWTCookieDomain, + MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite} + http.SetCookie(w, &jwtCookie) + + xsrfCookie := http.Cookie{Name: j.XSRFCookieName, Value: claims.Id, HttpOnly: false, Path: "/", Domain: j.JWTCookieDomain, + MaxAge: cookieExpiration, Secure: j.SecureCookies, SameSite: j.SameSite} + http.SetCookie(w, &xsrfCookie) + + return claims, nil +} + +// Get token from url, header or cookie +// if cookie used, verify xsrf token to match +func (j *Service) Get(r *http.Request) (Claims, string, error) { + + fromCookie := false + tokenString := "" + + // try to get from "token" query param + if tkQuery := r.URL.Query().Get(j.JWTQuery); tkQuery != "" { + tokenString = tkQuery + } + + // try to get from JWT header + if tokenHeader := r.Header.Get(j.JWTHeaderKey); tokenHeader != "" && tokenString == "" { + tokenString = tokenHeader + } + + // try to get from JWT cookie + if tokenString == "" { + fromCookie = true + jc, err := r.Cookie(j.JWTCookieName) + if err != nil { + return Claims{}, "", fmt.Errorf("token cookie was not presented: %w", err) + } + tokenString = jc.Value + } + + claims, err := j.Parse(tokenString) + if err != nil { + return Claims{}, "", fmt.Errorf("failed to get token: %w", err) + } + + // promote claim's aud to User.Audience + if claims.User != nil { + claims.User.Audience = claims.Audience + } + + if !fromCookie && j.IsExpired(claims) { + return Claims{}, "", fmt.Errorf("token expired") + } + + if j.DisableXSRF { + return claims, tokenString, nil + } + + if fromCookie && claims.User != nil { + xsrf := r.Header.Get(j.XSRFHeaderKey) + if claims.Id != xsrf { + return Claims{}, "", fmt.Errorf("xsrf mismatch") + } + } + + return claims, tokenString, nil +} + +// IsExpired returns true if claims expired +func (j *Service) IsExpired(claims Claims) bool { + return !claims.VerifyExpiresAt(time.Now().Unix(), true) +} + +// Reset token's cookies +func (j *Service) Reset(w http.ResponseWriter) { + jwtCookie := http.Cookie{Name: j.JWTCookieName, Value: "", HttpOnly: false, Path: "/", Domain: j.JWTCookieDomain, + MaxAge: -1, Expires: time.Unix(0, 0), Secure: j.SecureCookies, SameSite: j.SameSite} + http.SetCookie(w, &jwtCookie) + + xsrfCookie := http.Cookie{Name: j.XSRFCookieName, Value: "", HttpOnly: false, Path: "/", Domain: j.JWTCookieDomain, + MaxAge: -1, Expires: time.Unix(0, 0), Secure: j.SecureCookies, SameSite: j.SameSite} + http.SetCookie(w, &xsrfCookie) +} + +// checkAuds verifies if claims.Audience in the list of allowed by audReader +func (j *Service) checkAuds(claims *Claims, audReader Audience) error { + if audReader == nil { // lack of any allowed means any + return nil + } + auds, err := audReader.Get() + if err != nil { + return fmt.Errorf("failed to get auds: %w", err) + } + for _, a := range auds { + if strings.EqualFold(a, claims.Audience) { + return nil + } + } + return fmt.Errorf("aud %q not allowed", claims.Audience) +} + +func (c Claims) String() string { + b, err := json.Marshal(c) + if err != nil { + return fmt.Sprintf("%+v %+v", c.StandardClaims, c.User) + } + return string(b) +} + +// Secret defines interface returning secret key for given id (aud) +type Secret interface { + Get(aud string) (string, error) // aud matching is optional. Implementation may decide if supported or ignored +} + +// SecretFunc type is an adapter to allow the use of ordinary functions as Secret. If f is a function +// with the appropriate signature, SecretFunc(f) is a Handler that calls f. +type SecretFunc func(aud string) (string, error) + +// Get calls f() +func (f SecretFunc) Get(aud string) (string, error) { + return f(aud) +} + +// ClaimsUpdater defines interface adding extras to claims +type ClaimsUpdater interface { + Update(claims Claims) Claims +} + +// ClaimsUpdFunc type is an adapter to allow the use of ordinary functions as ClaimsUpdater. If f is a function +// with the appropriate signature, ClaimsUpdFunc(f) is a Handler that calls f. +type ClaimsUpdFunc func(claims Claims) Claims + +// Update calls f(id) +func (f ClaimsUpdFunc) Update(claims Claims) Claims { + return f(claims) +} + +// Validator defines interface to accept o reject claims with consumer defined logic +// It works with valid token and allows to reject some, based on token match or user's fields +type Validator interface { + Validate(token string, claims Claims) bool +} + +// ValidatorFunc type is an adapter to allow the use of ordinary functions as Validator. If f is a function +// with the appropriate signature, ValidatorFunc(f) is a Validator that calls f. +type ValidatorFunc func(token string, claims Claims) bool + +// Validate calls f(id) +func (f ValidatorFunc) Validate(token string, claims Claims) bool { + return f(token, claims) +} + +// Audience defines interface returning list of allowed audiences +type Audience interface { + Get() ([]string, error) +} + +// AudienceFunc type is an adapter to allow the use of ordinary functions as Audience. +type AudienceFunc func() ([]string, error) + +// Get calls f() +func (f AudienceFunc) Get() ([]string, error) { + return f() +} diff --git a/v2/token/jwt_test.go b/v2/token/jwt_test.go new file mode 100644 index 00000000..e30ee137 --- /dev/null +++ b/v2/token/jwt_test.go @@ -0,0 +1,639 @@ +package token + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// To generate a token, enter one of the tokens here into https://jwt.io, change the secret to one you're using in your test +// ("secret" in most cases here, "xyz 12345" in makeTestAuth), and alter the fields you want to be changed. + +var ( + testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0._2X1cAEoxjLA7XuN8xW8V9r7rYfP_m9lSRz_9_UFzac" + testJwtValidNoHandshake = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.OWPdibrSSSHuOV3DzzLH5soO6kUcERELL7_GLf7Ja_E" + testJwtValidSess = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJzZXNzX29ubHkiOnRydWV9.SjPlVgca_bijC2wbaite2_eNHk66VXgsxUKLy7eqlXM" + testJwtExpired = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE1MjY4ODc4MjIsImp0aSI6InJhbmRvbSBpZCIs" + + "ImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyLCJ1c2VyIjp7Im5hbWUiOiJuYW1lMSIsImlkIjoiaWQxIiwicGljdHVyZSI6IiI" + + "sImFkbWluIjpmYWxzZX0sInN0YXRlIjoiMTIzNDU2IiwiZnJvbSI6ImZyb20ifQ.4_dCrY9ihyfZIedz-kZwBTxmxU1a52V7IqeJrOqTzE4" + testJwtBadSign = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0.PRuys_Ez2QWhAMp3on4Xpdc5rebKcL7-HGncvYsdYns" + testJwtNbf = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjk5OTk4ODQyMjUsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19LCJoYW5kc2hha2UiOnsic3RhdGUiOiIxMjM0NTYiLCJmcm9tIjoiZnJvbSIsImlkIjoibXlpZC0xMjM0NTYifX0.T-rdC9_-6iuh7iKTW3wN8rFezDPJhz5y2bYnXVjw3nk" + testJwtNoneAlg = "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJpc3MiOiJodHRwczovL2p3dC1pZHAuZXhhbXBsZS5jb20iLCJzdWIiOiJtYWlsdG86bWlrZUBleGFtcGxlLmNvbSIsIm5iZiI6MTU0Njc0MzcxMSwiZXhwIjoxNTQ2NzQ3MzExLCJpYXQiOjE1NDY3NDM3MTEsImp0aSI6ImlkMTIzNDU2IiwidHlwIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9yZWdpc3RlciJ9." + testJwtNoAud = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjI3ODkxOTE4MjIsImp0aSI6InJhbmRvbSBpZCIsImlzcyI6InJlbWFyazQyIiwibmJmIjoxNTI2ODg0MjIyLCJ1c2VyIjp7Im5hbWUiOiJuYW1lMSIsImlkIjoiaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fSwiaGFuZHNoYWtlIjp7InN0YXRlIjoiMTIzNDU2IiwiZnJvbSI6ImZyb20iLCJpZCI6Im15aWQtMTIzNDU2In19.pzRsCcZjH7MItUvnBmyGv74Qg3qx8vCGmsZP6lF_Z9A" + testJwtValidAud = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X2F1ZF9vbmx5IiwiZXhwIjoyNzg5MTkxODIyLCJqdGkiOiJyYW5kb20gaWQiLCJpc3MiOiJyZW1hcms0MiIsIm5iZiI6MTUyNjg4NDIyMiwidXNlciI6eyJuYW1lIjoibmFtZTEiLCJpZCI6ImlkMSIsInBpY3R1cmUiOiJodHRwOi8vZXhhbXBsZS5jb20vcGljLnBuZyIsImlwIjoiMTI3LjAuMC4xIiwiZW1haWwiOiJtZUBleGFtcGxlLmNvbSIsImF0dHJzIjp7ImJvb2xhIjp0cnVlLCJzdHJhIjoic3RyYS12YWwifX0sImhhbmRzaGFrZSI6eyJzdGF0ZSI6IjEyMzQ1NiIsImZyb20iOiJmcm9tIiwiaWQiOiJteWlkLTEyMzQ1NiJ9fQ.Ll3uS2jvj_yYZms43_w6zJOdkDR305M4AiFVLXnSd7Y" + testJwtNonAudSign = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X2F1ZF9vbmx5IiwiZXhwIjoyNzg5MTkxODIyLCJqdGkiOiJyYW5kb20gaWQiLCJpc3MiOiJyZW1hcms0MiIsIm5iZiI6MTUyNjg4NDIyMiwidXNlciI6eyJuYW1lIjoibmFtZTEiLCJpZCI6ImlkMSIsInBpY3R1cmUiOiJodHRwOi8vZXhhbXBsZS5jb20vcGljLnBuZyIsImlwIjoiMTI3LjAuMC4xIiwiZW1haWwiOiJtZUBleGFtcGxlLmNvbSIsImF0dHJzIjp7ImJvb2xhIjp0cnVlLCJzdHJhIjoic3RyYS12YWwifX0sImhhbmRzaGFrZSI6eyJzdGF0ZSI6IjEyMzQ1NiIsImZyb20iOiJmcm9tIiwiaWQiOiJteWlkLTEyMzQ1NiJ9fQ.kJc-U970h3j9riUhFLR9vN_YCUQwZ66tjk7zdC9OiUg" +) + +var days31 = time.Hour * 24 * 31 + +const ( + jwtCustomCookieName = "jc1" + jwtCustomHeaderKey = "jh1" + xsrfCustomCookieName = "xc1" + xsrfCustomHeaderKey = "xh1" +) + +func mockKeyStore(aud string) (string, error) { + if aud == "test_aud_only" { + return "audsecret", nil + } + return "xyz 12345", nil +} + +func TestJWT_NewDefault(t *testing.T) { + j := NewService(Opts{}) + assert.Equal(t, defaultJWTCookieName, j.JWTCookieName) + assert.Equal(t, defaultJWTCookieDomain, j.JWTCookieDomain) + assert.Equal(t, defaultJWTHeaderKey, j.JWTHeaderKey) + assert.Equal(t, defaultXSRFCookieName, j.XSRFCookieName) + assert.Equal(t, defaultXSRFHeaderKey, j.XSRFHeaderKey) + assert.Equal(t, defaultIssuer, j.Issuer) +} + +func TestJWT_NewNotDefault(t *testing.T) { + j := NewService(Opts{JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, JWTCookieDomain: "blah.com", + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, Issuer: "i1", + }) + assert.Equal(t, jwtCustomCookieName, j.JWTCookieName) + assert.Equal(t, jwtCustomHeaderKey, j.JWTHeaderKey) + assert.Equal(t, xsrfCustomCookieName, j.XSRFCookieName) + assert.Equal(t, xsrfCustomHeaderKey, j.XSRFHeaderKey) + assert.Equal(t, "i1", j.Issuer) + assert.Equal(t, "blah.com", j.JWTCookieDomain) +} + +func TestJWT_Token(t *testing.T) { + + j := NewService(Opts{ + SecretReader: SecretFunc(mockKeyStore), + SecureCookies: false, + TokenDuration: time.Hour, + CookieDuration: days31, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + claims := testClaims + res, err := j.Token(claims) + assert.NoError(t, err) + assert.Equal(t, testJwtValid, res) + + j.SecretReader = nil + _, err = j.Token(claims) + assert.EqualError(t, err, "secret reader not defined") + + j.SecretReader = SecretFunc(func(string) (string, error) { return "", fmt.Errorf("err blah") }) + _, err = j.Token(claims) + assert.EqualError(t, err, "can't get secret: err blah") + + j.SecretReader = SecretFunc(mockKeyStore) + j.AudienceReader = AudienceFunc(func() ([]string, error) { return []string{"a1", "aa2"}, nil }) + _, err = j.Token(claims) + assert.EqualError(t, err, `aud rejected: aud "test_sys" not allowed`) + + j.AudienceReader = AudienceFunc(func() ([]string, error) { return []string{"a1", "test_sys", "aa2"}, nil }) + _, err = j.Token(claims) + assert.NoError(t, err, "aud test_sys allowed") + +} + +func TestJWT_Parse(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore)}) + claims, err := j.Parse(testJwtValid) + assert.NoError(t, err) + assert.False(t, j.IsExpired(claims)) + assert.Equal(t, &User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", + Email: "me@example.com", Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, claims.User) + + claims, err = j.Parse(testJwtExpired) + assert.NoError(t, err) + assert.True(t, j.IsExpired(claims)) + + _, err = j.Parse(testJwtNbf) + assert.EqualError(t, err, "token is not valid yet") + + _, err = j.Parse("bad") + assert.Error(t, err, "bad token") + + _, err = j.Parse(testJwtBadSign) + assert.EqualError(t, err, "can't parse token: signature is invalid") + + _, err = j.Parse(testJwtNoneAlg) + assert.EqualError(t, err, "can't parse token: unexpected signing method: none") + + j = NewService(Opts{ + SecretReader: SecretFunc(func(string) (string, error) { return "bad 12345", nil }), + }) + _, err = j.Parse(testJwtValid) + assert.Error(t, err, "bad token", "valid token parsed with wrong secret") + + j = NewService(Opts{ + SecretReader: SecretFunc(func(string) (string, error) { return "", fmt.Errorf("err blah") }), + }) + _, err = j.Parse(testJwtValid) + assert.EqualError(t, err, "can't get secret: err blah") + +} + +func TestJWT_Set(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, Issuer: "remark42", + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + DisableIAT: true, + }) + + claims := testClaims + claims.Handshake = nil + + rr := httptest.NewRecorder() + c, err := j.Set(rr, claims) + assert.NoError(t, err) + assert.Equal(t, claims, c) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies)) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + assert.Equal(t, testJwtValidNoHandshake, cookies[0].Value) + assert.Equal(t, 31*24*3600, cookies[0].MaxAge) + assert.Equal(t, xsrfCustomCookieName, cookies[1].Name) + assert.Equal(t, "random id", cookies[1].Value) + + claims.SessionOnly = true + rr = httptest.NewRecorder() + _, err = j.Set(rr, claims) + assert.NoError(t, err) + cookies = rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies)) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + assert.Equal(t, testJwtValidSess, cookies[0].Value) + assert.Equal(t, 0, cookies[0].MaxAge) + assert.Equal(t, xsrfCustomCookieName, cookies[1].Name) + assert.Equal(t, "random id", cookies[1].Value) + assert.Equal(t, "", cookies[0].Domain) + + j.DisableIAT = false + rr = httptest.NewRecorder() + _, err = j.Set(rr, claims) + assert.NoError(t, err) + cookies = rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies)) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + assert.NotEqual(t, testJwtValidSess, cookies[0].Value, "iat changed the token") + assert.Equal(t, "", rr.Result().Header.Get(jwtCustomHeaderKey), "no JWT header set") +} + +func TestJWT_SetWithDomain(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, Issuer: "remark42", + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, JWTCookieDomain: "example.com", + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + DisableIAT: true, + }) + + claims := testClaims + claims.Handshake = nil + + rr := httptest.NewRecorder() + c, err := j.Set(rr, claims) + assert.NoError(t, err) + assert.Equal(t, claims, c) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 2, len(cookies)) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + assert.Equal(t, "example.com", cookies[0].Domain) + assert.Equal(t, testJwtValidNoHandshake, cookies[0].Value) + assert.Equal(t, 31*24*3600, cookies[0].MaxAge) + assert.Equal(t, xsrfCustomCookieName, cookies[1].Name) + assert.Equal(t, "random id", cookies[1].Value) + +} + +func TestJWT_SendJWTHeader(t *testing.T) { + + j := NewService(Opts{ + SecretReader: SecretFunc(mockKeyStore), + SecureCookies: false, + TokenDuration: time.Hour, + CookieDuration: days31, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + DisableIAT: true, + SendJWTHeader: true, + }) + + rr := httptest.NewRecorder() + _, err := j.Set(rr, testClaims) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + require.Equal(t, 0, len(cookies), "no cookies set") + assert.Equal(t, testJwtValid, rr.Result().Header.Get("X-JWT")) +} + +func TestJWT_SetProlonged(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, Issuer: "remark42", + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + claims := testClaims + claims.Handshake = nil + claims.ExpiresAt = 0 + + rr := httptest.NewRecorder() + _, err := j.Set(rr, claims) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + + cc, err := j.Parse(cookies[0].Value) + assert.NoError(t, err) + assert.True(t, cc.ExpiresAt > time.Now().Unix()) +} + +func TestJWT_NoIssuer(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, Issuer: "xyz", + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + claims := testClaims + claims.Handshake = nil + claims.Issuer = "" + + rr := httptest.NewRecorder() + _, err := j.Set(rr, claims) + assert.NoError(t, err) + cookies := rr.Result().Cookies() + t.Log(cookies) + assert.Equal(t, jwtCustomCookieName, cookies[0].Name) + + cc, err := j.Parse(cookies[0].Value) + assert.NoError(t, err) + assert.Equal(t, "xyz", cc.Issuer) +} + +func TestJWT_GetFromHeader(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Add(jwtCustomHeaderKey, testJwtValid) + claims, token, err := j.Get(req) + assert.NoError(t, err) + assert.Equal(t, testJwtValid, token) + assert.False(t, j.IsExpired(claims)) + assert.Equal(t, &User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", + Email: "me@example.com", Audience: "test_sys", + Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, claims.User) + assert.Equal(t, "remark42", claims.Issuer) + + req = httptest.NewRequest("GET", "/", nil) + req.Header.Add(jwtCustomHeaderKey, testJwtExpired) + _, _, err = j.Get(req) + assert.Error(t, err) + + req = httptest.NewRequest("GET", "/", nil) + req.Header.Add(jwtCustomHeaderKey, "bad bad token") + _, _, err = j.Get(req) + require.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token contains an invalid number of segments"), err.Error()) +} + +func TestJWT_GetFromQuery(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + req := httptest.NewRequest("GET", "/blah?token="+testJwtValid, nil) + claims, token, err := j.Get(req) + assert.NoError(t, err) + assert.Equal(t, testJwtValid, token) + assert.False(t, j.IsExpired(claims)) + assert.Equal(t, &User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", + Email: "me@example.com", Audience: "test_sys", + Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, claims.User) + assert.Equal(t, "remark42", claims.Issuer) + + req = httptest.NewRequest("GET", "/blah?token="+testJwtExpired, nil) + _, _, err = j.Get(req) + assert.Error(t, err) + + req = httptest.NewRequest("GET", "/blah?token=blah", nil) + _, _, err = j.Get(req) + require.NotNil(t, err) + assert.True(t, strings.Contains(err.Error(), "failed to get token: can't parse token: token contains an invalid number of segments"), err.Error()) +} + +func TestJWT_GetFailed(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false}) + req := httptest.NewRequest("GET", "/", nil) + _, _, err := j.Get(req) + assert.Error(t, err, "token cookie was not presented") +} + +func TestJWT_SetAndGetWithCookies(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + }) + + claims := testClaims + claims.SessionOnly = true + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" { + _, e := j.Set(w, claims) + require.NoError(t, e) + w.WriteHeader(200) + } + })) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/valid") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + req := httptest.NewRequest("GET", "/valid", nil) + req.AddCookie(resp.Cookies()[0]) + req.Header.Add(xsrfCustomHeaderKey, "random id") + r, _, err := j.Get(req) + assert.NoError(t, err) + assert.Equal(t, &User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", + Email: "me@example.com", Audience: "test_sys", + Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, r.User) + assert.Equal(t, "remark42", claims.Issuer) + assert.Equal(t, true, claims.SessionOnly) + t.Log(resp.Cookies()) +} + +func TestJWT_SetAndGetWithXsrfMismatch(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + Issuer: "remark42", + DisableIAT: true, + }) + + claims := testClaims + claims.SessionOnly = true + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" { + _, e := j.Set(w, claims) + require.NoError(t, e) + w.WriteHeader(200) + } + })) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/valid") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + req := httptest.NewRequest("GET", "/valid", nil) + req.AddCookie(resp.Cookies()[0]) + req.Header.Add(xsrfCustomHeaderKey, "random id wrong") + _, _, err = j.Get(req) + assert.EqualError(t, err, "xsrf mismatch") + + j.DisableXSRF = true + req = httptest.NewRequest("GET", "/valid", nil) + req.AddCookie(resp.Cookies()[0]) + req.Header.Add(xsrfCustomHeaderKey, "random id wrong") + c, _, err := j.Get(req) + require.NoError(t, err, "xsrf mismatch, but ignored") + claims.User.Audience = c.Audience // set aud to user because we don't do the normal Get call + assert.Equal(t, claims, c) +} + +func TestJWT_SetAndGetWithCookiesExpired(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + ClaimsUpd: ClaimsUpdFunc(func(claims Claims) Claims { + claims.User.SetStrAttr("stra", "stra-val") + claims.User.SetBoolAttr("boola", true) + return claims + }), + DisableIAT: true, + }) + + claims := testClaims + claims.StandardClaims.ExpiresAt = time.Date(2018, 5, 21, 1, 35, 22, 0, time.Local).Unix() + claims.StandardClaims.NotBefore = time.Date(2018, 5, 21, 1, 30, 22, 0, time.Local).Unix() + claims.SessionOnly = true + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/expired" { + _, e := j.Set(w, claims) + require.NoError(t, e) + w.WriteHeader(200) + } + })) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/expired") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + req := httptest.NewRequest("GET", "/expired", nil) + req.AddCookie(resp.Cookies()[0]) + req.Header.Add(xsrfCustomHeaderKey, "random id") + r, _, err := j.Get(req) + assert.NoError(t, err) + assert.True(t, j.IsExpired(r)) +} + +func TestJWT_Reset(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + JWTCookieName: jwtCustomCookieName, JWTHeaderKey: jwtCustomHeaderKey, + XSRFCookieName: xsrfCustomCookieName, XSRFHeaderKey: xsrfCustomHeaderKey, + TokenDuration: time.Hour, CookieDuration: days31, + }) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/valid" { + j.Reset(w) + w.WriteHeader(200) + } + })) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/valid") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + + assert.Equal(t, "jc1=; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0", resp.Header.Get("Set-Cookie")) + assert.Equal(t, "0", resp.Header.Get("Content-Length")) +} + +func TestJWT_Validator(t *testing.T) { + ch := ValidatorFunc(func(token string, claims Claims) bool { + return token == "good" + }) + assert.True(t, ch.Validate("good", Claims{})) + assert.False(t, ch.Validate("bad", Claims{})) +} + +func TestClaims_String(t *testing.T) { + s := testClaims.String() + assert.True(t, strings.Contains(s, `"aud":"test_sys"`)) + assert.True(t, strings.Contains(s, `"exp":2789191822`)) + assert.True(t, strings.Contains(s, `"jti":"random id"`)) + assert.True(t, strings.Contains(s, `"iss":"remark42"`)) + assert.True(t, strings.Contains(s, `"nbf":1526884222`)) + assert.True(t, strings.Contains(s, `"user":`)) + assert.True(t, strings.Contains(s, `"name":"name1"`)) + assert.True(t, strings.Contains(s, `"picture":"http://example.com/pic.png"`)) +} + +func TestAudience(t *testing.T) { + + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, + }) + + c := Claims{ + StandardClaims: jwt.StandardClaims{ + Audience: "au1", + Issuer: "test iss", + }, + } + + assert.NoError(t, j.checkAuds(&c, nil), "any aud allowed") + + err := j.checkAuds(&c, AudienceFunc(func() ([]string, error) { return []string{"xxx", "yyy"}, nil })) + assert.EqualError(t, err, `aud "au1" not allowed`) + + err = j.checkAuds(&c, AudienceFunc(func() ([]string, error) { return []string{"xxx", "yyy", "au1"}, nil })) + assert.NoError(t, err, `au1 allowed`) +} + +func TestAudReader(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, AudSecrets: true, + }) + + a, err := j.aud(testJwtValid) + require.NoError(t, err) + assert.Equal(t, "test_sys", a) + + a, err = j.aud(testJwtBadSign) + require.NoError(t, err) + assert.Equal(t, "test_sys", a) + + _, err = j.aud(testJwtNoAud) + assert.EqualError(t, err, "empty aud") + + _, err = j.aud("blah bad bad") + assert.EqualError(t, err, "can't pre-parse token: token contains an invalid number of segments") +} + +func TestParseWithAud(t *testing.T) { + j := NewService(Opts{SecretReader: SecretFunc(mockKeyStore), SecureCookies: false, + TokenDuration: time.Hour, CookieDuration: days31, AudSecrets: true, + }) + + claims, err := j.Parse(testJwtValid) + assert.NoError(t, err) + assert.False(t, j.IsExpired(claims)) + assert.Equal(t, &User{Name: "name1", ID: "id1", Picture: "http://example.com/pic.png", IP: "127.0.0.1", + Email: "me@example.com", Attributes: map[string]interface{}{"boola": true, "stra": "stra-val"}}, claims.User) + + claims, err = j.Parse(testJwtValidAud) + assert.NoError(t, err) + assert.Equal(t, "test_aud_only", claims.Audience) + + claims, err = j.Parse(testJwtNonAudSign) + assert.EqualError(t, err, "can't parse token: signature is invalid") +} + +var testClaims = Claims{ + StandardClaims: jwt.StandardClaims{ + Id: "random id", + Issuer: "remark42", + Audience: "test_sys", + ExpiresAt: time.Date(2058, 5, 21, 7, 30, 22, 0, time.UTC).Unix(), + NotBefore: time.Date(2018, 5, 21, 6, 30, 22, 0, time.UTC).Unix(), + }, + + User: &User{ + ID: "id1", + Name: "name1", + IP: "127.0.0.1", + Email: "me@example.com", + Picture: "http://example.com/pic.png", + }, + + Handshake: &Handshake{ + From: "from", + State: "123456", + ID: "myid-123456", + }, +} diff --git a/v2/token/user.go b/v2/token/user.go new file mode 100644 index 00000000..cc9875f7 --- /dev/null +++ b/v2/token/user.go @@ -0,0 +1,169 @@ +package token + +import ( + "context" + "encoding/hex" + "fmt" + "hash" + "hash/crc64" + "io" + "net/http" + "regexp" +) + +var reValidSha = regexp.MustCompile("^[a-fA-F0-9]{40}$") +var reValidCrc64 = regexp.MustCompile("^[a-fA-F0-9]{16}$") + +const ( + adminAttr = "admin" // predefined attribute key for bool isAdmin status + paidSubscriberAttr = "is_paid_sub" // predefined attribute key for bool paid subscriptions status +) + +// User is the basic part of oauth data provided by service +type User struct { + // set by service + Name string `json:"name"` + ID string `json:"id"` + Picture string `json:"picture"` + Audience string `json:"aud,omitempty"` + + // set by client + IP string `json:"ip,omitempty"` + Email string `json:"email,omitempty"` + Attributes map[string]interface{} `json:"attrs,omitempty"` + Role string `json:"role,omitempty"` +} + +// SetBoolAttr sets boolean attribute +func (u *User) SetBoolAttr(key string, val bool) { + if u.Attributes == nil { + u.Attributes = map[string]interface{}{} + } + u.Attributes[key] = val +} + +// SetStrAttr sets string attribute +func (u *User) SetStrAttr(key, val string) { + if u.Attributes == nil { + u.Attributes = map[string]interface{}{} + } + u.Attributes[key] = val +} + +// BoolAttr gets boolean attribute +func (u *User) BoolAttr(key string) bool { + r, ok := u.Attributes[key].(bool) + if !ok { + return false + } + return r +} + +// StrAttr gets string attribute +func (u *User) StrAttr(key string) string { + r, ok := u.Attributes[key].(string) + if !ok { + return "" + } + return r +} + +// SetAdmin is a shortcut to set "admin" attribute +func (u *User) SetAdmin(val bool) { + u.SetBoolAttr(adminAttr, val) +} + +// IsAdmin is a shortcut to get admin attribute +func (u *User) IsAdmin() bool { + return u.BoolAttr(adminAttr) +} + +// SetPaidSub is a shortcut to set "paidSubscriberAttr" attribute +func (u *User) SetPaidSub(val bool) { + u.SetBoolAttr(paidSubscriberAttr, val) +} + +// IsPaidSub is a shortcut to get "paidSubscriberAttr" attribute +func (u *User) IsPaidSub() bool { + return u.BoolAttr(paidSubscriberAttr) +} + +// SliceAttr gets slice attribute +func (u *User) SliceAttr(key string) []string { + r, ok := u.Attributes[key].([]string) + if !ok { + return []string{} + } + return r +} + +// SetSliceAttr sets slice attribute for given key +func (u *User) SetSliceAttr(key string, val []string) { + if u.Attributes == nil { + u.Attributes = map[string]interface{}{} + } + u.Attributes[key] = val +} + +// HashID tries to hash val with hash.Hash and fallback to crc if needed +func HashID(h hash.Hash, val string) string { + + if reValidSha.MatchString(val) { + return val // already hashed or empty + } + + if _, err := io.WriteString(h, val); err != nil { + // fail back to crc64 + if val == "" { + val = "!empty string!" + } + if reValidCrc64.MatchString(val) { + return val // already crced + } + return fmt.Sprintf("%x", crc64.Checksum([]byte(val), crc64.MakeTable(crc64.ECMA))) + } + return hex.EncodeToString(h.Sum(nil)) +} + +type contextKey string + +// MustGetUserInfo gets user info and panics if can't extract it from the request. +// should be called from authenticated controllers only +func MustGetUserInfo(r *http.Request) User { + user, err := GetUserInfo(r) + if err != nil { + panic(err) + } + return user +} + +// GetUserInfo returns user info from request context +func GetUserInfo(r *http.Request) (user User, err error) { + + ctx := r.Context() + if ctx == nil { + return User{}, fmt.Errorf("no info about user") + } + if u, ok := ctx.Value(contextKey("user")).(User); ok { + return u, nil + } + + return User{}, fmt.Errorf("user can't be parsed") +} + +// SetUserInfo sets user into request context +func SetUserInfo(r *http.Request, user User) *http.Request { + ctx := r.Context() + ctx = context.WithValue(ctx, contextKey("user"), user) + return r.WithContext(ctx) +} + +// SetRole sets user role for RBAC +func (u *User) SetRole(role string) { + u.Role = role +} + +// GetRole gets user role +func (u *User) GetRole() string { + return u.Role +} diff --git a/v2/token/user_test.go b/v2/token/user_test.go new file mode 100644 index 00000000..ee063214 --- /dev/null +++ b/v2/token/user_test.go @@ -0,0 +1,129 @@ +package token + +import ( + "crypto/sha1" //nolint + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUser_HashID(t *testing.T) { + tbl := []struct { + id string + hash string + }{ + {"myid", "6e34471f84557e1713012d64a7477c71bfdac631"}, + {"", "da39a3ee5e6b4b0d3255bfef95601890afd80709"}, + {"blah blah", "135a1e01bae742c4a576b20fd41a683f6483ca43"}, + {"da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709"}, + } + + for i, tt := range tbl { + hh := sha1.New() + assert.Equal(t, tt.hash, HashID(hh, tt.id), "case #%d", i) + } +} + +type mockBadHasher struct{} + +func (m *mockBadHasher) Write([]byte) (n int, err error) { return 0, fmt.Errorf("err") } +func (m *mockBadHasher) Sum([]byte) []byte { return nil } +func (m *mockBadHasher) Reset() {} +func (m *mockBadHasher) Size() int { return 0 } +func (m *mockBadHasher) BlockSize() int { return 0 } + +func TestUser_HashIDWithCRC(t *testing.T) { + tbl := []struct { + id string + hash string + }{ + {"myid", "e337514486e387ed"}, + {"", "914cd8098b8a2128"}, + {"blah blah", "a9d6c06bfd811649"}, + {"a9d6c06bfd811649", "a9d6c06bfd811649"}, + } + + for i, tt := range tbl { + hh := &mockBadHasher{} + assert.Equal(t, tt.hash, HashID(hh, tt.id), "case #%d", i) + } +} + +func TestUser_Attrs(t *testing.T) { + u := User{Name: "test", IP: "127.0.0.1"} + + u.SetBoolAttr("k1", true) + v := u.BoolAttr("k1") + assert.True(t, v) + + u.SetBoolAttr("k1", false) + v = u.BoolAttr("k1") + assert.False(t, v) + err := u.StrAttr("k1") + assert.NotNil(t, err) + + u.SetStrAttr("k2", "v2") + vs := u.StrAttr("k2") + assert.Equal(t, "v2", vs) + + u.SetStrAttr("k2", "v22") + vs = u.StrAttr("k2") + assert.Equal(t, "v22", vs) + + vb := u.BoolAttr("k2") + assert.False(t, vb) + + u.SetSliceAttr("ks", []string{"ss1", "ss2", "blah"}) + assert.Equal(t, []string{"ss1", "ss2", "blah"}, u.SliceAttr("ks")) + assert.Equal(t, []string{}, u.SliceAttr("k2"), "not a slice") +} + +func TestUser_Admin(t *testing.T) { + u := User{Name: "test", IP: "127.0.0.1"} + assert.False(t, u.IsAdmin()) + u.SetAdmin(true) + assert.True(t, u.IsAdmin()) + u.SetAdmin(false) + assert.False(t, u.IsAdmin()) +} + +func TestUser_PaidSubscriber(t *testing.T) { + u := User{Name: "test"} + assert.False(t, u.IsPaidSub()) + u.SetPaidSub(true) + assert.True(t, u.IsPaidSub()) + u.SetPaidSub(false) + assert.False(t, u.IsPaidSub()) +} + +func TestUser_GetUserInfo(t *testing.T) { + r, err := http.NewRequest("GET", "http://blah.com", http.NoBody) + assert.NoError(t, err) + _, err = GetUserInfo(r) + assert.EqualError(t, err, "user can't be parsed") + + r = SetUserInfo(r, User{Name: "test", ID: "id"}) + u, err := GetUserInfo(r) + assert.NoError(t, err) + assert.Equal(t, User{Name: "test", ID: "id"}, u) +} + +func TestUser_MustGetUserInfo(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Log("recovered from panic") + } + }() + + r, err := http.NewRequest("GET", "http://blah.com", http.NoBody) + assert.NoError(t, err) + _ = MustGetUserInfo(r) + assert.Fail(t, "should panic") + + r = SetUserInfo(r, User{Name: "test", ID: "id"}) + u := MustGetUserInfo(r) + assert.NoError(t, err) + assert.Equal(t, User{Name: "test", ID: "id"}, u) +} From 4e32f30d2628c11f9104b1467c8d82cf968228dd Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Thu, 4 Apr 2024 11:01:31 +0200 Subject: [PATCH 2/3] switch v2 directory to have v2 naming, add GitHub CI setup --- .github/workflows/ci-v2.yml | 60 +++++++++++++++++++++++++ .github/workflows/ci.yml | 8 +++- v2/auth.go | 10 ++--- v2/auth_test.go | 8 ++-- v2/avatar/avatar.go | 4 +- v2/avatar/avatar_test.go | 4 +- v2/avatar/store.go | 2 +- v2/go.mod | 2 +- v2/go.sum | 72 ++++++++++++++++++++++++++++++ v2/middleware/auth.go | 6 +-- v2/middleware/auth_test.go | 6 +-- v2/middleware/user_updater.go | 2 +- v2/middleware/user_updater_test.go | 2 +- v2/provider/apple.go | 6 +-- v2/provider/apple_test.go | 4 +- v2/provider/custom_server.go | 6 +-- v2/provider/custom_server_test.go | 4 +- v2/provider/dev_provider.go | 6 +-- v2/provider/dev_provider_test.go | 4 +- v2/provider/direct.go | 4 +- v2/provider/direct_test.go | 4 +- v2/provider/oauth1.go | 4 +- v2/provider/oauth1_test.go | 4 +- v2/provider/oauth2.go | 4 +- v2/provider/oauth2_test.go | 4 +- v2/provider/providers.go | 2 +- v2/provider/providers_test.go | 2 +- v2/provider/sender/email.go | 2 +- v2/provider/sender/email_test.go | 2 +- v2/provider/service.go | 2 +- v2/provider/service_test.go | 2 +- v2/provider/telegram.go | 4 +- v2/provider/telegram_test.go | 2 +- v2/provider/verify.go | 6 +-- v2/provider/verify_test.go | 4 +- 35 files changed, 203 insertions(+), 65 deletions(-) create mode 100644 .github/workflows/ci-v2.yml diff --git a/.github/workflows/ci-v2.yml b/.github/workflows/ci-v2.yml new file mode 100644 index 00000000..f82400ff --- /dev/null +++ b/.github/workflows/ci-v2.yml @@ -0,0 +1,60 @@ +name: build-v2 + +on: + push: + branches: + tags: + paths: + - ".github/workflows/ci-v2.yml" + - "v2/**" + pull_request: + paths: + - ".github/workflows/ci-v2.yml" + - "v2/**" + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: disable and stop mono-xsp4.service (wtf?) + run: | + sudo systemctl stop mono-xsp4.service || true + sudo systemctl disable mono-xsp4.service || true + + - name: set up go + uses: actions/setup-go@v5 + with: + go-version: "1.21" + id: go + + - name: launch mongodb + uses: wbari/start-mongoDB@v0.2 + with: + mongoDBVersion: "6.0" + + - name: checkout + uses: actions/checkout@v4 + + - name: build and test + run: | + go test -timeout=60s -v -race -p 1 -covermode=atomic -coverprofile=$GITHUB_WORKSPACE/profile.cov ./... + go build -race + working-directory: v2 + env: + TZ: "America/Chicago" + ENABLE_MONGO_TESTS: "true" + + - name: golangci-lint + uses: golangci/golangci-lint-action@v4 + with: + version: latest + args: --config ../.golangci.yml + working-directory: v2 + + - name: submit coverage + run: | + go install github.com/mattn/goveralls@latest + goveralls -service="github" -coverprofile=$GITHUB_WORKSPACE/profile.cov + env: + COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index faa3c00b..c103bfd5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,13 @@ on: push: branches: tags: + paths-ignore: + - ".github/workflows/ci-v2.yml" + - "v2/**" pull_request: + paths-ignore: + - ".github/workflows/ci-v2.yml" + - "v2/**" jobs: build: @@ -16,7 +22,7 @@ jobs: sudo systemctl stop mono-xsp4.service || true sudo systemctl disable mono-xsp4.service || true - - name: set up go 1.21 + - name: set up go uses: actions/setup-go@v5 with: go-version: "1.21" diff --git a/v2/auth.go b/v2/auth.go index 91faa7e4..0760ca62 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -9,11 +9,11 @@ import ( "github.com/go-pkgz/rest" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/middleware" - "github.com/go-pkgz/auth/provider" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/avatar" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/middleware" + "github.com/go-pkgz/auth/v2/provider" + "github.com/go-pkgz/auth/v2/token" ) // Client is a type of auth client diff --git a/v2/auth_test.go b/v2/auth_test.go index 15b9efbe..29432256 100644 --- a/v2/auth_test.go +++ b/v2/auth_test.go @@ -17,10 +17,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/provider" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/avatar" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/provider" + "github.com/go-pkgz/auth/v2/token" ) func TestNewService(t *testing.T) { diff --git a/v2/avatar/avatar.go b/v2/avatar/avatar.go index ed691fa8..3e9b5c2a 100644 --- a/v2/avatar/avatar.go +++ b/v2/avatar/avatar.go @@ -19,8 +19,8 @@ import ( "github.com/rrivera/identicon" "golang.org/x/image/draw" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // Proxy provides http handler for avatars from avatar.Store diff --git a/v2/avatar/avatar_test.go b/v2/avatar/avatar_test.go index a5e0bc2b..937f8866 100644 --- a/v2/avatar/avatar_test.go +++ b/v2/avatar/avatar_test.go @@ -17,8 +17,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) func TestAvatar_Put(t *testing.T) { diff --git a/v2/avatar/store.go b/v2/avatar/store.go index ab51a1b8..cd2a3815 100644 --- a/v2/avatar/store.go +++ b/v2/avatar/store.go @@ -19,7 +19,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) // imgSfx for avatars diff --git a/v2/go.mod b/v2/go.mod index 2b6075d4..a15c84ce 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -1,4 +1,4 @@ -module github.com/go-pkgz/auth +module github.com/go-pkgz/auth/v2 go 1.21 diff --git a/v2/go.sum b/v2/go.sum index 97892557..4d432554 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -1,22 +1,35 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU= cloud.google.com/go/compute v1.25.1/go.mod h1:oopOIR53ly6viBYxaDhBfJwzUAxf1zE//uf3IB011ls= +cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dghubble/oauth1 v0.7.3 h1:EkEM/zMDMp3zOsX2DC/ZQ2vnEX3ELK0/l9kb+vs4ptE= github.com/dghubble/oauth1 v0.7.3/go.mod h1:oxTe+az9NSMIucDPDCCtzJGsPhciJV33xocHfcR2sVY= github.com/fasthttp-contrib/websocket v0.0.0-20160511215533-1f3b11f56072/go.mod h1:duJ4Jxv5lDcvg4QuQr0oowTf7dz4/CR8NtyCooz9HL8= +github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gavv/httpexpect v2.0.0+incompatible h1:1X9kcRshkSKEjNJJxX9Y9mQ5BRfbxU5kORdjhlA1yX8= github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= +github.com/go-oauth2/oauth2/v4 v4.5.2 h1:CuZhD3lhGuI6aNLyUbRHXsgG2RwGRBOuCBfd4WQKqBQ= github.com/go-oauth2/oauth2/v4 v4.5.2/go.mod h1:wk/2uLImWIa9VVQDgxz99H2GDbhmfi/9/Xr+GvkSUSQ= +github.com/go-pkgz/email v0.5.0 h1:fdtMDGJ8NwyBACLR0LYHaCIK/OeUwZHMhH7Q0+oty9U= github.com/go-pkgz/email v0.5.0/go.mod h1:BdxglsQnymzhfdbnncEE72a6DrucZHy6I+42LK2jLEc= +github.com/go-pkgz/repeater v1.1.3 h1:q6+JQF14ESSy28Dd7F+wRelY4F+41HJ0LEy/szNnMiE= github.com/go-pkgz/repeater v1.1.3/go.mod h1:hVTavuO5x3Gxnu8zW7d6sQBfAneKV8X2FjU48kGfpKw= +github.com/go-pkgz/rest v1.19.0 h1:FNMi5QX5dDIkuC+/e0r+CWsTuOTwUiWMRSA16Ou+9+A= github.com/go-pkgz/rest v1.19.0/go.mod h1:Po+W6zQzpMPP6XDGLdAN2aW7UKk1IyrLSb48Lp1N3oQ= github.com/go-session/session v3.1.2+incompatible/go.mod h1:8B3iivBQjrz/JtC68Np2T1yBBLxTan3mn/3OM0CyRt0= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -28,29 +41,44 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk= github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -58,57 +86,91 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.13.0/go.mod h1:+REjRxOmWfHCjfv9TTWB1jD1Frx4XydAD3zm1lskyM0= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rrivera/identicon v0.0.0-20240116195454-d5ba35832c0d h1:l3+2LWCbVxn5itfvXAfH9n4YL9jh8l1g5zcncbIc1cs= github.com/rrivera/identicon v0.0.0-20240116195454-d5ba35832c0d/go.mod h1:TbpErkob6SY7cyozRVSGoB3OlO2qOAgVN8O3KAJ4fMI= github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw= +github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= +github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8= github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tidwall/buntdb v1.1.2/go.mod h1:xAzi36Hir4FarpSHyfuZ6JzPJdjRZ8QlLZSntE2mqlI= +github.com/tidwall/buntdb v1.3.0 h1:gdhWO+/YwoB2qZMeAU9JcWWsHSYU3OvcieYgFRS0zwA= github.com/tidwall/buntdb v1.3.0/go.mod h1:lZZrZUWzlyDJKlLQ6DKAy53LnG7m5kHyrEHvvcDmBpU= github.com/tidwall/gjson v1.3.4/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/grect v0.0.0-20161006141115-ba9a043346eb/go.mod h1:lKYYLFIr9OIgdgrtgkZ9zgRxRdvPYsExnYBsEAd8W5M= +github.com/tidwall/grect v0.1.4 h1:dA3oIgNgWdSspFzn1kS4S/RDpZFLrIxAZOdJKjYapOg= github.com/tidwall/grect v0.1.4/go.mod h1:9FBsaYRaR0Tcy4UwefBX/UDcDcDy9V5jUcxHzv2jd5Q= +github.com/tidwall/lotsa v1.0.2 h1:dNVBH5MErdaQ/xd9s769R31/n2dXavsQ0Yf4TMEHHw8= +github.com/tidwall/lotsa v1.0.2/go.mod h1:X6NiU+4yHA3fE3Puvpnn1XMDrFZrE9JO2/w+UMuqgR8= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/rtred v0.1.2 h1:exmoQtOLvDoO8ud++6LwVsAMTu0KPzLTUrMln8u1yu8= github.com/tidwall/rtred v0.1.2/go.mod h1:hd69WNXQ5RP9vHd7dqekAz+RIdtfBogmglkZSRxCHFQ= github.com/tidwall/rtree v0.0.0-20180113144539-6cd427091e0e/go.mod h1:/h+UnNGt0IhNNJLkGikcdcJqm66zGD/uJGMRxK/9+Ao= github.com/tidwall/tinyqueue v0.0.0-20180302190814-1e39f5511563/go.mod h1:mLqSmt7Dv/CNneF2wfcChfN1rvapyQr01LGKnKex0DQ= +github.com/tidwall/tinyqueue v0.1.1 h1:SpNEvEggbpyN5DIReaJ2/1ndroY8iyEGxPYxoSaymYE= github.com/tidwall/tinyqueue v0.1.1/go.mod h1:O/QNHwrnjqr6IHItYrzoHAKYhBkLI67Q096fQP5zMYw= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4= github.com/valyala/fasthttp v1.34.0/go.mod h1:epZA5N+7pY6ZaEKRmstzOuYJx9HI8DI1oaCGZpdH4h0= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0 h1:6fRhSjgLCkTD3JnJxvaJ4Sj+TYblw757bqYgZaOq5ZY= github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= +github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= +github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= +github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= github.com/yudai/pp v2.0.1+incompatible/go.mod h1:PuxR/8QJ7cyCkFp/aUDS+JY727OFEZkTdatxwunjIkc= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= +go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -123,12 +185,16 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -147,6 +213,7 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -156,6 +223,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -165,6 +233,7 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -174,8 +243,10 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= @@ -184,4 +255,5 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index 64d6c7ce..c1b297bb 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -10,9 +10,9 @@ import ( "net/http" "strings" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/provider" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/provider" + "github.com/go-pkgz/auth/v2/token" ) // Authenticator is top level auth object providing middlewares diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index 18208aee..2a438f15 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -15,9 +15,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/provider" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/provider" + "github.com/go-pkgz/auth/v2/token" ) var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJwcm92aWRlcjFfaWQxIiwicGljdHVyZSI6Imh0dHA6Ly9leGFtcGxlLmNvbS9waWMucG5nIiwiaXAiOiIxMjcuMC4wLjEiLCJlbWFpbCI6Im1lQGV4YW1wbGUuY29tIiwiYXR0cnMiOnsiYm9vbGEiOnRydWUsInN0cmEiOiJzdHJhLXZhbCJ9fX0.orBYt_pVA4uvCCw0JMQLla3DA0mpjRTl_U9vT_wtI30" diff --git a/v2/middleware/user_updater.go b/v2/middleware/user_updater.go index 34dc5dcd..1103703b 100644 --- a/v2/middleware/user_updater.go +++ b/v2/middleware/user_updater.go @@ -3,7 +3,7 @@ package middleware import ( "net/http" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) // UserUpdater defines interface adding extras or modifying UserInfo in request context diff --git a/v2/middleware/user_updater_test.go b/v2/middleware/user_updater_test.go index 3797ab81..324490c7 100644 --- a/v2/middleware/user_updater_test.go +++ b/v2/middleware/user_updater_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) func TestUserUpdate(t *testing.T) { diff --git a/v2/provider/apple.go b/v2/provider/apple.go index 2663cf11..f83168ab 100644 --- a/v2/provider/apple.go +++ b/v2/provider/apple.go @@ -26,8 +26,8 @@ import ( "github.com/go-pkgz/rest" "github.com/golang-jwt/jwt" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) const ( @@ -41,7 +41,7 @@ const ( appleRequestContentType = "application/x-www-form-urlencoded" // UserAgent required to every request to Apple REST API - defaultUserAgent = "github.com/go-pkgz/auth" + defaultUserAgent = "github.com/go-pkgz/auth/v2" // AcceptJSONHeader is the content to accept from response AcceptJSONHeader = "application/json" diff --git a/v2/provider/apple_test.go b/v2/provider/apple_test.go index f448ae77..1b47deb1 100644 --- a/v2/provider/apple_test.go +++ b/v2/provider/apple_test.go @@ -24,8 +24,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) type customLoader struct{} // implement custom private key loader interface diff --git a/v2/provider/custom_server.go b/v2/provider/custom_server.go index 0ca145e3..d77a5941 100644 --- a/v2/provider/custom_server.go +++ b/v2/provider/custom_server.go @@ -15,9 +15,9 @@ import ( goauth2 "github.com/go-oauth2/oauth2/v4/server" "golang.org/x/oauth2" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/avatar" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // CustomHandlerOpt are options to initialize a handler for oauth2 server diff --git a/v2/provider/custom_server_test.go b/v2/provider/custom_server_test.go index 3083920f..945e7728 100644 --- a/v2/provider/custom_server_test.go +++ b/v2/provider/custom_server_test.go @@ -22,8 +22,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) func TestCustomProvider(t *testing.T) { diff --git a/v2/provider/dev_provider.go b/v2/provider/dev_provider.go index e93bb0e6..136877a2 100644 --- a/v2/provider/dev_provider.go +++ b/v2/provider/dev_provider.go @@ -11,9 +11,9 @@ import ( "golang.org/x/oauth2" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/avatar" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) const ( diff --git a/v2/provider/dev_provider_test.go b/v2/provider/dev_provider_test.go index f1fb8f88..1964af4c 100644 --- a/v2/provider/dev_provider_test.go +++ b/v2/provider/dev_provider_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) func TestDevProvider(t *testing.T) { diff --git a/v2/provider/direct.go b/v2/provider/direct.go index 742ebd5a..19a965e2 100644 --- a/v2/provider/direct.go +++ b/v2/provider/direct.go @@ -11,8 +11,8 @@ import ( "github.com/go-pkgz/rest" "github.com/golang-jwt/jwt" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) const ( diff --git a/v2/provider/direct_test.go b/v2/provider/direct_test.go index 334ceeee..0e8fcdf1 100644 --- a/v2/provider/direct_test.go +++ b/v2/provider/direct_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) func TestDirect_LoginHandler(t *testing.T) { diff --git a/v2/provider/oauth1.go b/v2/provider/oauth1.go index 4aec7f56..b00e1e1c 100644 --- a/v2/provider/oauth1.go +++ b/v2/provider/oauth1.go @@ -12,8 +12,8 @@ import ( "github.com/go-pkgz/rest" "github.com/golang-jwt/jwt" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // Oauth1Handler implements /login, /callback and /logout handlers for oauth1 flow diff --git a/v2/provider/oauth1_test.go b/v2/provider/oauth1_test.go index 7a8d1182..2565869f 100644 --- a/v2/provider/oauth1_test.go +++ b/v2/provider/oauth1_test.go @@ -15,8 +15,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) const ( diff --git a/v2/provider/oauth2.go b/v2/provider/oauth2.go index 6161357e..23cb605c 100644 --- a/v2/provider/oauth2.go +++ b/v2/provider/oauth2.go @@ -13,8 +13,8 @@ import ( "github.com/golang-jwt/jwt" "golang.org/x/oauth2" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // Oauth2Handler implements /login, /callback and /logout handlers from aouth2 flow diff --git a/v2/provider/oauth2_test.go b/v2/provider/oauth2_test.go index 26a51a8d..5448564b 100644 --- a/v2/provider/oauth2_test.go +++ b/v2/provider/oauth2_test.go @@ -15,8 +15,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) var testJwtValid = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJ0ZXN0X3N5cyIsImV4cCI6Mjc4OTE5MTgyMiwianRpIjoicmFuZG9tIGlkIiwiaXNzIjoicmVtYXJrNDIiLCJuYmYiOjE1MjY4ODQyMjIsInVzZXIiOnsibmFtZSI6Im5hbWUxIiwiaWQiOiJpZDEiLCJwaWN0dXJlIjoiaHR0cDovL2V4YW1wbGUuY29tL3BpYy5wbmciLCJpcCI6IjEyNy4wLjAuMSIsImVtYWlsIjoibWVAZXhhbXBsZS5jb20iLCJhdHRycyI6eyJib29sYSI6dHJ1ZSwic3RyYSI6InN0cmEtdmFsIn19fQ.NN7TK-IbzpNgHMtld9-7BDypMGDZdMpwCmUMSfd31Zk" diff --git a/v2/provider/providers.go b/v2/provider/providers.go index 4487b4cb..5f8758e0 100644 --- a/v2/provider/providers.go +++ b/v2/provider/providers.go @@ -8,7 +8,7 @@ import ( "github.com/dghubble/oauth1" "github.com/dghubble/oauth1/twitter" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" "golang.org/x/oauth2" "golang.org/x/oauth2/facebook" "golang.org/x/oauth2/github" diff --git a/v2/provider/providers_test.go b/v2/provider/providers_test.go index 88ea28ce..7d3b8fed 100644 --- a/v2/provider/providers_test.go +++ b/v2/provider/providers_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) func TestProviders_NewGoogle(t *testing.T) { diff --git a/v2/provider/sender/email.go b/v2/provider/sender/email.go index 2c773702..b70a71d4 100644 --- a/v2/provider/sender/email.go +++ b/v2/provider/sender/email.go @@ -4,7 +4,7 @@ package sender import ( "time" - "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/v2/logger" "github.com/go-pkgz/email" ) diff --git a/v2/provider/sender/email_test.go b/v2/provider/sender/email_test.go index 1ef73331..536139be 100644 --- a/v2/provider/sender/email_test.go +++ b/v2/provider/sender/email_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/go-pkgz/auth/logger" + "github.com/go-pkgz/auth/v2/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/v2/provider/service.go b/v2/provider/service.go index 953d4167..f085b5f2 100644 --- a/v2/provider/service.go +++ b/v2/provider/service.go @@ -7,7 +7,7 @@ import ( "net/http" "strings" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) const ( diff --git a/v2/provider/service_test.go b/v2/provider/service_test.go index 7ad14d72..3c1ed38f 100644 --- a/v2/provider/service_test.go +++ b/v2/provider/service_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/token" ) func TestHandler(t *testing.T) { diff --git a/v2/provider/telegram.go b/v2/provider/telegram.go index 177c1c02..8ecc2ef9 100644 --- a/v2/provider/telegram.go +++ b/v2/provider/telegram.go @@ -19,8 +19,8 @@ import ( "github.com/go-pkgz/rest" "github.com/golang-jwt/jwt" - "github.com/go-pkgz/auth/logger" - authtoken "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + authtoken "github.com/go-pkgz/auth/v2/token" ) // TelegramHandler implements login via telegram diff --git a/v2/provider/telegram_test.go b/v2/provider/telegram_test.go index 5121c7cd..cadaf3fd 100644 --- a/v2/provider/telegram_test.go +++ b/v2/provider/telegram_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" - authtoken "github.com/go-pkgz/auth/token" + authtoken "github.com/go-pkgz/auth/v2/token" ) // same across all tests diff --git a/v2/provider/verify.go b/v2/provider/verify.go index 8b0a03dd..a1376af3 100644 --- a/v2/provider/verify.go +++ b/v2/provider/verify.go @@ -12,9 +12,9 @@ import ( "github.com/go-pkgz/rest" "github.com/golang-jwt/jwt" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/avatar" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // VerifyHandler implements non-oauth2 provider authorizing users with some confirmation. diff --git a/v2/provider/verify_test.go b/v2/provider/verify_test.go index 291c8928..2c298ed6 100644 --- a/v2/provider/verify_test.go +++ b/v2/provider/verify_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/go-pkgz/auth/logger" - "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/auth/v2/logger" + "github.com/go-pkgz/auth/v2/token" ) // nolint From 7321017f6d988f7e80f55a12536a478569c255f1 Mon Sep 17 00:00:00 2001 From: Dmitry Verkhoturov Date: Thu, 4 Apr 2024 11:06:42 +0200 Subject: [PATCH 3/3] make RefreshCache types explicit instead of interface{} --- v2/middleware/auth.go | 6 +++--- v2/middleware/auth_test.go | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/v2/middleware/auth.go b/v2/middleware/auth.go index c1b297bb..c62d1686 100644 --- a/v2/middleware/auth.go +++ b/v2/middleware/auth.go @@ -28,8 +28,8 @@ type Authenticator struct { // RefreshCache defines interface storing and retrieving refreshed tokens type RefreshCache interface { - Get(key interface{}) (value interface{}, ok bool) - Set(key, value interface{}) + Get(key string) (value token.Claims, ok bool) + Set(key string, value token.Claims) } // TokenService defines interface accessing tokens @@ -173,7 +173,7 @@ func (a *Authenticator) refreshExpiredToken(w http.ResponseWriter, claims token. if a.RefreshCache != nil { if c, ok := a.RefreshCache.Get(tkn); ok { // already in cache - return c.(token.Claims), nil + return c, nil } } diff --git a/v2/middleware/auth_test.go b/v2/middleware/auth_test.go index 2a438f15..518380c1 100644 --- a/v2/middleware/auth_test.go +++ b/v2/middleware/auth_test.go @@ -543,16 +543,16 @@ func makeTestAuth(_ *testing.T) Authenticator { } type testRefreshCache struct { - data map[interface{}]interface{} + data map[string]token.Claims sync.RWMutex hits, misses int32 } func newTestRefreshCache() *testRefreshCache { - return &testRefreshCache{data: make(map[interface{}]interface{})} + return &testRefreshCache{data: make(map[string]token.Claims)} } -func (c *testRefreshCache) Get(key interface{}) (value interface{}, ok bool) { +func (c *testRefreshCache) Get(key string) (value token.Claims, ok bool) { c.RLock() defer c.RUnlock() value, ok = c.data[key] @@ -564,7 +564,7 @@ func (c *testRefreshCache) Get(key interface{}) (value interface{}, ok bool) { return value, ok } -func (c *testRefreshCache) Set(key, value interface{}) { +func (c *testRefreshCache) Set(key string, value token.Claims) { c.Lock() defer c.Unlock() c.data[key] = value