From bc78229deb56efe0678c41b017456fdcc56f21b8 Mon Sep 17 00:00:00 2001 From: Vivek Lohiya Date: Tue, 8 Oct 2024 12:09:54 +0530 Subject: [PATCH] Adding support for refresh Token api --- pkg/controller/controller.go | 2 +- pkg/controller/controller_test.go | 4 +- pkg/controller/postManager.go | 24 ++-- pkg/networkmanager/networkmanager.go | 10 +- pkg/networkmanager/networkmanager_test.go | 8 +- pkg/tokenmanager/tokenmanager.go | 165 ++++++++++++++++------ pkg/tokenmanager/tokenmanager_test.go | 130 +++++++++++++---- 7 files changed, 254 insertions(+), 89 deletions(-) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 606cdca22..d92d6e4e7 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -56,7 +56,7 @@ func RunController(params Params) *Controller { ctlr.addInformers() // Start Sync CM token Manager - go ctlr.CMTokenManager.Start(make(chan struct{}), tokenmanager.CMAccessTokenExpiration) + go ctlr.CMTokenManager.Start(make(chan struct{}), tokenmanager.CMRefreshTokenExpiration) // start request handler ctlr.RequestHandler.startRequestHandler() diff --git a/pkg/controller/controller_test.go b/pkg/controller/controller_test.go index c3fd45eda..6e33b6b2b 100644 --- a/pkg/controller/controller_test.go +++ b/pkg/controller/controller_test.go @@ -80,7 +80,7 @@ var _ = Describe("New Controller", func() { }) It("should create, start and stop the controller", func() { statusCode = 200 - responseLogin := tokenmanager.TokenResponse{ + responseLogin := tokenmanager.AccessTokenResponse{ AccessToken: "test.token", } server.AppendHandlers( @@ -99,7 +99,7 @@ var _ = Describe("New Controller", func() { ctlr := NewController(params, mockStatusManager) Expect(ctlr).ToNot(BeNil()) time.Sleep(1 * time.Second) - token := ctlr.CMTokenManager.GetToken() + token := ctlr.CMTokenManager.GetAccessToken() Expect(token).To(BeEquivalentTo("test.token"), "Token should be empty") Expect(ctlr.CMTokenManager.CMVersion).To(Equal("20.1.0")) Expect(ctlr.RequestHandler).ToNot(BeNil()) diff --git a/pkg/controller/postManager.go b/pkg/controller/postManager.go index 5dc408b29..b837d46c5 100644 --- a/pkg/controller/postManager.go +++ b/pkg/controller/postManager.go @@ -195,7 +195,7 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { } log.Infof("%v[AS3]%v posting request to %v", getRequestPrefix(cfg.id), postMgr.postManagerPrefix, cfg.as3APIURL) // add authorization header to the req - req.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) + req.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) // add content type header to the req req.Header.Add("Content-Type", "application/json") httpResp, responseMap := postMgr.httpPOST(req) @@ -268,17 +268,17 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { // return // } // // add authorization header to the req -// if postMgr.tokenManager.GetToken() == "" { +// if postMgr.tokenManager.GetAccessToken() == "" { // log.Debugf("[AS3] Waiting for max 5 seconds for token syncing..") // t := 0 // for t < 5 { // time.Sleep(1 * time.Second) -// if postMgr.tokenManager.GetToken() != "" { +// if postMgr.tokenManager.GetAccessToken() != "" { // log.Debugf("[AS3] Token is now available") // break // } // } -// if postMgr.tokenManager.GetToken() == "" { +// if postMgr.tokenManager.GetAccessToken() == "" { // log.Errorf("[AS3]%v Creating new HTTP request error: access token missing ", postMgr.postManagerPrefix) // return // } @@ -311,7 +311,7 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { // } // log.Debugf("[AS3]%v posting request to %v", postMgr.postManagerPrefix, cfg.as3APIURL+"/deployments") // // add authorization header to the req -// deployReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) +// deployReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) // deployReq.Header.Add("Content-Type", "application/json") // httpDeployResp, deployResponseMap := postMgr.httpPOST(deployReq) @@ -343,7 +343,7 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { // } // log.Debugf("[AS3]%v posting request to %v", postMgr.postManagerPrefix, cfg.as3APIURL) // // add authorization header to the req -// declareReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) +// declareReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) // httpDeclareResp, declareResponseMap := postMgr.httpPOST(declareReq) // if httpDeclareResp == nil || declareResponseMap == nil { @@ -378,7 +378,7 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { // } // log.Debugf("[AS3]%v posting update request to %v", postMgr.postManagerPrefix, cfg.as3APIURL+docID) // // add authorization header to the req -// updateReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) +// updateReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) // updateReq.Header.Add("Content-Type", "application/json") // httpUpdateResp, updateResponseMap := postMgr.httpPOST(updateReq) @@ -411,7 +411,7 @@ func (postMgr *PostManager) postConfig(cfg *as3Config) { // } // log.Debugf("[AS3]%v posting request to %v", postMgr.postManagerPrefix, cfg.as3APIURL+docID) // // add authorization header to the req -// deleteReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) +// deleteReq.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) // httpDeclareResp, declareResponseMap := postMgr.httpPOST(deleteReq) // if httpDeclareResp == nil || declareResponseMap == nil { @@ -570,7 +570,7 @@ func (postMgr *PostManager) getTenantConfigStatus(id string, cfg *as3Config) { } log.Debugf("[AS3]%v posting request with taskId to %v", postMgr.postManagerPrefix, url) // add authorization header to the req - req.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetToken()) + req.Header.Add("Authorization", "Bearer "+postMgr.tokenManager.GetAccessToken()) httpResp, responseMap := postMgr.httpPOST(req) if httpResp == nil || responseMap == nil { @@ -795,7 +795,7 @@ func (postMgr *PostManager) GetBigipAS3Version() (string, string, string, error) log.Debugf("[AS3]%v posting GET BIGIP AS3 Version request on %v", postMgr.postManagerPrefix, url) // add authorization header to the req - req.Header.Add("Authorization", postMgr.tokenManager.GetToken()) + req.Header.Add("Authorization", postMgr.tokenManager.GetAccessToken()) httpResp, responseMap := postMgr.httpReq(req) if httpResp == nil || responseMap == nil { @@ -833,7 +833,7 @@ func (postMgr *PostManager) GetBigipRegKey() (string, error) { log.Debugf("[AS3]%v Posting GET BIGIP Reg Key request on %v", postMgr.postManagerPrefix, url) // add authorization header to the req - req.Header.Add("Authorization", postMgr.tokenManager.GetToken()) + req.Header.Add("Authorization", postMgr.tokenManager.GetAccessToken()) httpResp, responseMap := postMgr.httpReq(req) if httpResp == nil || responseMap == nil { @@ -865,7 +865,7 @@ func (postMgr *PostManager) GetAS3DeclarationFromBigIP() (map[string]interface{} log.Debugf("[AS3]%v posting GET BIGIP AS3 declaration request on %v", postMgr.postManagerPrefix, url) // add authorization header to the req - req.Header.Add("Authorization", postMgr.tokenManager.GetToken()) + req.Header.Add("Authorization", postMgr.tokenManager.GetAccessToken()) httpResp, responseMap := postMgr.httpReq(req) if httpResp == nil || responseMap == nil { diff --git a/pkg/networkmanager/networkmanager.go b/pkg/networkmanager/networkmanager.go index f60f84374..6e90df26d 100644 --- a/pkg/networkmanager/networkmanager.go +++ b/pkg/networkmanager/networkmanager.go @@ -202,7 +202,7 @@ func (nm *NetworkManager) SetInstanceIds(bigIpConfigs []cisapiv1.BigIpConfig, co } // Set authorization header - req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetToken()) + req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetAccessToken()) // Perform request resp, err := nm.httpClient.Do(req) @@ -263,7 +263,7 @@ func (nm *NetworkManager) GetL3ForwardsFromInstance(instanceId string, controlle } // Set authorization header - req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetToken()) + req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetAccessToken()) // Perform request resp, err := nm.httpClient.Do(req) @@ -336,7 +336,7 @@ func (nm *NetworkManager) DeleteL3Forward(instanceId, l3ForwardID string) error } // Set authorization header - req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetToken()) + req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetAccessToken()) // Perform request resp, err := nm.httpClient.Do(req) @@ -387,7 +387,7 @@ func (nm *NetworkManager) GetTaskStatus(taskRef string) (string, string, error) return "", "", err } // Set authorization header - req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetToken()) + req.Header.Set("Authorization", "Bearer "+nm.CMTokenManager.GetAccessToken()) // Perform request resp, err := nm.httpClient.Do(req) @@ -576,7 +576,7 @@ func (nm *NetworkManager) HandleL3ForwardRequest(req *NetworkConfigRequest, l3Fo } // create the l3 forward - err := nm.PostL3Forward(nm.CMTokenManager.ServerURL+InstancesURI+req.BigIp.InstanceId+L3Forwards, nm.CMTokenManager.GetToken(), l3Forward) + err := nm.PostL3Forward(nm.CMTokenManager.ServerURL+InstancesURI+req.BigIp.InstanceId+L3Forwards, nm.CMTokenManager.GetAccessToken(), l3Forward) if err != nil { bigipStatus.L3Status = &cisapiv1.L3Status{ Message: Create + Failed, diff --git a/pkg/networkmanager/networkmanager_test.go b/pkg/networkmanager/networkmanager_test.go index cfbe2bbba..8c60f53d0 100644 --- a/pkg/networkmanager/networkmanager_test.go +++ b/pkg/networkmanager/networkmanager_test.go @@ -22,7 +22,7 @@ func stringToJson(s string) map[string]interface{} { var _ = Describe("Network Manager Tests", func() { var tokenManager *tokenmanager.TokenManager var server *ghttp.Server - var tokenResponse tokenmanager.TokenResponse + var tokenResponse tokenmanager.AccessTokenResponse var networkManager *NetworkManager var inventoryResponse string var l3ForwardResponse string @@ -66,7 +66,7 @@ var _ = Describe("Network Manager Tests", func() { Password: "admin", }, "", true, mockStatusManager) - tokenResponse = tokenmanager.TokenResponse{ + tokenResponse = tokenmanager.AccessTokenResponse{ AccessToken: "test.token", } server.AppendHandlers( @@ -218,7 +218,7 @@ var _ = Describe("Network Manager Tests", func() { Password: "admin", }, "", true, mockStatusManager) - tokenResponse = tokenmanager.TokenResponse{ + tokenResponse = tokenmanager.AccessTokenResponse{ AccessToken: "test.token", } server.AppendHandlers( @@ -401,7 +401,7 @@ var _ = Describe("Network Manager Tests", func() { Password: "admin", }, "", true, mockStatusManager) - tokenResponse = tokenmanager.TokenResponse{ + tokenResponse = tokenmanager.AccessTokenResponse{ AccessToken: "test.token", } server.AppendHandlers( diff --git a/pkg/tokenmanager/tokenmanager.go b/pkg/tokenmanager/tokenmanager.go index 2b7086240..791c67fae 100644 --- a/pkg/tokenmanager/tokenmanager.go +++ b/pkg/tokenmanager/tokenmanager.go @@ -19,25 +19,29 @@ import ( const ( //CM login url - CMLoginURL = "/api/login" - CMVersionURL = "/api/v1/system/infra/info" - CMAccessTokenExpiration = 5 * time.Minute - TokenFetchFailed = "Failed to fetch token" - Ok = "OK" - RetryInterval = time.Duration(10) + CMLoginURL = "/api/login" + CMVersionURL = "/api/v1/system/infra/info" + CMRefreshTokenURL = "/api/token-refresh" + CMRefreshTokenExpiration = 10 * time.Hour + CMAccessTokenExpiration = 5 * time.Minute + TokenFetchFailed = "Failed to fetch accessToken" + Ok = "OK" + RetryInterval = time.Duration(10) ) -// TokenManager is responsible for managing the authentication token. +// TokenManager is responsible for managing the authentication accessToken. type TokenManager struct { - mu sync.Mutex - token string - ServerURL string - credentials Credentials - SslInsecure bool - TrustedCerts string - httpClient *http.Client - CMVersion string - StatusManager statusmanager.StatusManagerInterface + mu sync.Mutex + accessToken string + accessTokenExpiry time.Time + refreshToken string + ServerURL string + credentials Credentials + SslInsecure bool + TrustedCerts string + httpClient *http.Client + CMVersion string + StatusManager statusmanager.StatusManagerInterface } // Credentials represent the username and password used for authentication. @@ -46,8 +50,13 @@ type Credentials struct { Password string `json:"password"` } -// TokenResponse represents the response received from the CM. -type TokenResponse struct { +// RefreshTokenResponse represents the response received from the CM. +type RefreshTokenResponse struct { + AccessToken string `json:"access_token"` +} + +// AccessTokenResponse represents the response received from the CM. +type AccessTokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` UserID string `json:"user_id"` @@ -65,14 +74,44 @@ func NewTokenManager(serverURL string, credentials Credentials, trustedCerts str } } -// GetToken returns the current valid saved token. -func (tm *TokenManager) GetToken() string { +// GetRefreshToken returns the current valid saved accessToken. +func (tm *TokenManager) GetRefreshToken() string { tm.mu.Lock() - token := tm.token + token := tm.refreshToken tm.mu.Unlock() return token } +// SetRefreshToken safely sets the accessToken in the TokenManager. +func (tm *TokenManager) SetRefreshToken(token string) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.refreshToken = token +} + +// GetAccessToken returns the current valid saved accessToken. +func (tm *TokenManager) GetAccessToken() string { + if time.Now().After(tm.accessTokenExpiry) { + if err := tm.RefreshAccessToken(); err == nil { + log.Debugf("[Token Manager] Successfully refreshed accessToken from Central Manager") + } else { + log.Errorf("[Token Manager] Failed to refresh accessToken from Central Manager: %v", err) + } + } + tm.mu.Lock() + token := tm.accessToken + tm.mu.Unlock() + return token +} + +// SetAccessToken safely sets the accessToken in the TokenManager. +func (tm *TokenManager) SetAccessToken(token string) { + tm.mu.Lock() + defer tm.mu.Unlock() + tm.accessToken = token + tm.accessTokenExpiry = time.Now().Add(CMAccessTokenExpiration) +} + func getHttpClient(trustedCerts string, sslInsecure bool) *http.Client { // Configure CA certificates rootCAs, _ := x509.SystemCertPool() @@ -98,7 +137,51 @@ func getHttpClient(trustedCerts string, sslInsecure bool) *http.Client { return &http.Client{Transport: tr} } -// SyncTokenWithoutRetry retrieves a new token from the CM. +// RefreshAccessToken retrieves a new accessToken from the CM. +func (tm *TokenManager) RefreshAccessToken() error { + + payload := []byte(`{"refresh_token":"` + tm.GetRefreshToken() + `"}`) + // Send POST request for accessToken + resp, err := tm.httpClient.Post(tm.ServerURL+CMRefreshTokenURL, "application/json", bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("unable to establish connection with Central Manager, Probable reasons might be: invalid custom-certs (or) custom-certs not provided using --trusted-certs-cfgmap flag") + } + defer resp.Body.Close() + + // Read the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("unable to read response body %v. error: %v", resp.Body, err.Error()) + } + + // Check for successful response + if resp.StatusCode != http.StatusOK { + switch resp.StatusCode { + case http.StatusUnauthorized: + return fmt.Errorf("unauthorized to fetch accessToken from Central Manager. "+ + "Please check the credentials, status code: %d, response: %s", resp.StatusCode, body) + case http.StatusServiceUnavailable: + return fmt.Errorf("failed to get accessToken due to service unavailability, "+ + "status code: %d, response: %s", resp.StatusCode, body) + case http.StatusNotFound, http.StatusMovedPermanently: + return fmt.Errorf("requested page/api not found, status code: %d, response: %s", resp.StatusCode, body) + default: + return fmt.Errorf("failed to get accessToken, status code: %d, response: %s", resp.StatusCode, body) + } + } + + // Parse the accessToken and its expiration time from the response + tokenResponse := RefreshTokenResponse{} + err = json.Unmarshal(body, &tokenResponse) + if err != nil { + return fmt.Errorf("unmarshaling failed for refreshAccessToken response %v. error: %v", body, err.Error()) + } + // Keep the accessToken updated in the TokenManager + tm.SetAccessToken(tokenResponse.AccessToken) + return nil +} + +// SyncTokenWithoutRetry retrieves a new accessToken from the CM. func (tm *TokenManager) SyncTokenWithoutRetry() (err error, exit bool) { var errMessage error // Prepare the request payload @@ -108,7 +191,7 @@ func (tm *TokenManager) SyncTokenWithoutRetry() (err error, exit bool) { return errMessage, false } - // Send POST request for token + // Send POST request for accessToken resp, err := tm.httpClient.Post(tm.ServerURL+CMLoginURL, "application/json", bytes.NewBuffer(payload)) if err != nil { errMessage = fmt.Errorf("unable to establish connection with Central Manager, Probable reasons might be: invalid custom-certs (or) custom-certs not provided using --trusted-certs-cfgmap flag") @@ -127,61 +210,61 @@ func (tm *TokenManager) SyncTokenWithoutRetry() (err error, exit bool) { if resp.StatusCode != http.StatusOK { switch resp.StatusCode { case http.StatusUnauthorized: - errMessage = fmt.Errorf("unauthorized to fetch token from Central Manager. "+ + errMessage = fmt.Errorf("unauthorized to fetch accessToken from Central Manager. "+ "Please check the credentials, status code: %d, response: %s", resp.StatusCode, body) return errMessage, true case http.StatusServiceUnavailable: tm.StatusManager.AddRequest(statusmanager.DeployConfig, "", "", false, &cisapiv1.CMStatus{ Message: TokenFetchFailed, - Error: fmt.Sprintf("failed to get token due to service unavailability, "+ + Error: fmt.Sprintf("failed to get accessToken due to service unavailability, "+ "status code: %d, response: %s", resp.StatusCode, body), LastUpdated: metav1.Now(), }) - errMessage = fmt.Errorf("failed to get token due to service unavailability, "+ + errMessage = fmt.Errorf("failed to get accessToken due to service unavailability, "+ "status code: %d, response: %s", resp.StatusCode, body) return errMessage, false case http.StatusNotFound, http.StatusMovedPermanently: errMessage = fmt.Errorf("requested page/api not found, status code: %d, response: %s", resp.StatusCode, body) return errMessage, true default: - errMessage = fmt.Errorf("failed to get token, status code: %d, response: %s", resp.StatusCode, body) + errMessage = fmt.Errorf("failed to get accessToken, status code: %d, response: %s", resp.StatusCode, body) return errMessage, false } } - // Parse the token and its expiration time from the response - tokenResponse := TokenResponse{} + // Parse the accessToken and its expiration time from the response + tokenResponse := AccessTokenResponse{} err = json.Unmarshal(body, &tokenResponse) if err != nil { - errMessage = fmt.Errorf("unmarshaling failed for token response %v. error: %v", body, err.Error()) + errMessage = fmt.Errorf("unmarshaling failed for accessToken response %v. error: %v", body, err.Error()) return errMessage, false } - // Keep the token updated in the TokenManager - tm.mu.Lock() - tm.token = tokenResponse.AccessToken - tm.mu.Unlock() - log.Debugf("[Token Manager] Successfully fetched token from Central Manager") + // Keep the accessToken updated in the TokenManager + tm.SetAccessToken(tokenResponse.AccessToken) + // Keep the refreshToken updated in the TokenManager + tm.SetRefreshToken(tokenResponse.RefreshToken) + log.Debugf("[Token Manager] Successfully fetched accessToken from Central Manager") return nil, false } -// Start maintains valid token. It fetches a new token before expiry. +// Start maintains valid accessToken. It fetches a new accessToken before expiry. func (tm *TokenManager) Start(stopCh chan struct{}, duration time.Duration) { - // Set ticker to 1 minute less than token expiry time to ensure token is refreshed on time + // Set ticker to 1 minute less than refreshToken expiry time to ensure accessToken is refreshed on time tokenUpdateTicker := time.Tick(duration - 60*time.Second) for { select { case <-tokenUpdateTicker: tm.SyncToken() case <-stopCh: - log.Debug("[Token Manager] Stopping synchronizing token") + log.Debug("[Token Manager] Stopping synchronizing refreshToken") close(stopCh) return } } } -// SyncToken is a helper function to fetch token and retry on failure +// SyncToken is a helper function to fetch refreshToken and retry on failure func (tm *TokenManager) SyncToken() { for { err, exit := tm.SyncTokenWithoutRetry() @@ -192,7 +275,7 @@ func (tm *TokenManager) SyncToken() { LastUpdated: metav1.Now(), }) if !exit { - log.Debugf("[Token Manager] Retrying to fetch token in %d seconds", RetryInterval) + log.Debugf("[Token Manager] Retrying to fetch refreshToken in %d seconds", RetryInterval) time.Sleep(RetryInterval * time.Second) } } else { @@ -216,7 +299,7 @@ func (tm *TokenManager) GetCMVersion() (string, error) { log.Debugf("posting GET CM version request on %v", url) // add authorization header to the req - req.Header.Add("Authorization", "Bearer "+tm.GetToken()) + req.Header.Add("Authorization", "Bearer "+tm.GetAccessToken()) httpResp, err := tm.httpClient.Do(req) diff --git a/pkg/tokenmanager/tokenmanager_test.go b/pkg/tokenmanager/tokenmanager_test.go index 05a2217bf..032359fc7 100644 --- a/pkg/tokenmanager/tokenmanager_test.go +++ b/pkg/tokenmanager/tokenmanager_test.go @@ -12,12 +12,12 @@ var _ = Describe("Token Manager Tests", func() { var tokenManager *TokenManager var server *ghttp.Server var statusCode int - var response TokenResponse - - Describe("GetToken", func() { - Context("when token fetch is successful", func() { + var response AccessTokenResponse + var refreshResponse RefreshTokenResponse + Describe("GetAccessToken", func() { + Context("when accessToken is fetched during login", func() { BeforeEach(func() { - // Mock the token server + // Mock the accessToken server server = ghttp.NewServer() mockStatusManager := mockmanager.NewMockStatusManager() tokenManager = NewTokenManager(server.URL(), Credentials{ @@ -26,30 +26,37 @@ var _ = Describe("Token Manager Tests", func() { }, "", true, mockStatusManager) }) AfterEach(func() { - // Stop the mock token server + // Stop the mock accessToken server server.Close() }) - It("should return error while fetching token", func() { + It("should return error while fetching accessToken", func() { statusCode = 500 - response = TokenResponse{ - AccessToken: "test.token", + response = AccessTokenResponse{ + AccessToken: "test.accessToken", } server.AppendHandlers( ghttp.CombineHandlers( - ghttp.VerifyRequest("POST", "/api/login"), + ghttp.VerifyRequest("POST", CMLoginURL), ghttp.RespondWithJSONEncoded(statusCode, response), )) + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", CMRefreshTokenURL), + ghttp.RespondWithJSONEncoded(statusCode, RefreshTokenResponse{ + AccessToken: "test.accessToken", + }), + )) go tokenManager.SyncToken() time.Sleep(1 * time.Second) - token := tokenManager.GetToken() + token := tokenManager.GetAccessToken() Expect(token).To(BeEmpty(), "Token should be empty") }) - It("should return a valid token", func() { + It("should return a valid accessToken", func() { statusCode = 200 - response = TokenResponse{ - AccessToken: "test.token", + response = AccessTokenResponse{ + AccessToken: "test.accessToken", } server.AppendHandlers( ghttp.CombineHandlers( @@ -57,13 +64,13 @@ var _ = Describe("Token Manager Tests", func() { ghttp.RespondWithJSONEncoded(statusCode, response), )) tokenManager.SyncTokenWithoutRetry() - token := tokenManager.GetToken() + token := tokenManager.GetAccessToken() Expect(token).To(Equal(response.AccessToken), "Token should not be nil") }) It("error code 401", func() { statusCode = 401 - response = TokenResponse{ - AccessToken: "test.token", + response = AccessTokenResponse{ + AccessToken: "test.accessToken", } server.AppendHandlers( ghttp.CombineHandlers( @@ -72,12 +79,12 @@ var _ = Describe("Token Manager Tests", func() { )) err, _ := tokenManager.SyncTokenWithoutRetry() Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unauthorized to fetch token")) + Expect(err.Error()).To(ContainSubstring("unauthorized to fetch accessToken")) }) It("error code 503", func() { statusCode = 503 - response = TokenResponse{ - AccessToken: "test.token", + response = AccessTokenResponse{ + AccessToken: "test.accessToken", } server.AppendHandlers( ghttp.CombineHandlers( @@ -86,12 +93,12 @@ var _ = Describe("Token Manager Tests", func() { )) err, _ := tokenManager.SyncTokenWithoutRetry() Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("failed to get token due to service unavailability")) + Expect(err.Error()).To(ContainSubstring("failed to get accessToken due to service unavailability")) }) It("error code 404", func() { statusCode = 404 - response = TokenResponse{ - AccessToken: "test.token", + response = AccessTokenResponse{ + AccessToken: "test.accessToken", } server.AppendHandlers( ghttp.CombineHandlers( @@ -103,6 +110,80 @@ var _ = Describe("Token Manager Tests", func() { Expect(err.Error()).To(ContainSubstring("requested page/api not found")) }) }) + Context("when accessToken is fetched during refresh", func() { + BeforeEach(func() { + // Mock the accessToken server + server = ghttp.NewServer() + mockStatusManager := mockmanager.NewMockStatusManager() + tokenManager = NewTokenManager(server.URL(), Credentials{ + Username: "admin", + Password: "admin", + }, "", true, mockStatusManager) + tokenManager.accessToken = "test.accessToken" + tokenManager.accessTokenExpiry = time.Now() + }) + AfterEach(func() { + // Stop the mock accessToken server + server.Close() + }) + + It("should return a valid accessToken on refresh", func() { + statusCode = 200 + refreshResponse = RefreshTokenResponse{ + AccessToken: "refreshed.accessToken", + } + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", CMRefreshTokenURL), + ghttp.RespondWithJSONEncoded(statusCode, refreshResponse), + )) + err := tokenManager.RefreshAccessToken() + Expect(err).NotTo(HaveOccurred()) + Expect(tokenManager.accessToken).To(Equal(refreshResponse.AccessToken), "Token should not be nil") + }) + It("error code 401 with refresh api", func() { + statusCode = 401 + refreshResponse = RefreshTokenResponse{ + AccessToken: "refreshed.accessToken", + } + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", CMRefreshTokenURL), + ghttp.RespondWithJSONEncoded(statusCode, refreshResponse), + )) + err := tokenManager.RefreshAccessToken() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unauthorized to fetch accessToken")) + }) + It("error code 503 with token refresh api", func() { + statusCode = 503 + refreshResponse = RefreshTokenResponse{ + AccessToken: "refreshed.accessToken", + } + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", CMRefreshTokenURL), + ghttp.RespondWithJSONEncoded(statusCode, refreshResponse), + )) + err := tokenManager.RefreshAccessToken() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to get accessToken due to service unavailability")) + }) + It("error code 404", func() { + statusCode = 404 + refreshResponse = RefreshTokenResponse{ + AccessToken: "refreshed.accessToken", + } + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", CMRefreshTokenURL), + ghttp.RespondWithJSONEncoded(statusCode, refreshResponse), + )) + err := tokenManager.RefreshAccessToken() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("requested page/api not found")) + }) + }) }) }) @@ -120,7 +201,8 @@ var _ = Describe("GetCMVersion", func() { Username: "admin", Password: "admin", }, "", true, mockStatusManager) - tm.token = "fake-token" + tm.accessToken = "fake-accessToken" + tm.accessTokenExpiry = time.Now().Add(5 * time.Minute) }) AfterEach(func() {