From 8323b6bc08766986fc885ada19abd1587c8ffe55 Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Wed, 6 Apr 2022 10:48:32 +0200 Subject: [PATCH] fix(MAGNETO-7807): Fixed rabbitmq sonar issues (#86) * fix(MAGNETO-7807): Fixed rabbitmq sonar issues * fix(MAGNETO-7807): Session init fix * fix(MAGNETO-7807): Remeved wrong struct field * fix(MAGNETO-7807): Removed wrong empty line * fix(MAGNETO-7807): Fixed code smells --- go.mod | 2 +- steps/rabbit/amqp.go | 74 ++++ steps/rabbit/amqp_mock.go | 68 ++++ steps/rabbit/context.go | 2 +- steps/rabbit/session.go | 163 +++++---- steps/rabbit/session_test.go | 644 +++++++++++++++++++++++++++++++++++ steps/rabbit/steps.go | 4 +- 7 files changed, 869 insertions(+), 88 deletions(-) create mode 100644 steps/rabbit/amqp.go create mode 100644 steps/rabbit/amqp_mock.go create mode 100644 steps/rabbit/session_test.go diff --git a/go.mod b/go.mod index 71cda6d5..63c6b785 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/AdguardTeam/dnsproxy v0.41.2 github.com/aws/aws-sdk-go v1.38.71 github.com/cucumber/godog v0.12.0 - github.com/cucumber/messages-go/v16 v16.0.1 // indirect + github.com/cucumber/messages-go/v16 v16.0.1 github.com/elastic/go-elasticsearch/v7 v7.12.0 github.com/go-redis/redis/v8 v8.7.1 github.com/google/uuid v1.2.0 diff --git a/steps/rabbit/amqp.go b/steps/rabbit/amqp.go new file mode 100644 index 00000000..04b7064a --- /dev/null +++ b/steps/rabbit/amqp.go @@ -0,0 +1,74 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "github.com/streadway/amqp" +) + +type AMQPServiceFunctions interface { + Dial(url string) (*amqp.Connection, error) + ConnectionChannel(c *amqp.Connection) (*amqp.Channel, error) + ChannelExchangeDeclare(channel *amqp.Channel, name string) error + ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) + ChannelQueueBind(channel *amqp.Channel, name, exchange string) error + ChannelConsume(channel *amqp.Channel, queue string, + ) (<-chan amqp.Delivery, error) + ChannelClose(channel *amqp.Channel) error + ChannelPublish(channel *amqp.Channel, + exchange string, + msg amqp.Publishing, + ) error +} + +type AMQPService struct{} + +func NewAMQPService() *AMQPService { + return &AMQPService{} +} + +func (a AMQPService) Dial(url string) (*amqp.Connection, error) { + return amqp.Dial(url) +} + +func (a AMQPService) ConnectionChannel(connection *amqp.Connection) (*amqp.Channel, error) { + return connection.Channel() +} +func (a AMQPService) ChannelExchangeDeclare(channel *amqp.Channel, name string) error { + return channel.ExchangeDeclare(name, "fanout", true, false, false, false, nil) +} + +func (a AMQPService) ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) { + return channel.QueueDeclare("", false, true, true, false, nil) +} +func (a AMQPService) ChannelQueueBind(channel *amqp.Channel, name, exchange string) error { + return channel.QueueBind(name, "", exchange, false, nil) +} + +func (a AMQPService) ChannelConsume(channel *amqp.Channel, queue string, +) (<-chan amqp.Delivery, error) { + return channel.Consume(queue, "", true, false, false, false, nil) +} + +func (a AMQPService) ChannelClose(channel *amqp.Channel) error { + return channel.Close() +} + +func (a AMQPService) ChannelPublish(channel *amqp.Channel, + exchange string, + msg amqp.Publishing, +) error { + return channel.Publish(exchange, "", false, false, msg) +} diff --git a/steps/rabbit/amqp_mock.go b/steps/rabbit/amqp_mock.go new file mode 100644 index 00000000..d7921777 --- /dev/null +++ b/steps/rabbit/amqp_mock.go @@ -0,0 +1,68 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "github.com/streadway/amqp" +) + +var ( + DialError error + ConnectionChannelError error + ChannelExchangeDeclareError error + ChannelQueueDeclareError error + ChannelQueueBindError error + ChannelConsumeError error + MockSubCh <-chan amqp.Delivery + ChannelPublishError error +) + +type AMQPServiceFuncMock struct{} + +func (a AMQPServiceFuncMock) Dial(url string) (*amqp.Connection, error) { + return nil, DialError +} + +func (a AMQPServiceFuncMock) ConnectionChannel(c *amqp.Connection) (*amqp.Channel, error) { + return nil, ConnectionChannelError +} + +func (a AMQPServiceFuncMock) ChannelExchangeDeclare(channel *amqp.Channel, name string) error { + return ChannelExchangeDeclareError +} + +func (a AMQPServiceFuncMock) ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) { + amqpQueue := amqp.Queue{} + return amqpQueue, ChannelQueueDeclareError +} + +func (a AMQPServiceFuncMock) ChannelQueueBind(channel *amqp.Channel, name, exchange string) error { + return ChannelQueueBindError +} + +func (a AMQPServiceFuncMock) ChannelConsume(channel *amqp.Channel, queue string, +) (<-chan amqp.Delivery, error) { + return MockSubCh, ChannelConsumeError +} +func (a AMQPServiceFuncMock) ChannelClose(channel *amqp.Channel) error { + return nil +} + +func (a AMQPServiceFuncMock) ChannelPublish(channel *amqp.Channel, + exchange string, + msg amqp.Publishing, +) error { + return ChannelPublishError +} diff --git a/steps/rabbit/context.go b/steps/rabbit/context.go index c7ae3d54..dfc430be 100755 --- a/steps/rabbit/context.go +++ b/steps/rabbit/context.go @@ -26,7 +26,7 @@ const contextKey ContextKey = "rabbitSession" // InitializeContext adds the rabbit session to the context. // The new context is returned because context is immutable. func InitializeContext(ctx context.Context) context.Context { - return context.WithValue(ctx, contextKey, &Session{}) + return context.WithValue(ctx, contextKey, &Session{AMQPService: *NewAMQPService()}) } // GetSession returns the rabbit session stored in context. diff --git a/steps/rabbit/session.go b/steps/rabbit/session.go index 445eb4e3..d55eb550 100644 --- a/steps/rabbit/session.go +++ b/steps/rabbit/session.go @@ -46,12 +46,14 @@ type Session struct { publishing amqp.Publishing // rabbit received delivery message msg amqp.Delivery + // ampq service + AMQPService AMQPServiceFunctions } // ConfigureConnection creates a rabbit connection based on the URI. func (s *Session) ConfigureConnection(ctx context.Context, uri string) error { var err error - s.Connection, err = amqp.Dial(uri) + s.Connection, err = s.AMQPService.Dial(uri) if err != nil { return fmt.Errorf("failed configuring connection '%s': %w", uri, err) } @@ -79,52 +81,23 @@ func (s *Session) ConfigureStandardProperties(ctx context.Context, props amqp.Pu func (s *Session) SubscribeTopic(ctx context.Context, topic string) error { GetLogger().LogSubscribedTopic(topic) var err error - s.channel, err = s.Connection.Channel() + s.channel, err = s.AMQPService.ConnectionChannel(s.Connection) if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.channel.ExchangeDeclare( - topic, // name - "fanout", // type - true, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil, // arguments - ) + err = s.AMQPService.ChannelExchangeDeclare(s.channel, topic) if err != nil { return errors.Wrap(err, "failed to declare an exchange") } - q, err := s.channel.QueueDeclare( - "", // name - false, // durable - true, // delete when unused - true, // exclusive - false, // no-wait - nil, // arguments - ) + q, err := s.AMQPService.ChannelQueueDeclare(s.channel) if err != nil { return errors.Wrap(err, "failed to declare a queue") } - err = s.channel.QueueBind( - q.Name, // queue name - "", // routing key - topic, // exchange - false, - nil, - ) + err = s.AMQPService.ChannelQueueBind(s.channel, q.Name, topic) if err != nil { return errors.Wrap(err, "failed to bind a queue") } - s.subCh, err = s.channel.Consume( - q.Name, // queue - "", // consumer - true, // auto-ack - false, // exclusive - false, // no-local - false, // no-wait - nil, // args - ) + s.subCh, err = s.AMQPService.ChannelConsume(s.channel, q.Name) go func() { logrus.Debugf("Receiving messages from topic %s...", topic) for msg := range s.subCh { @@ -146,37 +119,23 @@ func (s *Session) Unsubscribe(ctx context.Context) error { if s.channel == nil { return nil } - return s.channel.Close() + return s.AMQPService.ChannelClose(s.channel) } // PublishTextMessage publishes a text message in a rabbit topic. func (s *Session) PublishTextMessage(ctx context.Context, topic, message string) error { GetLogger().LogPublishedMessage(message, topic, s.Correlator) var err error - s.channel, err = s.Connection.Channel() + s.channel, err = s.AMQPService.ConnectionChannel(s.Connection) if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.channel.ExchangeDeclare( - topic, // name - "fanout", // type - true, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil, // arguments - ) + err = s.AMQPService.ChannelExchangeDeclare(s.channel, topic) if err != nil { return fmt.Errorf("failed to declare an exchange") } publishing := s.buildPublishingMessage([]byte(message)) - err = s.channel.Publish( - topic, // exchange - "", // routing key - false, // mandatory - false, // immediate - publishing, // publishing - ) + err = s.AMQPService.ChannelPublish(s.channel, topic, publishing) if err != nil { return fmt.Errorf("failed publishing the message '%s' to topic '%s': %w", message, topic, err) } @@ -195,30 +154,38 @@ func (s *Session) buildPublishingMessage(body []byte) amqp.Publishing { publishing.ContentType = "text/plain" } publishing.Headers = s.headers - publishing.Body = []byte(body) + publishing.Body = body return publishing } // PublishJSONMessage publishes a JSON message in a rabbit topic. -func (s *Session) PublishJSONMessage(ctx context.Context, topic string, props map[string]interface{}) error { +func (s *Session) PublishJSONMessage( + ctx context.Context, + topic string, + props map[string]interface{}, +) error { var json string var err error for key, value := range props { if json, err = sjson.Set(json, key, value); err != nil { - return fmt.Errorf("failed setting property '%s' with value '%s' in the message: %w", key, value, err) + return fmt.Errorf("failed setting property '%s' with value '%s' in the message: %w", + key, value, err) } } s.publishing.ContentType = "application/json" return s.PublishTextMessage(ctx, topic, json) } -// WaitForTextMessage waits up to timeout until the expected message is found in the received messages -// for this session. -func (s *Session) WaitForTextMessage(ctx context.Context, timeout time.Duration, expectedMsg string) error { +// WaitForTextMessage waits up to timeout until the expected message is found in +// the received messages for this session. +func (s *Session) WaitForTextMessage(ctx context.Context, + timeout time.Duration, + expectedMsg string, +) error { return waitUpTo(timeout, func() error { - for _, msg := range s.Messages { - if string(msg.Body) == expectedMsg { - s.msg = msg + for i := range s.Messages { + if string(s.Messages[i].Body) == expectedMsg { + s.msg = s.Messages[i] return nil } } @@ -228,12 +195,15 @@ func (s *Session) WaitForTextMessage(ctx context.Context, timeout time.Duration, // WaitForJSONMessageWithProperties waits up to timeout and verifies if there is a message received // in the topic with the requested properties. -func (s *Session) WaitForJSONMessageWithProperties(ctx context.Context, timeout time.Duration, props map[string]interface{}) error { +func (s *Session) WaitForJSONMessageWithProperties(ctx context.Context, + timeout time.Duration, + props map[string]interface{}, +) error { return waitUpTo(timeout, func() error { - for _, msg := range s.Messages { - logrus.Debugf("Checking message: %s", msg.Body) - if matchMessage(string(msg.Body), props) { - s.msg = msg + for i := range s.Messages { + logrus.Debugf("Checking message: %s", s.Messages[i].Body) + if matchMessage(string(s.Messages[i].Body), props) { + s.msg = s.Messages[i] return nil } } @@ -253,17 +223,23 @@ func matchMessage(msg string, expectedProps map[string]interface{}) bool { return true } -// WaitForMessagesWithStandardProperties waits for 'count' messages with standard rabbit properties that are equal to the expected values. -func (s *Session) WaitForMessagesWithStandardProperties(ctx context.Context, timeout time.Duration, count int, props amqp.Delivery) error { +// WaitForMessagesWithStandardProperties waits for 'count' messages with standard rabbit properties +// that are equal to the expected values. +func (s *Session) WaitForMessagesWithStandardProperties( + ctx context.Context, + timeout time.Duration, + count int, + props amqp.Delivery, +) error { return waitUpTo(timeout, func() error { err := fmt.Errorf("no message(s) received match(es) the standard properties") if count < 0 { return err } - for _, msg := range s.Messages { - logrus.Debugf("Checking message: %s", msg.Body) - s.msg = msg - if err := s.ValidateMessageStandardProperties(ctx, props); err == nil { + for i := range s.Messages { + logrus.Debugf("Checking message: %s", s.Messages[i].Body) + s.msg = s.Messages[i] + if err = s.ValidateMessageStandardProperties(ctx, props); err == nil { count-- if count == 0 { return nil @@ -274,8 +250,12 @@ func (s *Session) WaitForMessagesWithStandardProperties(ctx context.Context, tim }) } -// ValidateMessageStandardProperties checks if the message standard rabbit properties are equal the expected values. -func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props amqp.Delivery) error { +// ValidateMessageStandardProperties checks if the message standard rabbit properties are equal +// the expected values. +func (s *Session) ValidateMessageStandardProperties( + ctx context.Context, + props amqp.Delivery, +) error { msg := reflect.ValueOf(s.msg) expectedMsg := reflect.ValueOf(props) t := expectedMsg.Type() @@ -285,7 +265,9 @@ func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props a value := msg.Field(i).Interface() expectedValue := expectedMsg.Field(i).Interface() if value != expectedValue { - return fmt.Errorf("mismatch of standard rabbit property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of standard rabbit property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } } @@ -293,7 +275,10 @@ func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props a } // ValidateMessageHeaders checks if the message rabbit headers are equal the expected values. -func (s *Session) ValidateMessageHeaders(ctx context.Context, headers map[string]interface{}) error { +func (s *Session) ValidateMessageHeaders( + ctx context.Context, + headers map[string]interface{}, +) error { h := s.msg.Headers for key, expectedValue := range headers { value, found := h[key] @@ -301,7 +286,9 @@ func (s *Session) ValidateMessageHeaders(ctx context.Context, headers map[string return fmt.Errorf("missing rabbit message header '%s'", key) } if value != expectedValue { - return fmt.Errorf("mismatch of standard rabbit property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of standard rabbit property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } return nil @@ -316,21 +303,29 @@ func (s *Session) ValidateMessageTextBody(ctx context.Context, expectedMsg strin return nil } -// ValidateMessageJSONBody checks if the message json body properties of message in position 'pos' are equal the expected values. +// ValidateMessageJSONBody checks if the message json body properties of message in position 'pos' +// are equal the expected values. // if pos == -1 then it means last message stored, that is the one stored in s.msg -func (s *Session) ValidateMessageJSONBody(ctx context.Context, props map[string]interface{}, pos int) error { - m := golium.NewMapFromJSONBytes([]byte(s.msg.Body)) +func (s *Session) ValidateMessageJSONBody(ctx context.Context, + props map[string]interface{}, + pos int, +) error { + m := golium.NewMapFromJSONBytes(s.msg.Body) if pos != -1 { nMessages := len(s.Messages) if pos < 0 || pos >= nMessages { - return fmt.Errorf("trying to validate message in position: '%d', '%d' messages available", pos, nMessages) + return fmt.Errorf( + "trying to validate message in position: '%d', '%d' messages available", + pos, nMessages) } - m = golium.NewMapFromJSONBytes([]byte(s.Messages[pos].Body)) + m = golium.NewMapFromJSONBytes(s.Messages[pos].Body) } for key, expectedValue := range props { value := m.Get(key) if value != expectedValue { - return fmt.Errorf("mismatch of json property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of json property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } return nil diff --git a/steps/rabbit/session_test.go b/steps/rabbit/session_test.go new file mode 100644 index 00000000..2a4e32c3 --- /dev/null +++ b/steps/rabbit/session_test.go @@ -0,0 +1,644 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/TelefonicaTC2Tech/golium" + "github.com/cucumber/godog" + "github.com/streadway/amqp" +) + +const ( + rabbitmq = "amqp://guest:guest@localhost:5672/" + logsPath = "./logs" +) + +func TestConfigureConnection(t *testing.T) { + tests := []struct { + name string + uri string + connError error + wantErr bool + }{ + { + name: "Dial error", + connError: fmt.Errorf("dial error"), + wantErr: true, + }, + { + name: "Without connection error", + uri: rabbitmq, + connError: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + ctx := InitializeContext(context.Background()) + s.AMQPService = AMQPServiceFuncMock{} + DialError = tt.connError + if err := s.ConfigureConnection(ctx, tt.uri); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureConnection() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureHeaders(t *testing.T) { + var wrongRabbitHeader = make(map[string]interface{}) + wrongRabbitHeader["wrongParam"] = uint(5) + + var rabbitHeader = make(map[string]interface{}) + rabbitHeader["param"] = "value" + rabbitHeader["Header1"] = "value1" + rabbitHeader["Header2"] = "Value2" + tests := []struct { + name string + headers map[string]interface{} + wantErr bool + }{ + { + name: "Validate headers error", + wantErr: true, + headers: wrongRabbitHeader, + }, + { + name: "No error", + headers: rabbitHeader, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureHeaders(context.Background(), tt.headers); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureHeaders() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureStandardProperties(t *testing.T) { + rabbitHeaders := amqp.Publishing{} + + tests := []struct { + name string + propTable *godog.Table + }{ + { + name: "Configure", + propTable: golium.NewTable([][]string{{"ContentType"}, {"application/json"}}), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + ctx := InitializeContext(context.Background()) + golium.ConvertTableWithoutHeaderToStruct(ctx, tt.propTable, &rabbitHeaders) + s.ConfigureStandardProperties(context.Background(), rabbitHeaders) + }) + } +} + +func TestSubscribeTopic(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + tests := []struct { + name string + topic string + connectionChannelError error + channelExchangeDeclareError error + channelQueueDeclareError error + channelQueueBindError error + channelConsumeError error + subCh <-chan amqp.Delivery + wantErr bool + }{ + { + name: "Connection Channel Error", + connectionChannelError: fmt.Errorf("connection channel error"), + wantErr: true, + }, + { + name: "Channel Exchange Declare Error", + connectionChannelError: nil, + channelExchangeDeclareError: fmt.Errorf("channel exchange declare error"), + wantErr: true, + }, + { + name: "Channel Queue Declare Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: fmt.Errorf("channel queue declare error"), + wantErr: true, + }, + { + name: "Channel Queue Bind Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: fmt.Errorf("channel queue bind error"), + wantErr: true, + }, + { + name: "Channel Consume Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: nil, + subCh: nil, + channelConsumeError: fmt.Errorf("channel queue bind error"), + wantErr: true, + }, + { + name: "Channel registered without errors", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: nil, + subCh: nil, + channelConsumeError: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + goliumCtx := golium.InitializeContext(context.Background()) + ctx := InitializeContext(goliumCtx) + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + ConnectionChannelError = tt.connectionChannelError + ChannelExchangeDeclareError = tt.channelExchangeDeclareError + ChannelQueueDeclareError = tt.channelQueueDeclareError + ChannelQueueBindError = tt.channelQueueBindError + ChannelConsumeError = tt.channelConsumeError + MockSubCh = tt.subCh + if err := s.SubscribeTopic(ctx, tt.topic); (err != nil) != tt.wantErr { + t.Errorf("Session.SubscribeTopic() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestUnsubscribe(t *testing.T) { + tests := []struct { + name string + channel *amqp.Channel + wantErr bool + }{ + { + name: "Nil channel", + channel: nil, + wantErr: false, + }, + { + name: "Not nil channel", + channel: &amqp.Channel{}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + s.channel = tt.channel + if err := s.Unsubscribe(context.Background()); (err != nil) != tt.wantErr { + t.Errorf("Session.Unsubscribe() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPublishTextMessage(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + type args struct { + topic string + message string + } + tests := []struct { + name string + args args + connectionChannelError error + channelExchangeDeclareError error + channelPublishError error + wantErr bool + }{ + { + name: "Connection channel error", + connectionChannelError: fmt.Errorf("connection channel error"), + wantErr: true, + }, + { + name: "Channel exchange declare error", + connectionChannelError: nil, + channelExchangeDeclareError: fmt.Errorf("channel exchange declare error"), + wantErr: true, + }, + { + name: "Publish error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: fmt.Errorf("channel publish error"), + args: args{ + message: "test message", + }, + wantErr: true, + }, + { + name: "Publish without error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: nil, + args: args{ + message: "test message", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, s := setPublishTestEnv( + tt.connectionChannelError, + tt.channelExchangeDeclareError, + tt.channelPublishError) + if err := s.PublishTextMessage( + ctx, tt.args.topic, tt.args.message); (err != nil) != tt.wantErr { + t.Errorf("Session.PublishTextMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPublishJSONMessage(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + propsOk := make(map[string]interface{}) + propsOk["id"] = "1" + propsOk["name"] = "test" + type args struct { + topic string + props map[string]interface{} + } + tests := []struct { + name string + connectionChannelError error + channelExchangeDeclareError error + channelPublishError error + args args + wantErr bool + }{ + { + name: "Valid props", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: nil, + args: args{ + topic: "test_topic", + props: propsOk, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, s := setPublishTestEnv( + tt.connectionChannelError, + tt.channelExchangeDeclareError, + tt.channelPublishError) + if err := s.PublishJSONMessage(ctx, tt.args.topic, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf("Session.PublishJSONMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func setPublishTestEnv( + conChannelError, channelExchDecError, channelPubError error, +) (context.Context, *Session) { + goliumCtx := golium.InitializeContext(context.Background()) + ctx := InitializeContext(goliumCtx) + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + ConnectionChannelError = conChannelError + ChannelExchangeDeclareError = channelExchDecError + ChannelPublishError = channelPubError + return ctx, s +} + +func TestWaitForTextMessage(t *testing.T) { + type args struct { + timeout time.Duration + expectedMsg string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Expected message found", + args: args{ + timeout: time.Duration(5000), + expectedMsg: "expected string", + }, + wantErr: false, + }, + { + name: "Expected message not found", + args: args{ + timeout: time.Duration(5000), + expectedMsg: "error expected string", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + {Body: []byte("expected string")}, + } + if err := s.WaitForTextMessage( + context.Background(), tt.args.timeout, tt.args.expectedMsg); (err != nil) != tt.wantErr { + t.Errorf("Session.WaitForTextMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestWaitForJSONMessageWithProperties(t *testing.T) { + expectedJSON := make(map[string]interface{}) + expectedJSON["id"] = "1" + wrongExpectedJSON := make(map[string]interface{}) + wrongExpectedJSON["id"] = "5" + + type args struct { + timeout time.Duration + props map[string]interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Expected json found", + args: args{ + timeout: time.Duration(5000), + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "Expected json not found", + args: args{ + timeout: time.Duration(5000), + props: wrongExpectedJSON, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Body: []byte(`{"id": "1"}`), + }, + } + if err := s.WaitForJSONMessageWithProperties( + context.Background(), tt.args.timeout, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf("Session.WaitForJSONMessageWithProperties() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestWaitForMessagesWithStandardProperties(t *testing.T) { + type args struct { + timeout time.Duration + count int + props amqp.Delivery + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Message count < 0", + args: args{ + count: -1, + timeout: time.Duration(5000), + }, + wantErr: true, + }, + { + name: "Matching properties", + args: args{ + count: 1, + timeout: time.Duration(5000), + props: amqp.Delivery{ + Priority: 5, + }, + }, + wantErr: false, + }, + { + name: "Not matching properties", + args: args{ + count: 1, + timeout: time.Duration(5000), + props: amqp.Delivery{ + Priority: 10, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Priority: 5, + }, + } + if err := s.WaitForMessagesWithStandardProperties( + context.Background(), + tt.args.timeout, + tt.args.count, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf( + "Session.WaitForMessagesWithStandardProperties() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageHeaders(t *testing.T) { + refHeaders := make(amqp.Table) + refHeaders["id"] = "1" + + testHeaders := make(map[string]interface{}) + + tests := []struct { + name string + testKey string + testValue string + wantErr bool + }{ + { + name: "Key found and value matches", + testKey: "id", + testValue: "1", + wantErr: false, + }, + { + name: "Key not found", + testKey: "ids", + testValue: "1", + wantErr: true, + }, + { + name: "Key found wrong value", + testKey: "id", + testValue: "2", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.msg = amqp.Delivery{ + Headers: refHeaders, + } + testHeaders[tt.testKey] = tt.testValue + if err := s.ValidateMessageHeaders( + context.Background(), testHeaders); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageHeaders() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageTextBody(t *testing.T) { + tests := []struct { + name string + expectedMsg string + wantErr bool + }{ + { + name: "Mismatch of message text", + expectedMsg: "wrong message", + wantErr: true, + }, + { + name: "Right message", + expectedMsg: "message", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.msg = amqp.Delivery{ + Body: []byte("message"), + } + + if err := s.ValidateMessageTextBody( + context.Background(), tt.expectedMsg); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageTextBody() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageJSONBody(t *testing.T) { + expectedJSON := make(map[string]interface{}) + expectedJSON["id"] = "1" + wrongExpectedJSON := make(map[string]interface{}) + wrongExpectedJSON["id"] = "5" + type args struct { + props map[string]interface{} + pos int + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "JSON body match with pos -1", + args: args{ + pos: -1, + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "JSON body match with pos != -1", + args: args{ + pos: 0, + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "pos != -1 without messages", + args: args{ + pos: 1, + props: expectedJSON, + }, + wantErr: true, + }, + { + name: "JSON body mismatch with pos -1", + args: args{ + pos: -1, + props: wrongExpectedJSON, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Body: []byte(`{"id": "1"}`), + }, + } + s.msg = amqp.Delivery{ + Body: []byte(`{"id": "1"}`), + } + + if err := s.ValidateMessageJSONBody( + context.Background(), tt.args.props, tt.args.pos); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageJSONBody() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} diff --git a/steps/rabbit/steps.go b/steps/rabbit/steps.go index 5a475db7..f767d974 100755 --- a/steps/rabbit/steps.go +++ b/steps/rabbit/steps.go @@ -147,8 +147,8 @@ func (cs Steps) InitializeSteps(ctx context.Context, scenCtx *godog.ScenarioCont } return session.ValidateMessageJSONBody(ctx, props, pos) }) - scenCtx.AfterScenario(func(sc *godog.Scenario, err error) { - session.Unsubscribe(ctx) + scenCtx.After(func(ctx context.Context, sc *godog.Scenario, err error) (context.Context, error) { + return ctx, session.Unsubscribe(ctx) }) return ctx }