Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIXED] Make sure tmpAccounts does not accidentally have duplicates. #5588

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 37 additions & 29 deletions server/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2889,9 +2889,12 @@ func (a *Account) isIssuerClaimTrusted(claims *jwt.ActivationClaims) bool {
// check is done with the account's name, not the pointer. This is used
// during config reload where we are comparing current and new config
// in which pointers are different.
// No lock is acquired in this function, so it is assumed that the
// import maps are not changed while this executes.
// Acquires `a` read lock, but `b` is assumed to not be accessed
// by anyone but the caller (`b` is not registered anywhere).
func (a *Account) checkStreamImportsEqual(b *Account) bool {
a.mu.RLock()
defer a.mu.RUnlock()

if len(a.imports.streams) != len(b.imports.streams) {
return false
}
Expand Down Expand Up @@ -3264,6 +3267,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.nameTag = ac.Name
a.tags = ac.Tags

// Grab trace label under lock.
tl := a.traceLabel()

var td string
var tds int
if ac.Trace != nil {
Expand Down Expand Up @@ -3297,10 +3303,10 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if a.imports.services != nil {
old.imports.services = make(map[string]*serviceImport, len(a.imports.services))
}
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
for k, v := range a.imports.services {
old.imports.services[k] = v
delete(a.imports.services, k)
}
}

alteredScope := map[string]struct{}{}
Expand Down Expand Up @@ -3370,13 +3376,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
for _, e := range ac.Exports {
switch e.Type {
case jwt.Stream:
s.Debugf("Adding stream export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding stream export %q for %s", e.Subject, tl)
if err := a.addStreamExportWithAccountPos(
string(e.Subject), authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding stream export to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream export to account [%s]: %v", tl, err.Error())
}
case jwt.Service:
s.Debugf("Adding service export %q for %s", e.Subject, a.traceLabel())
s.Debugf("Adding service export %q for %s", e.Subject, tl)
rt := Singleton
switch e.ResponseType {
case jwt.ResponseTypeStream:
Expand All @@ -3386,7 +3392,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
if err := a.addServiceExportWithResponseAndAccountPos(
string(e.Subject), rt, authAccounts(e.TokenReq), e.AccountTokenPosition); err != nil {
s.Debugf("Error adding service export to account [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export to account [%s]: %v", tl, err)
continue
}
sub := string(e.Subject)
Expand All @@ -3396,13 +3402,13 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
if e.Latency.Sampling == jwt.Headers {
hdrNote = " (using headers)"
}
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, a.traceLabel(), err)
s.Debugf("Error adding latency tracking%s for service export to account [%s]: %v", hdrNote, tl, err)
}
}
if e.ResponseThreshold != 0 {
// Response threshold was set in options.
if err := a.SetServiceExportResponseThreshold(sub, e.ResponseThreshold); err != nil {
s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err)
s.Debugf("Error adding service export response threshold for [%s]: %v", tl, err)
}
}
if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil {
Expand Down Expand Up @@ -3450,44 +3456,41 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
}
var incompleteImports []*jwt.Import
for _, i := range ac.Imports {
// check tmpAccounts with priority
var acc *Account
var err error
if v, ok := s.tmpAccounts.Load(i.Account); ok {
acc = v.(*Account)
} else {
acc, err = s.lookupAccount(i.Account)
}
acc, err := s.lookupAccount(i.Account)
if acc == nil || err != nil {
s.Errorf("Can't locate account [%s] for import of [%v] %s (err=%v)", i.Account, i.Subject, i.Type, err)
incompleteImports = append(incompleteImports, i)
continue
}
from := string(i.Subject)
to := i.GetTo()
// Capture trace labels.
acc.mu.RLock()
atl := acc.traceLabel()
acc.mu.RUnlock()
// Grab from and to
from, to := string(i.Subject), i.GetTo()
switch i.Type {
case jwt.Stream:
if i.LocalSubject != _EMPTY_ {
// set local subject implies to is empty
to = string(i.LocalSubject)
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddMappedStreamImportWithClaim(acc, from, to, i)
} else {
s.Debugf("Adding stream import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding stream import %s:%q for %s:%q", atl, from, tl, to)
err = a.AddStreamImportWithClaim(acc, from, to, i)
}
if err != nil {
s.Debugf("Error adding stream import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding stream import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
case jwt.Service:
if i.LocalSubject != _EMPTY_ {
from = string(i.LocalSubject)
to = string(i.Subject)
}
s.Debugf("Adding service import %s:%q for %s:%q", acc.traceLabel(), from, a.traceLabel(), to)
s.Debugf("Adding service import %s:%q for %s:%q", atl, from, tl, to)
if err := a.AddServiceImportWithClaim(acc, from, to, i); err != nil {
s.Debugf("Error adding service import to account [%s]: %v", a.traceLabel(), err.Error())
s.Debugf("Error adding service import to account [%s]: %v", tl, err.Error())
incompleteImports = append(incompleteImports, i)
}
}
Expand Down Expand Up @@ -3663,7 +3666,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
// regardless of enabled or disabled. It handles both cases.
if jsEnabled {
if err := s.configJetStream(a); err != nil {
s.Errorf("Error configuring jetstream for account [%s]: %v", a.traceLabel(), err.Error())
s.Errorf("Error configuring jetstream for account [%s]: %v", tl, err.Error())
a.mu.Lock()
// Absent reload of js server cfg, this is going to be broken until js is disabled
a.incomplete = true
Expand Down Expand Up @@ -3802,8 +3805,13 @@ func (s *Server) buildInternalAccount(ac *jwt.AccountClaims) *Account {
// We don't want to register an account that is in the process of
// being built, however, to solve circular import dependencies, we
// need to store it here.
s.tmpAccounts.Store(ac.Subject, acc)
if v, loaded := s.tmpAccounts.LoadOrStore(ac.Subject, acc); loaded {
return v.(*Account)
}

// Update based on claims.
s.UpdateAccountClaims(acc, ac)

return acc
}

Expand Down
5 changes: 4 additions & 1 deletion server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2959,8 +2959,11 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b

// Add in the shadow subscription.
func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) {
im := ime.im
c.mu.Lock()
nsub := *sub // copy
c.mu.Unlock()

im := ime.im
nsub.im = im

if !im.usePub && ime.dyn && im.tr != nil {
Expand Down
6 changes: 3 additions & 3 deletions server/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1992,9 +1992,9 @@ func TestJWTAccountURLResolverPermanentFetchFailure(t *testing.T) {
importErrCnt++
}
case <-tmr.C:
// connecting and updating, each cause 3 traces (2 + 1 on iteration)
if importErrCnt != 6 {
t.Fatalf("Expected 6 debug traces, got %d", importErrCnt)
// connecting and updating, each cause 3 traces (2 + 1 on iteration) + 1 xtra fetch
if importErrCnt != 7 {
t.Fatalf("Expected 7 debug traces, got %d", importErrCnt)
}
return
}
Expand Down
4 changes: 2 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,11 +1140,11 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
if reloading && acc.Name != globalAccountName {
if ai, ok := s.accounts.Load(acc.Name); ok {
a = ai.(*Account)
a.mu.Lock()
// Before updating the account, check if stream imports have changed.
if !a.checkStreamImportsEqual(acc) {
awcsti[acc.Name] = struct{}{}
}
a.mu.Lock()
// Collect the sids for the service imports since we are going to
// replace with new ones.
var sids [][]byte
Expand Down Expand Up @@ -2107,7 +2107,6 @@ func (s *Server) fetchAccount(name string) (*Account, error) {
return nil, err
}
acc := s.buildInternalAccount(accClaims)
acc.claimJWT = claimJWT
// Due to possible race, if registerAccount() returns a non
// nil account, it means the same account was already
// registered and we should use this one.
Expand All @@ -2123,6 +2122,7 @@ func (s *Server) fetchAccount(name string) (*Account, error) {
var needImportSubs bool

acc.mu.Lock()
acc.claimJWT = claimJWT
if len(acc.imports.services) > 0 {
if acc.ic == nil {
acc.ic = s.createInternalAccountClient()
Expand Down