From 60a3d1748ef53e10e92cc089eb7e057890be5a43 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 8 Sep 2021 11:34:32 +0100 Subject: [PATCH] pass the peer ID to SecureInbound in the SecureTransport and SecureMuxer (#211) The peer ID may be empty. This will be the common case. In that case, connections from any peer are accepted. --- sec/insecure/insecure.go | 6 +- sec/insecure/insecure_test.go | 155 ++++++++++++++-------------------- sec/security.go | 8 +- 3 files changed, 73 insertions(+), 96 deletions(-) diff --git a/sec/insecure/insecure.go b/sec/insecure/insecure.go index 487bb675..1fd52457 100644 --- a/sec/insecure/insecure.go +++ b/sec/insecure/insecure.go @@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey { // // SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if a network error occurs during the ID exchange. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, @@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S return nil, err } + if t.key != nil && p != "" && p != conn.remote { + return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote) + } + return conn, nil } diff --git a/sec/insecure/insecure_test.go b/sec/insecure/insecure_test.go index 311cf4ef..433df841 100644 --- a/sec/insecure/insecure_test.go +++ b/sec/insecure/insecure_test.go @@ -1,157 +1,128 @@ package insecure import ( - "bytes" "context" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/sec" "io" "net" "testing" - ci "github.com/libp2p/go-libp2p-core/crypto" + "github.com/stretchr/testify/require" + + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec" ) // Run a set of sessions through the session setup and verification. func TestConnections(t *testing.T) { - clientTpt := newTestTransport(t, ci.RSA, 2048) - serverTpt := newTestTransport(t, ci.Ed25519, 1024) + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "") + require.NoError(t, clientErr) + require.NoError(t, serverErr) + testIDs(t, clientTpt, serverTpt, clientConn, serverConn) + testKeys(t, clientTpt, serverTpt, clientConn, serverConn) + testReadWrite(t, clientConn, serverConn) +} + +func TestPeerIdMatchInbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) - testConnection(t, clientTpt, serverTpt) + clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer()) + require.NoError(t, clientErr) + require.NoError(t, serverErr) + testIDs(t, clientTpt, serverTpt, clientConn, serverConn) + testKeys(t, clientTpt, serverTpt, clientConn, serverConn) + testReadWrite(t, clientConn, serverConn) +} + +func TestPeerIDMismatchInbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + _, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer") + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID") +} + +func TestPeerIDMismatchOutbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + _, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "") + require.Error(t, clientErr) + require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID") } func newTestTransport(t *testing.T, typ, bits int) *Transport { - priv, pub, err := ci.GenerateKeyPair(typ, bits) - if err != nil { - t.Fatal(err) - } + priv, pub, err := crypto.GenerateKeyPair(typ, bits) + require.NoError(t, err) id, err := peer.IDFromPublicKey(pub) - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) return NewWithIdentity(id, priv) } // Create a new pair of connected TCP sockets. func newConnPair(t *testing.T) (net.Conn, net.Conn) { lstnr, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - return nil, nil - } + require.NoError(t, err, "failed to listen") var clientErr error var client net.Conn - addr := lstnr.Addr() done := make(chan struct{}) go func() { defer close(done) + addr := lstnr.Addr() client, clientErr = net.Dial(addr.Network(), addr.String()) }() server, err := lstnr.Accept() - <-done + require.NoError(t, err, "failed to accept") + <-done lstnr.Close() - - if err != nil { - t.Fatalf("Failed to accept: %v", err) - } - - if clientErr != nil { - t.Fatalf("Failed to connect: %v", clientErr) - } - + require.NoError(t, clientErr, "failed to connect") return client, server } -// Create a new pair of connected sessions based off of the provided -// session generators. -func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) { +func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) { client, server := newConnPair(t) - // Connect the client and server sessions done := make(chan struct{}) - - var clientConn sec.SecureConn - var clientErr error go func() { defer close(done) - clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer()) + clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID) }() - - serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server) + serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID) <-done - - if serverErr != nil { - t.Fatal(serverErr) - } - - if clientErr != nil { - t.Fatal(clientErr) - } - - return clientConn, serverConn + return } // Check the peer IDs func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { - if clientConn.LocalPeer() != clientTpt.LocalPeer() { - t.Fatal("Client Local Peer ID mismatch.") - } - - if clientConn.RemotePeer() != serverTpt.LocalPeer() { - t.Fatal("Client Remote Peer ID mismatch.") - } - - if clientConn.LocalPeer() != serverConn.RemotePeer() { - t.Fatal("Server Local Peer ID mismatch.") - } + require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.") + require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.") + require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.") } // Check the keys func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { sk := serverConn.LocalPrivateKey() - pk := sk.GetPublic() - - if !sk.Equals(serverTpt.LocalPrivateKey()) { - t.Error("Private key Mismatch.") - } - - if !pk.Equals(clientConn.RemotePublicKey()) { - t.Error("Public key mismatch.") - } + require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch") + require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch") } // Check sending and receiving messages func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) { before := []byte("hello world") _, err := clientConn.Write(before) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) after := make([]byte, len(before)) _, err = io.ReadFull(serverConn, after) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(before, after) { - t.Errorf("Message mismatch. %v != %v", before, after) - } -} - -// Setup a new session with a pair of locally connected sockets -func testConnection(t *testing.T, clientTpt, serverTpt *Transport) { - clientConn, serverConn := connect(t, clientTpt, serverTpt) - - testIDs(t, clientTpt, serverTpt, clientConn, serverConn) - testKeys(t, clientTpt, serverTpt, clientConn, serverConn) - testReadWrite(t, clientConn, serverConn) - - clientConn.Close() - serverConn.Close() + require.NoError(t, err) + require.Equal(t, before, after, "message mismatch") } diff --git a/sec/security.go b/sec/security.go index 42321d18..a4cd7a2e 100644 --- a/sec/security.go +++ b/sec/security.go @@ -19,7 +19,8 @@ type SecureConn interface { // plain-text, native connections into authenticated, encrypted connections. type SecureTransport interface { // SecureInbound secures an inbound connection. - SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error) + // If p is empty, connections from any peer are accepted. + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) // SecureOutbound secures an outbound connection. SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) @@ -29,9 +30,10 @@ type SecureTransport interface { // and open outbound connections with simultaneous open. type SecureMuxer interface { // SecureInbound secures an inbound connection. - // The returned boolean indicates whether the connection should be trated as a server + // The returned boolean indicates whether the connection should be treated as a server // connection; in the case of SecureInbound it should always be true. - SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error) + // If p is empty, connections from any peer are accepted. + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) // SecureOutbound secures an outbound connection. // The returned boolean indicates whether the connection should be treated as a server