Skip to content

Commit

Permalink
holepunch: pass address function in constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Sep 25, 2024
1 parent bc8f18f commit 89523cb
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 76 deletions.
11 changes: 10 additions & 1 deletion p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
opts.HolePunchingOptions = append(hpOpts, opts.HolePunchingOptions...)

}
h.hps, err = holepunch.NewService(h, h.ids, opts.HolePunchingOptions...)
h.hps, err = holepunch.NewService(h, h.ids, func() []ma.Multiaddr {
addrs := h.AllAddrs()
if opts.AddrsFactory != nil {
addrs = opts.AddrsFactory(addrs)
}
// AllAddrs may ignore observed addresses in favour of NAT mappings. Use both for hole punching.
addrs = append(addrs, h.ids.OwnObservedAddrs()...)
addrs = ma.Unique(addrs)
return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
}, opts.HolePunchingOptions...)
if err != nil {
return nil, fmt.Errorf("failed to create hole punch service: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions p2p/protocol/holepunch/holepunch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ func TestNoHolePunchIfDirectConnExists(t *testing.T) {
require.GreaterOrEqual(t, nc1, 1)
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
require.GreaterOrEqual(t, nc2, 1)

require.NoError(t, hps.DirectConnect(h2.ID()))
require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1)
require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2)
Expand Down Expand Up @@ -473,8 +472,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
hps = addHolePunchService(t, h2, h2opt...)
}

// h1 has a relay addr
// h2 should connect to the relay addr
// h2 has a relay addr
var raddr ma.Multiaddr
for _, a := range h2.Addrs() {
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
Expand All @@ -483,6 +481,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc
}
}
require.NotEmpty(t, raddr)
// h1 should connect to the relay addr
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.ID(),
Addrs: []ma.Multiaddr{raddr},
Expand All @@ -492,7 +491,9 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePunc

