Skip to content
This repository has been archived by the owner on Sep 6, 2022. It is now read-only.

pass the peer ID to SecureInbound in the SecureTransport and SecureMuxer #211

Merged
merged 1 commit into from
Sep 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sec/insecure/insecure.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (t *Transport) LocalPrivateKey() ci.PrivKey {
//
// SecureInbound may fail if the remote peer sends an ID and public key that are inconsistent
// with each other, or if a network error occurs during the ID exchange.
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.SecureConn, error) {
func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) {
conn := &Conn{
Conn: insecure,
local: t.id,
Expand All @@ -72,6 +72,10 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn) (sec.S
return nil, err
}

if t.key != nil && p != "" && p != conn.remote {
return nil, fmt.Errorf("remote peer sent unexpected peer ID. expected=%s received=%s", p, conn.remote)
}

return conn, nil
}

Expand Down
155 changes: 63 additions & 92 deletions sec/insecure/insecure_test.go
Original file line number Diff line number Diff line change
@@ -1,157 +1,128 @@
package insecure

import (
"bytes"
"context"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
"io"
"net"
"testing"

ci "github.com/libp2p/go-libp2p-core/crypto"
"github.com/stretchr/testify/require"

"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
)

// Run a set of sessions through the session setup and verification.
func TestConnections(t *testing.T) {
clientTpt := newTestTransport(t, ci.RSA, 2048)
serverTpt := newTestTransport(t, ci.Ed25519, 1024)
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "")
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}

func TestPeerIdMatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

testConnection(t, clientTpt, serverTpt)
clientConn, serverConn, clientErr, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), clientTpt.LocalPeer())
require.NoError(t, clientErr)
require.NoError(t, serverErr)
testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)
}

func TestPeerIDMismatchInbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

_, _, _, serverErr := connect(t, clientTpt, serverTpt, serverTpt.LocalPeer(), "a-random-peer")
require.Error(t, serverErr)
require.Contains(t, serverErr.Error(), "remote peer sent unexpected peer ID")
}

func TestPeerIDMismatchOutbound(t *testing.T) {
clientTpt := newTestTransport(t, crypto.RSA, 2048)
serverTpt := newTestTransport(t, crypto.Ed25519, 1024)

_, _, clientErr, _ := connect(t, clientTpt, serverTpt, "a random peer", "")
require.Error(t, clientErr)
require.Contains(t, clientErr.Error(), "remote peer sent unexpected peer ID")
}

func newTestTransport(t *testing.T, typ, bits int) *Transport {
priv, pub, err := ci.GenerateKeyPair(typ, bits)
if err != nil {
t.Fatal(err)
}
priv, pub, err := crypto.GenerateKeyPair(typ, bits)
require.NoError(t, err)
id, err := peer.IDFromPublicKey(pub)
if err != nil {
t.Fatal(err)
}

require.NoError(t, err)
return NewWithIdentity(id, priv)
}

// Create a new pair of connected TCP sockets.
func newConnPair(t *testing.T) (net.Conn, net.Conn) {
lstnr, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
return nil, nil
}
require.NoError(t, err, "failed to listen")

var clientErr error
var client net.Conn
addr := lstnr.Addr()
done := make(chan struct{})

go func() {
defer close(done)
addr := lstnr.Addr()
client, clientErr = net.Dial(addr.Network(), addr.String())
}()

server, err := lstnr.Accept()
<-done
require.NoError(t, err, "failed to accept")

<-done
lstnr.Close()

if err != nil {
t.Fatalf("Failed to accept: %v", err)
}

if clientErr != nil {
t.Fatalf("Failed to connect: %v", clientErr)
}

require.NoError(t, clientErr, "failed to connect")
return client, server
}

