diff --git a/config/muxer.go b/config/muxer.go index 6a48cb3bd4..ee2adeb08a 100644 --- a/config/muxer.go +++ b/config/muxer.go @@ -50,6 +50,7 @@ func makeMuxer(h host.Host, tpts []MsMuxC) (mux.Transport, error) { if _, ok := transportSet[tptC.ID]; ok { return nil, fmt.Errorf("duplicate muxer transport: %s", tptC.ID) } + transportSet[tptC.ID] = struct{}{} } for _, tptC := range tpts { tpt, err := tptC.MuxC(h) diff --git a/config/muxer_test.go b/config/muxer_test.go index 5e7bab41ff..d7f7baeb5d 100644 --- a/config/muxer_test.go +++ b/config/muxer_test.go @@ -1,9 +1,13 @@ package config import ( + "context" "testing" + host "github.com/libp2p/go-libp2p-host" peer "github.com/libp2p/go-libp2p-peer" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + bhost "github.com/libp2p/go-libp2p/p2p/host/basic" mux "github.com/libp2p/go-stream-muxer" yamux "github.com/whyrusleeping/go-smux-yamux" ) @@ -52,3 +56,46 @@ func TestMuxerBadTypes(t *testing.T) { } } } + +func TestCatchDuplicateTransportsMuxer(t *testing.T) { + ctx := context.Background() + h := bhost.New(swarmt.GenSwarm(t, ctx)) + yamuxMuxer, err := MuxerConstructor(yamux.DefaultTransport) + if err != nil { + t.Fatal(err) + } + + var tests = map[string]struct { + h host.Host + transports []MsMuxC + expectedError string + }{ + "no duplicate transports": { + h: h, + transports: []MsMuxC{MsMuxC{yamuxMuxer, "yamux"}}, + expectedError: "", + }, + "duplicate transports": { + h: h, + transports: []MsMuxC{ + MsMuxC{yamuxMuxer, "yamux"}, + MsMuxC{yamuxMuxer, "yamux"}, + }, + expectedError: "duplicate muxer transport: yamux", + }, + } + for testName, test := range tests { + t.Run(testName, func(t *testing.T) { + _, err = makeMuxer(test.h, test.transports) + if err != nil { + if err.Error() != test.expectedError { + t.Errorf( + "\nexpected: [%v]\nactual: [%v]\n", + test.expectedError, + err, + ) + } + } + }) + } +} diff --git a/config/security.go b/config/security.go index c3639728ff..2798fda85f 100644 --- a/config/security.go +++ b/config/security.go @@ -61,6 +61,7 @@ func makeSecurityTransport(h host.Host, tpts []MsSecC) (security.Transport, erro if _, ok := transportSet[tptC.ID]; ok { return nil, fmt.Errorf("duplicate security transport: %s", tptC.ID) } + transportSet[tptC.ID] = struct{}{} } for _, tptC := range tpts { tpt, err := tptC.SecC(h)