diff --git a/pkg/messaging/mqtt/pubsub.go b/pkg/messaging/mqtt/pubsub.go index bc158f4bff..a479062979 100644 --- a/pkg/messaging/mqtt/pubsub.go +++ b/pkg/messaging/mqtt/pubsub.go @@ -25,7 +25,6 @@ var ( errSubscribeTimeout = errors.New("failed to subscribe due to timeout reached") errUnsubscribeTimeout = errors.New("failed to unsubscribe due to timeout reached") errUnsubscribeDeleteTopic = errors.New("failed to unsubscribe due to deletion of topic") - errAlreadySubscribed = errors.New("already subscribed to topic") errNotSubscribed = errors.New("not subscribed") errEmptyTopic = errors.New("empty topic") errEmptyID = errors.New("empty ID") @@ -47,6 +46,7 @@ type pubsub struct { subscriptions map[string]subscription } +// NewPubSub returns MQTT message publisher/subscriber. func NewPubSub(url, queue string, timeout time.Duration, logger log.Logger) (messaging.PubSub, error) { client, err := newClient(url, "mqtt-publisher", timeout) if err != nil { @@ -80,7 +80,23 @@ func (ps pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) e case true: // Check topic if ok = s.contains(topic); ok { - return errAlreadySubscribed + // Unlocking, so that Unsubscribe() can access ps.subscriptions + ps.mu.Unlock() + err := ps.Unsubscribe(id, topic) + ps.mu.Lock() // Lock so that deferred unlock handle it + if err != nil { + return err + } + if len(ps.subscriptions) == 0 { + client, err := newClient(ps.address, id, ps.timeout) + if err != nil { + return err + } + s = subscription{ + client: client, + topics: []string{topic}, + } + } } s.topics = append(s.topics, topic) default: @@ -93,6 +109,7 @@ func (ps pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) e topics: []string{topic}, } } + token := s.client.Subscribe(topic, qos, ps.mqttHandler(handler)) if token.Error() != nil { return token.Error() diff --git a/pkg/messaging/nats/pubsub.go b/pkg/messaging/nats/pubsub.go index c7b406584e..397ce88069 100644 --- a/pkg/messaging/nats/pubsub.go +++ b/pkg/messaging/nats/pubsub.go @@ -22,6 +22,7 @@ var ( ErrNotSubscribed = errors.New("not subscribed") ErrEmptyTopic = errors.New("empty topic") ErrEmptyID = errors.New("empty id") + ErrFailed = errors.New("failed") ) var _ messaging.PubSub = (*pubsub)(nil) @@ -69,20 +70,30 @@ func (ps *pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) if topic == "" { return ErrEmptyTopic } + ps.mu.Lock() - defer ps.mu.Unlock() // Check topic s, ok := ps.subscriptions[topic] - switch ok { - case true: - // Check topic ID + if ok { + // Check client ID if _, ok := s[id]; ok { - return ErrAlreadySubscribed + // Unlocking, so that Unsubscribe() can access ps.subscriptions + ps.mu.Unlock() + if err := ps.Unsubscribe(id, topic); err != nil { + return err + } + + ps.mu.Lock() + // value of s can be changed while ps.mu is unlocked + s = ps.subscriptions[topic] } - default: + } + defer ps.mu.Unlock() + if s == nil { s = make(map[string]subscription) ps.subscriptions[topic] = s } + nh := ps.natsHandler(handler) if ps.queue != "" { @@ -104,6 +115,7 @@ func (ps *pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) Subscription: sub, cancel: handler.Cancel, } + return nil } diff --git a/pkg/messaging/nats/pubsub_test.go b/pkg/messaging/nats/pubsub_test.go index 4b830e4c1d..428bb19721 100644 --- a/pkg/messaging/nats/pubsub_test.go +++ b/pkg/messaging/nats/pubsub_test.go @@ -88,6 +88,7 @@ func TestPubsub(t *testing.T) { clientID string errorMessage error pubsub bool //true for subscribe and false for unsubscribe + handler messaging.MessageHandler }{ { desc: "Subscribe to a topic with an ID", @@ -95,6 +96,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to the same topic with a different ID", @@ -102,13 +104,15 @@ func TestPubsub(t *testing.T) { clientID: "clientid2", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to an already subscribed topic with an ID", topic: fmt.Sprintf("%s.%s", chansPrefix, topic), clientID: "clientid1", - errorMessage: nats.ErrAlreadySubscribed, + errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with an ID", @@ -116,6 +120,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nil, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from a non-existent topic with an ID", @@ -123,6 +128,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nats.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from the same topic with a different ID", @@ -130,6 +136,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd2", errorMessage: nats.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from the same topic with a different ID not subscribed", @@ -137,6 +144,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd3", errorMessage: nats.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from an already unsubscribed topic with an ID", @@ -144,6 +152,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nats.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to a topic with a subtopic with an ID", @@ -151,13 +160,15 @@ func TestPubsub(t *testing.T) { clientID: "clientidd1", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to an already subscribed topic with a subtopic with an ID", topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic), clientID: "clientidd1", - errorMessage: nats.ErrAlreadySubscribed, + errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with a subtopic with an ID", @@ -165,6 +176,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd1", errorMessage: nil, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID", @@ -172,6 +184,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nats.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to an empty topic with an ID", @@ -179,6 +192,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nats.ErrEmptyTopic, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from an empty topic with an ID", @@ -186,6 +200,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nats.ErrEmptyTopic, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to a topic with empty id", @@ -193,6 +208,7 @@ func TestPubsub(t *testing.T) { clientID: "", errorMessage: nats.ErrEmptyID, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with empty id", @@ -200,12 +216,45 @@ func TestPubsub(t *testing.T) { clientID: "", errorMessage: nats.ErrEmptyID, pubsub: false, + handler: handler{false}, + }, + { + desc: "Subscribe to another topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"1"), + clientID: "clientid3", + errorMessage: nil, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Subscribe to another already subscribed topic with an ID with Unsubscribe failing", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"1"), + clientID: "clientid3", + errorMessage: nats.ErrFailed, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Subscribe to a new topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid4", + errorMessage: nil, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Unsubscribe from a topic with an ID with failing handler", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid4", + errorMessage: nats.ErrFailed, + pubsub: false, + handler: handler{true}, }, } for _, pc := range subcases { if pc.pubsub == true { - err := pubsub.Subscribe(pc.clientID, pc.topic, handler{}) + err := pubsub.Subscribe(pc.clientID, pc.topic, pc.handler) if pc.errorMessage == nil { require.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", pc.desc, err)) } else { @@ -222,7 +271,9 @@ func TestPubsub(t *testing.T) { } } -type handler struct{} +type handler struct { + fail bool +} func (h handler) Handle(msg messaging.Message) error { msgChan <- msg @@ -230,5 +281,8 @@ func (h handler) Handle(msg messaging.Message) error { } func (h handler) Cancel() error { + if h.fail { + return nats.ErrFailed + } return nil } diff --git a/pkg/messaging/rabbitmq/pubsub.go b/pkg/messaging/rabbitmq/pubsub.go index 632f2f52a1..3543646f2f 100644 --- a/pkg/messaging/rabbitmq/pubsub.go +++ b/pkg/messaging/rabbitmq/pubsub.go @@ -26,6 +26,7 @@ var ( ErrNotSubscribed = errors.New("not subscribed") ErrEmptyTopic = errors.New("empty topic") ErrEmptyID = errors.New("empty id") + ErrFailed = errors.New("failed") ) var _ messaging.PubSub = (*pubsub)(nil) @@ -72,16 +73,24 @@ func (ps *pubsub) Subscribe(id, topic string, handler messaging.MessageHandler) return ErrEmptyTopic } ps.mu.Lock() - defer ps.mu.Unlock() // Check topic s, ok := ps.subscriptions[topic] - switch ok { - case true: - // Check topic ID + if ok { + // Check client ID if _, ok := s[id]; ok { - return ErrAlreadySubscribed + // Unlocking, so that Unsubscribe() can access ps.subscriptions + ps.mu.Unlock() + if err := ps.Unsubscribe(id, topic); err != nil { + return err + } + + ps.mu.Lock() + // value of s can be changed while ps.mu is unlocked + s = ps.subscriptions[topic] } - default: + } + defer ps.mu.Unlock() + if s == nil { s = make(map[string]subscription) ps.subscriptions[topic] = s } diff --git a/pkg/messaging/rabbitmq/pubsub_test.go b/pkg/messaging/rabbitmq/pubsub_test.go index 35d95e909e..f44057c1eb 100644 --- a/pkg/messaging/rabbitmq/pubsub_test.go +++ b/pkg/messaging/rabbitmq/pubsub_test.go @@ -85,6 +85,7 @@ func TestPubsub(t *testing.T) { clientID string errorMessage error pubsub bool //true for subscribe and false for unsubscribe + handler messaging.MessageHandler }{ { desc: "Subscribe to a topic with an ID", @@ -92,6 +93,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to the same topic with a different ID", @@ -99,13 +101,15 @@ func TestPubsub(t *testing.T) { clientID: "clientid2", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to an already subscribed topic with an ID", topic: fmt.Sprintf("%s.%s", chansPrefix, topic), clientID: "clientid1", - errorMessage: rabbitmq.ErrAlreadySubscribed, + errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with an ID", @@ -113,6 +117,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: nil, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from a non-existent topic with an ID", @@ -120,6 +125,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: rabbitmq.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from the same topic with a different ID", @@ -127,6 +133,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd2", errorMessage: rabbitmq.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from the same topic with a different ID not subscribed", @@ -134,6 +141,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd3", errorMessage: rabbitmq.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from an already unsubscribed topic with an ID", @@ -141,6 +149,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: rabbitmq.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to a topic with a subtopic with an ID", @@ -148,13 +157,15 @@ func TestPubsub(t *testing.T) { clientID: "clientidd1", errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Subscribe to an already subscribed topic with a subtopic with an ID", topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic), clientID: "clientidd1", - errorMessage: rabbitmq.ErrAlreadySubscribed, + errorMessage: nil, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with a subtopic with an ID", @@ -162,6 +173,7 @@ func TestPubsub(t *testing.T) { clientID: "clientidd1", errorMessage: nil, pubsub: false, + handler: handler{false}, }, { desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID", @@ -169,6 +181,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: rabbitmq.ErrNotSubscribed, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to an empty topic with an ID", @@ -176,6 +189,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: rabbitmq.ErrEmptyTopic, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from an empty topic with an ID", @@ -183,6 +197,7 @@ func TestPubsub(t *testing.T) { clientID: "clientid1", errorMessage: rabbitmq.ErrEmptyTopic, pubsub: false, + handler: handler{false}, }, { desc: "Subscribe to a topic with empty id", @@ -190,6 +205,7 @@ func TestPubsub(t *testing.T) { clientID: "", errorMessage: rabbitmq.ErrEmptyID, pubsub: true, + handler: handler{false}, }, { desc: "Unsubscribe from a topic with empty id", @@ -197,12 +213,45 @@ func TestPubsub(t *testing.T) { clientID: "", errorMessage: rabbitmq.ErrEmptyID, pubsub: false, + handler: handler{false}, + }, + { + desc: "Subscribe to another topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"1"), + clientID: "clientid3", + errorMessage: nil, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Subscribe to another already subscribed topic with an ID with Unsubscribe failing", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"1"), + clientID: "clientid3", + errorMessage: rabbitmq.ErrFailed, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Subscribe to a new topic with an ID", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid4", + errorMessage: nil, + pubsub: true, + handler: handler{true}, + }, + { + desc: "Unsubscribe from a topic with an ID with failing handler", + topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"), + clientID: "clientid4", + errorMessage: rabbitmq.ErrFailed, + pubsub: false, + handler: handler{true}, }, } for _, pc := range subcases { if pc.pubsub == true { - err := pubsub.Subscribe(pc.clientID, pc.topic, handler{}) + err := pubsub.Subscribe(pc.clientID, pc.topic, pc.handler) if pc.errorMessage == nil { require.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", pc.desc, err)) } else { @@ -219,7 +268,9 @@ func TestPubsub(t *testing.T) { } } -type handler struct{} +type handler struct { + fail bool +} func (h handler) Handle(msg messaging.Message) error { msgChan <- msg @@ -227,5 +278,8 @@ func (h handler) Handle(msg messaging.Message) error { } func (h handler) Cancel() error { + if h.fail { + return rabbitmq.ErrFailed + } return nil }