Skip to content

Commit

Permalink
Merge pull request #400 from rargulati/add-secure-tpt-to-set
Browse files Browse the repository at this point in the history
Ensure duplicate transports are filtered
  • Loading branch information
Stebalien authored Aug 21, 2018
2 parents 7d6f952 + 6e72e88 commit c67b87c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func makeMuxer(h host.Host, tpts []MsMuxC) (mux.Transport, error) {
if _, ok := transportSet[tptC.ID]; ok {
return nil, fmt.Errorf("duplicate muxer transport: %s", tptC.ID)
}
transportSet[tptC.ID] = struct{}{}
}
for _, tptC := range tpts {
tpt, err := tptC.MuxC(h)
Expand Down
47 changes: 47 additions & 0 deletions config/muxer_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package config

import (
"context"
"testing"

host "github.com/libp2p/go-libp2p-host"
peer "github.com/libp2p/go-libp2p-peer"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
bhost "github.com/libp2p/go-libp2p/p2p/host/basic"
mux "github.com/libp2p/go-stream-muxer"
yamux "github.com/whyrusleeping/go-smux-yamux"
)
Expand Down Expand Up @@ -52,3 +56,46 @@ func TestMuxerBadTypes(t *testing.T) {
}
}
}

func TestCatchDuplicateTransportsMuxer(t *testing.T) {
ctx := context.Background()
h := bhost.New(swarmt.GenSwarm(t, ctx))
yamuxMuxer, err := MuxerConstructor(yamux.DefaultTransport)
if err != nil {
t.Fatal(err)
}

var tests = map[string]struct {
h host.Host
transports []MsMuxC
expectedError string
}{
"no duplicate transports": {
h: h,
transports: []MsMuxC{MsMuxC{yamuxMuxer, "yamux"}},
expectedError: "",
},
"duplicate transports": {
h: h,
transports: []MsMuxC{
MsMuxC{yamuxMuxer, "yamux"},
MsMuxC{yamuxMuxer, "yamux"},
},
expectedError: "duplicate muxer transport: yamux",
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
_, err = makeMuxer(test.h, test.transports)
if err != nil {
if err.Error() != test.expectedError {
t.Errorf(
"\nexpected: [%v]\nactual: [%v]\n",
test.expectedError,
err,
)
}
}
})
}
}
1 change: 1 addition & 0 deletions config/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func makeSecurityTransport(h host.Host, tpts []MsSecC) (security.Transport, erro
if _, ok := transportSet[tptC.ID]; ok {
return nil, fmt.Errorf("duplicate security transport: %s", tptC.ID)
}
transportSet[tptC.ID] = struct{}{}
}
for _, tptC := range tpts {
tpt, err := tptC.SecC(h)
Expand Down

0 comments on commit c67b87c

Please sign in to comment.