diff --git a/config/config.go b/config/config.go index fb5a2ab1b1..900c06bc30 100644 --- a/config/config.go +++ b/config/config.go @@ -38,6 +38,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/prometheus/client_golang/prometheus" @@ -145,6 +146,8 @@ type Config struct { CustomIPv6BlackHoleSuccessCounter bool UserFxOptions []fx.Option + + ShareTCPListener bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -289,6 +292,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), + fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { + if !cfg.ShareTCPListener { + return nil + } + return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) + }), fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { quicAddrPorts := map[string]struct{}{} diff --git a/libp2p_test.go b/libp2p_test.go index b290227fc1..3de82946d8 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -59,7 +59,7 @@ func TestTransportConstructor(t *testing.T) { _ connmgr.ConnectionGater, upgrader transport.Upgrader, ) transport.Transport { - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) require.NoError(t, err) return tpt } @@ -751,3 +751,27 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { }}, } } + +func TestSharedTCPAddr(t *testing.T) { + h, err := New( + ShareTCPListener(), + Transport(tcp.NewTCPTransport), + Transport(websocket.New), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"), + ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"), + ) + require.NoError(t, err) + sawTCP := false + sawWS := false + for _, addr := range h.Addrs() { + if strings.HasSuffix(addr.String(), "/tcp/8888") { + sawTCP = true + } + if strings.HasSuffix(addr.String(), "/tcp/8888/ws") { + sawWS = true + } + } + require.True(t, sawTCP) + require.True(t, sawWS) + h.Close() +} diff --git a/options.go b/options.go index 4fbf8eb2ac..0329b7e60b 100644 --- a/options.go +++ b/options.go @@ -643,3 +643,15 @@ func WithFxOption(opts ...fx.Option) Option { return nil } } + +// ShareTCPListener shares the same listen address between TCP and Websocket +// transports. This lets both transports use the same TCP port. +// +// Currently this behavior is Opt-in. In a future release this will be the +// default, and this option will be removed. +func ShareTCPListener() Option { + return func(cfg *Config) error { + cfg.ShareTCPListener = true + return nil + } +} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index ed4f00ff58..d264fd1230 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { upgrader := makeUpgrader(t, s) var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/swarm/swarm_addr_test.go b/p2p/net/swarm/swarm_addr_test.go index 435866e920..43e76716e5 100644 --- a/p2p/net/swarm/swarm_addr_test.go +++ b/p2p/net/swarm/swarm_addr_test.go @@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) { s, err := swarm.NewSwarm("local", nil, eventbus.NewBus()) require.NoError(t, err) - tcpTr, err := tcp.NewTCPTransport(nil, nil) + tcpTr, err := tcp.NewTCPTransport(nil, nil, nil) require.NoError(t, err) require.NoError(t, s.AddTransport(tcpTr)) reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 0ef43cf62e..add6f5cbba 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - tpt, err := websocket.New(nil, &network.NullResourceManager{}) + tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver})) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) { require.NoError(t, err) defer s.Close() - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { }) // Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out. - tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}) + tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(tpt) require.NoError(t, err) @@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm { err = s.AddTransport(wtTpt) require.NoError(t, err) - wsTpt, err := websocket.New(nil, &network.NullResourceManager{}) + wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil) require.NoError(t, err) err = s.AddTransport(wsTpt) require.NoError(t, err) diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index 2bbe8b27a5..773314a1b8 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm { if cfg.disableReuseport { tcpOpts = append(tcpOpts, tcp.DisableReuseport()) } - tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) + tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...) require.NoError(t, err) if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 8af2791b36..c2e81d2e93 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -84,23 +84,33 @@ func (l *listener) handleIncoming() { } catcher.Reset() - // gate the connection if applicable - if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { - log.Debugf("gater blocked incoming connection on local addr %s from %s", - maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) - if err := maconn.Close(); err != nil { - log.Warnf("failed to close incoming connection rejected by gater: %s", err) - } - continue + // Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation. + var connScope network.ConnManagementScope + if sc, ok := maconn.(interface { + Scope() network.ConnManagementScope + }); ok { + connScope = sc.Scope() } + if connScope == nil { + // gate the connection if applicable + if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + maconn.LocalMultiaddr(), maconn.RemoteMultiaddr()) + if err := maconn.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } - connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) - if err != nil { - log.Debugw("resource manager blocked accept of new connection", "error", err) - if err := maconn.Close(); err != nil { - log.Warnf("failed to incoming connection rejected by resource manager: %s", err) + var err error + connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := maconn.Close(); err != nil { + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) + } + continue } - continue } // The go routine below calls Release when the context is diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index e5d32b0c96..f6b63e32de 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u upgrader := swarmt.GenUpgrader(t, netw, nil) upgraders = append(upgraders, upgrader) - tpt, err := tcp.NewTCPTransport(upgrader, nil) + tpt, err := tcp.NewTCPTransport(upgrader, nil, nil) if err != nil { t.Fatal(err) } diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index df53da6eeb..99ce67b521 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -2,6 +2,8 @@ package transport_integration import ( "context" + "encoding/binary" + "net/netip" "strings" "testing" "time" @@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr { return addr } +func addrPort(addr ma.Multiaddr) netip.AddrPort { + a := netip.Addr{} + p := uint16(0) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + a, _ = netip.AddrFromSlice(c.RawValue()) + return false + } + if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP { + p = binary.BigEndian.Uint16(c.RawValue()) + return true + } + return false + }) + return netip.AddrPortFrom(a, p) +} + func TestInterceptPeerDial(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) { // remove the certhash component from WebTransport addresses require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) }).AnyTimes() + } else if strings.Contains(tc.Name, "WebSocket-Shared") { + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { + require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) + }) } else { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr()) + require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr()) }) } diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..60f8ca0c06 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "TCP-Shared / TLS / Yamux", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New)) + libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, + { + Name: "WebSocket-Shared", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, { Name: "WebSocket", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { diff --git a/p2p/transport/tcp/metrics_unix_test.go b/p2p/transport/tcp/metrics_unix_test.go new file mode 100644 index 0000000000..0a09526206 --- /dev/null +++ b/p2p/transport/tcp/metrics_unix_test.go @@ -0,0 +1,54 @@ +// go:build: unix + +package tcp + +import ( + "testing" + + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" + ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" + + "github.com/stretchr/testify/require" +) + +func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) { + + peerA, ia := makeInsecureMuxer(t) + _, ib := makeInsecureMuxer(t) + + sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil) + sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil) + + ua, err := tptu.New(ia, muxers, nil, nil, nil) + require.NoError(t, err) + ta, err := NewTCPTransport(ua, nil, sharedTCPSocketA, WithMetrics()) + require.NoError(t, err) + ub, err := tptu.New(ib, muxers, nil, nil, nil) + require.NoError(t, err) + tb, err := NewTCPTransport(ub, nil, sharedTCPSocketB, WithMetrics()) + require.NoError(t, err) + + zero := "/ip4/127.0.0.1/tcp/0" + + // Not running any test that needs more than 1 conn because the testsuite + // opens multiple conns via multiple listeners, which is not expected to work + // with the shared TCP socket. + subtestsToRun := []ttransport.TransportSubTestFn{ + ttransport.SubtestProtocols, + ttransport.SubtestBasic, + ttransport.SubtestCancel, + ttransport.SubtestPingPong, + + // Stolen from the stream muxer test suite. + ttransport.SubtestStress1Conn1Stream1Msg, + ttransport.SubtestStress1Conn1Stream100Msg, + ttransport.SubtestStress1Conn100Stream100Msg, + ttransport.SubtestStress1Conn1000Stream10Msg, + ttransport.SubtestStress1Conn100Stream100Msg10MB, + ttransport.SubtestStreamOpenStress, + ttransport.SubtestStreamReset, + } + + ttransport.SubtestTransportWithFs(t, ta, tb, zero, peerA, subtestsToRun) +} diff --git a/p2p/transport/tcp/tcp.go b/p2p/transport/tcp/tcp.go index d52bb96019..c80723436e 100644 --- a/p2p/transport/tcp/tcp.go +++ b/p2p/transport/tcp/tcp.go @@ -13,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/net/reuseport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" @@ -33,6 +34,9 @@ type canKeepAlive interface { var _ canKeepAlive = &net.TCPConn{} +// Deprecated: Use tcpreuse.ReuseportIsAvailable +var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable + func tryKeepAlive(conn net.Conn, keepAlive bool) { keepAliveConn, ok := conn.(canKeepAlive) if !ok { @@ -122,6 +126,9 @@ type TcpTransport struct { disableReuseport bool // Explicitly disable reuseport. enableMetrics bool + // share and demultiplex TCP listeners across multiple transports + sharedTcp *tcpreuse.ConnMgr + // TCP connect timeout connectTimeout time.Duration @@ -134,8 +141,8 @@ var _ transport.Transport = &TcpTransport{} var _ transport.DialUpdater = &TcpTransport{} // NewTCPTransport creates a tcp transport object that tracks dialers and listeners -// created. It represents an entire TCP stack (though it might not necessarily be). -func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) { +// created. +func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -143,6 +150,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, upgrader: upgrader, connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option rcmgr: rcmgr, + sharedTcp: sharedTCP, } for _, o := range opts { if err := o(tr); err != nil { @@ -168,6 +176,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co defer cancel() } + if t.sharedTcp != nil { + return t.sharedTcp.DialContext(ctx, raddr) + } + if t.UseReuseport() { return t.reuse.DialContext(ctx, raddr) } @@ -233,10 +245,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p // UseReuseport returns true if reuseport is enabled and available. func (t *TcpTransport) UseReuseport() bool { - return !t.disableReuseport && ReuseportIsAvailable() + return !t.disableReuseport && tcpreuse.ReuseportIsAvailable() } -func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { +func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) { if t.UseReuseport() { return t.reuse.Listen(laddr) } @@ -245,10 +257,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) { // Listen listens on the given multiaddr. func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { - list, err := t.maListen(laddr) + var list manet.Listener + var err error + + if t.sharedTcp == nil { + list, err = t.unsharedMAListen(laddr) + } else { + list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) + } if err != nil { return nil, err } + if t.enableMetrics { list = newTracingListener(&tcpListener{list, 0}) } diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index a57a65e420..1f939d92be 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -14,6 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" @@ -31,19 +32,19 @@ func TestTcpTransport(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" ttransport.SubtestTransport(t, ta, tb, zero, peerA) - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportWithMetrics(t *testing.T) { @@ -52,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil, WithMetrics()) + ta, err := NewTCPTransport(ua, nil, nil, WithMetrics()) require.NoError(t, err) ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil, WithMetrics()) + tb, err := NewTCPTransport(ub, nil, nil, WithMetrics()) require.NoError(t, err) zero := "/ip4/127.0.0.1/tcp/0" @@ -72,7 +73,7 @@ func TestResourceManager(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -81,7 +82,7 @@ func TestResourceManager(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) - tb, err := NewTCPTransport(ub, rcmgr) + tb, err := NewTCPTransport(ub, rcmgr, nil) require.NoError(t, err) t.Run("success", func(t *testing.T) { @@ -119,16 +120,16 @@ func TestTcpTransportCantDialDNS(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) if tpt.CanDial(dnsa) { t.Fatal("shouldn't be able to dial dns") } - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestTcpTransportCantListenUtp(t *testing.T) { @@ -137,15 +138,15 @@ func TestTcpTransportCantListenUtp(t *testing.T) { require.NoError(t, err) var u transport.Upgrader - tpt, err := NewTCPTransport(u, nil) + tpt, err := NewTCPTransport(u, nil, nil) require.NoError(t, err) _, err = tpt.Listen(utpa) require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport") - envReuseportVal = false + tcpreuse.EnvReuseportVal = false } - envReuseportVal = true + tcpreuse.EnvReuseportVal = true } func TestDialWithUpdates(t *testing.T) { @@ -154,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) { ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) - ta, err := NewTCPTransport(ua, nil) + ta, err := NewTCPTransport(ua, nil, nil) require.NoError(t, err) ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) require.NoError(t, err) @@ -162,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) { ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) - tb, err := NewTCPTransport(ub, nil) + tb, err := NewTCPTransport(ub, nil, nil) require.NoError(t, err) updCh := make(chan transport.DialUpdate, 1) diff --git a/p2p/transport/tcpreuse/connwithscope.go b/p2p/transport/tcpreuse/connwithscope.go new file mode 100644 index 0000000000..ca66f20325 --- /dev/null +++ b/p2p/transport/tcpreuse/connwithscope.go @@ -0,0 +1,26 @@ +package tcpreuse + +import ( + "fmt" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" + manet "github.com/multiformats/go-multiaddr/net" +) + +type connWithScope struct { + sampledconn.ManetTCPConnInterface + scope network.ConnManagementScope +} + +func (c connWithScope) Scope() network.ConnManagementScope { + return c.scope +} + +func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { + if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { + return &connWithScope{tcpconn, scope}, nil + } + + return nil, fmt.Errorf("manet.Conn is not a TCP Conn") +} diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go new file mode 100644 index 0000000000..fe58243d67 --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -0,0 +1,97 @@ +package tcpreuse + +import ( + "errors" + "fmt" + "time" + + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn" + manet "github.com/multiformats/go-multiaddr/net" +) + +// This is readiung the first 3 bytes of the packet. It should be instant. +const identifyConnTimeout = 1 * time.Second + +type DemultiplexedConnType int + +const ( + DemultiplexedConnType_Unknown DemultiplexedConnType = iota + DemultiplexedConnType_MultistreamSelect + DemultiplexedConnType_HTTP + DemultiplexedConnType_TLS +) + +func (t DemultiplexedConnType) String() string { + switch t { + case DemultiplexedConnType_MultistreamSelect: + return "MultistreamSelect" + case DemultiplexedConnType_HTTP: + return "HTTP" + case DemultiplexedConnType_TLS: + return "TLS" + default: + return fmt.Sprintf("Unknown(%d)", int(t)) + } +} + +func (t DemultiplexedConnType) IsKnown() bool { + return t >= 1 || t <= 3 +} + +// identifyConnType attempts to identify the connection type by peeking at the +// first few bytes. +// It Callers must not use the passed in Conn after this +// function returns. if an error is returned, the connection will be closed. +func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) { + if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + s, c, err := sampledconn.PeekBytes(c) + if err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if err := c.SetReadDeadline(time.Time{}); err != nil { + closeErr := c.Close() + return 0, nil, errors.Join(err, closeErr) + } + + if IsMultistreamSelect(s) { + return DemultiplexedConnType_MultistreamSelect, c, nil + } + if IsTLS(s) { + return DemultiplexedConnType_TLS, c, nil + } + if IsHTTP(s) { + return DemultiplexedConnType_HTTP, c, nil + } + return DemultiplexedConnType_Unknown, c, nil +} + +// Matchers are implemented here instead of in the transports so we can easily fuzz them together. +type Prefix = [3]byte + +func IsMultistreamSelect(s Prefix) bool { + return string(s[:]) == "\x13/m" +} + +func IsHTTP(s Prefix) bool { + switch string(s[:]) { + case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT": + return true + default: + return false + } +} + +func IsTLS(s Prefix) bool { + switch string(s[:]) { + case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03": + return true + default: + return false + } +} diff --git a/p2p/transport/tcpreuse/demultiplex_test.go b/p2p/transport/tcpreuse/demultiplex_test.go new file mode 100644 index 0000000000..e201f2ca75 --- /dev/null +++ b/p2p/transport/tcpreuse/demultiplex_test.go @@ -0,0 +1,50 @@ +package tcpreuse + +import "testing" + +func FuzzClash(f *testing.F) { + // make untyped literals type correctly + add := func(a, b, c byte) { f.Add(a, b, c) } + + // multistream-select + add('\x13', '/', 'm') + // http + add('G', 'E', 'T') + add('H', 'E', 'A') + add('P', 'O', 'S') + add('P', 'U', 'T') + add('D', 'E', 'L') + add('C', 'O', 'N') + add('O', 'P', 'T') + add('T', 'R', 'A') + add('P', 'A', 'T') + // tls + add('\x16', '\x03', '\x01') + add('\x16', '\x03', '\x02') + add('\x16', '\x03', '\x03') + add('\x16', '\x03', '\x04') + + f.Fuzz(func(t *testing.T, a, b, c byte) { + s := Prefix{a, b, c} + var total uint + + ms := IsMultistreamSelect(s) + if ms { + total++ + } + + http := IsHTTP(s) + if http { + total++ + } + + tls := IsTLS(s) + if tls { + total++ + } + + if total > 1 { + t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls) + } + }) +} diff --git a/p2p/transport/tcpreuse/dialer.go b/p2p/transport/tcpreuse/dialer.go new file mode 100644 index 0000000000..ad634583ed --- /dev/null +++ b/p2p/transport/tcpreuse/dialer.go @@ -0,0 +1,16 @@ +package tcpreuse + +import ( + "context" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// DialContext is like Dial but takes a context. +func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { + if t.useReuseport() { + return t.reuse.DialContext(ctx, raddr) + } + var d manet.Dialer + return d.DialContext(ctx, raddr) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go new file mode 100644 index 0000000000..7324b45849 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_common.go @@ -0,0 +1,89 @@ +package sampledconn + +import ( + "errors" + "io" + "net" + "syscall" + "time" + + manet "github.com/multiformats/go-multiaddr/net" +) + +const peekSize = 3 + +type PeekedBytes = [peekSize]byte + +var errNotSupported = errors.New("not supported on this platform") + +var ErrNotTCPConn = errors.New("passed conn is not a TCPConn") + +func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) { + if c, ok := conn.(syscall.Conn); ok { + b, err := OSPeekConn(c) + if err == nil { + return b, conn, nil + } + if err != errNotSupported { + return PeekedBytes{}, nil, err + } + // Fallback to wrapping the coonn + } + + if c, ok := conn.(ManetTCPConnInterface); ok { + return newFallbackSampledConn(c) + } + + return PeekedBytes{}, nil, ErrNotTCPConn +} + +type fallbackPeekingConn struct { + ManetTCPConnInterface + peekedBytes PeekedBytes + bytesPeeked uint8 +} + +// tcpConnInterface is the interface for TCPConn's functions +// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as +// a TCP Conn easier, but it's a potential footgun as you could skipped the +// peeked bytes if using the fallback +type tcpConnInterface interface { + net.Conn + syscall.Conn + + CloseRead() error + CloseWrite() error + + SetLinger(sec int) error + SetKeepAlive(keepalive bool) error + SetKeepAlivePeriod(d time.Duration) error + SetNoDelay(noDelay bool) error + MultipathTCP() (bool, error) + + io.ReaderFrom + io.WriterTo +} + +type ManetTCPConnInterface interface { + manet.Conn + tcpConnInterface +} + +func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) { + s := &fallbackPeekingConn{ManetTCPConnInterface: conn} + _, err := io.ReadFull(conn, s.peekedBytes[:]) + if err != nil { + return s.peekedBytes, nil, err + } + return s.peekedBytes, s, nil +} + +func (sc *fallbackPeekingConn) Read(b []byte) (int, error) { + if int(sc.bytesPeeked) != len(sc.peekedBytes) { + red := copy(b, sc.peekedBytes[sc.bytesPeeked:]) + sc.bytesPeeked += uint8(red) + return red, nil + } + + return sc.ManetTCPConnInterface.Read(b) +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go new file mode 100644 index 0000000000..7386112395 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -0,0 +1,11 @@ +//go:build !unix && !windows + +package sampledconn + +import ( + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + return PeekedBytes{}, errNotSupported +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go new file mode 100644 index 0000000000..a7c5a65f33 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_test.go @@ -0,0 +1,78 @@ +package sampledconn + +import ( + "io" + "syscall" + "testing" + "time" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + + "github.com/stretchr/testify/assert" +) + +func TestSampledConn(t *testing.T) { + testCases := []string{ + "platform", + "fallback", + } + + // Start a TCP server + listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")) + assert.NoError(t, err) + defer listener.Close() + + serverAddr := listener.Multiaddr() + + // Server goroutine + go func() { + for i := 0; i < len(testCases); i++ { + conn, err := listener.Accept() + assert.NoError(t, err) + defer conn.Close() + + // Write some data to the connection + _, err = conn.Write([]byte("hello")) + assert.NoError(t, err) + } + }() + + // Give the server a moment to start + time.Sleep(100 * time.Millisecond) + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + // Create a TCP client + clientConn, err := manet.Dial(serverAddr) + assert.NoError(t, err) + defer clientConn.Close() + + if tc == "platform" { + // Wrap the client connection in SampledConn + peeked, clientConn, err := PeekBytes(clientConn.(interface { + manet.Conn + syscall.Conn + })) + assert.NoError(t, err) + assert.Equal(t, "hel", string(peeked[:])) + + buf := make([]byte, 5) + _, err = clientConn.Read(buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + } else { + // Wrap the client connection in SampledConn + sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface)) + assert.NoError(t, err) + assert.Equal(t, "hel", string(sample[:])) + + buf := make([]byte, 5) + _, err = io.ReadFull(sampledConn, buf) + assert.NoError(t, err) + assert.Equal(t, "hello", string(buf)) + + } + }) + } +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go new file mode 100644 index 0000000000..9847e8d4be --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_unix.go @@ -0,0 +1,42 @@ +//go:build unix + +package sampledconn + +import ( + "errors" + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < peekSize { + var n int + n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK) + if errors.Is(readErr, syscall.EAGAIN) { + return false + } + if readErr != nil { + return true + } + readBytes += n + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +} diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go new file mode 100644 index 0000000000..46b0617996 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package sampledconn + +import ( + "errors" + "golang.org/x/sys/windows" + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < peekSize { + var n uint32 + flags := uint32(windows.MSG_PEEK) + wsabuf := windows.WSABuf{ + Len: uint32(len(s) - readBytes), + Buf: &s[readBytes], + } + + readErr = windows.WSARecv(windows.Handle(fd), &wsabuf, 1, &n, &flags, nil, nil) + if errors.Is(readErr, windows.WSAEWOULDBLOCK) { + return false + } + if readErr != nil { + return true + } + readBytes += int(n) + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +} diff --git a/p2p/transport/tcpreuse/listener.go b/p2p/transport/tcpreuse/listener.go new file mode 100644 index 0000000000..326e1e15b7 --- /dev/null +++ b/p2p/transport/tcpreuse/listener.go @@ -0,0 +1,329 @@ +package tcpreuse + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/connmgr" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel. + +// How long we wait for a connection to be accepted before dropping it. +const acceptTimeout = 30 * time.Second + +var log = logging.Logger("tcp-demultiplex") + +// ConnMgr enables you to share the same listen address between TCP and WebSocket transports. +type ConnMgr struct { + enableReuseport bool + reuse reuseport.Transport + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + + mx sync.Mutex + listeners map[string]*multiplexedListener +} + +func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + return &ConnMgr{ + enableReuseport: enableReuseport, + reuse: reuseport.Transport{}, + connGater: gater, + rcmgr: rcmgr, + listeners: make(map[string]*multiplexedListener), + } +} + +func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { + if t.useReuseport() { + return t.reuse.Listen(listenAddr) + } else { + return manet.Listen(listenAddr) + } +} + +func (t *ConnMgr) useReuseport() bool { + return t.enableReuseport && ReuseportIsAvailable() +} + +func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) { + haveTCP := false + addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool { + if haveTCP { + return true + } + if c.Protocol().Code == ma.P_TCP { + haveTCP = true + } + return false + }) + if !haveTCP { + return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr) + } + return addr, nil +} + +// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections +// accepted from returned listeners need to be upgraded with a `transport.Upgrader`. +// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port. +func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + laddr, err := getTCPAddr(laddr) + if err != nil { + return nil, err + } + + t.mx.Lock() + defer t.mx.Unlock() + ml, ok := t.listeners[laddr.String()] + if ok { + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + return nil, err + } + return dl, nil + } + + l, err := t.maListen(laddr) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + cancelFunc := func() error { + cancel() + t.mx.Lock() + defer t.mx.Unlock() + delete(t.listeners, laddr.String()) + delete(t.listeners, l.Multiaddr().String()) + return l.Close() + } + ml = &multiplexedListener{ + Listener: l, + listeners: make(map[DemultiplexedConnType]*demultiplexedListener), + ctx: ctx, + closeFn: cancelFunc, + connGater: t.connGater, + rcmgr: t.rcmgr, + } + t.listeners[laddr.String()] = ml + t.listeners[l.Multiaddr().String()] = ml + + dl, err := ml.DemultiplexedListen(connType) + if err != nil { + cerr := ml.Close() + return nil, errors.Join(err, cerr) + } + + ml.wg.Add(1) + go ml.run() + + return dl, nil +} + +var _ manet.Listener = &demultiplexedListener{} + +type multiplexedListener struct { + manet.Listener + listeners map[DemultiplexedConnType]*demultiplexedListener + mx sync.RWMutex + + connGater connmgr.ConnectionGater + rcmgr network.ResourceManager + ctx context.Context + closeFn func() error + wg sync.WaitGroup +} + +var ErrListenerExists = errors.New("listener already exists for this conn type on this address") + +func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { + if !connType.IsKnown() { + return nil, fmt.Errorf("unknown connection type: %s", connType) + } + + m.mx.Lock() + defer m.mx.Unlock() + if _, ok := m.listeners[connType]; ok { + return nil, ErrListenerExists + } + + ctx, cancel := context.WithCancel(m.ctx) + l := &demultiplexedListener{ + buffer: make(chan manet.Conn), + inner: m.Listener, + ctx: ctx, + cancelFunc: cancel, + closeFn: func() error { m.removeDemultiplexedListener(connType); return nil }, + } + + m.listeners[connType] = l + + return l, nil +} + +func (m *multiplexedListener) run() error { + defer m.Close() + defer m.wg.Done() + acceptQueue := make(chan struct{}, acceptQueueSize) + for { + c, err := m.Listener.Accept() + if err != nil { + return err + } + + // Gate and resource limit the connection here. + // If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer + // which clogs up our entire connection queue. + // This duplicates the responsibility of gating and resource limiting between here and the upgrader. The + // alternative without duplication requires moving the process of upgrading the connection here, which forces + // us to establish the websocket connection here. That is more duplication, or a significant breaking change. + // + // Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport + // integration tests. + if m.connGater != nil && !m.connGater.InterceptAccept(c) { + log.Debugf("gater blocked incoming connection on local addr %s from %s", + c.LocalMultiaddr(), c.RemoteMultiaddr()) + if err := c.Close(); err != nil { + log.Warnf("failed to close incoming connection rejected by gater: %s", err) + } + continue + } + connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr()) + if err != nil { + log.Debugw("resource manager blocked accept of new connection", "error", err) + if err := c.Close(); err != nil { + log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err) + } + continue + } + + select { + case acceptQueue <- struct{}{}: + // NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader. + case <-m.ctx.Done(): + c.Close() + log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) + } + + m.wg.Add(1) + go func() { + defer func() { <-acceptQueue }() + defer m.wg.Done() + ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout) + defer cancelCtx() + t, c, err := identifyConnType(c) + if err != nil { + connScope.Done() + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error demultiplexing connection: %s", err.Error()) + return + } + + connWithScope, err := manetConnWithScope(c, connScope) + if err != nil { + connScope.Done() + closeErr := c.Close() + err = errors.Join(err, closeErr) + log.Debugf("error wrapping connection with scope: %s", err.Error()) + return + } + + m.mx.RLock() + demux, ok := m.listeners[t] + m.mx.RUnlock() + if !ok { + closeErr := connWithScope.Close() + if closeErr != nil { + log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error()) + } else { + log.Debugf("no registered listener for demultiplex connection %s", t) + } + return + } + + select { + case demux.buffer <- connWithScope: + case <-ctx.Done(): + connWithScope.Close() + } + }() + } +} + +func (m *multiplexedListener) Close() error { + m.mx.Lock() + for _, l := range m.listeners { + l.cancelFunc() + } + err := m.closeListener() + m.mx.Unlock() + m.wg.Wait() + return err +} + +func (m *multiplexedListener) closeListener() error { + lerr := m.Listener.Close() + cerr := m.closeFn() + return errors.Join(lerr, cerr) +} + +func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnType) { + m.mx.Lock() + defer m.mx.Unlock() + + delete(m.listeners, c) + if len(m.listeners) == 0 { + m.closeListener() + m.mx.Unlock() + m.wg.Wait() + m.mx.Lock() + } +} + +type demultiplexedListener struct { + buffer chan manet.Conn + inner manet.Listener + ctx context.Context + cancelFunc context.CancelFunc + closeFn func() error +} + +func (m *demultiplexedListener) Accept() (manet.Conn, error) { + select { + case c := <-m.buffer: + return c, nil + case <-m.ctx.Done(): + return nil, transport.ErrListenerClosed + } +} + +func (m *demultiplexedListener) Close() error { + m.cancelFunc() + return m.closeFn() +} + +func (m *demultiplexedListener) Multiaddr() ma.Multiaddr { + return m.inner.Multiaddr() +} + +func (m *demultiplexedListener) Addr() net.Addr { + return m.inner.Addr() +} diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go new file mode 100644 index 0000000000..bdb030a676 --- /dev/null +++ b/p2p/transport/tcpreuse/listener_test.go @@ -0,0 +1,449 @@ +package tcpreuse + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multistream" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func selfSignedTLSConfig(t *testing.T) *tls.Config { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + certTemplate := x509.Certificate{ + SerialNumber: &big.Int{}, + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv) + require.NoError(t, err) + + cert := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + return tlsConfig +} + +type wsHandler struct{ conns chan *websocket.Conn } + +func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + u := websocket.Upgrader{} + c, _ := u.Upgrade(w, r, http.Header{}) + wh.conns <- c +} + +func TestListenerSingle(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 64 + for _, enableReuseport := range []bool{true, false} { + t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("hello-multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c, err := l.Accept() + require.NoError(t, err) + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + defer cc.Close() + buf := make([]byte, 30) + n, err := cc.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "hello-multistream", string(buf[:n])) { + return + } + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(l), wh) + }() + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } + }() + } + wg.Wait() + }) + + t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) { + cm := NewConnMgr(enableReuseport, nil, nil) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer l.Close() + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(l), "", "") + }() + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + var wg sync.WaitGroup + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "hello", string(buf)) { + return + } + }() + } + wg.Wait() + }) + } +} + +func TestListenerMultiplexed(t *testing.T) { + listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 20 + for _, enableReuseport := range []bool{true, false} { + cm := NewConnMgr(enableReuseport, nil, nil) + msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + defer msl.Close() + + wsl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + defer wsl.Close() + require.Equal(t, wsl.Multiaddr(), msl.Multiaddr()) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(wsl), wh) + }() + + wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer wssl.Close() + require.Equal(t, wssl.Multiaddr(), wsl.Multiaddr()) + whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(wssl), "", "") + }() + + // multistream connections + go func() { + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, msl.Addr().Network(), msl.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("multistream")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // ws connections + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + // wss connections + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", msl.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("websocket-tls")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c, err := msl.Accept() + if !assert.NoError(t, err) { + return + } + wg.Add(1) + go func() { + defer wg.Done() + cc := multistream.NewMSSelect(c, "a") + defer cc.Close() + buf := make([]byte, 20) + n, err := cc.Read(buf) + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, "multistream", string(buf[:n])) { + return + } + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-wh.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket", string(buf)) { + return + } + }() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + c := <-whs.conns + wg.Add(1) + go func() { + defer wg.Done() + defer c.Close() + msgType, buf, err := c.ReadMessage() + if !assert.NoError(t, err) { + return + } + if !assert.Equal(t, msgType, websocket.TextMessage) { + return + } + if !assert.Equal(t, "websocket-tls", string(buf)) { + return + } + }() + } + }() + wg.Wait() + } +} + +func TestListenerClose(t *testing.T) { + testClose := func(listenAddr ma.Multiaddr) { + // listen on port 0 + cm := NewConnMgr(false, nil, nil) + ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + ml.Close() + + mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + mll.Close() + wl.Close() + + ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) + require.NoError(t, err) + + // Now listen on the specific port previously used + listenAddr = ml.Multiaddr() + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + wl.Close() + + wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + require.Equal(t, wl.Multiaddr(), ml.Multiaddr()) + + ml.Close() + wl.Close() + } + listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")} + for _, listenAddr := range listenAddrs { + testClose(listenAddr) + } +} diff --git a/p2p/transport/tcp/reuseport.go b/p2p/transport/tcpreuse/reuseport.go similarity index 81% rename from p2p/transport/tcp/reuseport.go rename to p2p/transport/tcpreuse/reuseport.go index ba09304622..a2529c0bda 100644 --- a/p2p/transport/tcp/reuseport.go +++ b/p2p/transport/tcpreuse/reuseport.go @@ -1,4 +1,4 @@ -package tcp +package tcpreuse import ( "os" @@ -11,13 +11,13 @@ import ( // It default to true. const envReuseport = "LIBP2P_TCP_REUSEPORT" -// envReuseportVal stores the value of envReuseport. defaults to true. -var envReuseportVal = true +// EnvReuseportVal stores the value of envReuseport. defaults to true. +var EnvReuseportVal = true func init() { v := strings.ToLower(os.Getenv(envReuseport)) if v == "false" || v == "f" || v == "0" { - envReuseportVal = false + EnvReuseportVal = false log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v) } } @@ -31,5 +31,5 @@ func init() { // If this becomes a sought after feature, we could add this to the config. // In the end, reuseport is a stop-gap. func ReuseportIsAvailable() bool { - return envReuseportVal && reuseport.Available() + return EnvReuseportVal && reuseport.Available() } diff --git a/p2p/transport/testsuite/utils_suite.go b/p2p/transport/testsuite/utils_suite.go index 5e488397a5..8b002f8900 100644 --- a/p2p/transport/testsuite/utils_suite.go +++ b/p2p/transport/testsuite/utils_suite.go @@ -11,7 +11,9 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){ +type TransportSubTestFn func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID) + +var Subtests = []TransportSubTestFn{ SubtestProtocols, SubtestBasic, SubtestCancel, @@ -33,12 +35,17 @@ func getFunctionName(i interface{}) string { } func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) { + t.Helper() + SubtestTransportWithFs(t, ta, tb, addr, peerA, Subtests) +} + +func SubtestTransportWithFs(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID, tests []TransportSubTestFn) { maddr, err := ma.NewMultiaddr(addr) if err != nil { t.Fatal(err) } - for _, f := range Subtests { + for _, f := range tests { t.Run(getFunctionName(f), func(t *testing.T) { f(t, ta, tb, maddr, peerA) }) diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..50a8b9e823 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 30b70055d0..ce51611703 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -1,6 +1,7 @@ package websocket import ( + "errors" "io" "net" "sync" @@ -8,6 +9,8 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" ws "github.com/gorilla/websocket" ) @@ -22,20 +25,53 @@ type Conn struct { secure bool DefaultMessageType int reader io.Reader - closeOnce sync.Once + closeOnceVal func() error + laddr ma.Multiaddr + raddr ma.Multiaddr readLock, writeLock sync.Mutex } var _ net.Conn = (*Conn)(nil) +var _ manet.Conn = (*Conn)(nil) // NewConn creates a Conn given a regular gorilla/websocket Conn. +// +// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. func NewConn(raw *ws.Conn, secure bool) *Conn { - return &Conn{ + lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) + laddr, err := manet.FromNetAddr(lna) + if err != nil { + log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr()) + return nil + } + + rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + raddr, err := manet.FromNetAddr(rna) + if err != nil { + log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr()) + return nil + } + + c := &Conn{ Conn: raw, secure: secure, DefaultMessageType: ws.BinaryMessage, + laddr: laddr, + raddr: raddr, } + c.closeOnceVal = sync.OnceValue(c.closeOnceFn) + return c +} + +// LocalMultiaddr implements manet.Conn. +func (c *Conn) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr implements manet.Conn. +func (c *Conn) RemoteMultiaddr() ma.Multiaddr { + return c.raddr } func (c *Conn) Read(b []byte) (int, error) { @@ -99,26 +135,31 @@ func (c *Conn) Write(b []byte) (n int, err error) { return len(b), nil } -// Close closes the connection. Only the first call to Close will receive the -// close error, subsequent and concurrent calls will return nil. +func (c *Conn) Scope() network.ConnManagementScope { + nc := c.NetConn() + if sc, ok := nc.(interface { + Scope() network.ConnManagementScope + }); ok { + return sc.Scope() + } + return nil +} + +// Close closes the connection. +// subsequent and concurrent calls will return the same error value. // This method is thread-safe. func (c *Conn) Close() error { - var err error - c.closeOnce.Do(func() { - err1 := c.Conn.WriteControl( - ws.CloseMessage, - ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), - time.Now().Add(GracefulCloseTimeout), - ) - err2 := c.Conn.Close() - switch { - case err1 != nil: - err = err1 - case err2 != nil: - err = err2 - } - }) - return err + return c.closeOnceVal() +} + +func (c *Conn) closeOnceFn() error { + err1 := c.Conn.WriteControl( + ws.CloseMessage, + ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"), + time.Now().Add(GracefulCloseTimeout), + ) + err2 := c.Conn.Close() + return errors.Join(err1, err2) } func (c *Conn) LocalAddr() net.Addr { diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index 8071ddb814..dd399aa079 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -4,14 +4,16 @@ import ( "crypto/tls" "errors" "fmt" - "go.uber.org/zap" "net" "net/http" "sync" + "go.uber.org/zap" + logging "github.com/ipfs/go-log/v2" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -50,7 +52,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -60,19 +62,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err - } - nl, err := net.Listen(lnet, lnaddr) - if err != nil { - return nil, err + var nl net.Listener + + if sharedTcp == nil { + lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) + if err != nil { + return nil, err + } + nl, err = net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + } else { + var connType tcpreuse.DemultiplexedConnType + if parsed.isWSS { + connType = tcpreuse.DemultiplexedConnType_TLS + } else { + connType = tcpreuse.DemultiplexedConnType_HTTP + } + mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) + if err != nil { + return nil, err + } + nl = manet.NetListener(mal) } laddr, err := manet.FromNetAddr(nl.Addr()) if err != nil { return nil, err } + first, _ := ma.SplitFirst(a) // Don't resolve dns addresses. // We want to be able to announce domain names, so the peer can validate the TLS certificate. @@ -111,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { // The upgrader writes a response for us. return } - + nc := NewConn(c, l.isWss) + if nc == nil { + c.Close() + w.WriteHeader(500) + return + } select { case l.incoming <- NewConn(c, l.isWss): case <-l.closed: @@ -126,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) { if !ok { return nil, transport.ErrListenerClosed } - - mnc, err := manet.WrapNetConn(c) - if err != nil { - c.Close() - return nil, err - } - return mnc, nil + return c, nil case <-l.closed: return nil, transport.ErrListenerClosed } diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 0f07617dc7..e24cb88c6d 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -87,11 +88,13 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + sharedTcp *tcpreuse.ConnMgr } var _ transport.Transport = (*WebsocketTransport)(nil) -func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) { +func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } @@ -99,6 +102,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (* upgrader: u, rcmgr: rcmgr, tlsClientConf: &tls.Config{}, + sharedTcp: sharedTCP, } for _, opt := range opts { if err := opt(t); err != nil { @@ -233,7 +237,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf) + l, err := newListener(a, tlsConf, t.sharedTcp) if err != nil { return nil, err } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..9ca03775a2 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID } id, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf)) if err != nil { t.Fatal(err) } @@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) { tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig)) + tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig)) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) { require.Contains(t, serverMA.String(), "tls") _, u := newSecureUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) masToDial, err := tpt.Resolve(context.Background(), serverMA) @@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) { func TestWebsocketTransport(t *testing.T) { peerA, ua := newUpgrader(t) - ta, err := New(ua, nil) + ta, err := New(ua, nil, nil) if err != nil { t.Fatal(err) } _, ub := newUpgrader(t) - tb, err := New(ub, nil) + tb, err := New(ub, nil, nil) if err != nil { t.Fatal(err) } @@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSConfig(tlsConf)) } server, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) l, err := tpt.Listen(laddr) require.NoError(t, err) @@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) } _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}, opts...) + tpt, err := New(u, &network.NullResourceManager{}, nil, opts...) require.NoError(t, err) c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) @@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) { func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss") _, err = tpt.Listen(addr) @@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) { func TestWebsocketListenSecureAndInsecure(t *testing.T) { serverID, serverUpgrader := newUpgrader(t) - server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t))) + server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t))) require.NoError(t, err) lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) @@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("insecure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { t.Run("secure", func(t *testing.T) { _, clientUpgrader := newUpgrader(t) - client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) + client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})) require.NoError(t, err) // dialing the insecure address should succeed @@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { func TestConcurrentClose(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) require.NoError(t, err) l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) if err != nil { @@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) { func TestWriteZero(t *testing.T) { _, u := newUpgrader(t) - tpt, err := New(u, &network.NullResourceManager{}) + tpt, err := New(u, &network.NullResourceManager{}, nil) if err != nil { t.Fatal(err) }