diff --git a/benchmark_test.go b/benchmark_test.go index 5697856..12009a5 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -85,7 +85,7 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) { initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID) }() - respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn) + respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "") <-done if initErr != nil { diff --git a/go.mod b/go.mod index 57cc970..97cbc7f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/flynn/noise v1.0.0 github.com/gogo/protobuf v1.3.2 github.com/libp2p/go-buffer-pool v0.0.2 - github.com/libp2p/go-libp2p-core v0.9.0 + github.com/libp2p/go-libp2p-core v0.10.0 github.com/multiformats/go-multiaddr v0.3.3 // indirect github.com/multiformats/go-multihash v0.0.15 // indirect github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index e7b1790..54a11bf 100644 --- a/go.sum +++ b/go.sum @@ -28,9 +28,6 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ipfs/go-cid v0.0.7 h1:ysQJVJA3fNDF1qigJbsSQOdjhVLsOEoPdh0+R97k3jY= github.com/ipfs/go-cid v0.0.7/go.mod h1:6Ux9z5e+HpkQdckYoX1PG/6xqKspzlEIR5SDmgqgC/I= -github.com/jbenet/go-cienv v0.1.0/go.mod h1:TqNnHUmJgXau0nCzC7kXWeotg3J9W34CUv5Djy1+FlA= -github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o= -github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= @@ -48,8 +45,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs= github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM= github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs= -github.com/libp2p/go-libp2p-core v0.9.0 h1:t97Mv0LIBZlP2FXVRNKKVzHJCIjbIWGxYptGId4+htU= -github.com/libp2p/go-libp2p-core v0.9.0/go.mod h1:ESsbz31oC3C1AvMJoGx26RTuCkNhmkSRCqZ0kQtJ2/8= +github.com/libp2p/go-libp2p-core v0.10.0 h1:jFy7v5Muq58GTeYkPhGzIH8Qq4BFfziqc0ixPd/pP9k= +github.com/libp2p/go-libp2p-core v0.10.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg= github.com/libp2p/go-maddr-filter v0.1.0/go.mod h1:VzZhTXkMucEGGEOSKddrwGiOv0tUhgnKqNEmIAz/bPU= github.com/libp2p/go-msgio v0.0.6/go.mod h1:4ecVB6d9f4BDSL5fqvPiC4A3KivjWn+Venn/1ALLMWA= github.com/libp2p/go-openssl v0.0.7 h1:eCAzdLejcNVBzP/iZM9vqHnQm+XyCEbSSIheIPRGNsw= diff --git a/handshake.go b/handshake.go index f872af9..d70a529 100644 --- a/handshake.go +++ b/handshake.go @@ -60,9 +60,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error { } } - // We can re-use this buffer for all handshake messages as it's size + // We can re-use this buffer for all handshake messages as its size // will be the size of the maximum handshake message for the Noise XX pattern. - // Also, since we prefix every noise handshake message with it's length, we need to account for + // Also, since we prefix every noise handshake message with its length, we need to account for // it when we fetch the buffer from the pool maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize hbuf := pool.Get(maxMsgSize + LengthPrefixLength) @@ -242,8 +242,10 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati return err } - // if we know who we're trying to reach, make sure we have the right peer - if s.initiator && s.remoteID != id { + // check the peer ID for: + // * all outbound connection + // * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID) + if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) { // use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms. return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) } diff --git a/transport.go b/transport.go index a1daa7d..c8d7a44 100644 --- a/transport.go +++ b/transport.go @@ -36,8 +36,9 @@ func New(privkey crypto.PrivKey) (*Transport, error) { } // SecureInbound runs the Noise handshake as the responder. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, "", false) +// If p is empty, connections from any peer are accepted. +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + return newSecureSession(t, ctx, insecure, p, false) } // SecureOutbound runs the Noise handshake as the initiator. diff --git a/transport_test.go b/transport_test.go index 818a229..b65b9cb 100644 --- a/transport_test.go +++ b/transport_test.go @@ -5,14 +5,17 @@ import ( "context" "encoding/binary" "errors" - "golang.org/x/crypto/poly1305" "io" "math/rand" "net" "testing" "time" - crypto "github.com/libp2p/go-libp2p-core/crypto" + "github.com/stretchr/testify/assert" + + "golang.org/x/crypto/poly1305" + + "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" @@ -79,7 +82,7 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) }() - respConn, respErr := respTransport.SecureInbound(context.TODO(), resp) + respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") <-done if initErr != nil { @@ -161,24 +164,66 @@ func TestKeys(t *testing.T) { } } -func TestPeerIDMismatchFailsHandshake(t *testing.T) { +func TestPeerIDMatch(t *testing.T) { initTransport := newTestTransport(t, crypto.Ed25519, 2048) respTransport := newTestTransport(t, crypto.Ed25519, 2048) init, resp := newConnPair(t) - var initErr error done := make(chan struct{}) go func() { defer close(done) - _, initErr = initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + assert.NoError(t, err) + assert.Equal(t, conn.RemotePeer(), respTransport.localID) + b := make([]byte, 6) + _, err = conn.Read(b) + assert.NoError(t, err) + assert.Equal(t, b, []byte("foobar")) }() - _, _ = respTransport.SecureInbound(context.TODO(), resp) - <-done + conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID) + require.NoError(t, err) + require.Equal(t, conn.RemotePeer(), initTransport.localID) + _, err = conn.Write([]byte("foobar")) + require.NoError(t, err) +} - if initErr == nil { - t.Fatal("expected initiator to fail with peer ID mismatch error") - } +func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + init, resp := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") + errChan <- err + }() + + _, err := respTransport.SecureInbound(context.TODO(), resp, "") + require.Error(t, err) + + initErr := <-errChan + require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error") + require.Contains(t, initErr.Error(), "but remote key matches") +} + +func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + init, resp := newConnPair(t) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + assert.NoError(t, err) + _, err = conn.Read([]byte{0}) + assert.Error(t, err) + }() + + _, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id") + require.Error(t, err, "expected responder to fail with peer ID mismatch error") + <-done } func makeLargePlaintext(size int) []byte {