diff --git a/client/web/antrea-ui/src/api/auth.tsx b/client/web/antrea-ui/src/api/auth.tsx index 70a31a7b..b89f8349 100644 --- a/client/web/antrea-ui/src/api/auth.tsx +++ b/client/web/antrea-ui/src/api/auth.tsx @@ -36,7 +36,7 @@ api.defaults.withCredentials = true; export const authAPI = { login: async (username: string, password: string): Promise => { - return api.get(`auth/login`, { + return api.post(`auth/login`, {}, { headers: { "Authorization": "Basic " + encode(username + ":" + password), }, @@ -44,7 +44,7 @@ export const authAPI = { }, logout: async (): Promise => { - return api.get(`auth/logout`).then(_ => {}).catch((error) => handleError(error, "Error when trying to log out")); + return api.post(`auth/logout`, {}).then(_ => {}).catch((error) => handleError(error, "Error when trying to log out")); }, refreshToken: async (): Promise => { diff --git a/pkg/server/auth.go b/pkg/server/auth.go index b106753a..cb7bafb4 100644 --- a/pkg/server/auth.go +++ b/pkg/server/auth.go @@ -17,6 +17,7 @@ package server import ( "fmt" "net/http" + "strings" "time" "github.com/gin-gonic/gin" @@ -70,7 +71,7 @@ func (s *server) Login(c *gin.Context) { Value: refreshToken.Raw, Path: "/api/v1/auth", Domain: "", - MaxAge: int(refreshToken.ExpiresIn / time.Second), + MaxAge: 0, // make it a session cookie Secure: s.config.CookieSecure, HttpOnly: true, SameSite: http.SameSiteStrictMode, @@ -86,15 +87,30 @@ func (s *server) Login(c *gin.Context) { func (s *server) RefreshToken(c *gin.Context) { if sError := func() *serverError { - cookie, err := c.Request.Cookie("antrea-ui-refresh-token") - if err != nil { - return &serverError{ - code: http.StatusUnauthorized, - message: "Missing 'antrea-ui-refresh-token' cookie", - err: err, + // /refresh supports both the Authorization header and the token cookie, giving + // priority to the Authorization header + var refreshToken string + auth := c.GetHeader("Authorization") + if auth != "" { + t := strings.Split(auth, " ") + if len(t) != 2 || t[0] != "Bearer" { + return &serverError{ + code: http.StatusUnauthorized, + message: "Authorization header does not have valid format", + } + } + refreshToken = t[1] + } else { + cookie, err := c.Request.Cookie("antrea-ui-refresh-token") + if err != nil { + return &serverError{ + code: http.StatusUnauthorized, + message: "Missing 'antrea-ui-refresh-token' cookie", + err: err, + } } + refreshToken = cookie.Value } - refreshToken := cookie.Value if err := s.tokenManager.VerifyRefreshToken(refreshToken); err != nil { return &serverError{ code: http.StatusUnauthorized, @@ -157,7 +173,7 @@ func (s *server) AddAuthRoutes(r *gin.RouterGroup) { loginHandlers = append(loginHandlers, ratelimit.Middleware(loginRateLimiter)) } loginHandlers = append(loginHandlers, s.Login) - r.GET("/login", loginHandlers...) + r.POST("/login", loginHandlers...) r.GET("/refresh_token", s.RefreshToken) - r.GET("/logout", s.Logout) + r.POST("/logout", s.Logout) } diff --git a/pkg/server/auth_test.go b/pkg/server/auth_test.go index 701b6a94..8a2c3d23 100644 --- a/pkg/server/auth_test.go +++ b/pkg/server/auth_test.go @@ -45,7 +45,7 @@ func TestLogin(t *testing.T) { wrongPassword := "abc" sendRequest := func(ts *testServer, mutators ...func(req *http.Request)) *httptest.ResponseRecorder { - req := httptest.NewRequest("GET", "/api/v1/auth/login", nil) + req := httptest.NewRequest("POST", "/api/v1/auth/login", nil) for _, m := range mutators { m(req) } @@ -81,7 +81,7 @@ func TestLogin(t *testing.T) { assert.Equal(t, refreshToken.Raw, cookie.Value) assert.Equal(t, "/api/v1/auth", cookie.Path) assert.Equal(t, "", cookie.Domain) - assert.Equal(t, int(testTokenValidity/time.Second), cookie.MaxAge) + assert.Equal(t, 0, cookie.MaxAge) assert.True(t, cookie.HttpOnly) assert.Equal(t, http.SameSiteStrictMode, cookie.SameSite) }) @@ -132,56 +132,88 @@ func TestLogin(t *testing.T) { } func TestRefreshToken(t *testing.T) { - sendRequest := func(ts *testServer, refreshToken *string) *httptest.ResponseRecorder { + sendRequestWithAuthorizationHeader := func(ts *testServer, refreshToken string) *httptest.ResponseRecorder { req := httptest.NewRequest("GET", "/api/v1/auth/refresh_token", nil) - if refreshToken != nil { - req.AddCookie(&http.Cookie{ - Name: "antrea-ui-refresh-token", - Value: *refreshToken, - }) - } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", refreshToken)) rr := httptest.NewRecorder() ts.router.ServeHTTP(rr, req) return rr } - t.Run("valid refresh", func(t *testing.T) { - ts := newTestServer(t) - refreshToken := getTestToken() - token := getTestToken() - gomock.InOrder( - ts.tokenManager.EXPECT().VerifyRefreshToken(refreshToken.Raw), - ts.tokenManager.EXPECT().GetToken().Return(token, nil), - ) - rr := sendRequest(ts, &refreshToken.Raw) - assert.Equal(t, http.StatusOK, rr.Code) + sendRequestWithCookie := func(ts *testServer, refreshToken string) *httptest.ResponseRecorder { + req := httptest.NewRequest("GET", "/api/v1/auth/refresh_token", nil) + req.AddCookie(&http.Cookie{ + Name: "antrea-ui-refresh-token", + Value: refreshToken, + }) + rr := httptest.NewRecorder() + ts.router.ServeHTTP(rr, req) + return rr + } - // check body of response - var data apisv1alpha1.Token - require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &data)) - assert.Equal(t, token.Raw, data.AccessToken) - assert.Equal(t, "Bearer", data.TokenType) - assert.Equal(t, int64(testTokenValidity/time.Second), data.ExpiresIn) - }) + sendRequestNoAuth := func(ts *testServer) *httptest.ResponseRecorder { + req := httptest.NewRequest("GET", "/api/v1/auth/refresh_token", nil) + rr := httptest.NewRecorder() + ts.router.ServeHTTP(rr, req) + return rr + } - t.Run("missing cookie", func(t *testing.T) { + t.Run("no auth", func(t *testing.T) { ts := newTestServer(t) - rr := sendRequest(ts, nil) + rr := sendRequestNoAuth(ts) assert.Equal(t, http.StatusUnauthorized, rr.Code) }) - t.Run("wrong refresh token", func(t *testing.T) { - ts := newTestServer(t) - badToken := "bad" - ts.tokenManager.EXPECT().VerifyRefreshToken(badToken).Return(fmt.Errorf("bad token")) - rr := sendRequest(ts, &badToken) - assert.Equal(t, http.StatusUnauthorized, rr.Code) - }) + authMethods := []struct { + name string + requestFunc func(ts *testServer, refreshToken string) *httptest.ResponseRecorder + }{ + { + name: "auth header", + requestFunc: sendRequestWithAuthorizationHeader, + }, + { + name: "cookie", + requestFunc: sendRequestWithCookie, + }, + } + + for _, m := range authMethods { + t.Run(m.name, func(t *testing.T) { + t.Run("valid refresh", func(t *testing.T) { + ts := newTestServer(t) + refreshToken := getTestToken() + token := getTestToken() + gomock.InOrder( + ts.tokenManager.EXPECT().VerifyRefreshToken(refreshToken.Raw), + ts.tokenManager.EXPECT().GetToken().Return(token, nil), + ) + rr := m.requestFunc(ts, refreshToken.Raw) + assert.Equal(t, http.StatusOK, rr.Code) + + // check body of response + var data apisv1alpha1.Token + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &data)) + assert.Equal(t, token.Raw, data.AccessToken) + assert.Equal(t, "Bearer", data.TokenType) + assert.Equal(t, int64(testTokenValidity/time.Second), data.ExpiresIn) + }) + + t.Run("wrong refresh token", func(t *testing.T) { + ts := newTestServer(t) + badToken := "bad" + ts.tokenManager.EXPECT().VerifyRefreshToken(badToken).Return(fmt.Errorf("bad token")) + rr := m.requestFunc(ts, badToken) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + }) + } + } func TestLogout(t *testing.T) { sendRequest := func(ts *testServer, refreshToken *string) *httptest.ResponseRecorder { - req := httptest.NewRequest("GET", "/api/v1/auth/logout", nil) + req := httptest.NewRequest("POST", "/api/v1/auth/logout", nil) if refreshToken != nil { req.AddCookie(&http.Cookie{ Name: "antrea-ui-refresh-token", diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index efeef865..d058996a 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -120,9 +120,9 @@ func (ts *testServer) authorizeRequest(req *http.Request) { func TestAuthorization(t *testing.T) { unprotectedRoutes := map[string]bool{ "GET /healthz": true, - "GET /api/v1/auth/login": true, + "POST /api/v1/auth/login": true, "GET /api/v1/auth/refresh_token": true, - "GET /api/v1/auth/logout": true, + "POST /api/v1/auth/logout": true, "GET /api/v1/version": true, } ts := newTestServer(t) diff --git a/test/e2e/api_test.go b/test/e2e/api_test.go index 62eaaf58..79f4bdae 100644 --- a/test/e2e/api_test.go +++ b/test/e2e/api_test.go @@ -38,7 +38,7 @@ const ( func TestLoginRateLimiting(t *testing.T) { ctx := context.Background() badLogin := func() int { - resp, err := Request(ctx, host, "GET", "api/v1/auth/login", nil, func(req *http.Request) { + resp, err := Request(ctx, host, "POST", "api/v1/auth/login", nil, func(req *http.Request) { req.SetBasicAuth("admin", "bad") // invalid password }) require.NoError(t, err) diff --git a/test/e2e/client.go b/test/e2e/client.go index 7ac19e69..2687245c 100644 --- a/test/e2e/client.go +++ b/test/e2e/client.go @@ -32,7 +32,7 @@ type AuthProvider struct { func (p *AuthProvider) getAccessToken(ctx context.Context, host string) (string, error) { login := func(ctx context.Context) (*http.Response, error) { - return Request(ctx, host, "GET", "api/v1/auth/login", nil, func(req *http.Request) { + return Request(ctx, host, "POST", "api/v1/auth/login", nil, func(req *http.Request) { req.SetBasicAuth("admin", "admin") // default credentials }) }