Skip to content

Commit

Permalink
identify: avoid spuriously triggering pushes (#2299)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 30, 2023
1 parent b4faaf8 commit 6926113
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 6 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
go.uber.org/fx v1.19.2
go.uber.org/goleak v1.1.12
golang.org/x/crypto v0.7.0
golang.org/x/exp v0.0.0-20230321023759-10a507213a29
golang.org/x/sync v0.1.0
golang.org/x/sys v0.7.0
golang.org/x/tools v0.7.0
Expand Down Expand Up @@ -109,7 +110,6 @@ require (
go.uber.org/dig v1.16.1 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect
golang.org/x/mod v0.10.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/text v0.8.0 // indirect
Expand Down
52 changes: 47 additions & 5 deletions p2p/protocol/identify/id.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package identify

import (
"bytes"
"context"
"fmt"
"io"
"sort"
"sync"
"time"

"golang.org/x/exp/slices"

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
Expand Down Expand Up @@ -62,6 +66,31 @@ type identifySnapshot struct {
record *record.Envelope
}

// Equal says if two snapshots are identical.
// It does NOT compare the sequence number.
func (s identifySnapshot) Equal(other *identifySnapshot) bool {
hasRecord := s.record != nil
otherHasRecord := other.record != nil
if hasRecord != otherHasRecord {
return false
}
if hasRecord && !s.record.Equal(other.record) {
return false
}
if !slices.Equal(s.protocols, other.protocols) {
return false
}
if len(s.addrs) != len(other.addrs) {
return false
}
for i, a := range s.addrs {
if !a.Equal(other.addrs[i]) {
return false
}
}
return true
}

type IDService interface {
// IdentifyConn synchronously triggers an identify request on the connection and
// waits for it to complete. If the connection is being identified by another
Expand Down Expand Up @@ -249,10 +278,12 @@ func (ids *idService) loop(ctx context.Context) {
if !ok {
return
}
if updated := ids.updateSnapshot(); !updated {
continue
}
if ids.metricsTracer != nil {
ids.metricsTracer.TriggeredPushes(e)
}
ids.updateSnapshot()
select {
case triggerPush <- struct{}{}:
default: // we already have one more push queued, no need to queue another one
Expand Down Expand Up @@ -529,23 +560,34 @@ func readAllIDMessages(r pbio.Reader, finalMsg proto.Message) error {
return fmt.Errorf("too many parts")
}

func (ids *idService) updateSnapshot() {
func (ids *idService) updateSnapshot() (updated bool) {
addrs := ids.Host.Addrs()
sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) == -1 })
protos := ids.Host.Mux().Protocols()
sort.Slice(protos, func(i, j int) bool { return protos[i] < protos[j] })
snapshot := identifySnapshot{
addrs: ids.Host.Addrs(),
protocols: ids.Host.Mux().Protocols(),
addrs: addrs,
protocols: protos,
}

if !ids.disableSignedPeerRecord {
if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok {
snapshot.record = cab.GetPeerRecord(ids.Host.ID())
}
}

ids.currentSnapshot.Lock()
defer ids.currentSnapshot.Unlock()

if ids.currentSnapshot.snapshot.Equal(&snapshot) {
return false
}

snapshot.seq = ids.currentSnapshot.snapshot.seq + 1
ids.currentSnapshot.snapshot = snapshot
ids.currentSnapshot.Unlock()

log.Debugw("updating snapshot", "seq", snapshot.seq, "addrs", snapshot.addrs)
return true
}

func (ids *idService) writeChunkedIdentifyMsg(s network.Stream, mes *pb.Identify) error {
Expand Down
47 changes: 47 additions & 0 deletions p2p/protocol/identify/snapshot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package identify

import (
"crypto/rand"
"testing"

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"

ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

func TestSnapshotEquality(t *testing.T) {
addr1 := ma.StringCast("/ip4/127.0.0.1/tcp/1234")
addr2 := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1")

_, pubKey1, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
_, pubKey2, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
record1 := &record.Envelope{PublicKey: pubKey1}
record2 := &record.Envelope{PublicKey: pubKey2}

for _, tc := range []struct {
s1, s2 *identifySnapshot
result bool
}{
{s1: &identifySnapshot{record: record1}, s2: &identifySnapshot{record: record1}, result: true},
{s1: &identifySnapshot{record: record1}, s2: &identifySnapshot{record: record2}, result: false},
{s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, result: true},
{s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr2}}, result: false},
{s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1, addr2}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr2}}, result: false},
{s1: &identifySnapshot{addrs: []ma.Multiaddr{addr1}}, s2: &identifySnapshot{addrs: []ma.Multiaddr{addr1, addr2}}, result: false},
{s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, result: true},
{s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/bar"}}, result: false},
{s1: &identifySnapshot{protocols: []protocol.ID{"/foo", "/bar"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/bar"}}, result: false},
{s1: &identifySnapshot{protocols: []protocol.ID{"/foo"}}, s2: &identifySnapshot{protocols: []protocol.ID{"/foo", "/bar"}}, result: false},
} {
if tc.result {
require.Truef(t, tc.s1.Equal(tc.s2), "expected equal: %+v and %+v", tc.s1, tc.s2)
} else {
require.Falsef(t, tc.s1.Equal(tc.s2), "expected unequal: %+v and %+v", tc.s1, tc.s2)
}
}
}

0 comments on commit 6926113

Please sign in to comment.