Skip to content

Commit

Permalink
Ensure routing tables are safe for concurrent use (#121)
Browse files Browse the repository at this point in the history
Fixes #115
  • Loading branch information
iand committed Sep 12, 2023
1 parent 18d9578 commit 9c7dc65
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 40 deletions.
60 changes: 28 additions & 32 deletions routing/simplert/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package simplert

import (
"sort"
"sync"

"github.com/plprobelab/go-kademlia/internal/kadtest"

Expand All @@ -16,8 +17,10 @@ type peerInfo[K kad.Key[K], N kad.NodeID[K]] struct {

// SimpleRT is a bucket-based routing table. Peers are grouped into buckets
// according to the common prefix length their key shares with self.
type SimpleRT[K kad.Key[K], N kad.NodeID[K]] struct {
	self       K   // key of the local node; bucket ids are derived from the common prefix length with this key
	bucketSize int // capacity of a bucket; a new peer is discarded when its bucket is full and cannot be split

	mu      sync.RWMutex       // guards access to buckets
	buckets [][]peerInfo[K, N] // indexed by bucket id; must only be read or written while mu is held
}

var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*SimpleRT[key.Key256, kadtest.ID[key.Key256]])(nil)
Expand Down Expand Up @@ -46,6 +49,13 @@ func (rt *SimpleRT[K, N]) BucketSize() int {
}

// BucketIdForKey reports the bucket id for kadId. It is the concurrency-safe
// entry point: it takes the read lock on rt.mu and delegates the computation
// to bucketIdForKey, whose result and error are passed through unchanged.
func (rt *SimpleRT[K, N]) BucketIdForKey(kadId K) (int, error) {
	rt.mu.RLock()
	defer rt.mu.RUnlock()

	bid, err := rt.bucketIdForKey(kadId)
	return bid, err
}

// bucketIdForKey must only be called while rt.mu is held
func (rt *SimpleRT[K, N]) bucketIdForKey(kadId K) (int, error) {
bid := rt.self.CommonPrefixLength(kadId)
nBuckets := len(rt.buckets)
if bid >= nBuckets {
Expand All @@ -55,6 +65,8 @@ func (rt *SimpleRT[K, N]) BucketIdForKey(kadId K) (int, error) {
}

// SizeOfBucket returns the number of peers currently stored in the bucket
// with the given id. The read lock on rt.mu is held for the duration of the
// call. bucketId is used directly as an index into rt.buckets, so an id
// outside the current range of buckets panics.
func (rt *SimpleRT[K, N]) SizeOfBucket(bucketId int) int {
	rt.mu.RLock()
	defer rt.mu.RUnlock()

	bucket := rt.buckets[bucketId]
	return len(bucket)
}

Expand All @@ -63,40 +75,33 @@ func (rt *SimpleRT[K, N]) AddNode(id N) bool {
}

func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool {
//_, span := util.StartSpan(ctx, "routing.simple.addPeer", trace.WithAttributes(
// attribute.String("KadID", key.HexString(kadId)),
// attribute.Stringer("PeerID", id),
//))
//defer span.End()
rt.mu.Lock()
defer rt.mu.Unlock()

// no need to check the error here, it's already been checked in keyError
bid, _ := rt.BucketIdForKey(kadId)
bid, _ := rt.bucketIdForKey(kadId)

lastBucketId := len(rt.buckets) - 1

if rt.alreadyInBucket(kadId, bid) {
// span.AddEvent("peer not added, already in bucket " + strconv.Itoa(bid))
// discard new peer
return false
}

if bid < lastBucketId {
// new peer doesn't belong in last bucket
if len(rt.buckets[bid]) >= rt.bucketSize {
// span.AddEvent("peer not added, bucket " + strconv.Itoa(bid) + " full")
// bucket is full, discard new peer
return false
}

// add new peer to bucket
rt.buckets[bid] = append(rt.buckets[bid], peerInfo[K, N]{id, kadId})
// span.AddEvent("peer added to bucket " + strconv.Itoa(bid))
return true
}
if len(rt.buckets[lastBucketId]) < rt.bucketSize {
// last bucket is not full, add new peer
rt.buckets[lastBucketId] = append(rt.buckets[lastBucketId], peerInfo[K, N]{id, kadId})
// span.AddEvent("peer added to bucket " + strconv.Itoa(lastBucketId))
return true
}
// last bucket is full, try to split it
Expand All @@ -106,15 +111,11 @@ func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool {
// closeBucket contains peers with a CPL higher than lastBucketId
closeBucket := make([]peerInfo[K, N], 0)

// span.AddEvent("splitting last bucket (" + strconv.Itoa(lastBucketId) + ")")

for _, p := range rt.buckets[lastBucketId] {
if p.kadId.CommonPrefixLength(rt.self) == lastBucketId {
farBucket = append(farBucket, p)
} else {
closeBucket = append(closeBucket, p)
// span.AddEvent(p.id.String() + " moved to new bucket (" +
// strconv.Itoa(lastBucketId+1) + ")")
}
}
if len(farBucket) == rt.bucketSize &&
Expand All @@ -131,7 +132,7 @@ func (rt *SimpleRT[K, N]) addPeer(kadId K, id N) bool {
lastBucketId++
}

newBid, _ := rt.BucketIdForKey(kadId)
newBid, _ := rt.bucketIdForKey(kadId)
// add new peer to appropriate bucket
rt.buckets[newBid] = append(rt.buckets[newBid], peerInfo[K, N]{id, kadId})
// span.AddEvent("peer added to bucket " + strconv.Itoa(newBid))
Expand All @@ -149,29 +150,25 @@ func (rt *SimpleRT[K, N]) alreadyInBucket(kadId K, bucketId int) bool {
}

// RemoveKey removes the peer with the given Kademlia key from the routing
// table. It returns true if a peer with that key was found in the bucket
// selected for the key and removed, and false otherwise. The write lock on
// rt.mu is held for the duration of the call.
func (rt *SimpleRT[K, N]) RemoveKey(kadId K) bool {
	rt.mu.Lock()
	defer rt.mu.Unlock()

	// rt.mu is already held, so the unexported variant must be used here: the
	// exported BucketIdForKey acquires the read lock itself and sync.RWMutex
	// is not reentrant, so calling it from this method would deadlock.
	// NOTE(review): the error is discarded, as in the other callers in this
	// file; confirm bid is always a valid index into rt.buckets.
	bid, _ := rt.bucketIdForKey(kadId)

	bucket := rt.buckets[bid]
	for i, p := range bucket {
		if key.Equal(kadId, p.kadId) {
			// Order within a bucket is not preserved: move the last entry
			// into the vacated position and shrink the bucket by one.
			last := len(bucket) - 1
			bucket[i] = bucket[last]

			// Clear the slot that falls off the end so the backing array
			// does not keep a stale copy of the moved entry alive.
			var zero peerInfo[K, N]
			bucket[last] = zero

			rt.buckets[bid] = bucket[:last]
			return true
		}
	}
	// peer not found in the routing table
	return false
}

func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) {
bid, _ := rt.BucketIdForKey(kadId)
rt.mu.RLock()
defer rt.mu.RUnlock()
bid, _ := rt.bucketIdForKey(kadId)
for _, p := range rt.buckets[bid] {
if key.Equal(kadId, p.kadId) {
return p.id, true
Expand All @@ -184,13 +181,10 @@ func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) {
// TODO: not exactly working as expected
// returns min(n, bucketSize) peers from the bucket matching the given key
func (rt *SimpleRT[K, N]) NearestNodes(kadId K, n int) []N {
//_, span := util.StartSpan(ctx, "routing.simple.nearestPeers", trace.WithAttributes(
// attribute.String("KadID", key.HexString(kadId)),
// attribute.Int("n", int(n)),
//))
//defer span.End()
rt.mu.RLock()
defer rt.mu.RUnlock()

bid, _ := rt.BucketIdForKey(kadId)
bid, _ := rt.bucketIdForKey(kadId)

var peers []peerInfo[K, N]
// TODO: optimize this
Expand Down Expand Up @@ -240,6 +234,8 @@ func (rt *SimpleRT[K, N]) Cpl(kk K) int {

// CplSize returns the number of nodes in the table whose longest common prefix with the table's key is of length cpl.
func (rt *SimpleRT[K, N]) CplSize(cpl int) int {
rt.mu.RLock()
defer rt.mu.RUnlock()
bid := cpl // cpl is simply the bucket id
nBuckets := len(rt.buckets)
if bid >= nBuckets {
Expand Down
39 changes: 39 additions & 0 deletions routing/simplert/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package simplert

import (
"fmt"
"math/rand"
"sync"
"testing"

"github.com/libp2p/go-libp2p/core/peer"
Expand Down Expand Up @@ -201,3 +203,40 @@ func TestNearestPeers(t *testing.T) {
peers = rt2.NearestNodes(key0, 10)
require.Equal(t, peers[0], peers[1])
}

// TestTableConcurrentReadWrite drives a single routing table from several
// goroutines at once, mixing lookups with additions and removals. It makes no
// assertions of its own; it only needs to run to completion.
func TestTableConcurrentReadWrite(t *testing.T) {
	const (
		nodeCount = 5000
		workers   = 3
	)

	nodes := make([]*kt.ID[key.Key32], nodeCount)
	for idx := range nodes {
		nodes[idx] = kt.NewID(kt.RandomKey())
	}

	rt := New[key.Key32](kt.NewID(key.Key32(0)), 2)

	var wg sync.WaitGroup
	wg.Add(workers)

	// worker toggles every node in the table: a node that is absent gets
	// added, a node that is present gets removed. Each worker visits the
	// nodes in its own random order.
	worker := func() {
		defer wg.Done()

		work := append([]*kt.ID[key.Key32](nil), nodes...)
		rand.Shuffle(len(work), func(a, b int) { work[a], work[b] = work[b], work[a] })

		for _, node := range work {
			if _, found := rt.GetNode(node.Key()); found {
				// remove it
				rt.RemoveKey(node.Key())
				continue
			}
			// add new peer
			rt.AddNode(node)
		}
	}

	// start workers to concurrently read and write the routing table
	for w := 0; w < workers; w++ {
		go worker()
	}
	wg.Wait()
}
40 changes: 32 additions & 8 deletions routing/triert/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package triert

import (
"fmt"
"sync"
"sync/atomic"

"github.com/plprobelab/go-kademlia/internal/kadtest"
"github.com/plprobelab/go-kademlia/kad"
Expand All @@ -15,7 +17,8 @@ type TrieRT[K kad.Key[K], N kad.NodeID[K]] struct {
self K
keyFilter KeyFilterFunc[K, N]

keys *trie.Trie[K, N]
mu sync.Mutex // held to synchronise mutations to the trie
trie atomic.Value // holds a *trie.Trie[K, N]
}

var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*TrieRT[key.Key256, kadtest.ID[key.Key256]])(nil)
Expand All @@ -25,8 +28,9 @@ var _ kad.RoutingTable[key.Key256, kadtest.ID[key.Key256]] = (*TrieRT[key.Key256
func New[K kad.Key[K], N kad.NodeID[K]](self N, cfg *Config[K, N]) (*TrieRT[K, N], error) {
rt := &TrieRT[K, N]{
self: self.Key(),
keys: &trie.Trie[K, N]{},
}
rt.trie.Store(&trie.Trie[K, N]{})

if err := rt.apply(cfg); err != nil {
return nil, fmt.Errorf("apply config: %w", err)
}
Expand Down Expand Up @@ -55,18 +59,35 @@ func (rt *TrieRT[K, N]) AddNode(node N) bool {
return false
}

return rt.keys.Add(kk, node)
rt.mu.Lock()
defer rt.mu.Unlock()
this := rt.trie.Load().(*trie.Trie[K, N])
next, _ := trie.Add(this, kk, node)
if next == this {
return false
}
rt.trie.Store(next)
return true
}

// RemoveKey tries to remove a node identified by its Kademlia key from the
// routing table. It returns true if the key was found to be present in the
// table and was removed.
func (rt *TrieRT[K, N]) RemoveKey(kk K) bool {
	// rt.mu serialises mutations so that the load/modify/store sequence
	// below cannot interleave with another writer.
	rt.mu.Lock()
	defer rt.mu.Unlock()

	this := rt.trie.Load().(*trie.Trie[K, N])
	// The error from trie.Remove is discarded; an unchanged trie pointer is
	// what signals that nothing was removed.
	next, _ := trie.Remove(this, kk)
	if next == this {
		return false
	}
	// Publish the updated trie for subsequent rt.trie.Load calls.
	rt.trie.Store(next)
	return true
}

// NearestNodes returns the n closest nodes to a given key.
func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N {
closestEntries := trie.Closest(rt.keys, target, n)
this := rt.trie.Load().(*trie.Trie[K, N])
closestEntries := trie.Closest(this, target, n)
if len(closestEntries) == 0 {
return []N{}
}
Expand All @@ -80,7 +101,8 @@ func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N {
}

func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) {
found, node := trie.Find(rt.keys, kk)
this := rt.trie.Load().(*trie.Trie[K, N])
found, node := trie.Find(this, kk)
if !found {
var zero N
return zero, false
Expand All @@ -90,7 +112,8 @@ func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) {

// Size returns the number of peers contained in the table. It reads the
// currently published trie from rt.trie and does not take rt.mu.
func (rt *TrieRT[K, N]) Size() int {
	this := rt.trie.Load().(*trie.Trie[K, N])
	return this.Size()
}

// Cpl returns the longest common prefix length the supplied key shares with the table's key.
Expand All @@ -100,7 +123,8 @@ func (rt *TrieRT[K, N]) Cpl(kk K) int {

// CplSize returns the number of peers in the table whose longest common prefix with the table's key is of length cpl.
func (rt *TrieRT[K, N]) CplSize(cpl int) int {
n, err := countCpl(rt.keys, rt.self, cpl, 0)
this := rt.trie.Load().(*trie.Trie[K, N])
n, err := countCpl(this, rt.self, cpl, 0)
if err != nil {
return 0
}
Expand Down
41 changes: 41 additions & 0 deletions routing/triert/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package triert

import (
"math/rand"
"sync"
"testing"

"github.com/plprobelab/go-kademlia/internal/kadtest"
Expand Down Expand Up @@ -313,6 +314,46 @@ func TestKeyFilter(t *testing.T) {
require.Equal(t, want, got)
}

// TestTableConcurrentReadWrite drives a single routing table from several
// goroutines at once, mixing lookups with additions and removals. Apart from
// checking that the table can be constructed it makes no assertions; it only
// needs to run to completion.
func TestTableConcurrentReadWrite(t *testing.T) {
	const (
		nodeCount = 5000
		workers   = 3
	)

	nodes := make([]*kadtest.ID[key.Key32], nodeCount)
	for idx := range nodes {
		nodes[idx] = kadtest.NewID(kadtest.RandomKey())
	}

	rt, err := New[key.Key32](kadtest.NewID(key0), nil)
	if err != nil {
		t.Fatalf("unexpected error creating table: %v", err)
	}

	var wg sync.WaitGroup
	wg.Add(workers)

	// worker toggles every node in the table: a node that is absent gets
	// added, a node that is present gets removed. Each worker visits the
	// nodes in its own random order.
	worker := func() {
		defer wg.Done()

		work := append([]*kadtest.ID[key.Key32](nil), nodes...)
		rand.Shuffle(len(work), func(a, b int) { work[a], work[b] = work[b], work[a] })

		for _, node := range work {
			if _, found := rt.GetNode(node.Key()); found {
				// remove it
				rt.RemoveKey(node.Key())
				continue
			}
			// add new peer
			rt.AddNode(node)
		}
	}

	// start workers to concurrently read and write the routing table
	for w := 0; w < workers; w++ {
		go worker()
	}
	wg.Wait()
}

func BenchmarkBuildTable(b *testing.B) {
b.Run("1000", benchmarkBuildTable(1000))
b.Run("10000", benchmarkBuildTable(10000))
Expand Down

0 comments on commit 9c7dc65

Please sign in to comment.