Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

add the peer ID to SecureInbound #104

Merged
merged 1 commit into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (b benchenv) connect(stopTimer bool) (*secureSession, *secureSession) {
initSession, initErr = b.initTpt.SecureOutbound(context.TODO(), initConn, b.respTpt.localID)
}()

respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn)
respSession, respErr := b.respTpt.SecureInbound(context.TODO(), respConn, "")
<-done

if initErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/flynn/noise v1.0.0
github.com/gogo/protobuf v1.3.2
github.com/libp2p/go-buffer-pool v0.0.2
github.com/libp2p/go-libp2p-core v0.9.0
github.com/libp2p/go-libp2p-core v0.10.0
github.com/multiformats/go-multiaddr v0.3.3 // indirect
github.com/multiformats/go-multihash v0.0.15 // indirect
github.com/stretchr/testify v1.7.0
Expand Down
7 changes: 2 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ipfs/go-cid v0.0.7 h1:ysQJVJA3fNDF1qigJbsSQOdjhVLsOEoPdh0+R97k3jY=
github.com/ipfs/go-cid v0.0.7/go.mod h1:6Ux9z5e+HpkQdckYoX1PG/6xqKspzlEIR5SDmgqgC/I=
github.com/jbenet/go-cienv v0.1.0/go.mod h1:TqNnHUmJgXau0nCzC7kXWeotg3J9W34CUv5Djy1+FlA=
github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o=
github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=
Expand All @@ -48,8 +45,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/libp2p/go-buffer-pool v0.0.2 h1:QNK2iAFa8gjAe1SPz6mHSMuCcjs+X1wlHzeOSqcmlfs=
github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoRZd1Vi32+RXyFM=
github.com/libp2p/go-flow-metrics v0.0.3/go.mod h1:HeoSNUrOJVK1jEpDqVEiUOIXqhbnS27omG0uWU5slZs=
github.com/libp2p/go-libp2p-core v0.9.0 h1:t97Mv0LIBZlP2FXVRNKKVzHJCIjbIWGxYptGId4+htU=
github.com/libp2p/go-libp2p-core v0.9.0/go.mod h1:ESsbz31oC3C1AvMJoGx26RTuCkNhmkSRCqZ0kQtJ2/8=
github.com/libp2p/go-libp2p-core v0.10.0 h1:jFy7v5Muq58GTeYkPhGzIH8Qq4BFfziqc0ixPd/pP9k=
github.com/libp2p/go-libp2p-core v0.10.0/go.mod h1:ECdxehoYosLYHgDDFa2N4yE8Y7aQRAMf0sX9mf2sbGg=
github.com/libp2p/go-maddr-filter v0.1.0/go.mod h1:VzZhTXkMucEGGEOSKddrwGiOv0tUhgnKqNEmIAz/bPU=
github.com/libp2p/go-msgio v0.0.6/go.mod h1:4ecVB6d9f4BDSL5fqvPiC4A3KivjWn+Venn/1ALLMWA=
github.com/libp2p/go-openssl v0.0.7 h1:eCAzdLejcNVBzP/iZM9vqHnQm+XyCEbSSIheIPRGNsw=
Expand Down
10 changes: 6 additions & 4 deletions handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}
}

// We can re-use this buffer for all handshake messages as it's size
// We can re-use this buffer for all handshake messages as its size
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with it's length, we need to account for
// Also, since we prefix every noise handshake message with its length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*poly1305.TagSize
hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
Expand Down Expand Up @@ -242,8 +242,10 @@ func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStati
return err
}

// if we know who we're trying to reach, make sure we have the right peer
if s.initiator && s.remoteID != id {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to assert that we know the remote peer ID if we're the initiator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to get to the point where we can dial a multiaddr and learn the remote peer ID from the handshake, but this PR is probably not the place to work towards that. Will adjust the logic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree. We just need to be careful about that to make sure we can't get tricked into dialing the wrong peer.

Although I guess we enforce this in the swarm so it's probably not an issue.

// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}
Expand Down
5 changes: 3 additions & 2 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ func New(privkey crypto.PrivKey) (*Transport, error) {
}

// SecureInbound runs the Noise handshake as the responder.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, "", false)
// If p is empty, connections from any peer are accepted.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
return newSecureSession(t, ctx, insecure, p, false)
}

// SecureOutbound runs the Noise handshake as the initiator.
Expand Down
67 changes: 56 additions & 11 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ import (
"context"
"encoding/binary"
"errors"
"golang.org/x/crypto/poly1305"
"io"
"math/rand"
"net"
"testing"
"time"

crypto "github.com/libp2p/go-libp2p-core/crypto"
"github.com/stretchr/testify/assert"

"golang.org/x/crypto/poly1305"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"

Expand Down Expand Up @@ -79,7 +82,7 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess
initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
}()

respConn, respErr := respTransport.SecureInbound(context.TODO(), resp)
respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "")
<-done

if initErr != nil {
Expand Down Expand Up @@ -161,24 +164,66 @@ func TestKeys(t *testing.T) {
}
}

func TestPeerIDMismatchFailsHandshake(t *testing.T) {
func TestPeerIDMatch(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

var initErr error
done := make(chan struct{})
go func() {
defer close(done)
_, initErr = initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id")
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
assert.NoError(t, err)
assert.Equal(t, conn.RemotePeer(), respTransport.localID)
b := make([]byte, 6)
_, err = conn.Read(b)
assert.NoError(t, err)
assert.Equal(t, b, []byte("foobar"))
}()

_, _ = respTransport.SecureInbound(context.TODO(), resp)
<-done
conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID)
require.NoError(t, err)
require.Equal(t, conn.RemotePeer(), initTransport.localID)
_, err = conn.Write([]byte("foobar"))
require.NoError(t, err)
}

if initErr == nil {
t.Fatal("expected initiator to fail with peer ID mismatch error")
}
func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id")
errChan <- err
}()

_, err := respTransport.SecureInbound(context.TODO(), resp, "")
require.Error(t, err)

initErr := <-errChan
require.Error(t, initErr, "expected initiator to fail with peer ID mismatch error")
require.Contains(t, initErr.Error(), "but remote key matches")
}

func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) {
initTransport := newTestTransport(t, crypto.Ed25519, 2048)
respTransport := newTestTransport(t, crypto.Ed25519, 2048)
init, resp := newConnPair(t)

done := make(chan struct{})
go func() {
defer close(done)
conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID)
assert.NoError(t, err)
_, err = conn.Read([]byte{0})
assert.Error(t, err)
}()

_, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id")
require.Error(t, err, "expected responder to fail with peer ID mismatch error")
<-done
}

func makeLargePlaintext(size int) []byte {
Expand Down