Skip to content

Commit

Permalink
Merge pull request #1274 from nats-io/fix_1265
Browse files Browse the repository at this point in the history
[CHANGED] Reject channels with different case (Foo vs foo)
  • Loading branch information
kozlovic authored Oct 11, 2022
2 parents 432ad7c + 497a6a0 commit 069cc96
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 6 deletions.
1 change: 1 addition & 0 deletions server/clustering.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ func (r *raftFSM) lookupOrCreateChannel(name string, id uint64) (*channel, error
return nil, err
}
delete(cs.channels, name)
delete(s.channels.channelsLC, strings.ToLower(name))
}
// Channel does exist or has been deleted. Create now with given ID.
return cs.createChannelLocked(s, name, id)
Expand Down
32 changes: 26 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,18 @@ func (state State) String() string {

type channelStore struct {
sync.RWMutex
channels map[string]*channel
store stores.Store
stan *StanServer
channels map[string]*channel
channelsLC map[string]*channel
store stores.Store
stan *StanServer
}

func newChannelStore(srv *StanServer, s stores.Store) *channelStore {
cs := &channelStore{
channels: make(map[string]*channel),
store: s,
stan: srv,
channels: make(map[string]*channel),
channelsLC: make(map[string]*channel),
store: s,
stan: srv,
}
return cs
}
Expand Down Expand Up @@ -313,6 +315,13 @@ func (cs *channelStore) createChannel(s *StanServer, name string) (*channel, err
return c, err
}

func (cs *channelStore) checkCase(name string) error {
if c := cs.channelsLC[strings.ToLower(name)]; c != nil {
return fmt.Errorf("rejecting channel %q because channel %q alreay exists (different cases not allowed)", name, c.name)
}
return nil
}

func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint64) (retChan *channel, retErr error) {
defer func() {
if retErr != nil {
Expand All @@ -329,6 +338,8 @@ func (cs *channelStore) createChannelLocked(s *StanServer, name string, id uint6
return nil, ErrChanDelInProgress
}
return c, nil
} else if err := cs.checkCase(name); err != nil {
return nil, err
}
if s.isClustered {
if s.isLeader() && id == 0 {
Expand Down Expand Up @@ -370,6 +381,7 @@ func (cs *channelStore) create(s *StanServer, name string, sc *stores.Channel) (
}
c.nextSequence = lastSequence + 1
cs.channels[name] = c
cs.channelsLC[strings.ToLower(name)] = c
cl := cs.store.GetChannelLimits(name)
if cl.MaxInactivity > 0 {
c.activity = &channelActivity{maxInactivity: cl.MaxInactivity}
Expand Down Expand Up @@ -900,6 +912,9 @@ func (s *StanServer) lookupOrCreateChannel(name string) (*channel, error) {
}
cs.RUnlock()
return c, nil
} else if err := cs.checkCase(name); err != nil {
cs.RUnlock()
return nil, err
}
cs.RUnlock()
return cs.createChannel(s, name)
Expand All @@ -914,6 +929,9 @@ func (s *StanServer) lookupOrCreateChannelPreventDelete(name string) (*channel,
cs.Unlock()
return nil, false, ErrChanDelInProgress
}
} else if err := cs.checkCase(name); err != nil {
cs.Unlock()
return nil, false, err
} else {
var err error
c, err = cs.createChannelLocked(s, name, 0)
Expand Down Expand Up @@ -3127,6 +3145,7 @@ func (s *StanServer) processDeleteChannel(channel string) {
return
}
delete(s.channels.channels, channel)
delete(s.channels.channelsLC, strings.ToLower(channel))
s.log.Noticef("Channel %q has been deleted", channel)
}

Expand Down Expand Up @@ -5287,6 +5306,7 @@ func (s *StanServer) processSubscriptionRequest(m *nats.Msg) {
}
}
if err != nil {
s.log.Errorf("Unable to create subscription on %q: %v", sr.Subject, err)
s.channels.turnOffPreventDelete(c)
s.channels.maybeStartChannelDeleteTimer(sr.Subject, c)
s.sendSubscriptionResponseErr(m.Reply, err)
Expand Down
146 changes: 146 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,3 +1703,149 @@ func TestInternalSubsLimits(t *testing.T) {
})
}
}

