diff --git a/server/clustering.go b/server/clustering.go
index 003a7303..9f93bf3e 100644
--- a/server/clustering.go
+++ b/server/clustering.go
@@ -738,6 +738,7 @@ func (r *raftFSM) lookupOrCreateChannel(name string, id uint64) (*channel, error
 			return nil, err
 		}
 		delete(cs.channels, name)
+		delete(s.channels.channelsLC, strings.ToLower(name))
 	}
 	// Channel does exist or has been deleted. Create now with given ID.
 	return cs.createChannelLocked(s, name, id)
diff --git a/server/server.go b/server/server.go
index 46f97f07..66c611f4 100644
--- a/server/server.go
+++ b/server/server.go
@@ -274,16 +274,18 @@ func (state State) String() string {
 
 type channelStore struct {
 	sync.RWMutex
-	channels map[string]*channel
-	store    stores.Store
-	stan     *StanServer
+	channels   map[string]*channel
+	channelsLC map[string]*channel
+	store      stores.Store
+	stan       *StanServer
 }
 
 func newChannelStore(srv *StanServer, s stores.Store) *channelStore {
 	cs := &channelStore{
-		channels: make(map[string]*channel),
-		store:    s,
-		stan:     srv,
+		channels:   make(map[string]*channel),
+		channelsLC: make(map[string]*channel),
+		store:      s,
+		stan:       srv,
 	}
 	return cs
 }
@@ -313,6 +315,13 @@ func (cs *channelStore) createChannel(s *StanServer, name string) (*channel, err
 	return c, err
 }
 
+func (cs *channelStore) checkCase(name string) error {
+	if c := cs.channelsLC[strings.ToLower(name)]; c != nil {
+		return fmt.Errorf("rejecting channel %q because channel %q already exists (different cases not allowed)", name, c.name)
+	}
+	return nil
+}
+
 func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint64) (retChan *channel, retErr error) {
 	defer func() {
 		if retErr != nil {
@@ -329,6 +338,8 @@ func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint6
 			return nil, ErrChanDelInProgress
 		}
 		return c, nil
+	} else if err := cs.checkCase(name); err != nil {
+		return nil, err
 	}
 	if s.isClustered {
 		if s.isLeader() && id == 0 {
@@ -370,6 +381,7 @@ func (cs *channelStore) create(s *StanServer, name string, sc *stores.Channel) (
 	}
 	c.nextSequence = lastSequence + 1
 	cs.channels[name] = c
+	cs.channelsLC[strings.ToLower(name)] = c
 	cl := cs.store.GetChannelLimits(name)
 	if cl.MaxInactivity > 0 {
 		c.activity = &channelActivity{maxInactivity: cl.MaxInactivity}
@@ -900,6 +912,9 @@ func (s *StanServer) lookupOrCreateChannel(name string) (*channel, error) {
 		}
 		cs.RUnlock()
 		return c, nil
+	} else if err := cs.checkCase(name); err != nil {
+		cs.RUnlock()
+		return nil, err
 	}
 	cs.RUnlock()
 	return cs.createChannel(s, name)
@@ -914,6 +929,9 @@ func (s *StanServer) lookupOrCreateChannelPreventDelete(name string) (*channel,
 			cs.Unlock()
 			return nil, false, ErrChanDelInProgress
 		}
+	} else if err := cs.checkCase(name); err != nil {
+		cs.Unlock()
+		return nil, false, err
 	} else {
 		var err error
 		c, err = cs.createChannelLocked(s, name, 0)
@@ -3127,6 +3145,7 @@ func (s *StanServer) processDeleteChannel(channel string) {
 		return
 	}
 	delete(s.channels.channels, channel)
+	delete(s.channels.channelsLC, strings.ToLower(channel))
 	s.log.Noticef("Channel %q has been deleted", channel)
 }
 
@@ -5287,6 +5306,7 @@ func (s *StanServer) processSubscriptionRequest(m *nats.Msg) {
 		}
 	}
 	if err != nil {
+		s.log.Errorf("Unable to create subscription on %q: %v", sr.Subject, err)
 		s.channels.turnOffPreventDelete(c)
 		s.channels.maybeStartChannelDeleteTimer(sr.Subject, c)
 		s.sendSubscriptionResponseErr(m.Reply, err)
diff --git a/server/server_test.go b/server/server_test.go
index fc9997e5..299f4a04 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -1703,3 +1703,149 @@ func TestInternalSubsLimits(t *testing.T) {
 		})
 	}
 }
+
+func TestChannelNameRejectedIfAlreadyExistsWithDifferentCase(t *testing.T) {
+	for _, tinfo := range []struct {
+		name    string
+		st      string
+		restart bool
+	}{
+		{"memory", stores.TypeMemory, false},
+		{"file", stores.TypeFile, true},
+		{"sql", stores.TypeSQL, true},
+		{"clustered", stores.TypeFile, true},
+	} {
+		t.Run(tinfo.name, func(t *testing.T) {
+			var o *Options
+			if tinfo.st == stores.TypeSQL {
+				if !doSQL {
+					t.SkipNow()
+				}
+			}
+			// Force persistent store to be the tinfo.st for this test.
+			orgps := persistentStoreType
+			persistentStoreType = tinfo.st
+			defer func() { persistentStoreType = orgps }()
+
+			if tinfo.st == stores.TypeSQL || tinfo.st == stores.TypeFile {
+				o = getTestDefaultOptsForPersistentStore()
+			} else if tinfo.st == stores.TypeMemory {
+				o = GetDefaultOptions()
+			}
+
+			cleanupDatastore(t)
+			defer cleanupDatastore(t)
+			cleanupRaftLog(t)
+			defer cleanupRaftLog(t)
+
+			var servers []*StanServer
+			if tinfo.name == "clustered" {
+				ns := natsdTest.RunDefaultServer()
+				defer ns.Shutdown()
+
+				o1 := getTestDefaultOptsForClustering("a", true)
+				servers = append(servers, runServerWithOpts(t, o1, nil))
+				o2 := getTestDefaultOptsForClustering("b", false)
+				servers = append(servers, runServerWithOpts(t, o2, nil))
+				o3 := getTestDefaultOptsForClustering("c", false)
+				servers = append(servers, runServerWithOpts(t, o3, nil))
+			} else {
+				servers = append(servers, runServerWithOpts(t, o, nil))
+			}
+			for _, s := range servers {
+				defer s.Shutdown()
+			}
+
+			sc := NewDefaultConnection(t)
+			defer sc.Close()
+
+			sendOK := func(channel, content string) {
+				t.Helper()
+				if err := sc.Publish(channel, []byte(content)); err != nil {
+					t.Fatalf("Error on send: %v", err)
+				}
+			}
+			sendFail := func(channel, content string) {
+				t.Helper()
+				err := sc.Publish(channel, []byte(content))
+				if err == nil || !strings.Contains(err.Error(), "exists") {
+					t.Fatalf("Expected error that channel already exists, got: %v", err)
+				}
+			}
+			sendOK("Foo", "1")
+			sendOK("Foo", "2")
+			sendOK("Foo", "3")
+			sendOK("Foo", "4")
+			// Change channel name case
+			sendFail("foo", "1")
+			sendFail("foo", "2")
+			// Back to "Foo"
+			sendOK("Foo", "5")
+			sendOK("Foo", "6")
+
+			recvOK := func(channel string) {
+				t.Helper()
+
+				ch := make(chan *stan.Msg, 6)
+				sub, err := sc.Subscribe(channel, func(m *stan.Msg) {
+					ch <- m
+				}, stan.DeliverAllAvailable())
+				if err != nil {
+					t.Fatalf("Error on subscribe: %v", err)
+				}
+				defer sub.Unsubscribe()
+
+				// We want to get all 6 messages
+				for i := 0; i < 6; i++ {
+					select {
+					case m := <-ch:
+						if v, err := strconv.ParseInt(string(m.Data), 10, 64); err != nil || int(v) != i+1 {
+							t.Fatalf("Invalid message %v: %s", i+1, m.Data)
+						}
+					case <-time.After(time.Second):
+						t.Fatalf("Failed receiving message %v", i+1)
+					}
+				}
+			}
+			recvFail := func(channel string) {
+				t.Helper()
+
+				_, err := sc.Subscribe(channel, func(m *stan.Msg) {}, stan.DeliverAllAvailable())
+				if err == nil || !strings.Contains(err.Error(), "exists") {
+					t.Fatalf("Expected error that channel already exists, got: %v", err)
+				}
+			}
+			recvOK("Foo")
+			recvFail("foo")
+			recvFail("FoO")
+
+			if !tinfo.restart {
+				return
+			}
+
+			sc.Close()
+
+			for i, s := range servers {
+				s.Shutdown()
+				s.mu.RLock()
+				opts := s.opts
+				s.mu.RUnlock()
+				s = runServerWithOpts(t, opts, nil)
+				defer s.Shutdown()
+				servers[i] = s
+			}
+
+			if tinfo.name == "clustered" {
+				getLeader(t, 10*time.Second, servers...)
+			}
+
+			sc = NewDefaultConnection(t)
+			defer sc.Close()
+
+			// Try to receive again, but change the order...
+			recvFail("foo")
+			recvOK("Foo")
+			recvFail("FoO")
+		})
+	}
+}
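For reference, here is a minimal client-side sketch (not part of the patch) of the behavior the new `checkCase` guard enforces. It assumes a locally running server built with this change, the default `test-cluster` cluster ID used throughout the tests, and the `stan.go` client; the client ID and channel names are illustrative only:

```go
package main

import (
	"fmt"
	"log"

	stan "github.com/nats-io/stan.go"
)

func main() {
	// Assumes a nats-streaming-server with this patch running on localhost.
	sc, err := stan.Connect("test-cluster", "case-check-client")
	if err != nil {
		log.Fatalf("Error on connect: %v", err)
	}
	defer sc.Close()

	// First publish creates channel "Foo".
	if err := sc.Publish("Foo", []byte("1")); err != nil {
		log.Fatalf("Error on send: %v", err)
	}
	// A name differing only in case is now rejected by checkCase with
	// an error of the form:
	//   rejecting channel "foo" because channel "Foo" already exists ...
	if err := sc.Publish("foo", []byte("2")); err != nil {
		fmt.Println("publish rejected:", err)
	}
}
```

The original channel keeps accepting publishes and subscriptions; only the differently-cased name is rejected, with the "already exists" error that both `sendFail` and `recvFail` in the test match on.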