Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: QUIC/Webtransport Transports now will prefer their owned listeners for dialing out #2936

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"net/netip"
"regexp"
Expand Down Expand Up @@ -587,3 +588,70 @@ func TestWebRTCReuseAddrWithQUIC(t *testing.T) {
require.Contains(t, h1.Addrs()[0].String(), "quic-v1")
})
}

func TestUseCorrectTransportForDialOut(t *testing.T) {
listAddrOrder := [][]string{
{"/ip4/127.0.0.1/udp/0/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1/webtransport"},
{"/ip4/127.0.0.1/udp/0/quic-v1/webtransport", "/ip4/127.0.0.1/udp/0/quic-v1"},
{"/ip4/0.0.0.0/udp/0/quic-v1", "/ip4/0.0.0.0/udp/0/quic-v1/webtransport"},
{"/ip4/0.0.0.0/udp/0/quic-v1/webtransport", "/ip4/0.0.0.0/udp/0/quic-v1"},
}
for _, order := range listAddrOrder {
h1, err := New(ListenAddrStrings(order...), Transport(quic.NewTransport), Transport(webtransport.New))
require.NoError(t, err)
t.Cleanup(func() {
h1.Close()
})

go func() {
h1.SetStreamHandler("/echo-port", func(s network.Stream) {
m := s.Conn().RemoteMultiaddr()
v, err := m.ValueForProtocol(ma.P_UDP)
if err != nil {
s.Reset()
return
}
s.Write([]byte(v))
s.Close()
})
}()

for _, addr := range h1.Addrs() {
t.Run("order "+strings.Join(order, ",")+" Dial to "+addr.String(), func(t *testing.T) {
h2, err := New(ListenAddrStrings(
"/ip4/0.0.0.0/udp/0/quic-v1",
"/ip4/0.0.0.0/udp/0/quic-v1/webtransport",
), Transport(quic.NewTransport), Transport(webtransport.New))
require.NoError(t, err)
defer h2.Close()
t.Log("H2 Addrs", h2.Addrs())
var myExpectedDialOutAddr ma.Multiaddr
addrIsWT, _ := webtransport.IsWebtransportMultiaddr(addr)
isLocal := func(a ma.Multiaddr) bool {
return strings.Contains(a.String(), "127.0.0.1")
}
addrIsLocal := isLocal(addr)
for _, a := range h2.Addrs() {
aIsWT, _ := webtransport.IsWebtransportMultiaddr(a)
if addrIsWT == aIsWT && isLocal(a) == addrIsLocal {
myExpectedDialOutAddr = a
break
}
}

err = h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{addr}})
require.NoError(t, err)

s, err := h2.NewStream(context.Background(), h1.ID(), "/echo-port")
require.NoError(t, err)

port, err := io.ReadAll(s)
require.NoError(t, err)

myExpectedPort, err := myExpectedDialOutAddr.ValueForProtocol(ma.P_UDP)
require.NoError(t, err)
require.Equal(t, myExpectedPort, string(port))
})
}
}
}
5 changes: 3 additions & 2 deletions p2p/transport/quic/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
}

tlsConf, keyCh := t.identity.ConfigForPeer(p)
ctx = quicreuse.WithAssociation(ctx, t)
pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
Expand Down Expand Up @@ -196,7 +197,7 @@ func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID
if err != nil {
return nil, err
}
tr, err := t.connManager.TransportForDial(network, addr)
tr, err := t.connManager.TransportWithAssociationForDial(t, network, addr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,7 +314,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version)
}
} else {
ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease)
ln, err := t.connManager.ListenQUICAndAssociate(t, addr, &tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
Expand Down
38 changes: 33 additions & 5 deletions p2p/transport/quicreuse/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) {
}

func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease)
}

// ListenQUICAndAssociate returns a QUIC listener and associates the underlying transport with the given association.
func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
netw, host, err := manet.DialArgs(addr)
if err != nil {
return nil, err
Expand All @@ -117,7 +122,7 @@ func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWi
key := laddr.String()
entry, ok := c.quicListeners[key]
if !ok {
tr, err := c.transportForListen(netw, laddr)
tr, err := c.transportForListen(association, netw, laddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -176,13 +181,18 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
}

func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForListen(network, laddr)
tr, err := reuse.TransportForListen(network, laddr)
if err != nil {
return nil, err
}
tr.associate(association)
return tr, nil
}

conn, err := net.ListenUDP(network, laddr)
Expand All @@ -199,6 +209,14 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re
}, nil
}

type associationKey struct{}

// WithAssociation returns a new context with the given association. Used in
// DialQUIC to prefer a transport that has the given association.
func WithAssociation(ctx context.Context, association any) context.Context {
return context.WithValue(ctx, associationKey{}, association)
}

func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) {
naddr, v, err := FromQuicMultiaddr(raddr)
if err != nil {
Expand All @@ -219,7 +237,12 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
return nil, errors.New("unknown QUIC version")
}

tr, err := c.TransportForDial(netw, naddr)
var tr refCountedQuicTransport
if association := ctx.Value(associationKey{}); association != nil {
tr, err = c.TransportWithAssociationForDial(association, netw, naddr)
} else {
tr, err = c.TransportForDial(netw, naddr)
}
if err != nil {
return nil, err
}
Expand All @@ -232,12 +255,17 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
}

