From 51cb1d3615b863d0a50c68ad27072343e434d60f Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Mon, 24 Jun 2024 19:46:33 -0700 Subject: [PATCH] Make sure tmpAccounts does not accidentally have duplicates. With this change did not need lookup from tmpAccounts on import processing. Signed-off-by: Derek Collison --- server/accounts.go | 66 ++++++++++++++++++++++++++-------------------- server/client.go | 5 +++- server/jwt_test.go | 6 ++--- server/server.go | 4 +-- 4 files changed, 46 insertions(+), 35 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 6f5a1a7026d..aca96b5d24b 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -2889,9 +2889,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool { // check is done with the account's name, not the pointer. This is used // during config reload where we are comparing current and new config // in which pointers are different. -// No lock is acquired in this function, so it is assumed that the -// import maps are not changed while this executes. +// Acquires `a` read lock, but `b` is assumed to not be accessed +// by anyone but the caller (`b` is not registered anywhere). func (a *Account) checkStreamImportsEqual(b *Account) bool { + a.mu.RLock() + defer a.mu.RUnlock() + if len(a.imports.streams) != len(b.imports.streams) { return false } @@ -3264,6 +3267,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.nameTag = ac.Name a.tags = ac.Tags + // Grab trace label under lock. + tl := a.traceLabel() + var td string var tds int if ac.Trace != nil { @@ -3297,10 +3303,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if a.imports.services != nil { old.imports.services = make(map[string]*serviceImport, len(a.imports.services)) - } - for k, v := range a.imports.services { - old.imports.services[k] = v - delete(a.imports.services, k) + for k, v := range a.imports.services { + old.imports.services[k] = v + delete(a.imports.services, k) + } } alteredScope := map[string]struct{}{} @@ -3370,13 +3376,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim for _, e := range ac.Exports { switch e.Type { case jwt.Stream: - s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel()) + s.Debugf("Adding stream export %q for %s", e.Subject, tl) if err := a.addStreamExportWithAccountPos( string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error()) } case jwt.Service: - s.Debugf("Adding service export %q for %s", e.Subject, a.traceLabel()) + s.Debugf("Adding service export %q for %s", e.Subject, tl) rt := Singleton switch e.ResponseType { case jwt.ResponseTypeStream: @@ -3386,7 +3392,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } if err := a.addServiceExportWithResponseAndAccountPos( string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil { - s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export to account [%s]: %v", tl, err) continue } sub := string(e.Subject) @@ -3396,13 +3402,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim if e.Latency.Sampling == jwt.Headers { hdrNote = " (using headers)" } - s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err) + s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err) } } if e.ResponseThreshold != 0 { // Response threshold was set in options. if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil { - s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err) + s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err) } } if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil { @@ -3450,34 +3456,31 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim } var incompleteImports []*jwt.Import for _, i := range ac.Imports { - // check tmpAccounts with priority - var acc *Account - var err error - if v, ok := s.tmpAccounts.Load(i.Account); ok { - acc = v.(*Account) - } else { - acc, err = s.lookupAccount(i.Account) - } + acc, err := s.lookupAccount(i.Account) if acc == nil || err != nil { s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err) incompleteImports = append(incompleteImports, i) continue } - from := string(i.Subject) - to := i.GetTo() + // Capture trace labels. + acc.mu.RLock() + atl := acc.traceLabel() + acc.mu.RUnlock() + // Grab from and to + from, to := string(i.Subject), i.GetTo() switch i.Type { case jwt.Stream: if i.LocalSubject != _EMPTY_ { // set local subject implies to is empty to = string(i.LocalSubject) - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddMappedStreamImportWithClaim(acc, from, to, i) } else { - s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to) err = a.AddStreamImportWithClaim(acc, from, to, i) } if err != nil { - s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } case jwt.Service: @@ -3485,9 +3488,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim from = string(i.LocalSubject) to = string(i.Subject) } - s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to) + s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to) if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil { - s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error()) + s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error()) incompleteImports = append(incompleteImports, i) } } @@ -3663,7 +3666,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim // regardless of enabled or disabled. It handles both cases. if jsEnabled { if err := s.configJetStream(a); err != nil { - s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error()) + s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error()) a.mu.Lock() // Absent reload of js server cfg, this is going to be broken until js is disabled a.incomplete = true @@ -3802,8 +3805,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account { // We don't want to register an account that is in the process of // being built, however, to solve circular import dependencies, we // need to store it here. - s.tmpAccounts.Store(ac.Subject, acc) + if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded { + return v.(*Account) + } + + // Update based on claims. s.UpdateAccountClaims(acc, ac) + return acc } diff --git a/server/client.go b/server/client.go index 5f41b4fb09e..b458d51b6fc 100644 --- a/server/client.go +++ b/server/client.go @@ -2959,8 +2959,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b // Add in the shadow subscription. func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) { - im := ime.im + c.mu.Lock() nsub := *sub // copy + c.mu.Unlock() + + im := ime.im nsub.im = im if !im.usePub && ime.dyn && im.tr != nil { diff --git a/server/jwt_test.go b/server/jwt_test.go index b6140813c24..d6cdd1dd7d9 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -1992,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) { importErrCnt++ } case <-tmr.C: - // connecting and updating, each cause 3 traces (2 + 1 on iteration) - if importErrCnt != 6 { - t.Fatalf("Expected 6 debug traces, got %d", importErrCnt) + // connecting and updating, each cause 3 traces (2 + 1 on iteration) + 1 xtra fetch + if importErrCnt != 7 { + t.Fatalf("Expected 7 debug traces, got %d", importErrCnt) } return } diff --git a/server/server.go b/server/server.go index fc8c886381c..c0113a29400 100644 --- a/server/server.go +++ b/server/server.go @@ -1140,11 +1140,11 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error) if reloading && acc.Name != globalAccountName { if ai, ok := s.accounts.Load(acc.Name); ok { a = ai.(*Account) - a.mu.Lock() // Before updating the account, check if stream imports have changed. if !a.checkStreamImportsEqual(acc) { awcsti[acc.Name] = struct{}{} } + a.mu.Lock() // Collect the sids for the service imports since we are going to // replace with new ones. var sids [][]byte @@ -2107,7 +2107,6 @@ func (s *Server) fetchAccount(name string) (*Account, error) { return nil, err } acc := s.buildInternalAccount(accClaims) - acc.claimJWT = claimJWT // Due to possible race, if registerAccount() returns a non // nil account, it means the same account was already // registered and we should use this one. @@ -2123,6 +2122,7 @@ func (s *Server) fetchAccount(name string) (*Account, error) { var needImportSubs bool acc.mu.Lock() + acc.claimJWT = claimJWT if len(acc.imports.services) > 0 { if acc.ic == nil { acc.ic = s.createInternalAccountClient()