func TestChannelNameRejectedIfAlreadyExistsWithDifferentCase(t *testing.T) {
for _, tinfo := range []struct {
name string
st string
restart bool
}{
{"memory", stores.TypeMemory, false},
{"file", stores.TypeFile, true},
{"sql", stores.TypeSQL, true},
{"clustered", stores.TypeFile, true},
} {
t.Run(tinfo.name, func(t *testing.T) {
var o *Options
if tinfo.st == stores.TypeSQL {
if !doSQL {
t.SkipNow()
}
}
// Force persistent store to be the tinfo.st for this test.
orgps := persistentStoreType
persistentStoreType = tinfo.st
defer func() { persistentStoreType = orgps }()

if tinfo.st == stores.TypeSQL || tinfo.st == stores.TypeFile {
o = getTestDefaultOptsForPersistentStore()
} else if tinfo.st == stores.TypeMemory {
o = GetDefaultOptions()
}

cleanupDatastore(t)
defer cleanupDatastore(t)
cleanupRaftLog(t)
defer cleanupRaftLog(t)

var servers []*StanServer
if tinfo.name == "clustered" {
ns := natsdTest.RunDefaultServer()
defer ns.Shutdown()

o1 := getTestDefaultOptsForClustering("a", true)
servers = append(servers, runServerWithOpts(t, o1, nil))
o2 := getTestDefaultOptsForClustering("b", false)
servers = append(servers, runServerWithOpts(t, o2, nil))
o3 := getTestDefaultOptsForClustering("c", false)
servers = append(servers, runServerWithOpts(t, o3, nil))
} else {
servers = append(servers, runServerWithOpts(t, o, nil))
}
for _, s := range servers {
defer s.Shutdown()
}

sc := NewDefaultConnection(t)
defer sc.Close()

sendOK := func(channel, content string) {
t.Helper()
if err := sc.Publish(channel, []byte(content)); err != nil {
t.Fatalf("Error on send: %v", err)
}
}
sendFail := func(channel, content string) {
t.Helper()
err := sc.Publish(channel, []byte(content))
if err == nil || !strings.Contains(err.Error(), "exists") {
t.Fatalf("Expected error that channel already exists, got: %v", err)
}
}
sendOK("Foo", "1")
sendOK("Foo", "2")
sendOK("Foo", "3")
sendOK("Foo", "4")
// Change channel name case
sendFail("foo", "1")
sendFail("foo", "2")
// Back to "Foo"
sendOK("Foo", "5")
sendOK("Foo", "6")

recvOK := func(channel string) {
t.Helper()

ch := make(chan *stan.Msg, 6)
sub, err := sc.Subscribe(channel, func(m *stan.Msg) {
ch <- m
}, stan.DeliverAllAvailable())
if err != nil {
t.Fatalf("Error on subscribe: %v", err)
}
defer sub.Unsubscribe()

// We want to get all 6 messages
for i := 0; i < 6; i++ {
select {
case m := <-ch:
if v, err := strconv.ParseInt(string(m.Data), 10, 64); err != nil || int(v) != i+1 {
t.Fatalf("Invalid message %v: %s", i+1, m.Data)
}
case <-time.After(time.Second):
t.Fatalf("Failed receiving message %v", i+1)
}
}
}
recvFail := func(channel string) {
t.Helper()

_, err := sc.Subscribe(channel, func(m *stan.Msg) {}, stan.DeliverAllAvailable())
if err == nil || !strings.Contains(err.Error(), "exists") {
t.Fatalf("Expected error that channel already exists, got %v", err)
}
}
recvOK("Foo")
recvFail("foo")
recvFail("FoO")

if !tinfo.restart {
return
}

sc.Close()

for i, s := range servers {
s.Shutdown()
s.mu.RLock()
opts := s.opts
s.mu.RUnlock()
s = runServerWithOpts(t, opts, nil)
defer s.Shutdown()
servers[i] = s
}

if tinfo.name == "clustered" {
getLeader(t, 10*time.Second, servers...)
}

sc = NewDefaultConnection(t)
defer sc.Close()

// Try to receive again, but change the order...
recvFail("foo")
recvOK("Foo")
recvFail("FoO")
})
}
}

0 comments on commit 069cc96

Please sign in to comment.