// Create a new pair of connected sessions based off of the provided
// session generators.
func connect(t *testing.T, clientTpt, serverTpt *Transport) (sec.SecureConn, sec.SecureConn) {
func connect(t *testing.T, clientTpt, serverTpt *Transport, clientExpectsID, serverExpectsID peer.ID) (clientConn sec.SecureConn, serverConn sec.SecureConn, clientErr, serverErr error) {
client, server := newConnPair(t)

// Connect the client and server sessions
done := make(chan struct{})

var clientConn sec.SecureConn
var clientErr error
go func() {
defer close(done)
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, serverTpt.LocalPeer())
clientConn, clientErr = clientTpt.SecureOutbound(context.TODO(), client, clientExpectsID)
}()

serverConn, serverErr := serverTpt.SecureInbound(context.TODO(), server)
serverConn, serverErr = serverTpt.SecureInbound(context.TODO(), server, serverExpectsID)
<-done

if serverErr != nil {
t.Fatal(serverErr)
}

if clientErr != nil {
t.Fatal(clientErr)
}

return clientConn, serverConn
return
}

// Check the peer IDs
func testIDs(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
if clientConn.LocalPeer() != clientTpt.LocalPeer() {
t.Fatal("Client Local Peer ID mismatch.")
}

if clientConn.RemotePeer() != serverTpt.LocalPeer() {
t.Fatal("Client Remote Peer ID mismatch.")
}

if clientConn.LocalPeer() != serverConn.RemotePeer() {
t.Fatal("Server Local Peer ID mismatch.")
}
require.Equal(t, clientConn.LocalPeer(), clientTpt.LocalPeer(), "Client Local Peer ID mismatch.")
require.Equal(t, clientConn.RemotePeer(), serverTpt.LocalPeer(), "Client Remote Peer ID mismatch.")
require.Equal(t, clientConn.LocalPeer(), serverConn.RemotePeer(), "Server Local Peer ID mismatch.")
}

// Check the keys
func testKeys(t *testing.T, clientTpt, serverTpt *Transport, clientConn, serverConn sec.SecureConn) {
sk := serverConn.LocalPrivateKey()
pk := sk.GetPublic()

if !sk.Equals(serverTpt.LocalPrivateKey()) {
t.Error("Private key Mismatch.")
}

if !pk.Equals(clientConn.RemotePublicKey()) {
t.Error("Public key mismatch.")
}
require.True(t, sk.Equals(serverTpt.LocalPrivateKey()), "private key mismatch")
require.True(t, sk.GetPublic().Equals(clientConn.RemotePublicKey()), "public key mismatch")
}

// Check sending and receiving messages
func testReadWrite(t *testing.T, clientConn, serverConn sec.SecureConn) {
before := []byte("hello world")
_, err := clientConn.Write(before)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

after := make([]byte, len(before))
_, err = io.ReadFull(serverConn, after)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(before, after) {
t.Errorf("Message mismatch. %v != %v", before, after)
}
}

// Setup a new session with a pair of locally connected sockets
func testConnection(t *testing.T, clientTpt, serverTpt *Transport) {
clientConn, serverConn := connect(t, clientTpt, serverTpt)

testIDs(t, clientTpt, serverTpt, clientConn, serverConn)
testKeys(t, clientTpt, serverTpt, clientConn, serverConn)
testReadWrite(t, clientConn, serverConn)

clientConn.Close()
serverConn.Close()
require.NoError(t, err)
require.Equal(t, before, after, "message mismatch")
}
8 changes: 5 additions & 3 deletions sec/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ type SecureConn interface {
// plain-text, native connections into authenticated, encrypted connections.
type SecureTransport interface {
// SecureInbound secures an inbound connection.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)

// SecureOutbound secures an outbound connection.
SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, error)
Expand All @@ -29,9 +30,10 @@ type SecureTransport interface {
// and open outbound connections with simultaneous open.
type SecureMuxer interface {
// SecureInbound secures an inbound connection.
// The returned boolean indicates whether the connection should be trated as a server
// The returned boolean indicates whether the connection should be treated as a server
// connection; in the case of SecureInbound it should always be true.
SecureInbound(ctx context.Context, insecure net.Conn) (SecureConn, bool, error)
// If p is empty, connections from any peer are accepted.
SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (SecureConn, bool, error)

// SecureOutbound secures an outbound connection.
// The returned boolean indicates whether the connection should be treated as a server
Expand Down