func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
return c.TransportWithAssociationForDial(nil, network, raddr)
}

// TransportWithAssociationForDial returns a QUIC transport for dialing, preferring a transport with the given association.
func (c *ConnManager) TransportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return reuse.TransportForDial(network, raddr)
return reuse.transportWithAssociationForDial(association, network, raddr)
}

var laddr *net.UDPAddr
Expand Down
4 changes: 1 addition & 3 deletions p2p/transport/quicreuse/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) {

const alpn = "proto"

var tlsConf tls.Config
tlsConf.NextProtos = []string{alpn}
ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil)
require.NoError(t, err)
defer ln1.Close()
Expand Down Expand Up @@ -96,7 +94,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {

_, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil)
require.NoError(t, err)
quicTr, err := cm.transportForListen(netw, naddr)
quicTr, err := cm.transportForListen(nil, netw, naddr)
require.NoError(t, err)
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
Expand Down
49 changes: 42 additions & 7 deletions p2p/transport/quicreuse/reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ type refcountedTransport struct {
mutex sync.Mutex
refCount int
unusedSince time.Time

assocations map[any]struct{}
}

// associate an arbitrary value with this transport.
// This lets us "tag" the refcountedTransport when listening so we can use it
// later for dialing. Necessary for holepunching and learning about our own
// observed listening address.
func (c *refcountedTransport) associate(a any) {
if a == nil {
return
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.assocations == nil {
c.assocations = make(map[any]struct{})
}
c.assocations[a] = struct{}{}
}

// hasAssociation returns true if the transport has the given association.
// If it is a nil association, it will always return true.
func (c *refcountedTransport) hasAssociation(a any) bool {
if a == nil {
return true
}
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.assocations[a]
return ok
}

func (c *refcountedTransport) IncreaseCount() {
Expand Down Expand Up @@ -204,7 +234,7 @@ func (r *reuse) gc() {
}
}

func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
func (r *reuse) transportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
var ip *net.IP

// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
Expand All @@ -224,29 +254,34 @@ func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcounte
r.mutex.Lock()
defer r.mutex.Unlock()

tr, err := r.transportForDialLocked(network, ip)
tr, err := r.transportForDialLocked(association, network, ip)
if err != nil {
return nil, err
}
tr.IncreaseCount()
return tr, nil
}

func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) {
func (r *reuse) transportForDialLocked(association any, network string, source *net.IP) (*refcountedTransport, error) {
if source != nil {
// We already have at least one suitable transport...
if trs, ok := r.unicast[source.String()]; ok {
// ... we don't care which port we're dialing from. Just use the first.
// Prefer a transport that has the given association. We want to
// reuse the transport the association used for listening.
for _, tr := range trs {
return tr, nil
if tr.hasAssociation(association) {
return tr, nil
}
}
}
}

// Use a transport listening on 0.0.0.0 (or ::).
// Again, we don't care about the port number.
// Again, prefer a transport that has the given association.
for _, tr := range r.globalListeners {
return tr, nil
if tr.hasAssociation(association) {
return tr, nil
}
}

// Use a transport we've previously dialed from
Expand Down
14 changes: 7 additions & 7 deletions p2p/transport/quicreuse/reuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {

addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.TransportForDial("udp4", addr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", addr)
require.NoError(t, err)
require.Equal(t, 1, conn.GetCount())
laddr := conn.LocalAddr().(*net.UDPAddr)
Expand All @@ -111,7 +111,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
// dial
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
conn, err := reuse.TransportForDial("udp4", raddr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 2, conn.GetCount())
}
Expand All @@ -122,7 +122,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {

raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
tr, err := reuse.TransportForDial("udp4", raddr)
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port}
lconn, err := reuse.TransportForListen("udp4", laddr)
Expand All @@ -138,7 +138,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
// dial any address
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
require.NoError(t, err)
rTr, err := reuse.TransportForDial("udp4", raddr)
rTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)

// open a listener
Expand All @@ -149,7 +149,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
// new dials should go via the listener connection
raddr, err = net.ResolveUDPAddr("udp4", "1.1.1.1:1235")
require.NoError(t, err)
tr, err := reuse.TransportForDial("udp4", raddr)
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, lTr, tr)
require.Equal(t, 2, tr.GetCount())
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 1, lconn.GetCount())
// dial
conn, err := reuse.TransportForDial("udp4", raddr)
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 1, conn.GetCount())
}
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestReuseGarbageCollect(t *testing.T) {

raddr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:1234")
require.NoError(t, err)
dTr, err := reuse.TransportForDial("udp4", raddr)
dTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
require.NoError(t, err)
require.Equal(t, 1, dTr.GetCount())

Expand Down
3 changes: 2 additions & 1 deletion p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
return verifyRawCerts(rawCerts, certHashes)
}
}
ctx = quicreuse.WithAssociation(ctx, t)
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -331,7 +332,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
}
tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3)

ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease)
ln, err := t.connManager.ListenQUICAndAssociate(t, laddr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
}
Expand Down
Loading