func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
t.Helper()
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr {
return append(h.Addrs(), ma.StringCast("/ip4/1.2.3.4/tcp/1234"))
}, opts...)
require.NoError(t, err)
return hps
}
Expand All @@ -505,7 +506,6 @@ func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host,
libp2p.ResourceManager(&network.NullResourceManager{}),
)
require.NoError(t, err)
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
require.NoError(t, err)
hps := addHolePunchService(t, h, opts...)
return h, hps
}
23 changes: 12 additions & 11 deletions p2p/protocol/holepunch/holepuncher.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ type holePuncher struct {
host host.Host
refCount sync.WaitGroup

ids identify.IDService
ids identify.IDService
listenAddrs func() []ma.Multiaddr

// active hole punches for deduplicating
activeMx sync.Mutex
Expand All @@ -50,13 +51,14 @@ type holePuncher struct {
filter AddrFilter
}

func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher {
func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher {
hp := &holePuncher{
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
tracer: tracer,
filter: filter,
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
tracer: tracer,
filter: filter,
listenAddrs: listenAddrs,
}
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
h.Network().Notify((*netNotifiee)(hp))
Expand Down Expand Up @@ -102,16 +104,15 @@ func (hp *holePuncher) directConnect(rp peer.ID) error {
if getDirectConnection(hp.host, rp) != nil {
return nil
}

// short-circuit hole punching if a direct dial works.
// attempt a direct connection ONLY if we have a public address for the remote peer
for _, a := range hp.host.Peerstore().Addrs(rp) {
if manet.IsPublicAddr(a) && !isRelayAddress(a) {
if !isRelayAddress(a) && manet.IsPublicAddr(a) {
forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching")
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)

tstart := time.Now()
// This dials *all* public addresses from the peerstore.
// This dials *all* addresses, public and private, from the peerstore.
err := hp.host.Connect(dialCtx, peer.AddrInfo{ID: rp})
dt := time.Since(tstart)
cancel()
Expand Down Expand Up @@ -206,7 +207,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
str.SetDeadline(time.Now().Add(StreamTimeout))

// send a CONNECT and start RTT measurement.
obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs())
obsAddrs := removeRelayAddrs(hp.listenAddrs())
if hp.filter != nil {
obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs)
}
Expand Down
73 changes: 22 additions & 51 deletions p2p/protocol/holepunch/svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ import (
"context"
"errors"
"fmt"
"slices"
"sync"
"time"

logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-msgio/pbio"
Expand Down Expand Up @@ -47,7 +46,13 @@ type Service struct {
ctxCancel context.CancelFunc

host host.Host
ids identify.IDService
// ids helps with connection reversal. We wait for identify to complete and attempt
// a direct connection to the peer if it's publicly reachable.
ids identify.IDService
// listenAddrs provides the addresses for the host to be used for hole punching. We use this
// and not host.Addrs because host.Addrs might remove public unreachable address and only advertise
// publicly reachable relay addresses.
listenAddrs func() []ma.Multiaddr

holePuncherMx sync.Mutex
holePuncher *holePuncher
Expand All @@ -65,7 +70,7 @@ type Service struct {
// no matter if they are behind a NAT / firewall or not.
// The Service handles DCUtR streams (which are initiated from the node behind
// a NAT / Firewall once we establish a connection to them through a relay.
func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service, error) {
func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, opts ...Option) (*Service, error) {
if ids == nil {
return nil, errors.New("identify service can't be nil")
}
Expand All @@ -76,6 +81,7 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
ctxCancel: cancel,
host: h,
ids: ids,
listenAddrs: listenAddrs,
hasPublicAddrsChan: make(chan struct{}),
}

Expand All @@ -88,18 +94,18 @@ func NewService(h host.Host, ids identify.IDService, opts ...Option) (*Service,
s.tracer.Start()

s.refCount.Add(1)
go s.watchForPublicAddr()
go s.waitForPublicAddr()

return s, nil
}

func (s *Service) watchForPublicAddr() {
func (s *Service) waitForPublicAddr() {
defer s.refCount.Done()

log.Debug("waiting until we have at least one public address", "peer", s.host.ID())

// TODO: We should have an event here that fires when identify discovers a new
// address (and when autonat confirms that address).
// address.
// As we currently don't have an event like this, just check our observed addresses
// regularly (exponential backoff starting at 250 ms, capped at 5s).
duration := 250 * time.Millisecond
Expand All @@ -125,44 +131,27 @@ func (s *Service) watchForPublicAddr() {
}
}

// Only start the holePuncher if we're behind a NAT / firewall.
sub, err := s.host.EventBus().Subscribe(&event.EvtLocalReachabilityChanged{}, eventbus.Name("holepunch"))
if err != nil {
log.Debugf("failed to subscripe to Reachability event: %s", err)
s.holePuncherMx.Lock()
if s.ctx.Err() != nil {
// service is closed
return
}
defer sub.Close()
for {
select {
case <-s.ctx.Done():
return
case e, ok := <-sub.Out():
if !ok {
return
}
if e.(event.EvtLocalReachabilityChanged).Reachability != network.ReachabilityPrivate {
continue
}
s.holePuncherMx.Lock()
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter)
s.holePuncherMx.Unlock()
close(s.hasPublicAddrsChan)
return
}
}
s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter)
s.holePuncherMx.Unlock()
close(s.hasPublicAddrsChan)
}

// Close closes the Hole Punch Service.
func (s *Service) Close() error {
var err error
s.ctxCancel()
s.holePuncherMx.Lock()
if s.holePuncher != nil {
err = s.holePuncher.Close()
}
s.holePuncherMx.Unlock()
s.tracer.Close()
s.host.RemoveStreamHandler(Protocol)
s.ctxCancel()
s.refCount.Wait()
return err
}
Expand All @@ -172,7 +161,7 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, remo
if !isRelayAddress(str.Conn().RemoteMultiaddr()) {
return 0, nil, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr())
}
ownAddrs = s.getPublicAddrs()
ownAddrs = s.listenAddrs()
if s.filter != nil {
ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs)
}
Expand Down Expand Up @@ -277,25 +266,7 @@ func (s *Service) handleNewStream(str network.Stream) {

// getPublicAddrs returns public observed and interface addresses
func (s *Service) getPublicAddrs() []ma.Multiaddr {
addrs := removeRelayAddrs(s.ids.OwnObservedAddrs())

interfaceListenAddrs, err := s.host.Network().InterfaceListenAddresses()
if err != nil {
log.Debugf("failed to get to get InterfaceListenAddresses: %s", err)
} else {
addrs = append(addrs, interfaceListenAddrs...)
}

addrs = ma.Unique(addrs)

publicAddrs := make([]ma.Multiaddr, 0, len(addrs))

for _, addr := range addrs {
if manet.IsPublicAddr(addr) {
publicAddrs = append(publicAddrs, addr)
}
}
return publicAddrs
return slices.DeleteFunc(s.listenAddrs(), func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
}

// DirectConnect is only exposed for testing purposes.
Expand Down
9 changes: 2 additions & 7 deletions p2p/protocol/holepunch/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package holepunch

import (
"context"
"slices"

"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -11,13 +12,7 @@ import (
)

func removeRelayAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
result := make([]ma.Multiaddr, 0, len(addrs))
for _, addr := range addrs {
if !isRelayAddress(addr) {
result = append(result, addr)
}
}
return result
return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return isRelayAddress(a) })
}

func isRelayAddress(a ma.Multiaddr) bool {
Expand Down

0 comments on commit 89523cb

Please sign in to comment.