From b0e6ce9862eb4626b6199cc9d2604b2aa8de5ed2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 5 Sep 2021 13:24:32 +0100 Subject: [PATCH] pass the peer ID to SecureInbound in the SecureTransport and SecureMuxer The peer ID may be empty. This will be the common case. In that case, connections from any peer are accepted. --- sec/insecure/insecure.go | 6 +- sec/insecure/insecure_test.go | 155 ++++++++++++++-------------------- sec/security.go | 8 +- 3 files changed, 73 insertions(+), 96 deletions(-) diff --git a/sec/insecure/insecure.go b/sec/insecure/insecure.go index 487bb675..1fd52457 100644 --- a/sec/insecure/insecure.go +++ b/sec/insecure/insecure.go @@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey { // // SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent // with each other, or if a network error occurs during the ID exchange. -func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) { +func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { conn := &Conn{ Conn: insecure, local: t.id, @@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S return nil, err } + if t.key != nil && p != "" && p != conn.remote { + return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote) + } + return conn, nil } diff --git a/sec/insecure/insecure_test.go b/sec/insecure/insecure_test.go index 311cf4ef..433df841 100644 --- a/sec/insecure/insecure_test.go +++ b/sec/insecure/insecure_test.go @@ -1,157 +1,128 @@ package insecure import ( - "bytes" "context" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/sec" "io" "net" "testing" - ci "github.com/libp2p/go-libp2p-core/crypto" + "github.com/stretchr/testify/require" + + "github.com/libp2p/go-libp2p-core/crypto" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec" ) // Run a set of sessions through the session setup and verification. func TestConnections(t *testing.T) { - clientTpt := newTestTransport(t, ci.RSA, 2048) - serverTpt := newTestTransport(t, ci.Ed25519, 1024) + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "") + require.NoError(t, clientErr) + require.NoError(t, serverErr) + testIDs(t, clientTpt, serverTpt, clientConn, serverConn) + testKeys(t, clientTpt, serverTpt, clientConn, serverConn) + testReadWrite(t, clientConn, serverConn) +} + +func TestPeerIdMatchInbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) - testConnection(t, clientTpt, serverTpt) + clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer()) + require.NoError(t, clientErr) + require.NoError(t, serverErr) + testIDs(t, clientTpt, serverTpt, clientConn, serverConn) + testKeys(t, clientTpt, serverTpt, clientConn, serverConn) + testReadWrite(t, clientConn, serverConn) +} + +func TestPeerIDMismatchInbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + _, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer") + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID") +} + +func TestPeerIDMismatchOutbound(t *testing.T) { + clientTpt := newTestTransport(t, crypto.RSA, 2048) + serverTpt := newTestTransport(t, crypto.Ed25519, 1024) + + _, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "") + require.Error(t, clientErr) + require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID") } func newTestTransport(t *testing.T, typ, bits int) *Transport { - priv, pub, err := ci.GenerateKeyPair(typ, bits) - if err != nil { - t.Fatal(err) - } + priv, pub, err := crypto.GenerateKeyPair(typ, bits) + require.NoError(t, err) id, err := peer.IDFromPublicKey(pub) - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) return NewWithIdentity(id, priv) } // Create a new pair of connected TCP sockets. func newConnPair(t *testing.T) (net.Conn, net.Conn) { lstnr, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Failed to listen: %v", err) - return nil, nil - } + require.NoError(t, err, "failed to listen") var clientErr error var client net.Conn - addr := lstnr.Addr() done := make(chan struct{}) go func() { defer close(done) + addr := lstnr.Addr() client, clientErr = net.Dial(addr.Network(), addr.String()) }() server, err := lstnr.Accept() - <-done + require.NoError(t, err, "failed to accept") + <-done lstnr.Close() - - if err != nil { - t.Fatalf("Failed to accept: %v", err) - } - - if clientErr != nil { - t.Fatalf("Failed to connect: %v", clientErr) - } - + require.NoError(t, clientErr, "failed to connect") return client, server } -// Create a new pair of connected sessions based off of the provided -// session generators. -func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) { +func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) { client, server := newConnPair(t) - // Connect the client and server sessions done := make(chan struct{}) - - var clientConn sec.SecureConn - var clientErr error go func() { defer close(done) - clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer()) + clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID) }() - - serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server) + serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID) <-done - - if serverErr != nil { - t.Fatal(serverErr) - } - - if clientErr != nil { - t.Fatal(clientErr) - } - - return clientConn, serverConn + return } // Check the peer IDs func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { - if clientConn.LocalPeer() != clientTpt.LocalPeer() { - t.Fatal("Client Local Peer ID mismatch.") - } - - if clientConn.RemotePeer() != serverTpt.LocalPeer() { - t.Fatal("Client Remote Peer ID mismatch.") - } - - if clientConn.LocalPeer() != serverConn.RemotePeer() { - t.Fatal("Server Local Peer ID mismatch.") - } + require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.") + require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.") + require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.") } // Check the keys func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) { sk := serverConn.LocalPrivateKey() - pk := sk.GetPublic() - - if !sk.Equals(serverTpt.LocalPrivateKey()) { - t.Error("Private key Mismatch.") - } - - if !pk.Equals(clientConn.RemotePublicKey()) { - t.Error("Public key mismatch.") - } + require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch") + require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch") } // Check sending and receiving messages func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) { before := []byte("hello world") _, err := clientConn.Write(before) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) after := make([]byte, len(before)) _, err = io.ReadFull(serverConn, after) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(before, after) { - t.Errorf("Message mismatch. %v != %v", before, after) - } -} - -// Setup a new session with a pair of locally connected sockets -func testConnection(t *testing.T, clientTpt, serverTpt *Transport) { - clientConn, serverConn := connect(t, clientTpt, serverTpt) - - testIDs(t, clientTpt, serverTpt, clientConn, serverConn) - testKeys(t, clientTpt, serverTpt, clientConn, serverConn) - testReadWrite(t, clientConn, serverConn) - - clientConn.Close() - serverConn.Close() + require.NoError(t, err) + require.Equal(t, before, after, "message mismatch") } diff --git a/sec/security.go b/sec/security.go index 42321d18..a4cd7a2e 100644 --- a/sec/security.go +++ b/sec/security.go @@ -19,7 +19,8 @@ type SecureConn interface { // plain-text, native connections into authenticated, encrypted connections. type SecureTransport interface { // SecureInbound secures an inbound connection. - SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error) + // If p is empty, connections from any peer are accepted. + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) // SecureOutbound secures an outbound connection. SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error) @@ -29,9 +30,10 @@ type SecureTransport interface { // and open outbound connections with simultaneous open. type SecureMuxer interface { // SecureInbound secures an inbound connection. - // The returned boolean indicates whether the connection should be trated as a server + // The returned boolean indicates whether the connection should be treated as a server // connection; in the case of SecureInbound it should always be true. - SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error) + // If p is empty, connections from any peer are accepted. + SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error) // SecureOutbound secures an outbound connection. // The returned boolean indicates whether the connection should be treated as a server