fix: apply addrFilters in the dht (#872)
* fix: correctly apply addrFilters in the dht

This still does not cover the fullrt client, but it was not covered before either.

* Add address filter tests

* use channel instead of waitgroup to catch timeouts

* filter multiaddresses when serving provider records

---------

Co-authored-by: Dennis Trautwein <git@dtrautwein.eu>
Jorropo and dennis-tra committed Sep 4, 2023
1 parent 2cbe38a commit 0c90569
Showing 7 changed files with 218 additions and 23 deletions.
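
For context on the feature being fixed: the address filter is a function supplied by whoever constructs the DHT, and the changes below apply it consistently (peerstore additions, outgoing provider announcements, and served provider records). The sketch below shows one plausible way to configure such a filter; the `AddressFilter` option name and the surrounding wiring are assumptions for illustration, not something this commit defines.

```go
package main

import (
	"context"

	"github.com/libp2p/go-libp2p"
	dht "github.com/libp2p/go-libp2p-kad-dht"
	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
)

func main() {
	ctx := context.Background()

	h, err := libp2p.New()
	if err != nil {
		panic(err)
	}

	// onlyPublic keeps publicly routable addresses and drops everything else.
	// A filter like this is what ends up in dht.addrFilter.
	onlyPublic := func(addrs []ma.Multiaddr) []ma.Multiaddr {
		out := make([]ma.Multiaddr, 0, len(addrs))
		for _, a := range addrs {
			if manet.IsPublicAddr(a) {
				out = append(out, a)
			}
		}
		return out
	}

	// AddressFilter is assumed here to be the option that installs the filter.
	d, err := dht.New(ctx, h, dht.AddressFilter(onlyPublic))
	if err != nil {
		panic(err)
	}
	defer d.Close()

	_ = d // use the DHT as usual; filtered addresses flow through the paths patched below
}
```
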
10 changes: 7 additions & 3 deletions dht.go
@@ -925,8 +925,12 @@ func (dht *IpfsDHT) maybeAddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Dura
 	if p == dht.self || dht.host.Network().Connectedness(p) == network.Connected {
 		return
 	}
-	if dht.addrFilter != nil {
-		addrs = dht.addrFilter(addrs)
+	dht.peerstore.AddAddrs(p, dht.filterAddrs(addrs), ttl)
+}
+
+func (dht *IpfsDHT) filterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
+	if f := dht.addrFilter; f != nil {
+		return f(addrs)
 	}
-	dht.peerstore.AddAddrs(p, addrs, ttl)
+	return addrs
 }
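
The new `filterAddrs` helper centralizes the nil check, so call sites can pass addresses through it unconditionally. As a quick illustration of its contract, here is a hypothetical in-package test (not part of this commit) written as it would sit in `dht_test.go`, reusing the existing `setupDHT` helper and the unexported fields the tests below also touch:

```go
func TestFilterAddrsHelper(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	d := setupDHT(ctx, t, false)
	addrs := []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/4001")}

	// With no filter configured, addresses pass through untouched.
	d.addrFilter = nil
	require.Equal(t, addrs, d.filterAddrs(addrs))

	// With a filter configured, whatever it returns is used as-is.
	d.addrFilter = func([]ma.Multiaddr) []ma.Multiaddr { return nil }
	require.Nil(t, d.filterAddrs(addrs))
}
```
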
179 changes: 173 additions & 6 deletions dht_test.go
@@ -14,6 +14,8 @@ import (
	"testing"
	"time"

	"github.com/libp2p/go-libp2p-kad-dht/internal/net"
	"github.com/libp2p/go-libp2p-kad-dht/providers"
	"github.com/libp2p/go-libp2p/core/crypto"
	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/network"
@@ -561,12 +563,136 @@ func TestProvides(t *testing.T) {
			if prov.ID != dhts[3].self {
				t.Fatal("Got back wrong provider")
			}
			if len(prov.Addrs) == 0 {
				t.Fatal("Got no addresses back")
			}
		case <-ctxT.Done():
			t.Fatal("Did not get a provider back.")
		}
	}
}

type testMessageSender struct {
	sendRequest func(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error)
	sendMessage func(ctx context.Context, p peer.ID, pmes *pb.Message) error
}

var _ pb.MessageSender = (*testMessageSender)(nil)

func (t testMessageSender) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
	return t.sendRequest(ctx, p, pmes)
}

func (t testMessageSender) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
	return t.sendMessage(ctx, p, pmes)
}

func TestProvideAddressFilter(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	dhts := setupDHTS(t, ctx, 2)

	connect(t, ctx, dhts[0], dhts[1])
	testMaddr := ma.StringCast("/ip4/99.99.99.99/tcp/9999")

	done := make(chan struct{})
	impl := net.NewMessageSenderImpl(dhts[0].host, dhts[0].protocols)
	tms := &testMessageSender{
		sendMessage: func(ctx context.Context, p peer.ID, pmes *pb.Message) error {
			defer close(done)
			assert.Equal(t, pmes.Type, pb.Message_ADD_PROVIDER)
			assert.Len(t, pmes.ProviderPeers[0].Addrs, 1)
			assert.True(t, pmes.ProviderPeers[0].Addresses()[0].Equal(testMaddr))
			return impl.SendMessage(ctx, p, pmes)
		},
		sendRequest: func(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
			return impl.SendRequest(ctx, p, pmes)
		},
	}
	pm, err := pb.NewProtocolMessenger(tms)
	require.NoError(t, err)

	dhts[0].protoMessenger = pm
	dhts[0].addrFilter = func(multiaddrs []ma.Multiaddr) []ma.Multiaddr {
		return []ma.Multiaddr{testMaddr}
	}

	if err := dhts[0].Provide(ctx, testCaseCids[0], true); err != nil {
		t.Fatal(err)
	}

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout")
	}
}

type testProviderManager struct {
	addProvider func(ctx context.Context, key []byte, prov peer.AddrInfo) error
	getProviders func(ctx context.Context, key []byte) ([]peer.AddrInfo, error)
	close func() error
}

var _ providers.ProviderStore = (*testProviderManager)(nil)

func (t *testProviderManager) AddProvider(ctx context.Context, key []byte, prov peer.AddrInfo) error {
	return t.addProvider(ctx, key, prov)
}

func (t *testProviderManager) GetProviders(ctx context.Context, key []byte) ([]peer.AddrInfo, error) {
	return t.getProviders(ctx, key)
}

func (t *testProviderManager) Close() error {
	return t.close()
}

func TestHandleAddProviderAddressFilter(t *testing.T) {
	ctx := context.Background()

	d := setupDHT(ctx, t, false)
	provider := setupDHT(ctx, t, false)

	testMaddr := ma.StringCast("/ip4/99.99.99.99/tcp/9999")

	d.addrFilter = func(multiaddrs []ma.Multiaddr) []ma.Multiaddr {
		return []ma.Multiaddr{testMaddr}
	}

	done := make(chan struct{})
	d.providerStore = &testProviderManager{
		addProvider: func(ctx context.Context, key []byte, prov peer.AddrInfo) error {
			defer close(done)
			assert.True(t, prov.Addrs[0].Equal(testMaddr))
			return nil
		},
		close: func() error { return nil },
	}

	pmes := &pb.Message{
		Type: pb.Message_ADD_PROVIDER,
		Key: []byte("test-key"),
		ProviderPeers: pb.RawPeerInfosToPBPeers([]peer.AddrInfo{{
			ID: provider.self,
			Addrs: []ma.Multiaddr{
				ma.StringCast("/ip4/55.55.55.55/tcp/5555"),
				ma.StringCast("/ip4/66.66.66.66/tcp/6666"),
			},
		}}),
	}

	_, err := d.handleAddProvider(ctx, provider.self, pmes)
	require.NoError(t, err)

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout")
	}
}

func TestLocalProvides(t *testing.T) {
	// t.Skip("skipping test to debug another")
	ctx, cancel := context.WithCancel(context.Background())
@@ -603,6 +729,47 @@ func TestLocalProvides(t *testing.T) {
	}
}

func TestAddressFilterProvide(t *testing.T) {
	// t.Skip("skipping test to debug another")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	testMaddr := ma.StringCast("/ip4/99.99.99.99/tcp/9999")

	d := setupDHT(ctx, t, false)
	provider := setupDHT(ctx, t, false)

	d.addrFilter = func(maddrs []ma.Multiaddr) []ma.Multiaddr {
		return []ma.Multiaddr{
			testMaddr,
		}
	}

	_, err := d.handleAddProvider(ctx, provider.self, &pb.Message{
		Type: pb.Message_ADD_PROVIDER,
		Key: []byte("random-key"),
		ProviderPeers: pb.PeerInfosToPBPeers(provider.host.Network(), []peer.AddrInfo{{
			ID: provider.self,
			Addrs: provider.host.Addrs(),
		}}),
	})
	require.NoError(t, err)

	// because of the identify protocol we add all
	// addresses to the peerstore, although the addresses
	// will be filtered in the above handleAddProvider call
	d.peerstore.AddAddrs(provider.self, provider.host.Addrs(), time.Hour)

	resp, err := d.handleGetProviders(ctx, d.self, &pb.Message{
		Type: pb.Message_GET_PROVIDERS,
		Key: []byte("random-key"),
	})
	require.NoError(t, err)

	assert.True(t, resp.ProviderPeers[0].Addresses()[0].Equal(testMaddr))
	assert.Len(t, resp.ProviderPeers[0].Addresses(), 1)
}

// if minPeers or avgPeers is 0, dont test for it.
func waitForWellFormedTables(t *testing.T, dhts []*IpfsDHT, minPeers, avgPeers int, timeout time.Duration) {
	// test "well-formed-ness" (>= minPeers peers in every routing table)
@@ -630,7 +797,7 @@ func checkForWellFormedTablesOnce(t *testing.T, dhts []*IpfsDHT, minPeers, avgPe
 		rtlen := dht.routingTable.Size()
 		totalPeers += rtlen
 		if minPeers > 0 && rtlen < minPeers {
-			//t.Logf("routing table for %s only has %d peers (should have >%d)", dht.self, rtlen, minPeers)
+			// t.Logf("routing table for %s only has %d peers (should have >%d)", dht.self, rtlen, minPeers)
 			return false
 		}
 	}
@@ -1568,9 +1735,7 @@ func TestProvideDisabled(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	var (
-		optsA, optsB []Option
-	)
+	var optsA, optsB []Option
 	optsA = append(optsA, ProtocolPrefix("/provMaybeDisabled"))
 	optsB = append(optsB, ProtocolPrefix("/provMaybeDisabled"))

@@ -1995,8 +2160,10 @@ func TestBootStrapWhenRTIsEmpty(t *testing.T) {
 	// convert the bootstrap addresses to a p2p address
 	bootstrapAddrs := make([]peer.AddrInfo, nBootStraps)
 	for i := 0; i < nBootStraps; i++ {
-		b := peer.AddrInfo{ID: bootstrappers[i].self,
-			Addrs: bootstrappers[i].host.Addrs()}
+		b := peer.AddrInfo{
+			ID: bootstrappers[i].self,
+			Addrs: bootstrappers[i].host.Addrs(),
+		}
 		bootstrapAddrs[i] = b
 	}

5 changes: 4 additions & 1 deletion fullrt/dht.go
@@ -844,7 +844,10 @@ func (dht *FullRT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err e
 	}

 	successes := dht.execOnMany(ctx, func(ctx context.Context, p peer.ID) error {
-		err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.h)
+		err := dht.protoMessenger.PutProviderAddrs(ctx, p, keyMH, peer.AddrInfo{
+			ID: dht.self,
+			Addrs: dht.h.Addrs(),
+		})
 		return err
 	}, peers, true)

16 changes: 14 additions & 2 deletions handlers.go
@@ -321,7 +321,16 @@ func (dht *IpfsDHT) handleGetProviders(ctx context.Context, p peer.ID, pmes *pb.
 	if err != nil {
 		return nil, err
 	}
-	resp.ProviderPeers = pb.PeerInfosToPBPeers(dht.host.Network(), providers)
+
+	filtered := make([]peer.AddrInfo, len(providers))
+	for i, provider := range providers {
+		filtered[i] = peer.AddrInfo{
+			ID: provider.ID,
+			Addrs: dht.filterAddrs(provider.Addrs),
+		}
+	}
+
+	resp.ProviderPeers = pb.PeerInfosToPBPeers(dht.host.Network(), filtered)

 	// Also send closer peers.
 	closer := dht.betterPeersToQuery(pmes, p, dht.bucketSize)
@@ -359,7 +368,10 @@ func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.M
 			continue
 		}

-		dht.providerStore.AddProvider(ctx, key, peer.AddrInfo{ID: pi.ID, Addrs: pi.Addrs})
+		// We run the address filter after checking the length; this allows transient
+		// nodes with varying /p2p-circuit addresses to still have their announcement go through.
+		addrs := dht.filterAddrs(pi.Addrs)
+		dht.providerStore.AddProvider(ctx, key, peer.AddrInfo{ID: pi.ID, Addrs: addrs})
 	}

 	return nil, nil
5 changes: 4 additions & 1 deletion lookup_optim.go
@@ -236,7 +236,10 @@ func (os *optimisticState) stopFn(qps *qpeerset.QueryPeerset) bool {
 }

 func (os *optimisticState) putProviderRecord(pid peer.ID) {
-	err := os.dht.protoMessenger.PutProvider(os.putCtx, pid, []byte(os.key), os.dht.host)
+	err := os.dht.protoMessenger.PutProviderAddrs(os.putCtx, pid, []byte(os.key), peer.AddrInfo{
+		ID: os.dht.self,
+		Addrs: os.dht.filterAddrs(os.dht.host.Addrs()),
+	})
 	os.peerStatesLk.Lock()
 	if err != nil {
 		os.peerStates[pid] = failure
21 changes: 12 additions & 9 deletions pb/protocol_messenger.go
@@ -169,8 +169,16 @@ func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id
 	return peers, nil
 }

-// PutProvider asks a peer to store that we are a provider for the given key.
-func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key multihash.Multihash, host host.Host) (err error) {
+// PutProvider is deprecated; please use [ProtocolMessenger.PutProviderAddrs].
+func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key multihash.Multihash, h host.Host) error {
+	return pm.PutProviderAddrs(ctx, p, key, peer.AddrInfo{
+		ID: h.ID(),
+		Addrs: h.Addrs(),
+	})
+}
+
+// PutProviderAddrs asks a peer to store that we are a provider for the given key.
+func (pm *ProtocolMessenger) PutProviderAddrs(ctx context.Context, p peer.ID, key multihash.Multihash, self peer.AddrInfo) (err error) {
 	ctx, span := internal.StartSpan(ctx, "ProtocolMessenger.PutProvider")
 	defer span.End()
 	if span.IsRecording() {
@@ -182,19 +190,14 @@ func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key mul
 		}()
 	}

-	pi := peer.AddrInfo{
-		ID: host.ID(),
-		Addrs: host.Addrs(),
-	}
-
 	// TODO: We may want to limit the type of addresses in our provider records
 	// For example, in a WAN-only DHT prohibit sharing non-WAN addresses (e.g. 192.168.0.100)
-	if len(pi.Addrs) < 1 {
+	if len(self.Addrs) < 1 {
 		return fmt.Errorf("no known addresses for self, cannot put provider")
 	}

 	pmes := NewMessage(Message_ADD_PROVIDER, key, 0)
-	pmes.ProviderPeers = RawPeerInfosToPBPeers([]peer.AddrInfo{pi})
+	pmes.ProviderPeers = RawPeerInfosToPBPeers([]peer.AddrInfo{self})

 	return pm.m.SendMessage(ctx, p, pmes)
 }
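
For downstream users of the messenger, the deprecation means passing an explicit peer.AddrInfo instead of a host, which is what lets callers filter the addresses they advertise. A sketch of the migration pattern, mirroring the new call sites in routing.go and fullrt/dht.go; the `announce` helper and its `filter` parameter are illustrative, not part of this commit:

```go
package example

import (
	"context"

	pb "github.com/libp2p/go-libp2p-kad-dht/pb"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/core/peer"
	ma "github.com/multiformats/go-multiaddr"
	"github.com/multiformats/go-multihash"
)

// announce publishes a provider record for key to peer p.
// The old form was pm.PutProvider(ctx, p, key, h); with PutProviderAddrs
// the caller decides which of its addresses to advertise.
func announce(ctx context.Context, pm *pb.ProtocolMessenger, p peer.ID,
	key multihash.Multihash, h host.Host,
	filter func([]ma.Multiaddr) []ma.Multiaddr) error {

	addrs := h.Addrs()
	if filter != nil {
		addrs = filter(addrs)
	}
	return pm.PutProviderAddrs(ctx, p, key, peer.AddrInfo{
		ID:    h.ID(),
		Addrs: addrs,
	})
}
```
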
5 changes: 4 additions & 1 deletion routing.go
@@ -464,7 +464,10 @@ func (dht *IpfsDHT) classicProvide(ctx context.Context, keyMH multihash.Multihas
 		go func(p peer.ID) {
 			defer wg.Done()
 			logger.Debugf("putProvider(%s, %s)", internal.LoggableProviderRecordBytes(keyMH), p)
-			err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.host)
+			err := dht.protoMessenger.PutProviderAddrs(ctx, p, keyMH, peer.AddrInfo{
+				ID: dht.self,
+				Addrs: dht.filterAddrs(dht.host.Addrs()),
+			})
 			if err != nil {
 				logger.Debug(err)
 			}
