Skip to content

Commit

Permalink
add a Request struct
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 21, 2023
1 parent 25a0fe9 commit e1c362a
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 132 deletions.
100 changes: 58 additions & 42 deletions p2p/protocol/autonatv2/autonat.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"golang.org/x/exp/slices"
)

//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto

const (
ServiceName = "libp2p.autonatv2"
DialBackProtocol = "/libp2p/autonat/2/dial-back"
Expand All @@ -28,33 +30,62 @@ const (
streamTimeout = time.Minute
dialBackStreamTimeout = 5 * time.Second
dialBackDialTimeout = 30 * time.Second
minHandshakeSizeBytes = 30_000 // for amplification attack prevention
maxHandshakeSizeBytes = 100_000
minHandshakeSizeBytes = 30_000
maxPeerAddresses = 50
// maxPeerAddresses is the number of addresses in a dial request the server
// will inspect, rest are ignored.
maxPeerAddresses = 50
)

var (
ErrNoValidPeers = errors.New("no valid peers for autonat v2")
ErrDialRefused = errors.New("dial refused")
)

var (
log = logging.Logger("autonatv2")
)

// Request is the request to verify reachability of a single address
type Request struct {
// Addr is the multiaddr to verify
Addr ma.Multiaddr
// SendDialData indicates whether to send dial data if the server requests it for Addr
SendDialData bool
}

// Result is the result of the CheckReachability call
type Result struct {
// Idx is the index of the dialed address
Idx int
// Addr is the dialed address
Addr ma.Multiaddr
// Reachability of the dialed address
Reachability network.Reachability
// Status is the outcome of the dialback
Status pb.DialStatus
}

// AutoNAT implements the AutoNAT v2 client and server. Users can check reachability
// for their addresses using the CheckReachability method.
type AutoNAT struct {
host host.Host
sub event.Subscription
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
srv *Server
cli *Client
mx sync.Mutex
peers *peersMap
host host.Host
sub event.Subscription

// for cleanly closing
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup

srv *Server
cli *client

mx sync.Mutex
peers *peersMap

allowAllAddrs bool // for testing
}

// New returns a new AutoNAT instance. The returned instance runs the server when the provided host
// is publicly reachable.
func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
s := defaultSettings()
for _, o := range opts {
Expand Down Expand Up @@ -83,21 +114,21 @@ func New(h host.Host, dialer host.Host, opts ...AutoNATOption) (*AutoNAT, error)
new(event.EvtPeerIdentificationCompleted),
})
if err != nil {
return nil, fmt.Errorf("event subscription failed: %w", err)
return nil, fmt.Errorf("event subscription: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())

ctx, cancel := context.WithCancel(context.Background())
an := &AutoNAT{
host: h,
ctx: ctx,
cancel: cancel,
sub: sub,
srv: NewServer(h, dialer, s),
cli: NewClient(h),
cli: newClient(h),
allowAllAddrs: s.allowAllAddrs,
peers: newPeersMap(),
}
an.cli.Register()
an.cli.RegisterDialBack()

an.wg.Add(1)
go an.background()
Expand Down Expand Up @@ -136,30 +167,13 @@ func (an *AutoNAT) Close() {
an.wg.Wait()
}

// Result is the result of the CheckReachability call
type Result struct {
// Idx is the index of the dialed address
Idx int
// Addr is the dialed address
Addr ma.Multiaddr
// Reachability of the dialed address
Reachability network.Reachability
// Status is the outcome of the dialback
Status pb.DialStatus
}

// CheckReachability makes a single dial request for checking reachability. For highPriorityAddrs dial charge is paid
// if the server asks for it. For lowPriorityAddrs dial charge is rejected.
func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma.Multiaddr, lowPriorityAddrs []ma.Multiaddr) (Result, error) {
func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Result, error) {
if !an.allowAllAddrs {
for _, a := range highPriorityAddrs {
if !manet.IsPublicAddr(a) {
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
}
}
for _, a := range lowPriorityAddrs {
if !manet.IsPublicAddr(a) {
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", a)
for _, r := range reqs {
if !manet.IsPublicAddr(r.Addr) {
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr)
}
}
}
Expand All @@ -168,7 +182,7 @@ func (an *AutoNAT) CheckReachability(ctx context.Context, highPriorityAddrs []ma
return Result{}, ErrNoValidPeers
}

res, err := an.cli.CheckReachability(ctx, p, highPriorityAddrs, lowPriorityAddrs)
res, err := an.cli.CheckReachability(ctx, p, reqs)
if err != nil {
log.Debugf("reachability check with %s failed, err: %s", p, err)
return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
Expand All @@ -181,12 +195,14 @@ func (an *AutoNAT) updatePeer(p peer.ID) {
an.mx.Lock()
defer an.mx.Unlock()

// There are no ordering gurantees between identify and swarm events. Check peerstore
// and swarm for the current state
protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
connectedness := an.host.Network().Connectedness(p)
if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected {
an.peers.Put(p)
} else {
an.peers.Remove(p)
an.peers.Delete(p)
}
}

Expand Down Expand Up @@ -219,13 +235,13 @@ func (p *peersMap) Put(pid peer.ID) {
p.peerIdx[pid] = len(p.peers) - 1
}

func (p *peersMap) Remove(pid peer.ID) {
func (p *peersMap) Delete(pid peer.ID) {
idx, ok := p.peerIdx[pid]
if !ok {
return
}
delete(p.peerIdx, pid)
p.peers[idx] = p.peers[len(p.peers)-1]
p.peerIdx[p.peers[idx]] = idx
p.peers = p.peers[:len(p.peers)-1]
delete(p.peerIdx, pid)
}
87 changes: 60 additions & 27 deletions p2p/protocol/autonatv2/autonat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT
t.Error(err)
}
an.srv.Enable()
an.cli.Register()
an.cli.RegisterDialBack()
return an
}

Expand Down Expand Up @@ -78,7 +78,7 @@ func identify(t *testing.T, cli *AutoNAT, srv *AutoNAT) {

func TestAutoNATPrivateAddr(t *testing.T) {
an := newAutoNAT(t, nil)
res, err := an.CheckReachability(context.Background(), []ma.Multiaddr{ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}, nil)
res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
require.Equal(t, res, Result{})
require.NotNil(t, err)
}
Expand Down Expand Up @@ -115,7 +115,12 @@ func TestClientRequest(t *testing.T) {
idAndConnect(t, an.host, p)
waitForPeer(t, an)

res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
res, err := an.CheckReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
{Addr: addrs[1]},
})
require.Equal(t, res, Result{})
require.NotNil(t, err)
require.True(t, gotReq.Load())
Expand Down Expand Up @@ -160,7 +165,12 @@ func TestClientServerError(t *testing.T) {
for i, tc := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
p.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
res, err := an.CheckReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
{Addr: addrs[1]},
})
require.Equal(t, res, Result{})
require.NotNil(t, err)
<-done
Expand Down Expand Up @@ -251,7 +261,12 @@ func TestClientDataRequest(t *testing.T) {
for i, tc := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
p.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
res, err := an.CheckReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
{Addr: addrs[1]},
})
require.Equal(t, res, Result{})
require.NotNil(t, err)
<-done
Expand Down Expand Up @@ -472,7 +487,12 @@ func TestClientDialBacks(t *testing.T) {
for i, tc := range tests {
t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
p.SetStreamHandler(DialProtocol, tc.handler)
res, err := an.CheckReachability(context.Background(), addrs[:1], addrs[1:])
res, err := an.CheckReachability(
context.Background(),
[]Request{
{Addr: addrs[0], SendDialData: true},
{Addr: addrs[1]},
})
if !tc.success {
if tc.isError {
require.Equal(t, res, Result{})
Expand Down Expand Up @@ -530,28 +550,41 @@ func TestEventSubscription(t *testing.T) {
}

func TestPeersMap(t *testing.T) {
p := newPeersMap()
emptyPeerID := peer.ID("")
require.Equal(t, emptyPeerID, p.GetRand())

allPeers := make(map[peer.ID]bool)
for i := 0; i < 20; i++ {
pid := peer.ID(fmt.Sprintf("peer-%d", i))
allPeers[pid] = true
p.Put(pid)
}
foundPeers := make(map[peer.ID]bool)
for i := 0; i < 1000; i++ {
pid := p.GetRand()
require.NotEqual(t, emptyPeerID, p)
require.True(t, allPeers[pid])
foundPeers[pid] = true
if len(foundPeers) == len(allPeers) {
break
t.Run("single_item", func(t *testing.T) {
p := newPeersMap()
p.Put("peer1")
p.Delete("peer1")
p.Put("peer1")
require.Equal(t, peer.ID("peer1"), p.GetRand())
p.Delete("peer1")
require.Equal(t, emptyPeerID, p.GetRand())
})

t.Run("multiple_items", func(t *testing.T) {
p := newPeersMap()
require.Equal(t, emptyPeerID, p.GetRand())

allPeers := make(map[peer.ID]bool)
for i := 0; i < 20; i++ {
pid := peer.ID(fmt.Sprintf("peer-%d", i))
allPeers[pid] = true
p.Put(pid)
}
}
for pid := range allPeers {
p.Remove(pid)
}
require.Equal(t, emptyPeerID, p.GetRand())
foundPeers := make(map[peer.ID]bool)
for i := 0; i < 1000; i++ {
pid := p.GetRand()
require.NotEqual(t, emptyPeerID, p)
require.True(t, allPeers[pid])
foundPeers[pid] = true
if len(foundPeers) == len(allPeers) {
break
}
}
for pid := range allPeers {
p.Delete(pid)
}
require.Equal(t, emptyPeerID, p.GetRand())
})
}
Loading

0 comments on commit e1c362a

Please sign in to comment.