From 9c7dc65093603cf7aa396c5ab7e81e55a3c95222 Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Tue, 12 Sep 2023 12:32:50 +0100 Subject: [PATCH] Ensure routing tables are safe for concurrent use (#121) Fixes #115 --- routing/simplert/table.go | 60 ++++++++++++++++------------------ routing/simplert/table_test.go | 39 ++++++++++++++++++++++ routing/triert/table.go | 40 ++++++++++++++++++----- routing/triert/table_test.go | 41 +++++++++++++++++++++++ 4 files changed, 140 insertions(+), 40 deletions(-) diff --git a/routing/simplert/table.go b/routing/simplert/table.go index 16a7b04..cd4e367 100644 --- a/routing/simplert/table.go +++ b/routing/simplert/table.go @@ -2,6 +2,7 @@ package simplert import ( "sort" + "sync" "github.com/plprobelab/go-kademlia/internal/kadtest" @@ -16,8 +17,10 @@ type peerInfo[K kad.Key[K], N kad.NodeID[K]] struct { type SimpleRT[K kad.Key[K], N kad.NodeID[K]] struct { self K - buckets [][]peerInfo[K, N] bucketSize int + + mu sync.RWMutex // guards access to buckets + buckets [][]peerInfo[K, N] } var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*SimpleRT[key.Key256, kadtest.ID[key.Key256]])(nil) @@ -46,6 +49,13 @@ func (rt *SimpleRT[K, N]) BucketSize() int { } func (rt *SimpleRT[K, N]) BucketIdForKey(kadId K) (int, error) { + rt.mu.RLock() + defer rt.mu.RUnlock() + return rt.bucketIdForKey(kadId) +} + +// bucketIdForKey must only be called while rt.mu is held +func (rt *SimpleRT[K, N]) bucketIdForKey(kadId K) (int, error) { bid := rt.self.CommonPrefixLength(kadId) nBuckets := len(rt.buckets) if bid >= nBuckets { @@ -55,6 +65,8 @@ func (rt *SimpleRT[K, N]) BucketIdForKey(kadId K) (int, error) { } func (rt *SimpleRT[K, N]) SizeOfBucket(bucketId int) int { + rt.mu.RLock() + defer rt.mu.RUnlock() return len(rt.buckets[bucketId]) } @@ -63,19 +75,15 @@ func (rt *SimpleRT[K, N]) AddNode(id N) bool { } func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool { - //_, span := util.StartSpan(ctx, "routing.simple.addPeer", trace.WithAttributes( - // attribute.String("KadID", key.HexString(kadId)), - // attribute.Stringer("PeerID", id), - //)) - //defer span.End() + rt.mu.Lock() + defer rt.mu.Unlock() // no need to check the error here, it's already been checked in keyError - bid, _ := rt.BucketIdForKey(kadId) + bid, _ := rt.bucketIdForKey(kadId) lastBucketId := len(rt.buckets) - 1 if rt.alreadyInBucket(kadId, bid) { - // span.AddEvent("peer not added, already in bucket " + strconv.Itoa(bid)) // discard new peer return false } @@ -83,20 +91,17 @@ func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool { if bid < lastBucketId { // new peer doesn't belong in last bucket if len(rt.buckets[bid]) >= rt.bucketSize { - // span.AddEvent("peer not added, bucket " + strconv.Itoa(bid) + " full") // bucket is full, discard new peer return false } // add new peer to bucket rt.buckets[bid] = append(rt.buckets[bid], peerInfo[K, N]{id, kadId}) - // span.AddEvent("peer added to bucket " + strconv.Itoa(bid)) return true } if len(rt.buckets[lastBucketId]) < rt.bucketSize { // last bucket is not full, add new peer rt.buckets[lastBucketId] = append(rt.buckets[lastBucketId], peerInfo[K, N]{id, kadId}) - // span.AddEvent("peer added to bucket " + strconv.Itoa(lastBucketId)) return true } // last bucket is full, try to split it @@ -106,15 +111,11 @@ func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool { // closeBucket contains peers with a CPL higher than lastBucketId closeBucket := make([]peerInfo[K, N], 0) - // span.AddEvent("splitting last bucket (" + strconv.Itoa(lastBucketId) + ")") - for _, p := range rt.buckets[lastBucketId] { if p.kadId.CommonPrefixLength(rt.self) == lastBucketId { farBucket = append(farBucket, p) } else { closeBucket = append(closeBucket, p) - // span.AddEvent(p.id.String() + " moved to new bucket (" + - // strconv.Itoa(lastBucketId+1) + ")") } } if len(farBucket) == rt.bucketSize && @@ -131,7 +132,7 @@ func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool { lastBucketId++ } - newBid, _ := rt.BucketIdForKey(kadId) + newBid, _ := rt.bucketIdForKey(kadId) // add new peer to appropraite bucket rt.buckets[newBid] = append(rt.buckets[newBid], peerInfo[K, N]{id, kadId}) // span.AddEvent("peer added to bucket " + strconv.Itoa(newBid)) @@ -149,29 +150,25 @@ func (rt *SimpleRT[K, N]) alreadyInBucket(kadId K, bucketId int) bool { } func (rt *SimpleRT[K, N]) RemoveKey(kadId K) bool { - //_, span := util.StartSpan(ctx, "routing.simple.removeKey", trace.WithAttributes( - // attribute.String("KadID", key.HexString(kadId)), - //)) - //defer span.End() + rt.mu.Lock() + defer rt.mu.Unlock() - bid, _ := rt.BucketIdForKey(kadId) + bid, _ := rt.bucketIdForKey(kadId) for i, p := range rt.buckets[bid] { if key.Equal(kadId, p.kadId) { // remove peer from bucket rt.buckets[bid][i] = rt.buckets[bid][len(rt.buckets[bid])-1] rt.buckets[bid] = rt.buckets[bid][:len(rt.buckets[bid])-1] - - // span.AddEvent(fmt.Sprint(p.id.String(), "removed from bucket", bid)) return true } } - // peer not found in the routing table - // span.AddEvent(fmt.Sprint("peer not found in bucket", bid)) return false } func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) { - bid, _ := rt.BucketIdForKey(kadId) + rt.mu.RLock() + defer rt.mu.RUnlock() + bid, _ := rt.bucketIdForKey(kadId) for _, p := range rt.buckets[bid] { if key.Equal(kadId, p.kadId) { return p.id, true @@ -184,13 +181,10 @@ func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) { // TODO: not exactly working as expected // returns min(n, bucketSize) peers from the bucket matching the given key func (rt *SimpleRT[K, N]) NearestNodes(kadId K, n int) []N { - //_, span := util.StartSpan(ctx, "routing.simple.nearestPeers", trace.WithAttributes( - // attribute.String("KadID", key.HexString(kadId)), - // attribute.Int("n", int(n)), - //)) - //defer span.End() + rt.mu.RLock() + defer rt.mu.RUnlock() - bid, _ := rt.BucketIdForKey(kadId) + bid, _ := rt.bucketIdForKey(kadId) var peers []peerInfo[K, N] // TODO: optimize this @@ -240,6 +234,8 @@ func (rt *SimpleRT[K, N]) Cpl(kk K) int { // CplSize returns the number of nodes in the table whose longest common prefix with the table's key is of length cpl. func (rt *SimpleRT[K, N]) CplSize(cpl int) int { + rt.mu.RLock() + defer rt.mu.RUnlock() bid := cpl // cpl is simply the bucket id nBuckets := len(rt.buckets) if bid >= nBuckets { diff --git a/routing/simplert/table_test.go b/routing/simplert/table_test.go index 9eb0f2e..8fb87cf 100644 --- a/routing/simplert/table_test.go +++ b/routing/simplert/table_test.go @@ -2,6 +2,8 @@ package simplert import ( "fmt" + "math/rand" + "sync" "testing" "github.com/libp2p/go-libp2p/core/peer" @@ -201,3 +203,40 @@ func TestNearestPeers(t *testing.T) { peers = rt2.NearestNodes(key0, 10) require.Equal(t, peers[0], peers[1]) } + +func TestTableConcurrentReadWrite(t *testing.T) { + nodes := make([]*kt.ID[key.Key32], 5000) + for i := range nodes { + nodes[i] = kt.NewID(kt.RandomKey()) + } + + rt := New[key.Key32](kt.NewID(key.Key32(0)), 2) + + workers := 3 + var wg sync.WaitGroup + wg.Add(workers) + + // start workers to concurrently read and write the routing table + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + work := make([]*kt.ID[key.Key32], len(nodes)) + copy(work, nodes) + rand.Shuffle(len(work), func(i, j int) { work[i], work[j] = work[j], work[i] }) + + for i := range work { + node := work[i] + _, found := rt.GetNode(node.Key()) + if !found { + // add new peer + rt.AddNode(work[i]) + } else { + // remove it + rt.RemoveKey(node.Key()) + } + + } + }() + } + wg.Wait() +} diff --git a/routing/triert/table.go b/routing/triert/table.go index 15e92fc..dc58e39 100644 --- a/routing/triert/table.go +++ b/routing/triert/table.go @@ -2,6 +2,8 @@ package triert import ( "fmt" + "sync" + "sync/atomic" "github.com/plprobelab/go-kademlia/internal/kadtest" "github.com/plprobelab/go-kademlia/kad" @@ -15,7 +17,8 @@ type TrieRT[K kad.Key[K], N kad.NodeID[K]] struct { self K keyFilter KeyFilterFunc[K, N] - keys *trie.Trie[K, N] + mu sync.Mutex // held to synchronise mutations to the trie + trie atomic.Value // holds a *trie.Trie[K, N] } var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*TrieRT[key.Key256, kadtest.ID[key.Key256]])(nil) @@ -25,8 +28,9 @@ var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*TrieRT[key.Key256 func New[K kad.Key[K], N kad.NodeID[K]](self N, cfg *Config[K, N]) (*TrieRT[K, N], error) { rt := &TrieRT[K, N]{ self: self.Key(), - keys: &trie.Trie[K, N]{}, } + rt.trie.Store(&trie.Trie[K, N]{}) + if err := rt.apply(cfg); err != nil { return nil, fmt.Errorf("apply config: %w", err) } @@ -55,18 +59,35 @@ func (rt *TrieRT[K, N]) AddNode(node N) bool { return false } - return rt.keys.Add(kk, node) + rt.mu.Lock() + defer rt.mu.Unlock() + this := rt.trie.Load().(*trie.Trie[K, N]) + next, _ := trie.Add(this, kk, node) + if next == this { + return false + } + rt.trie.Store(next) + return true } // RemoveKey tries to remove a node identified by its Kademlia key from the // routing table. It returns true if the key was found to be present in the table and was removed. func (rt *TrieRT[K, N]) RemoveKey(kk K) bool { - return rt.keys.Remove(kk) + rt.mu.Lock() + defer rt.mu.Unlock() + this := rt.trie.Load().(*trie.Trie[K, N]) + next, _ := trie.Remove(this, kk) + if next == this { + return false + } + rt.trie.Store(next) + return true } // NearestNodes returns the n closest nodes to a given key. func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N { - closestEntries := trie.Closest(rt.keys, target, n) + this := rt.trie.Load().(*trie.Trie[K, N]) + closestEntries := trie.Closest(this, target, n) if len(closestEntries) == 0 { return []N{} } @@ -80,7 +101,8 @@ func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N { } func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) { - found, node := trie.Find(rt.keys, kk) + this := rt.trie.Load().(*trie.Trie[K, N]) + found, node := trie.Find(this, kk) if !found { var zero N return zero, false @@ -90,7 +112,8 @@ func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) { // Size returns the number of peers contained in the table. func (rt *TrieRT[K, N]) Size() int { - return rt.keys.Size() + this := rt.trie.Load().(*trie.Trie[K, N]) + return this.Size() } // Cpl returns the longest common prefix length the supplied key shares with the table's key. @@ -100,7 +123,8 @@ func (rt *TrieRT[K, N]) Cpl(kk K) int { // CplSize returns the number of peers in the table whose longest common prefix with the table's key is of length cpl. func (rt *TrieRT[K, N]) CplSize(cpl int) int { - n, err := countCpl(rt.keys, rt.self, cpl, 0) + this := rt.trie.Load().(*trie.Trie[K, N]) + n, err := countCpl(this, rt.self, cpl, 0) if err != nil { return 0 } diff --git a/routing/triert/table_test.go b/routing/triert/table_test.go index 2931e50..a29159e 100644 --- a/routing/triert/table_test.go +++ b/routing/triert/table_test.go @@ -2,6 +2,7 @@ package triert import ( "math/rand" + "sync" "testing" "github.com/plprobelab/go-kademlia/internal/kadtest" @@ -313,6 +314,46 @@ func TestKeyFilter(t *testing.T) { require.Equal(t, want, got) } +func TestTableConcurrentReadWrite(t *testing.T) { + nodes := make([]*kadtest.ID[key.Key32], 5000) + for i := range nodes { + nodes[i] = kadtest.NewID(kadtest.RandomKey()) + } + + rt, err := New[key.Key32](kadtest.NewID(key0), nil) + if err != nil { + t.Fatalf("unexpected error creating table: %v", err) + } + + workers := 3 + var wg sync.WaitGroup + wg.Add(workers) + + // start workers to concurrently read and write the routing table + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + work := make([]*kadtest.ID[key.Key32], len(nodes)) + copy(work, nodes) + rand.Shuffle(len(work), func(i, j int) { work[i], work[j] = work[j], work[i] }) + + for i := range work { + node := work[i] + _, found := rt.GetNode(node.Key()) + if !found { + // add new peer + rt.AddNode(work[i]) + } else { + // remove it + rt.RemoveKey(node.Key()) + } + + } + }() + } + wg.Wait() +} + func BenchmarkBuildTable(b *testing.B) { b.Run("1000", benchmarkBuildTable(1000)) b.Run("10000", benchmarkBuildTable(10000))