From 30656d62d0d2f96c49a8cea5163848c0794c637d Mon Sep 17 00:00:00 2001 From: cyb3r4nt <104218001+cyb3r4nt@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:59:54 +0300 Subject: [PATCH 1/2] Fix registration of dev provider in Service.authMiddleware.Providers Now it is possible to have a configuration, where only one single dev provider is enabled. Providers were not registered into Service.authMiddleware.Provicers slice in the Service.AddDevProvider() and Service.AddAppleProvider() methods before. --- v2/auth.go | 54 ++++++++++++++++++++---------------------- v2/auth_test.go | 63 ++++++++++++++++++++++++++----------------------- 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/v2/auth.go b/v2/auth.go index 0760ca6..efcd552 100644 --- a/v2/auth.go +++ b/v2/auth.go @@ -234,39 +234,44 @@ func (s *Service) AddProviderWithUserAttributes(name, cid, csecret string, userA L: s.logger, UserAttributes: userAttributes, } - s.addProvider(name, p) + s.addProviderByName(name, p) } -func (s *Service) addProvider(name string, p provider.Params) { +func (s *Service) addProviderByName(name string, p provider.Params) { + var prov provider.Provider switch strings.ToLower(name) { case "github": - s.providers = append(s.providers, provider.NewService(provider.NewGithub(p))) + prov = provider.NewGithub(p) case "google": - s.providers = append(s.providers, provider.NewService(provider.NewGoogle(p))) + prov = provider.NewGoogle(p) case "facebook": - s.providers = append(s.providers, provider.NewService(provider.NewFacebook(p))) + prov = provider.NewFacebook(p) case "yandex": - s.providers = append(s.providers, provider.NewService(provider.NewYandex(p))) + prov = provider.NewYandex(p) case "battlenet": - s.providers = append(s.providers, provider.NewService(provider.NewBattlenet(p))) + prov = provider.NewBattlenet(p) case "microsoft": - s.providers = append(s.providers, provider.NewService(provider.NewMicrosoft(p))) + prov = provider.NewMicrosoft(p) case "twitter": - s.providers = append(s.providers, provider.NewService(provider.NewTwitter(p))) + prov = provider.NewTwitter(p) case "patreon": - s.providers = append(s.providers, provider.NewService(provider.NewPatreon(p))) + prov = provider.NewPatreon(p) case "dev": - s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) + prov = provider.NewDev(p) default: return } + s.addProvider(prov) +} + +func (s *Service) addProvider(prov provider.Provider) { + s.providers = append(s.providers, provider.NewService(prov)) s.authMiddleware.Providers = s.providers } // AddProvider adds provider for given name func (s *Service) AddProvider(name, cid, csecret string) { - p := provider.Params{ URL: s.opts.URL, JwtService: s.jwtService, @@ -277,8 +282,7 @@ func (s *Service) AddProvider(name, cid, csecret string) { L: s.logger, UserAttributes: map[string]string{}, } - - s.addProvider(name, p) + s.addProviderByName(name, p) } // AddDevProvider with a custom host and port @@ -292,7 +296,7 @@ func (s *Service) AddDevProvider(host string, port int) { Port: port, Host: host, } - s.providers = append(s.providers, provider.NewService(provider.NewDev(p))) + s.addProvider(provider.NewDev(p)) } // AddAppleProvider allow SignIn with Apple ID @@ -311,7 +315,7 @@ func (s *Service) AddAppleProvider(appleConfig provider.AppleConfig, privKeyLoad return fmt.Errorf("an AppleProvider creating failed: %w", err) } - s.providers = append(s.providers, provider.NewService(appleProvider)) + s.addProvider(appleProvider) return nil } @@ -326,9 +330,7 @@ func (s *Service) AddCustomProvider(name string, client Client, copts provider.C Csecret: client.Csecret, L: s.logger, } - - s.providers = append(s.providers, provider.NewService(provider.NewCustom(name, p, copts))) - s.authMiddleware.Providers = s.providers + s.addProvider(provider.NewCustom(name, p, copts)) } // AddDirectProvider adds provider with direct check against data store @@ -342,8 +344,7 @@ func (s *Service) AddDirectProvider(name string, credChecker provider.CredChecke CredChecker: credChecker, AvatarSaver: s.avatarProxy, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddDirectProviderWithUserIDFunc adds provider with direct check against data store and sets custom UserIDFunc allows @@ -359,8 +360,7 @@ func (s *Service) AddDirectProviderWithUserIDFunc(name string, credChecker provi AvatarSaver: s.avatarProxy, UserIDFunc: ufn, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddVerifProvider adds provider user's verification sent by sender @@ -375,14 +375,12 @@ func (s *Service) AddVerifProvider(name, msgTmpl string, sender provider.Sender) Template: msgTmpl, UseGravatar: s.useGravatar, } - s.providers = append(s.providers, provider.NewService(dh)) - s.authMiddleware.Providers = s.providers + s.addProvider(dh) } // AddCustomHandler adds user-defined self-implemented handler of auth provider -func (s *Service) AddCustomHandler(handler provider.Provider) { - s.providers = append(s.providers, provider.NewService(handler)) - s.authMiddleware.Providers = s.providers +func (s *Service) AddCustomHandler(p provider.Provider) { + s.addProvider(p) } // DevAuth makes dev oauth2 server, for testing and development only! diff --git a/v2/auth_test.go b/v2/auth_test.go index 2943225..f2934f1 100644 --- a/v2/auth_test.go +++ b/v2/auth_test.go @@ -227,7 +227,11 @@ func TestIntegrationAvatar(t *testing.T) { } func TestIntegrationList(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddProvider("github", "cid", "csec") + // add go-oauth2/oauth2 provider + svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) + }) defer teardown() resp, err := http.Get("http://127.0.0.1:8089/auth/list") @@ -237,7 +241,7 @@ func TestIntegrationList(t *testing.T) { b, err := io.ReadAll(resp.Body) require.NoError(t, err) - assert.Equal(t, `["dev","github","custom123","direct","direct_custom","email"]`+"\n", string(b)) + assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b)) } func TestIntegrationUserInfo(t *testing.T) { @@ -336,7 +340,11 @@ func TestBadRequests(t *testing.T) { } func TestDirectProvider(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + })) + }) defer teardown() // login @@ -374,19 +382,28 @@ func TestDirectProvider(t *testing.T) { } func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddDirectProviderWithUserIDFunc("directCustom", + provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { + return user == "dev_direct" && password == "password", nil + }), + func(user string, r *http.Request) string { + return "blah" + }, + ) + }) defer teardown() // login jar, err := cookiejar.New(nil) require.Nil(t, err) client := &http.Client{Jar: jar, Timeout: 5 * time.Second} - resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad") + resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 403, resp.StatusCode) - resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password") + resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password") require.Nil(t, err) defer resp.Body.Close() assert.Equal(t, 200, resp.StatusCode) @@ -396,7 +413,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) - assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) + assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`) require.Equal(t, 2, len(resp.Cookies())) assert.Equal(t, "JWT", resp.Cookies()[0].Name) @@ -412,7 +429,9 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) { } func TestVerifProvider(t *testing.T) { - _, teardown := prepService(t) + _, teardown := prepService(t, func(svc *Service) { + svc.AddVerifProvider("email", "{{.Token}}", &sender) + }) defer teardown() // login @@ -488,7 +507,7 @@ func TestStatus(t *testing.T) { } -func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unparam +func prepService(t *testing.T, providerConfigFunctions ...func(svc *Service)) (svc *Service, teardown func()) { //nolint unparam options := Opts{ SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }), @@ -509,28 +528,12 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara } svc = NewService(options) - svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 - svc.AddProvider("github", "cid", "csec") // add github provider - - // add go-oauth2/oauth2 provider - svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{}) - // add direct provider - svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { - return user == "dev_direct" && password == "password", nil - })) - - // add direct provider with custom user id func - svc.AddDirectProviderWithUserIDFunc("direct_custom", - provider.CredCheckerFunc(func(user, password string) (ok bool, err error) { - return user == "dev_direct" && password == "password", nil - }), - func(user string, r *http.Request) string { - return "blah" - }, - ) + svc.AddDevProvider("localhost", 18084) // add dev provider on 18084 - svc.AddVerifProvider("email", "{{.Token}}", &sender) + for _, f := range providerConfigFunctions { + f(svc) + } // run dev/test oauth2 server on :18084 devAuth, err := svc.DevAuth() @@ -546,7 +549,7 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara _, _ = w.Write([]byte("open route, no token needed\n")) })) mux.Handle("/private", m.Auth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // token required - _, _ = w.Write([]byte("open route, no token needed\n")) + _, _ = w.Write([]byte("protected route, authenticated with token\n")) }))) // setup auth routes From 523974c6eb32c42cda88f0bbd4065b6b61f9318b Mon Sep 17 00:00:00 2001 From: cyb3r4nt <104218001+cyb3r4nt@users.noreply.github.com> Date: Fri, 30 Aug 2024 20:00:14 +0300 Subject: [PATCH 2/2] fix race conditions in TestTelegramConfirmedRequest There were two different race conditions between logic in TestTelegramConfirmedRequest and TelegramAPIMock.GetUpdatesFunc and TelegramAPIMock.SendFunc: * GetUpdatesFunc may start before token was fetched, then it produces empty telegramUpdate response, which causes assertions in SendFunc to fail. * When token becomes used and removed from wait queue after successful login completion, then GetUpdatesFunc may be still called and new telegram update is created for same token. This breaks telegram update processing logic, and SendFunc gets called with the error parameter, which also breaks assertions. --- provider/telegram_test.go | 39 +++++++++++++++++++++++++----------- v2/provider/telegram_test.go | 38 ++++++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/provider/telegram_test.go b/provider/telegram_test.go index 12afe07..f56839b 100644 --- a/provider/telegram_test.go +++ b/provider/telegram_test.go @@ -89,22 +89,37 @@ func TestTelegramUnconfirmedRequest(t *testing.T) { func TestTelegramConfirmedRequest(t *testing.T) { var servedToken string - var mu sync.Mutex + + // is set when token becomes used, + // no sync is required because only a single goroutine in TelegramHandler.Run() reads and writes it + var tokenAlreadyUsed bool + + var wgToken sync.WaitGroup + wgToken.Add(1) + defer func() { + if t.Failed() && servedToken == "" { + wgToken.Done() // for the case when test fails before token is generated + } + }() m := &TelegramAPIMock{ GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { - var upd telegramUpdate + wgToken.Wait() - mu.Lock() - defer mu.Unlock() - if servedToken != "" { - resp := fmt.Sprintf(getUpdatesResp, servedToken) + if tokenAlreadyUsed || t.Failed() { + return nil, fmt.Errorf("token %s has been already used", servedToken) + } - err := json.Unmarshal([]byte(resp), &upd) - if err != nil { - t.Fatal(err) - } + var upd telegramUpdate + resp := fmt.Sprintf(getUpdatesResp, servedToken) + err := json.Unmarshal([]byte(resp), &upd) + if err != nil { + t.Fatal(err) } + + // token is served only once + tokenAlreadyUsed = true + return &upd, nil }, AvatarFunc: func(ctx context.Context, userID int) (string, error) { @@ -147,10 +162,10 @@ func TestTelegramConfirmedRequest(t *testing.T) { err := json.Unmarshal(w.Body.Bytes(), &resp) assert.NoError(t, err) assert.Equal(t, "my_auth_bot", resp.Bot) + assert.NotEmpty(t, resp.Token) - mu.Lock() servedToken = resp.Token - mu.Unlock() + wgToken.Done() // Check the token confirmation assert.Eventually(t, func() bool { diff --git a/v2/provider/telegram_test.go b/v2/provider/telegram_test.go index 228ae23..7e0cbf0 100644 --- a/v2/provider/telegram_test.go +++ b/v2/provider/telegram_test.go @@ -89,22 +89,36 @@ func TestTelegramUnconfirmedRequest(t *testing.T) { func TestTelegramConfirmedRequest(t *testing.T) { var servedToken string - var mu sync.Mutex + // is set when token becomes used, + // no sync is required because only a single goroutine in TelegramHandler.Run() reads and writes it + var tokenAlreadyUsed bool + + var wgToken sync.WaitGroup + wgToken.Add(1) + defer func() { + if t.Failed() && servedToken == "" { + wgToken.Done() // for the case when test fails before token is generated + } + }() m := &TelegramAPIMock{ GetUpdatesFunc: func(ctx context.Context) (*telegramUpdate, error) { - var upd telegramUpdate + wgToken.Wait() - mu.Lock() - defer mu.Unlock() - if servedToken != "" { - resp := fmt.Sprintf(getUpdatesResp, servedToken) + if tokenAlreadyUsed || t.Failed() { + return nil, fmt.Errorf("token %s has been already used", servedToken) + } - err := json.Unmarshal([]byte(resp), &upd) - if err != nil { - t.Fatal(err) - } + var upd telegramUpdate + resp := fmt.Sprintf(getUpdatesResp, servedToken) + err := json.Unmarshal([]byte(resp), &upd) + if err != nil { + t.Fatal(err) } + + // token is served only once + tokenAlreadyUsed = true + return &upd, nil }, AvatarFunc: func(ctx context.Context, userID int) (string, error) { @@ -147,10 +161,10 @@ func TestTelegramConfirmedRequest(t *testing.T) { err := json.Unmarshal(w.Body.Bytes(), &resp) assert.NoError(t, err) assert.Equal(t, "my_auth_bot", resp.Bot) + assert.NotEmpty(t, resp.Token) - mu.Lock() servedToken = resp.Token - mu.Unlock() + wgToken.Done() // Check the token confirmation assert.Eventually(t, func() bool {