From 1f3f4fe3f2e14402fda853c5518a622633f5897a Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Tue, 5 Apr 2022 16:26:25 +0200 Subject: [PATCH 1/5] fix(MAGNETO-7807): Fixed rabbitmq sonar issues --- go.mod | 2 +- steps/dns/session.go | 2 +- steps/jwt/session_test.go | 882 +++++++++++++++++++++++++++++++++++ steps/rabbit/amqp.go | 110 +++++ steps/rabbit/amqp_mock.go | 85 ++++ steps/rabbit/context.go | 4 +- steps/rabbit/session.go | 129 +++-- steps/rabbit/session_test.go | 644 +++++++++++++++++++++++++ steps/rabbit/steps.go | 4 +- 9 files changed, 1815 insertions(+), 47 deletions(-) create mode 100644 steps/jwt/session_test.go create mode 100644 steps/rabbit/amqp.go create mode 100644 steps/rabbit/amqp_mock.go create mode 100644 steps/rabbit/session_test.go diff --git a/go.mod b/go.mod index 71cda6d5..63c6b785 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/AdguardTeam/dnsproxy v0.41.2 github.com/aws/aws-sdk-go v1.38.71 github.com/cucumber/godog v0.12.0 - github.com/cucumber/messages-go/v16 v16.0.1 // indirect + github.com/cucumber/messages-go/v16 v16.0.1 github.com/elastic/go-elasticsearch/v7 v7.12.0 github.com/go-redis/redis/v8 v8.7.1 github.com/google/uuid v1.2.0 diff --git a/steps/dns/session.go b/steps/dns/session.go index b01c916c..fd84e244 100644 --- a/steps/dns/session.go +++ b/steps/dns/session.go @@ -129,7 +129,7 @@ func (s *Session) SendDoHQuery(ctx context.Context, method string, qtype uint16, switch method { case "GET": dq := base64.RawURLEncoding.EncodeToString(data) - request, err = http.NewRequest("GET", fmt.Sprintf("%s?dns=%s", s.Server, dq), nil) + request, err = http.NewRequest("GET", fmt.Sprintf("%s?dns=%s", s.Server, dq), http.NewRequest("GET", fmt.Sprintf("%s?dns=%s", s.Server, dq), http.NoBody)) if err != nil { return err } diff --git a/steps/jwt/session_test.go b/steps/jwt/session_test.go new file mode 100644 index 00000000..c4325621 --- /dev/null +++ b/steps/jwt/session_test.go @@ -0,0 +1,882 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package jwt + +import ( + "context" + "testing" + + "github.com/TelefonicaTC2Tech/golium" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jws" +) + +var bytePayload = []byte("payload") + +const ( + symmetricKey = "sign_symmetric_key_that_is_long_enough_for_algorithm_" + + "HS512_(with_more_than 256 bits!)" + + signPublicKey = ` +-----BEGIN RSA PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoNNaEB/t0c4kZNhoz9G5 +t5fq+1OQVl4szeB+MwicYdzf2Ho23w7/0TFBa8azF++o0CYJYAgqRh/MCKXD7gsF +iXjm/TSj2M0GgJXGtPn/ZS7ULOLKm/Fp+mB6209qtcBPON9L4c8++tUCn6wwIqaO +x6OzdjUMMC9jhSkOJChBxRJtkJjwf4usRi9nYCCxsVfVeJGOxwEOghMxjdA5vJCx +XcXzFnhGQigT6EHUoxLg7JRvKgAdvN9+lpAvn8lDnCJsrCjDsFrz0BAewiKBcM8C +tR5AcCFf3pG+oQ7Uq62idvjmKsCB96jkTfVidr19Rj03E1Lg+2RCuS+4Qtylnl4b +RQIDAQAB +-----END RSA PUBLIC KEY----- +` + encryptPublicKey = ` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu+5LoHNwgUG8+bn5Mn51 +uTQ6EfP6XlblPv23i1LmTHpoGsgGirnIwYkrjHUqxNOP1XOq32p+SJwZdRZOLyoQ +Gigf6gSoYHuCG2AfsPdkVjd7JBAiejbNwdODn9coqGbgVJFi3iR4qfI9GSMmE5Qj +pvLvv/XKSVTkAkobgxmeKs6RzdWepyWgOXUgWdyYJXj5B7yCotMeYDrhbwtfmX0j +yoyMN0hyLJRDG7UPbuvl+PrHzYiC4TVo+cQm89qOJnnvAsGTg7QYcK1854pU8evh +CDRvfCc1KNcyPbJj8ZjrvamK16KbxBxwsMXBuRznKQufNx60+Ej63vwRBxcH+y7T +MQIDAQAB +-----END PUBLIC KEY----- +` + wrongTypeKey = ` +-----BEGIN EC PRIVATE KEY----- +MHQCAQEEIFL3sLnioGcDvHWM/BPlNw96BOx1KKco2qsq4UwhQUosoAcGBSuBBAAK +oUQDQgAEXs1Fmq4QdPAbn3NycdEU+HOjc3kW9efbso2kI/vdDTWcSCMk310s53G3 +tRClDBPPuuJAsKghbPfaTaUpmXFCNA== +-----END EC PRIVATE KEY----- +` + signPrivateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAoNNaEB/t0c4kZNhoz9G5t5fq+1OQVl4szeB+MwicYdzf2Ho2 +3w7/0TFBa8azF++o0CYJYAgqRh/MCKXD7gsFiXjm/TSj2M0GgJXGtPn/ZS7ULOLK +m/Fp+mB6209qtcBPON9L4c8++tUCn6wwIqaOx6OzdjUMMC9jhSkOJChBxRJtkJjw +f4usRi9nYCCxsVfVeJGOxwEOghMxjdA5vJCxXcXzFnhGQigT6EHUoxLg7JRvKgAd +vN9+lpAvn8lDnCJsrCjDsFrz0BAewiKBcM8CtR5AcCFf3pG+oQ7Uq62idvjmKsCB +96jkTfVidr19Rj03E1Lg+2RCuS+4Qtylnl4bRQIDAQABAoIBABKWu0c32Y4xjIVX +ei3jKNsupQttqjZBZl5Zf1y4txKcrAbigWsg2bK9RbmGWvb+TX3Zl6XQ68n1LOkm +99GQ1pAAOHq11eZeNE0ygqgyaTGxyvZxNEf4DG7TLgAhWs0tMDr7nFK6WKY3brkz +9tBafhBXPIwCL6l2IMOobikBujBj9Pe5mpTwCzbqE8TEzfnVB5DzqroFR2W/O8/6 +af6f7Co5cY+kta38G9hCq78wW4iU5qesynVw7K50mzUMCHRVF5SsKZwHmgNxjv7O ++B7jt1mzIKjPhBhN8ZjEIsFsKgZgOnZFSn8CKhxNGU9FYAo/r1Ih5TEKsPufAVbd +ChWkSoECgYEA0G8xd8lVM0WylukoI+A9u7uvpiiRkiksPmNuwhEtFSxvSLFKhqeM +gG/iDSe6DJGKSCgyCnLv6QVAwLcIhKmzJKUYrJb9OzxsBlaMpXucisWpKqRCKB7/ +RQhK+rkxrKtExWGDvk/0+nbTvqXVPnd5ifp0QtfSHtp6OT5GX0xttWECgYEAxYbW +roeN1B2b+/+WFTafguA5e72g4WUDa6OQNRdePmfUdZHyzGOP7CMF/BJ59VeLRsn+ +ZfHI2cD8YCb+8jxtmDK8M1h0dr3mUcrbExU09hRYhBKiL76xMsV8QU7zerHXcyDp +Mjk0D743lhUmllm02sLJFEGEfGc+Idk2+TW1DGUCgYEAyAGa303zkrKTr1nmKZ7Y +vhdYckHFhhI6IVe6hUCEGSg9VOzDDbkjCm/R4zu2vK6/mYPwmLQ34EspGoPICbzp +aQV/SsXMExZktiRA695UlZkcPg3Gacdsvio6AKLKttzVre1nxKvm8JwrjWqF2F4+ +4xbQjv+X4gFVfS5zyqiFMaECgYBt8+YTFw/jEGxg9WAVBOf8EVbOQ7uHXBRwWYcP +lqd2c5O3snuGPLHDz6coLxzGbmnwCMbc9p9IX33dBDgMnYigHTXYGxgRdRn9U79p +OvfVN3QiaMDxdOPskDPfotQz60U0KBDHTUJmtQr6N2HYda0PzTfjV6kpGstiSiio +xrW2ZQKBgHlbw5kRWNKe/UCmyqL/EgPH8mE3Q/s+zFaMfu9bV3D/n5iffAS9POHv +pFUZU0CeuJcHF9D8VjnzjeisJW5uK/EEOVKaerTL01qjYE9nzOl18xvevKEgfD2l +jXW8Z4nFjhcqsv1DlylqVga/B47YHfg0qTcfLJ78gMaiGMNaLQhF +-----END RSA PRIVATE KEY----- +` + encryptPrivateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAu+5LoHNwgUG8+bn5Mn51uTQ6EfP6XlblPv23i1LmTHpoGsgG +irnIwYkrjHUqxNOP1XOq32p+SJwZdRZOLyoQGigf6gSoYHuCG2AfsPdkVjd7JBAi +ejbNwdODn9coqGbgVJFi3iR4qfI9GSMmE5QjpvLvv/XKSVTkAkobgxmeKs6RzdWe +pyWgOXUgWdyYJXj5B7yCotMeYDrhbwtfmX0jyoyMN0hyLJRDG7UPbuvl+PrHzYiC +4TVo+cQm89qOJnnvAsGTg7QYcK1854pU8evhCDRvfCc1KNcyPbJj8ZjrvamK16Kb +xBxwsMXBuRznKQufNx60+Ej63vwRBxcH+y7TMQIDAQABAoIBAQCon6SUD4C/OfEK +UehbekS/LTF9smDQDUAdSSJLjNK/hIWsohXcm96aaS3+FZOOkBXa/LIxTSiKliXx +fVYh06gnECGypQM/rxKK5bEJ5LDO+3EuZpvI9Suh4tuTrEb683QN7XW8xRTPyF8y +EuuzXZSv1ANzRmN/cQA1XbFZ7L6SwMKNiYn3FVkORZvABnCd80Nc4tqqMz22nRdl +kWFw/gI+h6igIv1DeSe2gJbHe0HrbqatW2IbAgHiBptr7u4BxhxB2ppDaXoxzksU +vyoc0QmVRPFy01EXt6kxiCG2wF+87VtY/cCiOqUxy3xV3SHriVQx8ewOR2W2GHgG +md7b6FlRAoGBAMLrAnD2jqUWVvt7b/51z+dwcbLiAm0VV75szRhnOuY52SIfB+3h +JdTW6ujnsPvHaWWVxv8tX+QaxzS4UFsakrgFcaBvKBjaicyxxoJ+/Nu0wok8ZFmy +qzfLBDV2919TfaDMwML+N4pHIpTAxFFxKlqahf6cIoJ86RWcL3kCkV83AoGBAPbS +wGiZERX5NWZoPqj01Beq8VdrW4P9v9kBixg4lRTHog57Uhk9zaKrE67uGhVwEEBU +nnZ+D0+90AwUu1Cy+HITGhaUBFIUJLY9RyT0nih944MzUwsybcxmlucR6bNLRlRp +v2mjawGPDoFDq/IPhttZOxZPHBd4CnUgZm3OvQTXAoGAK2hw471U6Rj/iAPmXhHY +mh8lgwPoLGjbYJIUXsHmkQ0C+SFV/7jrVuoB6JpohLnVFAV2CrANMdxwzqHZa2CQ +miDEPElk8ZwBoi9ZGQi0wS0RQcTMSFmM3eD9b/atgnIygRP4PbSlo8rRvbTsQ4Lj +Psg43QnieZLdya09uUJEI6MCgYEAkN2XYozcU1pGNkne5Ql1ZkLFjbqMJwcKz9Ix +ElE7ZsvY2MkWoYv9oojob5Z+JrD0SN2heAh68iGE92I/opi4azO87x2G/6mk9nU2 +yYDtRvTEUOAR0JOTkBFyZkLEOKBoseizGMx6ZJrTN5lBVTw5uYpAvNJHuZqSALa4 +h6B8nlcCgYEAkcS2h37xeCNfdLZh2GDEAnEcy8w/op6F7NzKWSnpFskOmIEUsCvM +5vEiKbAHyDTDPyoy0Zx8waMV6eAhZGxD65uqifsH8dj6ltMAGvOfHR9Cs1qD3TQt +HippkcLN9e1ETEx1zNxmWAXFHOX2Ia2kSccmuuCKEzglN47DrJE5/Gs= +-----END RSA PRIVATE KEY----- +` + token = "eyJhbGciOiJSU0ExXzUiLCJjdHkiOiJKV1QiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwidHlwIj" + + "oiSldFIn0.UxTx-Yh41FVC1V9FJdjKDbBcoX2-0As-S1IBr1KZibxwTDT_n1wlw0x3MzckR78cRWMBajU80a" + + "4rsSRCAY0Gzi0aehuQgbrYufg1EURCE7X_IQAtDH-9nvgcVzfWxonfUCVvWlH95Zk9FrWPcQhhxBbR1aDuyr" + + "Ozl4iHt2p0Grrt37e8EPyqJFE7edC2LVI-H3CfYaHlrILJFukJR1-swpEzl9r-uXGifAgEtSrT9DL56wLLCm" + + "OxIwusMQvRvMl53uakmVYfQtWF7-Ibr4aMUOBYr9H31UhZP5ZbaVAzONRngFpnTPWbvhJg12kycT1snd8-8m" + + "f8uRnfqcPkQWmrMw.2wNQ3hktIGHz9dpt3nOZvw.Ki9j5uT7JeSNGpO-JMWvLMpJRHgBHqSasV3dBoDH4pHI" + + "aTT7n2A19_vLiLo4df0xtQGqaHFSgS49vcuV2N9yOuW49fZv6nJuGvXkk8HJcRDHtrSS3_AhNSX_zBrJw2do" + + "FwHjKixXxyS1nboX3Q-p7AaTRIx9l6mRetOa_xwXogEusM9GEMqKP6GkxNE669j9MwR3DIDFO83S3Ntj4GiK" + + "XXEhJgRQSQgsgL4qvuyfqK4VSY22m_Z6ndCtL5hsvC1chhF0PvB5M-6U36ynRc9_tx8iwv-Zwy_Ja5q0gPHY" + + "R7fcxJ-u9eHz7ZZMLB0x6qyjDit3OsVF30ehiJvC6WfFq8v1MuDPdtHuLScTyCdG8jE7UtP0Djt9JxvvjvHY" + + "dt9-RwgOLgVMAlQCBt4PA1NvUcjzKkfQaWrKR65Uc_MVtXAuYVqcR3lcZH5DbjS84UyMFAYf-5QqW7xctRg8" + + "0VW_H7rvtI_NsPGtI6Qah3XKu5dkKmCokLW9MhSReQpiJW5oebpAL7vJ6fh_S9xPNF-msw.cUoLbK2mGgvaO" + + "Od3K_Jh9Q" +) + +func TestConfigureSignatureAlgorithm(t *testing.T) { + tests := []struct { + name string + alg string + wantErr bool + }{ + { + name: "Valid signature algorithm", + alg: "HS512", + wantErr: false, + }, + { + name: "Invalid signature algorithm", + alg: "invalid_signature", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureSignatureAlgorithm( + context.Background(), tt.alg); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureSignatureAlgorithm() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureKeyEncryptionAlgorithm(t *testing.T) { + type args struct { + ctx context.Context + alg string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Valid encryption algorithm", + args: args{ + alg: string(jwa.RSA1_5), + ctx: context.Background(), + }, + + wantErr: false, + }, + { + name: "Invalid encryption algorithm", + args: args{ + alg: "invalid_encryption", + ctx: context.Background(), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureKeyEncryptionAlgorithm( + tt.args.ctx, tt.args.alg); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureKeyEncryptionAlgorithm() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureContentEncryptionAlgorithm(t *testing.T) { + tests := []struct { + name string + alg string + wantErr bool + }{ + { + name: "Valid content encryption algorithm", + alg: string(jwa.A128CBC_HS256), + + wantErr: false, + }, + { + name: "Invalid content encryption algorithm", + alg: "invalid_content_encryption", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureContentEncryptionAlgorithm( + context.Background(), tt.alg); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureContentEncryptionAlgorithm() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} + +func TestConfigureSymmetricKey(t *testing.T) { + type args struct { + ctx context.Context + symmetricKey string + } + tests := []struct { + name string + args args + }{ + { + name: "Configure Symmetric Key", + args: args{ + ctx: context.Background(), + symmetricKey: symmetricKey, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.ConfigureSymmetricKey(tt.args.ctx, tt.args.symmetricKey) + }) + } +} + +func TestConfigurePublicKey(t *testing.T) { + type args struct { + ctx context.Context + publicKeyPEM string + } + tests := []struct { + name string + + args args + wantErr bool + }{ + { + name: "Not valid pem", + wantErr: true, + args: args{ + ctx: context.Background(), + publicKeyPEM: "WRONG", + }, + }, + { + name: "Valid Public Key", + wantErr: false, + args: args{ + ctx: context.Background(), + publicKeyPEM: encryptPublicKey, + }, + }, + { + name: "Valid RSA Public Key", + wantErr: false, + args: args{ + ctx: context.Background(), + publicKeyPEM: signPublicKey, + }, + }, + { + name: "Default", + wantErr: true, + args: args{ + ctx: context.Background(), + publicKeyPEM: wrongTypeKey, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigurePublicKey(tt.args.ctx, tt.args.publicKeyPEM); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigurePublicKey() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigurePrivateKey(t *testing.T) { + type args struct { + ctx context.Context + privateKeyPEM string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Not valid pem", + wantErr: true, + args: args{ + ctx: context.Background(), + privateKeyPEM: "WRONG", + }, + }, + { + name: "Valid Private Key", + wantErr: false, + args: args{ + ctx: context.Background(), + privateKeyPEM: signPrivateKey, + }, + }, + { + name: "Default", + wantErr: true, + args: args{ + ctx: context.Background(), + privateKeyPEM: wrongTypeKey, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigurePrivateKey(tt.args.ctx, tt.args.privateKeyPEM); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigurePrivateKey() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureJSONPayload(t *testing.T) { + paramsInput := make(map[string]interface{}) + paramsInput["title"] = "foo1" + type args struct { + ctx context.Context + props map[string]interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Valid payload", + args: args{ + ctx: context.Background(), + props: paramsInput, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureJSONPayload(tt.args.ctx, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureJSONPayload() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestGenerateSignedJWTInContext(t *testing.T) { + type args struct { + ctx context.Context + ctxtKey string + } + tests := []struct { + name string + args args + payload []byte + signatureAlgorithm jwa.SignatureAlgorithm + privateKey interface{} + wantErr bool + }{ + { + name: "Nil payload", + payload: nil, + wantErr: true, + args: args{ + ctx: context.Background(), + }, + }, + { + name: "Empty signature", + payload: bytePayload, + signatureAlgorithm: "", + wantErr: true, + args: args{ + ctx: context.Background(), + }, + }, + { + name: "Nil private key", + payload: bytePayload, + signatureAlgorithm: jwa.RS256, + privateKey: nil, + wantErr: true, + args: args{ + ctx: context.Background(), + }, + }, + { + name: "Valid token", + payload: bytePayload, + signatureAlgorithm: jwa.RS256, + privateKey: "valid", + wantErr: false, + args: args{ + ctx: context.Background(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(tt.args.ctx) + ctx := InitializeContext(ctxGolium) + s := &Session{} + s.Payload = tt.payload + s.SignatureAlgorithm = tt.signatureAlgorithm + if tt.privateKey == nil { + s.PrivateKey = tt.privateKey + } else { + s.ConfigurePrivateKey(ctx, signPrivateKey) + } + if err := s.GenerateSignedJWTInContext( + ctx, tt.args.ctxtKey); (err != nil) != tt.wantErr { + t.Errorf("Session.GenerateSignedJWTInContext() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestGenerateEncryptedJWTInContext(t *testing.T) { + type args struct { + ctx context.Context + ctxtKey string + } + tests := []struct { + name string + args args + payload []byte + keyEncryptionAlgorithm jwa.KeyEncryptionAlgorithm + contentEncryptionAlgorithm jwa.ContentEncryptionAlgorithm + publicKey interface{} + wantErr bool + }{ + { + name: "Nil payload", + payload: nil, + wantErr: true, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + }, + { + name: "Empty Encryption Algorithm", + payload: bytePayload, + keyEncryptionAlgorithm: "", + wantErr: true, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + }, + { + name: "Empty Content Encryption Algorithn", + payload: bytePayload, + keyEncryptionAlgorithm: jwa.RSA1_5, + contentEncryptionAlgorithm: "", + wantErr: true, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + }, + { + name: "Nil Public Key", + payload: bytePayload, + keyEncryptionAlgorithm: jwa.RSA1_5, + contentEncryptionAlgorithm: jwa.A128CBC_HS256, + publicKey: nil, + wantErr: true, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + }, + { + name: "Valid token", + payload: bytePayload, + keyEncryptionAlgorithm: jwa.RSA1_5, + contentEncryptionAlgorithm: jwa.A128CBC_HS256, + publicKey: "valid", + wantErr: false, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(tt.args.ctx) + ctx := InitializeContext(ctxGolium) + s := &Session{} + s.Payload = tt.payload + s.KeyEncryptionAlgorithm = tt.keyEncryptionAlgorithm + s.ContentEncryptionAlgorithm = tt.contentEncryptionAlgorithm + if tt.publicKey == nil { + s.PublicKey = tt.publicKey + } else { + s.ConfigurePublicKey(ctx, encryptPublicKey) + } + if err := s.GenerateEncryptedJWTInContext(ctx, tt.args.ctxtKey); (err != nil) != tt.wantErr { + t.Errorf("Session.GenerateEncryptedJWTInContext() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestGenerateSignedEncryptedJWTInContext(t *testing.T) { + type args struct { + ctx context.Context + ctxtKey string + } + tests := []struct { + name string + args args + signedError bool + wantErr bool + }{ + { + name: "Error generating signed JWT in Context", + signedError: true, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + wantErr: true, + }, + { + name: "Valid generation", + signedError: false, + args: args{ + ctx: context.Background(), + ctxtKey: "jwt.jwse", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(tt.args.ctx) + ctx := InitializeContext(ctxGolium) + s := &Session{} + + s.Payload = bytePayload + s.SignatureAlgorithm = jwa.RS256 + if !tt.signedError { + s.ConfigurePrivateKey(ctx, signPrivateKey) + } + s.KeyEncryptionAlgorithm = jwa.RSA1_5 + s.ContentEncryptionAlgorithm = jwa.A128CBC_HS256 + s.ConfigurePublicKey(ctx, encryptPublicKey) + + if err := s.GenerateSignedEncryptedJWTInContext( + ctx, tt.args.ctxtKey); (err != nil) != tt.wantErr { + t.Errorf( + "Session.GenerateSignedEncryptedJWTInContext() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestProcessSignedEncryptedJWT(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + tokenError bool + args args + wantErr bool + }{ + { + name: "Valid token", + args: args{ + ctx: context.Background(), + }, + tokenError: false, + wantErr: false, + }, + { + name: "Wrong token", + args: args{ + ctx: context.Background(), + }, + tokenError: true, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(tt.args.ctx) + ctx := InitializeContext(ctxGolium) + s := &Session{} + if tt.tokenError { + s.Token = "fakeToken1" + } else { + s.Token = token + } + s.ConfigurePrivateKey(ctx, encryptPrivateKey) + s.KeyEncryptionAlgorithm = jwa.RSA1_5 + if err := s.ProcessSignedEncryptedJWT(ctx, s.Token); (err != nil) != tt.wantErr { + t.Errorf("Session.ProcessSignedEncryptedJWT() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateJWTRequirements(t *testing.T) { + type args struct { + ctx context.Context + } + tests := []struct { + name string + token string + signatureAlgorithm jwa.SignatureAlgorithm + publicKey interface{} + signedMessage *jws.Message + args args + wantErr bool + }{ + { + name: "Empty token", + args: args{ + ctx: context.Background(), + }, + token: "", + wantErr: true, + }, + { + name: "Empty signature", + args: args{ + ctx: context.Background(), + }, + token: token, + signatureAlgorithm: "", + wantErr: true, + }, + { + name: "Empty Public Key", + args: args{ + ctx: context.Background(), + }, + token: token, + signatureAlgorithm: jwa.RS256, + publicKey: nil, + wantErr: true, + }, + { + name: "Nil Signed Message", + args: args{ + ctx: context.Background(), + }, + token: token, + signatureAlgorithm: jwa.RS256, + publicKey: "valid", + signedMessage: nil, + wantErr: false, + }, + { + name: "Signed Message", + args: args{ + ctx: context.Background(), + }, + token: token, + signatureAlgorithm: jwa.RS256, + publicKey: "valid", + signedMessage: &jws.Message{}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(tt.args.ctx) + ctx := InitializeContext(ctxGolium) + s := &Session{} + s.Token = tt.token + s.SignatureAlgorithm = tt.signatureAlgorithm + if tt.publicKey == nil { + s.PublicKey = tt.publicKey + } else { + s.ConfigurePublicKey(ctx, encryptPublicKey) + } + + s.SignedMessage = tt.signedMessage + + if err := s.ValidateJWTRequirements(); (err != nil) != tt.wantErr { + t.Errorf("Session.validateJWTRequirements() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidatePayloadJSONProperties(t *testing.T) { + testPayload := make(map[string]interface{}) + testPayload["jsonKey"] = "jsonValue1" + missmatchPayload := make(map[string]interface{}) + missmatchPayload["jsonKey"] = "jsonValueError" + tests := []struct { + name string + payload []byte + expectedPayload map[string]interface{} + wantErr bool + }{ + { + name: "Nil payload", + payload: nil, + wantErr: true, + }, + { + name: "Valid payload", + wantErr: false, + payload: []byte("nonEmptyPayload"), + expectedPayload: testPayload, + }, + { + name: "Payload missmatch", + wantErr: true, + payload: []byte("nonEmptyPayload"), + expectedPayload: missmatchPayload, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(context.Background()) + ctx := InitializeContext(ctxGolium) + s := &Session{} + if tt.payload == nil { + s.Payload = tt.payload + } else { + s.ConfigureJSONPayload(ctx, testPayload) + } + if err := s.ValidatePayloadJSONProperties(ctx, tt.expectedPayload); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidatePayloadJSONProperties() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} +func TestValidateJWT(t *testing.T) { + testPayload := make(map[string]interface{}) + testPayload["jsonKey"] = "jsonValue2" + tests := []struct { + name string + invalidReq bool + wrongToken bool + wantErr bool + }{ + { + name: "Invalid JWT Requirements", + invalidReq: true, + wrongToken: false, + wantErr: true, + }, + { + name: "Valid JWT", + invalidReq: false, + wrongToken: false, + wantErr: false, + }, + { + name: "Wrong Token", + invalidReq: false, + wrongToken: true, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(context.Background()) + ctx := InitializeContext(ctxGolium) + s := &Session{} + if !tt.invalidReq { + s.SignatureAlgorithm = jwa.HS512 + s.ConfigureSymmetricKey(ctx, symmetricKey) + s.ConfigureJSONPayload(ctx, testPayload) + s.GenerateSignedJWTInContext(ctx, "jwt.jws") + if tt.wrongToken { + s.Token = "fakeToken2" + } + s.ProcessSignedJWT(ctx, s.Token) + } + if err := s.ValidateJWT(ctx); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateJWT() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} +func TestValidateInvalidJWT(t *testing.T) { + testPayload := make(map[string]interface{}) + testPayload["jsonKey"] = "jsonValue3" + tests := []struct { + name string + invalidReq bool + wrongToken bool + expectedError string + wantErr bool + }{ + { + name: "Invalid JWT Requirements", + invalidReq: true, + wrongToken: false, + wantErr: true, + }, + { + name: "Error when validation ok", + invalidReq: false, + wrongToken: false, + wantErr: true, + }, + { + name: "Validation error", + invalidReq: false, + wrongToken: true, + wantErr: false, + expectedError: "", + }, + { + name: "Validation error", + invalidReq: false, + wrongToken: true, + wantErr: true, + expectedError: "expectedError", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxGolium := golium.InitializeContext(context.Background()) + ctx := InitializeContext(ctxGolium) + s := &Session{} + if !tt.invalidReq { + s.SignatureAlgorithm = jwa.HS512 + s.ConfigureSymmetricKey(ctx, symmetricKey) + s.ConfigureJSONPayload(ctx, testPayload) + s.GenerateSignedJWTInContext(ctx, "jwt.jws") + if tt.wrongToken { + s.Token = "fakeToken3" + } + s.ProcessSignedJWT(ctx, s.Token) + } + if err := s.ValidateInvalidJWT(ctx, tt.expectedError); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateInvalidJWT() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/steps/rabbit/amqp.go b/steps/rabbit/amqp.go new file mode 100644 index 00000000..3f29ba96 --- /dev/null +++ b/steps/rabbit/amqp.go @@ -0,0 +1,110 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "github.com/streadway/amqp" +) + +type AMQPServiceFunctions interface { + Dial(url string) (*amqp.Connection, error) + ConnectionChannel(c *amqp.Connection) (*amqp.Channel, error) + ChannelExchangeDeclare( + channel *amqp.Channel, + name, kind string, + durable, autoDelete, internal, noWait bool, + args amqp.Table, + ) error + ChannelQueueDeclare( + channel *amqp.Channel, + name string, + durable, autoDelete, exclusive, noWait bool, + args amqp.Table, + ) (amqp.Queue, error) + ChannelQueueBind(channel *amqp.Channel, + name, key, exchange string, + noWait bool, args amqp.Table, + ) error + ChannelConsume(channel *amqp.Channel, + queue, consumer string, + autoAck, exclusive, noLocal, noWait bool, + args amqp.Table, + ) (<-chan amqp.Delivery, error) + ChannelClose(channel *amqp.Channel) error + ChannelPublish(channel *amqp.Channel, + exchange, key string, + mandatory, immediate bool, + msg amqp.Publishing, + ) error +} + +type AMQPService struct { + Connection *amqp.Connection +} + +func NewAMQPService() *AMQPService { + return &AMQPService{} +} + +func (a AMQPService) Dial(url string) (*amqp.Connection, error) { + return amqp.Dial(url) +} + +func (a AMQPService) ConnectionChannel(connection *amqp.Connection) (*amqp.Channel, error) { + return connection.Channel() +} +func (a AMQPService) ChannelExchangeDeclare( + channel *amqp.Channel, + name, kind string, + durable, autoDelete, internal, noWait bool, + args amqp.Table, +) error { + return channel.ExchangeDeclare(name, kind, durable, autoDelete, internal, noWait, args) +} + +func (a AMQPService) ChannelQueueDeclare( + channel *amqp.Channel, + name string, + durable, autoDelete, exclusive, noWait bool, + args amqp.Table, +) (amqp.Queue, error) { + return channel.QueueDeclare(name, durable, autoDelete, exclusive, noWait, args) +} +func (a AMQPService) ChannelQueueBind(channel *amqp.Channel, + name, key, exchange string, + noWait bool, args amqp.Table, +) error { + return channel.QueueBind(name, key, exchange, noWait, args) +} + +func (a AMQPService) ChannelConsume(channel *amqp.Channel, + queue, consumer string, + autoAck, exclusive, noLocal, noWait bool, + args amqp.Table, +) (<-chan amqp.Delivery, error) { + return channel.Consume(queue, consumer, autoAck, exclusive, noLocal, noWait, args) +} + +func (a AMQPService) ChannelClose(channel *amqp.Channel) error { + return channel.Close() +} + +func (a AMQPService) ChannelPublish(channel *amqp.Channel, + exchange, key string, + mandatory, immediate bool, + msg amqp.Publishing, +) error { + return channel.Publish(exchange, key, mandatory, immediate, msg) +} diff --git a/steps/rabbit/amqp_mock.go b/steps/rabbit/amqp_mock.go new file mode 100644 index 00000000..c2c1a166 --- /dev/null +++ b/steps/rabbit/amqp_mock.go @@ -0,0 +1,85 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "github.com/streadway/amqp" +) + +var ( + DialError error + ConnectionChannelError error + ChannelExchangeDeclareError error + ChannelQueueDeclareError error + ChannelQueueBindError error + ChannelConsumeError error + MockSubCh <-chan amqp.Delivery + ChannelPublishError error +) + +type AMQPServiceFuncMock struct{} + +func (a AMQPServiceFuncMock) Dial(url string) (*amqp.Connection, error) { + return nil, DialError +} + +func (a AMQPServiceFuncMock) ConnectionChannel(c *amqp.Connection) (*amqp.Channel, error) { + return nil, ConnectionChannelError +} + +func (a AMQPServiceFuncMock) ChannelExchangeDeclare( + channel *amqp.Channel, + name, kind string, + durable, autoDelete, internal, noWait bool, + args amqp.Table, +) error { + return ChannelExchangeDeclareError +} + +func (a AMQPServiceFuncMock) ChannelQueueDeclare( + channel *amqp.Channel, + name string, + durable, autoDelete, exclusive, noWait bool, + args amqp.Table, +) (amqp.Queue, error) { + amqpQueue := amqp.Queue{} + return amqpQueue, ChannelQueueDeclareError +} + +func (a AMQPServiceFuncMock) ChannelQueueBind(channel *amqp.Channel, + name, key, exchange string, + noWait bool, args amqp.Table, +) error { + return ChannelQueueBindError +} + +func (a AMQPServiceFuncMock) ChannelConsume(channel *amqp.Channel, + queue, consumer string, + autoAck, exclusive, noLocal, noWait bool, + args amqp.Table, +) (<-chan amqp.Delivery, error) { + return MockSubCh, ChannelConsumeError +} +func (a AMQPServiceFuncMock) ChannelClose(channel *amqp.Channel) error { + return nil +} + +func (a AMQPServiceFuncMock) ChannelPublish(channel *amqp.Channel, + exchange, key string, + mandatory, immediate bool, + msg amqp.Publishing, +) error { + return ChannelPublishError +} diff --git a/steps/rabbit/context.go b/steps/rabbit/context.go index c7ae3d54..ef37a067 100755 --- a/steps/rabbit/context.go +++ b/steps/rabbit/context.go @@ -26,7 +26,9 @@ const contextKey ContextKey = "rabbitSession" // InitializeContext adds the rabbit session to the context. // The new context is returned because context is immutable. func InitializeContext(ctx context.Context) context.Context { - return context.WithValue(ctx, contextKey, &Session{}) + var session Session + session.AMQPService = *NewAMQPService() + return context.WithValue(ctx, contextKey, session) } // GetSession returns the rabbit session stored in context. diff --git a/steps/rabbit/session.go b/steps/rabbit/session.go index 445eb4e3..919c0d30 100644 --- a/steps/rabbit/session.go +++ b/steps/rabbit/session.go @@ -25,6 +25,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/streadway/amqp" + "github.com/tidwall/sjson" ) @@ -46,12 +47,14 @@ type Session struct { publishing amqp.Publishing // rabbit received delivery message msg amqp.Delivery + // ampq service + AMQPService AMQPServiceFunctions } // ConfigureConnection creates a rabbit connection based on the URI. func (s *Session) ConfigureConnection(ctx context.Context, uri string) error { var err error - s.Connection, err = amqp.Dial(uri) + s.Connection, err = s.AMQPService.Dial(uri) if err != nil { return fmt.Errorf("failed configuring connection '%s': %w", uri, err) } @@ -79,23 +82,25 @@ func (s *Session) ConfigureStandardProperties(ctx context.Context, props amqp.Pu func (s *Session) SubscribeTopic(ctx context.Context, topic string) error { GetLogger().LogSubscribedTopic(topic) var err error - s.channel, err = s.Connection.Channel() + s.channel, err = s.AMQPService.ConnectionChannel(s.Connection) if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.channel.ExchangeDeclare( + err = s.AMQPService.ChannelExchangeDeclare( + s.channel, topic, // name "fanout", // type true, // durable false, // auto-deleted false, // internal false, // no-wait - nil, // arguments + nil, // arguments) ) if err != nil { return errors.Wrap(err, "failed to declare an exchange") } - q, err := s.channel.QueueDeclare( + q, err := s.AMQPService.ChannelQueueDeclare( + s.channel, "", // name false, // durable true, // delete when unused @@ -106,7 +111,8 @@ func (s *Session) SubscribeTopic(ctx context.Context, topic string) error { if err != nil { return errors.Wrap(err, "failed to declare a queue") } - err = s.channel.QueueBind( + err = s.AMQPService.ChannelQueueBind( + s.channel, q.Name, // queue name "", // routing key topic, // exchange @@ -116,7 +122,8 @@ func (s *Session) SubscribeTopic(ctx context.Context, topic string) error { if err != nil { return errors.Wrap(err, "failed to bind a queue") } - s.subCh, err = s.channel.Consume( + s.subCh, err = s.AMQPService.ChannelConsume( + s.channel, q.Name, // queue "", // consumer true, // auto-ack @@ -146,18 +153,19 @@ func (s *Session) Unsubscribe(ctx context.Context) error { if s.channel == nil { return nil } - return s.channel.Close() + return s.AMQPService.ChannelClose(s.channel) } // PublishTextMessage publishes a text message in a rabbit topic. func (s *Session) PublishTextMessage(ctx context.Context, topic, message string) error { GetLogger().LogPublishedMessage(message, topic, s.Correlator) var err error - s.channel, err = s.Connection.Channel() + s.channel, err = s.AMQPService.ConnectionChannel(s.Connection) if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.channel.ExchangeDeclare( + err = s.AMQPService.ChannelExchangeDeclare( + s.channel, topic, // name "fanout", // type true, // durable @@ -170,7 +178,8 @@ func (s *Session) PublishTextMessage(ctx context.Context, topic, message string) return fmt.Errorf("failed to declare an exchange") } publishing := s.buildPublishingMessage([]byte(message)) - err = s.channel.Publish( + err = s.AMQPService.ChannelPublish( + s.channel, topic, // exchange "", // routing key false, // mandatory @@ -195,30 +204,38 @@ func (s *Session) buildPublishingMessage(body []byte) amqp.Publishing { publishing.ContentType = "text/plain" } publishing.Headers = s.headers - publishing.Body = []byte(body) + publishing.Body = body return publishing } // PublishJSONMessage publishes a JSON message in a rabbit topic. -func (s *Session) PublishJSONMessage(ctx context.Context, topic string, props map[string]interface{}) error { +func (s *Session) PublishJSONMessage( + ctx context.Context, + topic string, + props map[string]interface{}, +) error { var json string var err error for key, value := range props { if json, err = sjson.Set(json, key, value); err != nil { - return fmt.Errorf("failed setting property '%s' with value '%s' in the message: %w", key, value, err) + return fmt.Errorf("failed setting property '%s' with value '%s' in the message: %w", + key, value, err) } } s.publishing.ContentType = "application/json" return s.PublishTextMessage(ctx, topic, json) } -// WaitForTextMessage waits up to timeout until the expected message is found in the received messages -// for this session. -func (s *Session) WaitForTextMessage(ctx context.Context, timeout time.Duration, expectedMsg string) error { +// WaitForTextMessage waits up to timeout until the expected message is found in +// the received messages for this session. +func (s *Session) WaitForTextMessage(ctx context.Context, + timeout time.Duration, + expectedMsg string, +) error { return waitUpTo(timeout, func() error { - for _, msg := range s.Messages { - if string(msg.Body) == expectedMsg { - s.msg = msg + for i := range s.Messages { + if string(s.Messages[i].Body) == expectedMsg { + s.msg = s.Messages[i] return nil } } @@ -228,12 +245,15 @@ func (s *Session) WaitForTextMessage(ctx context.Context, timeout time.Duration, // WaitForJSONMessageWithProperties waits up to timeout and verifies if there is a message received // in the topic with the requested properties. -func (s *Session) WaitForJSONMessageWithProperties(ctx context.Context, timeout time.Duration, props map[string]interface{}) error { +func (s *Session) WaitForJSONMessageWithProperties(ctx context.Context, + timeout time.Duration, + props map[string]interface{}, +) error { return waitUpTo(timeout, func() error { - for _, msg := range s.Messages { - logrus.Debugf("Checking message: %s", msg.Body) - if matchMessage(string(msg.Body), props) { - s.msg = msg + for i := range s.Messages { + logrus.Debugf("Checking message: %s", s.Messages[i].Body) + if matchMessage(string(s.Messages[i].Body), props) { + s.msg = s.Messages[i] return nil } } @@ -253,17 +273,23 @@ func matchMessage(msg string, expectedProps map[string]interface{}) bool { return true } -// WaitForMessagesWithStandardProperties waits for 'count' messages with standard rabbit properties that are equal to the expected values. -func (s *Session) WaitForMessagesWithStandardProperties(ctx context.Context, timeout time.Duration, count int, props amqp.Delivery) error { +// WaitForMessagesWithStandardProperties waits for 'count' messages with standard rabbit properties +// that are equal to the expected values. +func (s *Session) WaitForMessagesWithStandardProperties( + ctx context.Context, + timeout time.Duration, + count int, + props amqp.Delivery, +) error { return waitUpTo(timeout, func() error { err := fmt.Errorf("no message(s) received match(es) the standard properties") if count < 0 { return err } - for _, msg := range s.Messages { - logrus.Debugf("Checking message: %s", msg.Body) - s.msg = msg - if err := s.ValidateMessageStandardProperties(ctx, props); err == nil { + for i := range s.Messages { + logrus.Debugf("Checking message: %s", s.Messages[i].Body) + s.msg = s.Messages[i] + if err = s.ValidateMessageStandardProperties(ctx, props); err == nil { count-- if count == 0 { return nil @@ -274,8 +300,12 @@ func (s *Session) WaitForMessagesWithStandardProperties(ctx context.Context, tim }) } -// ValidateMessageStandardProperties checks if the message standard rabbit properties are equal the expected values. -func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props amqp.Delivery) error { +// ValidateMessageStandardProperties checks if the message standard rabbit properties are equal +// the expected values. +func (s *Session) ValidateMessageStandardProperties( + ctx context.Context, + props amqp.Delivery, +) error { msg := reflect.ValueOf(s.msg) expectedMsg := reflect.ValueOf(props) t := expectedMsg.Type() @@ -285,7 +315,9 @@ func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props a value := msg.Field(i).Interface() expectedValue := expectedMsg.Field(i).Interface() if value != expectedValue { - return fmt.Errorf("mismatch of standard rabbit property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of standard rabbit property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } } @@ -293,7 +325,10 @@ func (s *Session) ValidateMessageStandardProperties(ctx context.Context, props a } // ValidateMessageHeaders checks if the message rabbit headers are equal the expected values. -func (s *Session) ValidateMessageHeaders(ctx context.Context, headers map[string]interface{}) error { +func (s *Session) ValidateMessageHeaders( + ctx context.Context, + headers map[string]interface{}, +) error { h := s.msg.Headers for key, expectedValue := range headers { value, found := h[key] @@ -301,7 +336,9 @@ func (s *Session) ValidateMessageHeaders(ctx context.Context, headers map[string return fmt.Errorf("missing rabbit message header '%s'", key) } if value != expectedValue { - return fmt.Errorf("mismatch of standard rabbit property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of standard rabbit property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } return nil @@ -316,21 +353,29 @@ func (s *Session) ValidateMessageTextBody(ctx context.Context, expectedMsg strin return nil } -// ValidateMessageJSONBody checks if the message json body properties of message in position 'pos' are equal the expected values. +// ValidateMessageJSONBody checks if the message json body properties of message in position 'pos' +// are equal the expected values. // if pos == -1 then it means last message stored, that is the one stored in s.msg -func (s *Session) ValidateMessageJSONBody(ctx context.Context, props map[string]interface{}, pos int) error { - m := golium.NewMapFromJSONBytes([]byte(s.msg.Body)) +func (s *Session) ValidateMessageJSONBody(ctx context.Context, + props map[string]interface{}, + pos int, +) error { + m := golium.NewMapFromJSONBytes(s.msg.Body) if pos != -1 { nMessages := len(s.Messages) if pos < 0 || pos >= nMessages { - return fmt.Errorf("trying to validate message in position: '%d', '%d' messages available", pos, nMessages) + return fmt.Errorf( + "trying to validate message in position: '%d', '%d' messages available", + pos, nMessages) } - m = golium.NewMapFromJSONBytes([]byte(s.Messages[pos].Body)) + m = golium.NewMapFromJSONBytes(s.Messages[pos].Body) } for key, expectedValue := range props { value := m.Get(key) if value != expectedValue { - return fmt.Errorf("mismatch of json property '%s': expected '%s', actual '%s'", key, expectedValue, value) + return fmt.Errorf( + "mismatch of json property '%s': expected '%s', actual '%s'", + key, expectedValue, value) } } return nil diff --git a/steps/rabbit/session_test.go b/steps/rabbit/session_test.go new file mode 100644 index 00000000..2a4e32c3 --- /dev/null +++ b/steps/rabbit/session_test.go @@ -0,0 +1,644 @@ +// Copyright 2021 Telefonica Cybersecurity & Cloud Tech SL +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rabbit + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/TelefonicaTC2Tech/golium" + "github.com/cucumber/godog" + "github.com/streadway/amqp" +) + +const ( + rabbitmq = "amqp://guest:guest@localhost:5672/" + logsPath = "./logs" +) + +func TestConfigureConnection(t *testing.T) { + tests := []struct { + name string + uri string + connError error + wantErr bool + }{ + { + name: "Dial error", + connError: fmt.Errorf("dial error"), + wantErr: true, + }, + { + name: "Without connection error", + uri: rabbitmq, + connError: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + ctx := InitializeContext(context.Background()) + s.AMQPService = AMQPServiceFuncMock{} + DialError = tt.connError + if err := s.ConfigureConnection(ctx, tt.uri); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureConnection() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureHeaders(t *testing.T) { + var wrongRabbitHeader = make(map[string]interface{}) + wrongRabbitHeader["wrongParam"] = uint(5) + + var rabbitHeader = make(map[string]interface{}) + rabbitHeader["param"] = "value" + rabbitHeader["Header1"] = "value1" + rabbitHeader["Header2"] = "Value2" + tests := []struct { + name string + headers map[string]interface{} + wantErr bool + }{ + { + name: "Validate headers error", + wantErr: true, + headers: wrongRabbitHeader, + }, + { + name: "No error", + headers: rabbitHeader, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + if err := s.ConfigureHeaders(context.Background(), tt.headers); (err != nil) != tt.wantErr { + t.Errorf("Session.ConfigureHeaders() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigureStandardProperties(t *testing.T) { + rabbitHeaders := amqp.Publishing{} + + tests := []struct { + name string + propTable *godog.Table + }{ + { + name: "Configure", + propTable: golium.NewTable([][]string{{"ContentType"}, {"application/json"}}), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + ctx := InitializeContext(context.Background()) + golium.ConvertTableWithoutHeaderToStruct(ctx, tt.propTable, &rabbitHeaders) + s.ConfigureStandardProperties(context.Background(), rabbitHeaders) + }) + } +} + +func TestSubscribeTopic(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + tests := []struct { + name string + topic string + connectionChannelError error + channelExchangeDeclareError error + channelQueueDeclareError error + channelQueueBindError error + channelConsumeError error + subCh <-chan amqp.Delivery + wantErr bool + }{ + { + name: "Connection Channel Error", + connectionChannelError: fmt.Errorf("connection channel error"), + wantErr: true, + }, + { + name: "Channel Exchange Declare Error", + connectionChannelError: nil, + channelExchangeDeclareError: fmt.Errorf("channel exchange declare error"), + wantErr: true, + }, + { + name: "Channel Queue Declare Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: fmt.Errorf("channel queue declare error"), + wantErr: true, + }, + { + name: "Channel Queue Bind Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: fmt.Errorf("channel queue bind error"), + wantErr: true, + }, + { + name: "Channel Consume Error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: nil, + subCh: nil, + channelConsumeError: fmt.Errorf("channel queue bind error"), + wantErr: true, + }, + { + name: "Channel registered without errors", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelQueueDeclareError: nil, + channelQueueBindError: nil, + subCh: nil, + channelConsumeError: nil, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + goliumCtx := golium.InitializeContext(context.Background()) + ctx := InitializeContext(goliumCtx) + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + ConnectionChannelError = tt.connectionChannelError + ChannelExchangeDeclareError = tt.channelExchangeDeclareError + ChannelQueueDeclareError = tt.channelQueueDeclareError + ChannelQueueBindError = tt.channelQueueBindError + ChannelConsumeError = tt.channelConsumeError + MockSubCh = tt.subCh + if err := s.SubscribeTopic(ctx, tt.topic); (err != nil) != tt.wantErr { + t.Errorf("Session.SubscribeTopic() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestUnsubscribe(t *testing.T) { + tests := []struct { + name string + channel *amqp.Channel + wantErr bool + }{ + { + name: "Nil channel", + channel: nil, + wantErr: false, + }, + { + name: "Not nil channel", + channel: &amqp.Channel{}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + s.channel = tt.channel + if err := s.Unsubscribe(context.Background()); (err != nil) != tt.wantErr { + t.Errorf("Session.Unsubscribe() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPublishTextMessage(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + type args struct { + topic string + message string + } + tests := []struct { + name string + args args + connectionChannelError error + channelExchangeDeclareError error + channelPublishError error + wantErr bool + }{ + { + name: "Connection channel error", + connectionChannelError: fmt.Errorf("connection channel error"), + wantErr: true, + }, + { + name: "Channel exchange declare error", + connectionChannelError: nil, + channelExchangeDeclareError: fmt.Errorf("channel exchange declare error"), + wantErr: true, + }, + { + name: "Publish error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: fmt.Errorf("channel publish error"), + args: args{ + message: "test message", + }, + wantErr: true, + }, + { + name: "Publish without error", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: nil, + args: args{ + message: "test message", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, s := setPublishTestEnv( + tt.connectionChannelError, + tt.channelExchangeDeclareError, + tt.channelPublishError) + if err := s.PublishTextMessage( + ctx, tt.args.topic, tt.args.message); (err != nil) != tt.wantErr { + t.Errorf("Session.PublishTextMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPublishJSONMessage(t *testing.T) { + os.MkdirAll(logsPath, os.ModePerm) + defer os.RemoveAll(logsPath) + propsOk := make(map[string]interface{}) + propsOk["id"] = "1" + propsOk["name"] = "test" + type args struct { + topic string + props map[string]interface{} + } + tests := []struct { + name string + connectionChannelError error + channelExchangeDeclareError error + channelPublishError error + args args + wantErr bool + }{ + { + name: "Valid props", + connectionChannelError: nil, + channelExchangeDeclareError: nil, + channelPublishError: nil, + args: args{ + topic: "test_topic", + props: propsOk, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, s := setPublishTestEnv( + tt.connectionChannelError, + tt.channelExchangeDeclareError, + tt.channelPublishError) + if err := s.PublishJSONMessage(ctx, tt.args.topic, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf("Session.PublishJSONMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func setPublishTestEnv( + conChannelError, channelExchDecError, channelPubError error, +) (context.Context, *Session) { + goliumCtx := golium.InitializeContext(context.Background()) + ctx := InitializeContext(goliumCtx) + s := &Session{} + s.AMQPService = AMQPServiceFuncMock{} + ConnectionChannelError = conChannelError + ChannelExchangeDeclareError = channelExchDecError + ChannelPublishError = channelPubError + return ctx, s +} + +func TestWaitForTextMessage(t *testing.T) { + type args struct { + timeout time.Duration + expectedMsg string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Expected message found", + args: args{ + timeout: time.Duration(5000), + expectedMsg: "expected string", + }, + wantErr: false, + }, + { + name: "Expected message not found", + args: args{ + timeout: time.Duration(5000), + expectedMsg: "error expected string", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + {Body: []byte("expected string")}, + } + if err := s.WaitForTextMessage( + context.Background(), tt.args.timeout, tt.args.expectedMsg); (err != nil) != tt.wantErr { + t.Errorf("Session.WaitForTextMessage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestWaitForJSONMessageWithProperties(t *testing.T) { + expectedJSON := make(map[string]interface{}) + expectedJSON["id"] = "1" + wrongExpectedJSON := make(map[string]interface{}) + wrongExpectedJSON["id"] = "5" + + type args struct { + timeout time.Duration + props map[string]interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Expected json found", + args: args{ + timeout: time.Duration(5000), + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "Expected json not found", + args: args{ + timeout: time.Duration(5000), + props: wrongExpectedJSON, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Body: []byte(`{"id": "1"}`), + }, + } + if err := s.WaitForJSONMessageWithProperties( + context.Background(), tt.args.timeout, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf("Session.WaitForJSONMessageWithProperties() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestWaitForMessagesWithStandardProperties(t *testing.T) { + type args struct { + timeout time.Duration + count int + props amqp.Delivery + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Message count < 0", + args: args{ + count: -1, + timeout: time.Duration(5000), + }, + wantErr: true, + }, + { + name: "Matching properties", + args: args{ + count: 1, + timeout: time.Duration(5000), + props: amqp.Delivery{ + Priority: 5, + }, + }, + wantErr: false, + }, + { + name: "Not matching properties", + args: args{ + count: 1, + timeout: time.Duration(5000), + props: amqp.Delivery{ + Priority: 10, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Priority: 5, + }, + } + if err := s.WaitForMessagesWithStandardProperties( + context.Background(), + tt.args.timeout, + tt.args.count, tt.args.props); (err != nil) != tt.wantErr { + t.Errorf( + "Session.WaitForMessagesWithStandardProperties() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageHeaders(t *testing.T) { + refHeaders := make(amqp.Table) + refHeaders["id"] = "1" + + testHeaders := make(map[string]interface{}) + + tests := []struct { + name string + testKey string + testValue string + wantErr bool + }{ + { + name: "Key found and value matches", + testKey: "id", + testValue: "1", + wantErr: false, + }, + { + name: "Key not found", + testKey: "ids", + testValue: "1", + wantErr: true, + }, + { + name: "Key found wrong value", + testKey: "id", + testValue: "2", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.msg = amqp.Delivery{ + Headers: refHeaders, + } + testHeaders[tt.testKey] = tt.testValue + if err := s.ValidateMessageHeaders( + context.Background(), testHeaders); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageHeaders() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageTextBody(t *testing.T) { + tests := []struct { + name string + expectedMsg string + wantErr bool + }{ + { + name: "Mismatch of message text", + expectedMsg: "wrong message", + wantErr: true, + }, + { + name: "Right message", + expectedMsg: "message", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.msg = amqp.Delivery{ + Body: []byte("message"), + } + + if err := s.ValidateMessageTextBody( + context.Background(), tt.expectedMsg); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageTextBody() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} + +func TestValidateMessageJSONBody(t *testing.T) { + expectedJSON := make(map[string]interface{}) + expectedJSON["id"] = "1" + wrongExpectedJSON := make(map[string]interface{}) + wrongExpectedJSON["id"] = "5" + type args struct { + props map[string]interface{} + pos int + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "JSON body match with pos -1", + args: args{ + pos: -1, + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "JSON body match with pos != -1", + args: args{ + pos: 0, + props: expectedJSON, + }, + wantErr: false, + }, + { + name: "pos != -1 without messages", + args: args{ + pos: 1, + props: expectedJSON, + }, + wantErr: true, + }, + { + name: "JSON body mismatch with pos -1", + args: args{ + pos: -1, + props: wrongExpectedJSON, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Session{} + s.Messages = []amqp.Delivery{ + { + Body: []byte(`{"id": "1"}`), + }, + } + s.msg = amqp.Delivery{ + Body: []byte(`{"id": "1"}`), + } + + if err := s.ValidateMessageJSONBody( + context.Background(), tt.args.props, tt.args.pos); (err != nil) != tt.wantErr { + t.Errorf("Session.ValidateMessageJSONBody() error = %v, wantErr %v", + err, tt.wantErr) + } + }) + } +} diff --git a/steps/rabbit/steps.go b/steps/rabbit/steps.go index 5a475db7..f767d974 100755 --- a/steps/rabbit/steps.go +++ b/steps/rabbit/steps.go @@ -147,8 +147,8 @@ func (cs Steps) InitializeSteps(ctx context.Context, scenCtx *godog.ScenarioCont } return session.ValidateMessageJSONBody(ctx, props, pos) }) - scenCtx.AfterScenario(func(sc *godog.Scenario, err error) { - session.Unsubscribe(ctx) + scenCtx.After(func(ctx context.Context, sc *godog.Scenario, err error) (context.Context, error) { + return ctx, session.Unsubscribe(ctx) }) return ctx } From 94aa5875601d2648b6634b43d5cde4c88145a4fa Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Tue, 5 Apr 2022 16:35:02 +0200 Subject: [PATCH 2/5] fix(MAGNETO-7807): Session init fix --- steps/rabbit/context.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/steps/rabbit/context.go b/steps/rabbit/context.go index ef37a067..dfc430be 100755 --- a/steps/rabbit/context.go +++ b/steps/rabbit/context.go @@ -26,9 +26,7 @@ const contextKey ContextKey = "rabbitSession" // InitializeContext adds the rabbit session to the context. // The new context is returned because context is immutable. func InitializeContext(ctx context.Context) context.Context { - var session Session - session.AMQPService = *NewAMQPService() - return context.WithValue(ctx, contextKey, session) + return context.WithValue(ctx, contextKey, &Session{AMQPService: *NewAMQPService()}) } // GetSession returns the rabbit session stored in context. From 57ce912d36d69e3759bd57034fc748263ebaade6 Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Tue, 5 Apr 2022 16:54:42 +0200 Subject: [PATCH 3/5] fix(MAGNETO-7807): Remeved wrong struct field --- steps/rabbit/amqp.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/steps/rabbit/amqp.go b/steps/rabbit/amqp.go index 3f29ba96..e440a4c3 100644 --- a/steps/rabbit/amqp.go +++ b/steps/rabbit/amqp.go @@ -50,9 +50,7 @@ type AMQPServiceFunctions interface { ) error } -type AMQPService struct { - Connection *amqp.Connection -} +type AMQPService struct{} func NewAMQPService() *AMQPService { return &AMQPService{} From fa58c7cf873661386073a8083dee0a9afd50e09a Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Tue, 5 Apr 2022 16:56:30 +0200 Subject: [PATCH 4/5] fix(MAGNETO-7807): Removed wrong empty line --- steps/rabbit/session.go | 1 - 1 file changed, 1 deletion(-) diff --git a/steps/rabbit/session.go b/steps/rabbit/session.go index 919c0d30..efc64761 100644 --- a/steps/rabbit/session.go +++ b/steps/rabbit/session.go @@ -25,7 +25,6 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/streadway/amqp" - "github.com/tidwall/sjson" ) From 5c53300d97e95fe2b483565e95108de695e21ca8 Mon Sep 17 00:00:00 2001 From: Jordi Puig Bou Date: Wed, 6 Apr 2022 09:45:07 +0200 Subject: [PATCH 5/5] fix(MAGNETO-7807): Fixed code smells --- steps/rabbit/amqp.go | 64 +++++++++------------------------------ steps/rabbit/amqp_mock.go | 27 +++-------------- steps/rabbit/session.go | 61 ++++--------------------------------- 3 files changed, 26 insertions(+), 126 deletions(-) diff --git a/steps/rabbit/amqp.go b/steps/rabbit/amqp.go index e440a4c3..04b7064a 100644 --- a/steps/rabbit/amqp.go +++ b/steps/rabbit/amqp.go @@ -21,31 +21,14 @@ import ( type AMQPServiceFunctions interface { Dial(url string) (*amqp.Connection, error) ConnectionChannel(c *amqp.Connection) (*amqp.Channel, error) - ChannelExchangeDeclare( - channel *amqp.Channel, - name, kind string, - durable, autoDelete, internal, noWait bool, - args amqp.Table, - ) error - ChannelQueueDeclare( - channel *amqp.Channel, - name string, - durable, autoDelete, exclusive, noWait bool, - args amqp.Table, - ) (amqp.Queue, error) - ChannelQueueBind(channel *amqp.Channel, - name, key, exchange string, - noWait bool, args amqp.Table, - ) error - ChannelConsume(channel *amqp.Channel, - queue, consumer string, - autoAck, exclusive, noLocal, noWait bool, - args amqp.Table, + ChannelExchangeDeclare(channel *amqp.Channel, name string) error + ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) + ChannelQueueBind(channel *amqp.Channel, name, exchange string) error + ChannelConsume(channel *amqp.Channel, queue string, ) (<-chan amqp.Delivery, error) ChannelClose(channel *amqp.Channel) error ChannelPublish(channel *amqp.Channel, - exchange, key string, - mandatory, immediate bool, + exchange string, msg amqp.Publishing, ) error } @@ -63,36 +46,20 @@ func (a AMQPService) Dial(url string) (*amqp.Connection, error) { func (a AMQPService) ConnectionChannel(connection *amqp.Connection) (*amqp.Channel, error) { return connection.Channel() } -func (a AMQPService) ChannelExchangeDeclare( - channel *amqp.Channel, - name, kind string, - durable, autoDelete, internal, noWait bool, - args amqp.Table, -) error { - return channel.ExchangeDeclare(name, kind, durable, autoDelete, internal, noWait, args) +func (a AMQPService) ChannelExchangeDeclare(channel *amqp.Channel, name string) error { + return channel.ExchangeDeclare(name, "fanout", true, false, false, false, nil) } -func (a AMQPService) ChannelQueueDeclare( - channel *amqp.Channel, - name string, - durable, autoDelete, exclusive, noWait bool, - args amqp.Table, -) (amqp.Queue, error) { - return channel.QueueDeclare(name, durable, autoDelete, exclusive, noWait, args) +func (a AMQPService) ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) { + return channel.QueueDeclare("", false, true, true, false, nil) } -func (a AMQPService) ChannelQueueBind(channel *amqp.Channel, - name, key, exchange string, - noWait bool, args amqp.Table, -) error { - return channel.QueueBind(name, key, exchange, noWait, args) +func (a AMQPService) ChannelQueueBind(channel *amqp.Channel, name, exchange string) error { + return channel.QueueBind(name, "", exchange, false, nil) } -func (a AMQPService) ChannelConsume(channel *amqp.Channel, - queue, consumer string, - autoAck, exclusive, noLocal, noWait bool, - args amqp.Table, +func (a AMQPService) ChannelConsume(channel *amqp.Channel, queue string, ) (<-chan amqp.Delivery, error) { - return channel.Consume(queue, consumer, autoAck, exclusive, noLocal, noWait, args) + return channel.Consume(queue, "", true, false, false, false, nil) } func (a AMQPService) ChannelClose(channel *amqp.Channel) error { @@ -100,9 +67,8 @@ func (a AMQPService) ChannelClose(channel *amqp.Channel) error { } func (a AMQPService) ChannelPublish(channel *amqp.Channel, - exchange, key string, - mandatory, immediate bool, + exchange string, msg amqp.Publishing, ) error { - return channel.Publish(exchange, key, mandatory, immediate, msg) + return channel.Publish(exchange, "", false, false, msg) } diff --git a/steps/rabbit/amqp_mock.go b/steps/rabbit/amqp_mock.go index c2c1a166..d7921777 100644 --- a/steps/rabbit/amqp_mock.go +++ b/steps/rabbit/amqp_mock.go @@ -39,36 +39,20 @@ func (a AMQPServiceFuncMock) ConnectionChannel(c *amqp.Connection) (*amqp.Channe return nil, ConnectionChannelError } -func (a AMQPServiceFuncMock) ChannelExchangeDeclare( - channel *amqp.Channel, - name, kind string, - durable, autoDelete, internal, noWait bool, - args amqp.Table, -) error { +func (a AMQPServiceFuncMock) ChannelExchangeDeclare(channel *amqp.Channel, name string) error { return ChannelExchangeDeclareError } -func (a AMQPServiceFuncMock) ChannelQueueDeclare( - channel *amqp.Channel, - name string, - durable, autoDelete, exclusive, noWait bool, - args amqp.Table, -) (amqp.Queue, error) { +func (a AMQPServiceFuncMock) ChannelQueueDeclare(channel *amqp.Channel) (amqp.Queue, error) { amqpQueue := amqp.Queue{} return amqpQueue, ChannelQueueDeclareError } -func (a AMQPServiceFuncMock) ChannelQueueBind(channel *amqp.Channel, - name, key, exchange string, - noWait bool, args amqp.Table, -) error { +func (a AMQPServiceFuncMock) ChannelQueueBind(channel *amqp.Channel, name, exchange string) error { return ChannelQueueBindError } -func (a AMQPServiceFuncMock) ChannelConsume(channel *amqp.Channel, - queue, consumer string, - autoAck, exclusive, noLocal, noWait bool, - args amqp.Table, +func (a AMQPServiceFuncMock) ChannelConsume(channel *amqp.Channel, queue string, ) (<-chan amqp.Delivery, error) { return MockSubCh, ChannelConsumeError } @@ -77,8 +61,7 @@ func (a AMQPServiceFuncMock) ChannelClose(channel *amqp.Channel) error { } func (a AMQPServiceFuncMock) ChannelPublish(channel *amqp.Channel, - exchange, key string, - mandatory, immediate bool, + exchange string, msg amqp.Publishing, ) error { return ChannelPublishError diff --git a/steps/rabbit/session.go b/steps/rabbit/session.go index efc64761..d55eb550 100644 --- a/steps/rabbit/session.go +++ b/steps/rabbit/session.go @@ -85,52 +85,19 @@ func (s *Session) SubscribeTopic(ctx context.Context, topic string) error { if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.AMQPService.ChannelExchangeDeclare( - s.channel, - topic, // name - "fanout", // type - true, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil, // arguments) - ) + err = s.AMQPService.ChannelExchangeDeclare(s.channel, topic) if err != nil { return errors.Wrap(err, "failed to declare an exchange") } - q, err := s.AMQPService.ChannelQueueDeclare( - s.channel, - "", // name - false, // durable - true, // delete when unused - true, // exclusive - false, // no-wait - nil, // arguments - ) + q, err := s.AMQPService.ChannelQueueDeclare(s.channel) if err != nil { return errors.Wrap(err, "failed to declare a queue") } - err = s.AMQPService.ChannelQueueBind( - s.channel, - q.Name, // queue name - "", // routing key - topic, // exchange - false, - nil, - ) + err = s.AMQPService.ChannelQueueBind(s.channel, q.Name, topic) if err != nil { return errors.Wrap(err, "failed to bind a queue") } - s.subCh, err = s.AMQPService.ChannelConsume( - s.channel, - q.Name, // queue - "", // consumer - true, // auto-ack - false, // exclusive - false, // no-local - false, // no-wait - nil, // args - ) + s.subCh, err = s.AMQPService.ChannelConsume(s.channel, q.Name) go func() { logrus.Debugf("Receiving messages from topic %s...", topic) for msg := range s.subCh { @@ -163,28 +130,12 @@ func (s *Session) PublishTextMessage(ctx context.Context, topic, message string) if err != nil { return errors.Wrap(err, "failed to open a channel") } - err = s.AMQPService.ChannelExchangeDeclare( - s.channel, - topic, // name - "fanout", // type - true, // durable - false, // auto-deleted - false, // internal - false, // no-wait - nil, // arguments - ) + err = s.AMQPService.ChannelExchangeDeclare(s.channel, topic) if err != nil { return fmt.Errorf("failed to declare an exchange") } publishing := s.buildPublishingMessage([]byte(message)) - err = s.AMQPService.ChannelPublish( - s.channel, - topic, // exchange - "", // routing key - false, // mandatory - false, // immediate - publishing, // publishing - ) + err = s.AMQPService.ChannelPublish(s.channel, topic, publishing) if err != nil { return fmt.Errorf("failed publishing the message '%s' to topic '%s': %w", message, topic, err) }