Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: add comprehensive end-to-end tests for connection gating #2200

Merged
merged 3 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion core/connmgr/gater.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ import (
// DisconnectReasons is that we require stream multiplexing capability to open a
// control protocol stream to transmit the message.
type ConnectionGater interface {

// InterceptPeerDial tests whether we're permitted to Dial the specified peer.
//
// This is called by the network.Network implementation when dialling a peer.
Expand Down
237 changes: 237 additions & 0 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
package transport_integration

import (
"context"
"testing"
"time"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/net/swarm"

"github.com/libp2p/go-libp2p-testing/race"

"github.com/golang/mock/gomock"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

//go:generate go run github.com/golang/mock/mockgen -package transport_integration -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater

func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
for {
if _, err := addr.ValueForProtocol(ma.P_CERTHASH); err != nil {
break
}
addr, _ = ma.SplitLast(addr)
}
return addr
}

func TestInterceptPeerDial(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{})
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
connGater.EXPECT().InterceptPeerDial(h2.ID())
require.ErrorIs(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}), swarm.ErrGaterDisallowedConnection)
})
}
}

func TestInterceptAddrDial(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{})
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gomock.InOrder(
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
connGater.EXPECT().InterceptAddrDial(h2.ID(), h2.Addrs()[0]),
)
require.ErrorIs(t, h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}), swarm.ErrNoGoodAddresses)
})
}
}

func TestInterceptSecuredOutgoing(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{})
require.Len(t, h2.Addrs(), 1)
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gomock.InOrder(
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
}),
)
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
}
}

func TestInterceptUpgradedOutgoing(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ConnGater: connGater})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{})
require.Len(t, h2.Addrs(), 1)
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gomock.InOrder(
connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true),
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
}))
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
}
}

func TestInterceptAccept(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater})
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// The basic host dials the first connection.
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
})
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
}
}

func TestInterceptSecuredIncoming(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater})
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
}
}

func TestInterceptUpgradedIncoming(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
}
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)

h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater})
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
require.Error(t, err)
require.NotErrorIs(t, err, context.DeadlineExceeded)
})
}
}
109 changes: 109 additions & 0 deletions p2p/test/transport/mock_connection_gater_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p/config"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
Expand All @@ -30,8 +31,9 @@ type TransportTestCase struct {
}

type TransportTestCaseOpts struct {
NoListen bool
NoRcmgr bool
NoListen bool
NoRcmgr bool
ConnGater connmgr.ConnectionGater
}

func transformOpts(opts TransportTestCaseOpts) []config.Option {
Expand All @@ -40,6 +42,9 @@ func transformOpts(opts TransportTestCaseOpts) []config.Option {
if opts.NoRcmgr {
libp2pOpts = append(libp2pOpts, libp2p.ResourceManager(&network.NullResourceManager{}))
}
if opts.ConnGater != nil {
libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater(opts.ConnGater))
}
return libp2pOpts
}

Expand Down