diff --git a/v2/backend.go b/v2/backend.go index a8c7775a..4e5d313f 100644 --- a/v2/backend.go +++ b/v2/backend.go @@ -102,7 +102,7 @@ func NewBackendPublicKey(ds ds.TxnDatastore, cfg *RecordBackendConfig) (be *Reco // The values returned from [ProvidersBackend.Fetch] will be of type // [*providerSet] (unexported). The cfg parameter can be nil, in which case the // [DefaultProviderBackendConfig] will be used. -func NewBackendProvider(pstore peerstore.Peerstore, dstore ds.Batching, cfg *ProvidersBackendConfig) (be *ProvidersBackend, err error) { +func NewBackendProvider(pstore peerstore.Peerstore, dstore ds.Datastore, cfg *ProvidersBackendConfig) (be *ProvidersBackend, err error) { if cfg == nil { if cfg, err = DefaultProviderBackendConfig(); err != nil { return nil, fmt.Errorf("default provider backend config: %w", err) diff --git a/v2/backend_provider.go b/v2/backend_provider.go index 1ddd764f..3be9d88a 100644 --- a/v2/backend_provider.go +++ b/v2/backend_provider.go @@ -257,14 +257,15 @@ func (p *ProvidersBackend) StartGarbageCollection() { p.gcCancel = cancel p.gcDone = make(chan struct{}) - p.log.Info("Provider backend's started garbage collection schedule") + // init ticker outside the goroutine to prevent race condition with + // clock mock in garbage collection test. + ticker := p.cfg.clk.Ticker(p.cfg.GCInterval) go func() { defer close(p.gcDone) - - ticker := p.cfg.clk.Ticker(p.cfg.GCInterval) defer ticker.Stop() + p.log.Info("Provider backend started garbage collection schedule") for { select { case <-ctx.Done(): diff --git a/v2/backend_provider_test.go b/v2/backend_provider_test.go index 10407e54..d3ab465d 100644 --- a/v2/backend_provider_test.go +++ b/v2/backend_provider_test.go @@ -9,7 +9,6 @@ import ( "github.com/benbjohnson/clock" ds "github.com/ipfs/go-datastore" - syncds "github.com/ipfs/go-datastore/sync" "github.com/libp2p/go-libp2p" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,7 +21,9 @@ func newBackendProvider(t testing.TB, cfg *ProvidersBackendConfig) *ProvidersBac h, err := libp2p.New(libp2p.NoListenAddrs) require.NoError(t, err) - dstore := syncds.MutexWrap(ds.NewMapDatastore()) + dstore, err := InMemoryDatastore() + require.NoError(t, err) + t.Cleanup(func() { if err = dstore.Close(); err != nil { t.Logf("closing datastore: %s", err) diff --git a/v2/coord/conversion.go b/v2/coord/conversion.go deleted file mode 100644 index d605507b..00000000 --- a/v2/coord/conversion.go +++ /dev/null @@ -1,39 +0,0 @@ -package coord - -import ( - "github.com/libp2p/go-libp2p/core/peer" - "github.com/plprobelab/go-kademlia/kad" - - "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" -) - -// kadPeerIDToAddrInfo converts a kad.NodeID to a peer.AddrInfo with no addresses. -// This function will panic if id's underlying type is not kadt.PeerID -func kadPeerIDToAddrInfo(id kad.NodeID[kadt.Key]) peer.AddrInfo { - peerID := id.(kadt.PeerID) - return peer.AddrInfo{ - ID: peer.ID(peerID), - } -} - -// addrInfoToKadPeerID converts a peer.AddrInfo to a kad.NodeID. 
-func addrInfoToKadPeerID(addrInfo peer.AddrInfo) kadt.PeerID { - return kadt.PeerID(addrInfo.ID) -} - -// sliceOfPeerIDToSliceOfKadPeerID converts a slice of peer.ID to a slice of kadt.PeerID -func sliceOfPeerIDToSliceOfKadPeerID(peers []peer.ID) []kadt.PeerID { - nodes := make([]kadt.PeerID, len(peers)) - for i := range peers { - nodes[i] = kadt.PeerID(peers[i]) - } - return nodes -} - -func sliceOfAddrInfoToSliceOfKadPeerID(addrInfos []peer.AddrInfo) []kadt.PeerID { - peers := make([]kadt.PeerID, len(addrInfos)) - for i := range addrInfos { - peers[i] = kadt.PeerID(addrInfos[i].ID) - } - return peers -} diff --git a/v2/coord/coordinator.go b/v2/coord/coordinator.go index 9c1e8a70..4a4f3875 100644 --- a/v2/coord/coordinator.go +++ b/v2/coord/coordinator.go @@ -4,15 +4,14 @@ import ( "context" "errors" "fmt" + "reflect" + "sync" "time" "github.com/benbjohnson/clock" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/kaderr" - "github.com/plprobelab/go-kademlia/network/address" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -23,6 +22,7 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) // A Coordinator coordinates the state machines that comprise a Kademlia DHT @@ -40,9 +40,7 @@ type Coordinator struct { rt kad.RoutingTable[kadt.Key, kadt.PeerID] // rtr is the message router used to send messages - rtr Router - - routingNotifications chan RoutingNotification + rtr Router[kadt.Key, kadt.PeerID, *pb.Message] // networkBehaviour is the behaviour responsible for communicating with the network networkBehaviour *NetworkBehaviour @@ -57,6 +55,10 @@ type Coordinator struct { tele *Telemetry } +type RoutingNotifier interface { + Notify(context.Context, RoutingNotification) +} + type CoordinatorConfig struct { PeerstoreTTL time.Duration // duration for which a peer is kept in the peerstore @@ -72,6 +74,8 @@ type CoordinatorConfig struct { MeterProvider metric.MeterProvider // the meter provider to use when initialising metric instruments TracerProvider trace.TracerProvider // the tracer provider to use when initialising tracing + + RoutingNotifier RoutingNotifier // receives notifications of routing events } // Validate checks the configuration options and returns an error if any have invalid values. 
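The patch replaces the coordinator's internal notification channel with this pluggable RoutingNotifier interface. Callers that still want channel semantics can recover them with a small adapter; a minimal sketch, shown as if it lived in package coord (the adapter type, constructor and buffer size are illustrative, not part of this patch):

```go
package coord

import "context"

// chanRoutingNotifier adapts the RoutingNotifier interface back to the
// channel-based delivery that Coordinator.RoutingNotifications used to offer.
type chanRoutingNotifier struct {
	ch chan RoutingNotification
}

func newChanRoutingNotifier(buf int) *chanRoutingNotifier {
	return &chanRoutingNotifier{ch: make(chan RoutingNotification, buf)}
}

// Notify implements RoutingNotifier. The send is non-blocking and drops the
// event when the buffer is full, mirroring the select in dispatchEvent that
// this patch removes.
func (n *chanRoutingNotifier) Notify(ctx context.Context, ev RoutingNotification) {
	select {
	case <-ctx.Done():
	case n.ch <- ev:
	default:
	}
}

// Notifications returns the receive side of the channel.
func (n *chanRoutingNotifier) Notifications() <-chan RoutingNotification {
	return n.ch
}
```

Wiring it up is then a one-liner on the config, e.g. `cfg.RoutingNotifier = newChanRoutingNotifier(20)`; the nil check added to Validate in the next hunk rejects configurations that unset the field.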
@@ -131,6 +135,13 @@ func (cfg *CoordinatorConfig) Validate() error { } } + if cfg.RoutingNotifier == nil { + return &kaderr.ConfigurationError{ + Component: "CoordinatorConfig", + Err: fmt.Errorf("routing notifier must not be nil"), + } + } + return nil } @@ -145,10 +156,11 @@ func DefaultCoordinatorConfig() *CoordinatorConfig { Logger: slog.New(zapslog.NewHandler(logging.Logger("coord").Desugar().Core())), MeterProvider: otel.GetMeterProvider(), TracerProvider: otel.GetTracerProvider(), + RoutingNotifier: nullRoutingNotifier{}, } } -func NewCoordinator(self kadt.PeerID, rtr Router, rt routing.RoutingTableCpl[kadt.Key, kadt.PeerID], cfg *CoordinatorConfig) (*Coordinator, error) { +func NewCoordinator(self kadt.PeerID, rtr Router[kadt.Key, kadt.PeerID, *pb.Message], rt routing.RoutingTableCpl[kadt.Key, kadt.PeerID], cfg *CoordinatorConfig) (*Coordinator, error) { if cfg == nil { cfg = DefaultCoordinatorConfig() } else if err := cfg.Validate(); err != nil { @@ -227,8 +239,6 @@ func NewCoordinator(self kadt.PeerID, rtr Router, rt routing.RoutingTableCpl[kad networkBehaviour: networkBehaviour, routingBehaviour: routingBehaviour, queryBehaviour: queryBehaviour, - - routingNotifications: make(chan RoutingNotification, 20), // buffered mainly to allow tests to read the channel after running an operation } go d.eventLoop(ctx) @@ -245,20 +255,6 @@ func (c *Coordinator) ID() kadt.PeerID { return c.self } -func (c *Coordinator) Addresses() []ma.Multiaddr { - // TODO: return configured listen addresses - info, err := c.rtr.GetNodeInfo(context.TODO(), peer.ID(c.self)) - if err != nil { - return nil - } - return info.Addrs -} - -// RoutingNotifications returns a channel that may be read to be notified of routing updates -func (c *Coordinator) RoutingNotifications() <-chan RoutingNotification { - return c.routingNotifications -} - func (c *Coordinator) eventLoop(ctx context.Context) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.eventLoop") defer span.End() @@ -295,11 +291,7 @@ func (c *Coordinator) dispatchEvent(ctx context.Context, ev BehaviourEvent) { case RoutingCommand: c.routingBehaviour.Notify(ctx, ev) case RoutingNotification: - select { - case <-ctx.Done(): - case c.routingNotifications <- ev: - default: - } + c.cfg.RoutingNotifier.Notify(ctx, ev) default: panic(fmt.Sprintf("unexpected event: %T", ev)) } @@ -307,14 +299,14 @@ func (c *Coordinator) dispatchEvent(ctx context.Context, ev BehaviourEvent) { // GetNode retrieves the node associated with the given node id from the DHT's local routing table. // If the node isn't found in the table, it returns ErrNodeNotFound. 
-func (c *Coordinator) GetNode(ctx context.Context, id peer.ID) (Node, error) { +func (c *Coordinator) GetNode(ctx context.Context, id kadt.PeerID) (Node, error) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.GetNode") defer span.End() - if _, exists := c.rt.GetNode(kadt.PeerID(id).Key()); !exists { + if _, exists := c.rt.GetNode(id.Key()); !exists { return nil, ErrNodeNotFound } - nh, err := c.networkBehaviour.getNodeHandler(ctx, kadt.PeerID(id)) + nh, err := c.networkBehaviour.getNodeHandler(ctx, id) if err != nil { return nil, err } @@ -362,9 +354,9 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) return QueryStats{}, err } - seedIDs := make([]peer.ID, 0, len(seeds)) + seedIDs := make([]kadt.PeerID, 0, len(seeds)) for _, s := range seeds { - seedIDs = append(seedIDs, s.ID()) + seedIDs = append(seedIDs, kadt.PeerID(s.ID())) } waiter := NewWaiter[BehaviourEvent]() @@ -373,8 +365,6 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) cmd := &EventStartQuery{ QueryID: queryID, Target: target, - ProtocolID: address.ProtocolID("TODO"), - Message: &fakeMessage{key: target}, KnownClosestNodes: seedIDs, Notify: waiter, } @@ -431,22 +421,20 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) } } -// AddNodes suggests new DHT nodes and their associated addresses to be added to the routing table. +// AddNodes suggests new DHT nodes to be added to the routing table. // If the routing table is updated as a result of this operation an EventRoutingUpdated notification // is emitted on the routing notification channel. -func (c *Coordinator) AddNodes(ctx context.Context, ais []peer.AddrInfo) error { +func (c *Coordinator) AddNodes(ctx context.Context, ids []kadt.PeerID) error { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.AddNodes") defer span.End() - for _, ai := range ais { - if ai.ID == peer.ID(c.self) { + for _, id := range ids { + if id.Equal(c.self) { // skip self continue } - // TODO: apply address filter - - c.routingBehaviour.Notify(ctx, &EventAddAddrInfo{ - NodeInfo: ai, + c.routingBehaviour.Notify(ctx, &EventAddNode{ + NodeID: id, }) } @@ -455,12 +443,10 @@ func (c *Coordinator) AddNodes(ctx context.Context, ais []peer.AddrInfo) error { } // Bootstrap instructs the dht to begin bootstrapping the routing table. 
-func (c *Coordinator) Bootstrap(ctx context.Context, seeds []peer.ID) error { +func (c *Coordinator) Bootstrap(ctx context.Context, seeds []kadt.PeerID) error { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.Bootstrap") defer span.End() c.routingBehaviour.Notify(ctx, &EventStartBootstrap{ - // Bootstrap state machine uses the message - Message: &fakeMessage{key: kadt.PeerID(c.self).Key()}, SeedNodes: seeds, }) @@ -469,15 +455,12 @@ func (c *Coordinator) Bootstrap(ctx context.Context, seeds []peer.ID) error { // NotifyConnectivity notifies the coordinator that a peer has passed a connectivity check // which means it is connected and supports finding closer nodes -func (c *Coordinator) NotifyConnectivity(ctx context.Context, id peer.ID) error { +func (c *Coordinator) NotifyConnectivity(ctx context.Context, id kadt.PeerID) error { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.NotifyConnectivity") defer span.End() - ai := peer.AddrInfo{ - ID: id, - } c.routingBehaviour.Notify(ctx, &EventNotifyConnectivity{ - NodeInfo: ai, + NodeID: id, }) return nil @@ -485,7 +468,7 @@ func (c *Coordinator) NotifyConnectivity(ctx context.Context, id peer.ID) error // NotifyNonConnectivity notifies the coordinator that a peer has failed a connectivity check // which means it is not connected and/or it doesn't support finding closer nodes -func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id peer.ID) error { +func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id kadt.PeerID) error { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.NotifyNonConnectivity") defer span.End() @@ -495,3 +478,106 @@ func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id peer.ID) err return nil } + +// A BufferedRoutingNotifier is a [RoutingNotifier] that buffers [RoutingNotification] events and provides methods +// to expect occurrences of specific events. It is designed for use in a test environment. 
+type BufferedRoutingNotifier struct { + mu sync.Mutex + buffered []RoutingNotification + signal chan struct{} +} + +func NewBufferedRoutingNotifier() *BufferedRoutingNotifier { + return &BufferedRoutingNotifier{ + signal: make(chan struct{}, 1), + } +} + +func (w *BufferedRoutingNotifier) Notify(ctx context.Context, ev RoutingNotification) { + w.mu.Lock() + w.buffered = append(w.buffered, ev) + select { + case w.signal <- struct{}{}: + default: + } + w.mu.Unlock() +} + +func (w *BufferedRoutingNotifier) Expect(ctx context.Context, expected RoutingNotification) (RoutingNotification, error) { + for { + // look in buffered events + w.mu.Lock() + for i, ev := range w.buffered { + if reflect.TypeOf(ev) == reflect.TypeOf(expected) { + // remove first from buffer and return it + w.buffered = w.buffered[:i+copy(w.buffered[i:], w.buffered[i+1:])] + w.mu.Unlock() + return ev, nil + } + } + w.mu.Unlock() + + // wait to be signaled that there is a new event + select { + case <-ctx.Done(): + return nil, fmt.Errorf("test deadline exceeded while waiting for event %T", expected) + case <-w.signal: + } + } +} + +// ExpectRoutingUpdated blocks until an [EventRoutingUpdated] event is seen for the specified peer id +func (w *BufferedRoutingNotifier) ExpectRoutingUpdated(ctx context.Context, id kadt.PeerID) (*EventRoutingUpdated, error) { + for { + // look in buffered events + w.mu.Lock() + for i, ev := range w.buffered { + if tev, ok := ev.(*EventRoutingUpdated); ok { + if id.Equal(tev.NodeID) { + // remove first from buffer and return it + w.buffered = w.buffered[:i+copy(w.buffered[i:], w.buffered[i+1:])] + w.mu.Unlock() + return tev, nil + } + } + } + w.mu.Unlock() + + // wait to be signaled that there is a new event + select { + case <-ctx.Done(): + return nil, fmt.Errorf("test deadline exceeded while waiting for routing updated event") + case <-w.signal: + } + } +} + +// ExpectRoutingRemoved blocks until an [EventRoutingRemoved] event is seen for the specified peer id +func (w *BufferedRoutingNotifier) ExpectRoutingRemoved(ctx context.Context, id kadt.PeerID) (*EventRoutingRemoved, error) { + for { + // look in buffered events + w.mu.Lock() + for i, ev := range w.buffered { + if tev, ok := ev.(*EventRoutingRemoved); ok { + if id.Equal(tev.NodeID) { + // remove first from buffer and return it + w.buffered = w.buffered[:i+copy(w.buffered[i:], w.buffered[i+1:])] + w.mu.Unlock() + return tev, nil + } + } + } + w.mu.Unlock() + + // wait to be signaled that there is a new event + select { + case <-ctx.Done(): + return nil, fmt.Errorf("test deadline exceeded while waiting for routing removed event") + case <-w.signal: + } + } +} + +type nullRoutingNotifier struct{} + +func (nullRoutingNotifier) Notify(context.Context, RoutingNotification) {} diff --git a/v2/coord/coordinator_test.go b/v2/coord/coordinator_test.go index f9b0e484..ba32444e 100644 --- a/v2/coord/coordinator_test.go +++ b/v2/coord/coordinator_test.go @@ -2,15 +2,11 @@ package coord import ( "context" - "fmt" "log" - "reflect" - "sync" "testing" "time" "github.com/benbjohnson/clock" - "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest" @@ -20,58 +16,6 @@ import ( const peerstoreTTL = 10 * time.Minute -type notificationWatcher struct { - mu sync.Mutex - buffered []RoutingNotification - signal chan struct{} -} - -func (w *notificationWatcher) Watch(t *testing.T, ctx context.Context, ch <-chan RoutingNotification) { - t.Helper() - w.signal = make(chan struct{}, 1) 
- go func() { - for { - select { - case <-ctx.Done(): - return - case ev := <-ch: - w.mu.Lock() - t.Logf("buffered routing notification: %T\n", ev) - w.buffered = append(w.buffered, ev) - select { - case w.signal <- struct{}{}: - default: - } - w.mu.Unlock() - - } - } - }() -} - -func (w *notificationWatcher) Expect(ctx context.Context, expected RoutingNotification) (RoutingNotification, error) { - for { - // look in buffered events - w.mu.Lock() - for i, ev := range w.buffered { - if reflect.TypeOf(ev) == reflect.TypeOf(expected) { - // remove first from buffer and return it - w.buffered = w.buffered[:i+copy(w.buffered[i:], w.buffered[i+1:])] - w.mu.Unlock() - return ev, nil - } - } - w.mu.Unlock() - - // wait to be signaled that there is a new event - select { - case <-ctx.Done(): - return nil, fmt.Errorf("test deadline exceeded while waiting for event %T", expected) - case <-w.signal: - } - } -} - func TestConfigValidate(t *testing.T) { t.Run("default is valid", func(t *testing.T) { cfg := DefaultCoordinatorConfig() @@ -140,6 +84,12 @@ func TestConfigValidate(t *testing.T) { cfg.TracerProvider = nil require.Error(t, cfg.Validate()) }) + + t.Run("routing notifier not nil", func(t *testing.T) { + cfg := DefaultCoordinatorConfig() + cfg.RoutingNotifier = nil + require.Error(t, cfg.Validate()) + }) } func TestExhaustiveQuery(t *testing.T) { @@ -156,11 +106,11 @@ func TestExhaustiveQuery(t *testing.T) { // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) - self := kadt.PeerID(nodes[0].NodeInfo.ID) + self := kadt.PeerID(nodes[0].NodeID) c, err := NewCoordinator(self, nodes[0].Router, nodes[0].RoutingTable, ccfg) require.NoError(t, err) - target := kadt.PeerID(nodes[3].NodeInfo.ID).Key() + target := kadt.PeerID(nodes[3].NodeID).Key() visited := make(map[string]int) @@ -175,9 +125,9 @@ func TestExhaustiveQuery(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, len(visited)) - require.Contains(t, visited, nodes[1].NodeInfo.ID.String()) - require.Contains(t, visited, nodes[2].NodeInfo.ID.String()) - require.Contains(t, visited, nodes[3].NodeInfo.ID.String()) + require.Contains(t, visited, nodes[1].NodeID.String()) + require.Contains(t, visited, nodes[2].NodeID.String()) + require.Contains(t, visited, nodes[3].NodeID.String()) } func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { @@ -192,24 +142,24 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { ccfg.Clock = clk ccfg.PeerstoreTTL = peerstoreTTL + rn := NewBufferedRoutingNotifier() + ccfg.RoutingNotifier = rn + // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) // A will then ask C, C will reply with D's address (and B's address) - self := kadt.PeerID(nodes[0].NodeInfo.ID) + self := kadt.PeerID(nodes[0].NodeID) c, err := NewCoordinator(self, nodes[0].Router, nodes[0].RoutingTable, ccfg) if err != nil { log.Fatalf("unexpected error creating coordinator: %v", err) } - w := new(notificationWatcher) - w.Watch(t, ctx, c.RoutingNotifications()) - qfn := func(ctx context.Context, node Node, stats QueryStats) error { return nil } // Run a query to find the value - target := kadt.PeerID(nodes[3].NodeInfo.ID).Key() + target := nodes[3].NodeID.Key() _, err = c.Query(ctx, target, qfn) require.NoError(t, err) @@ -224,20 +174,20 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { // However the order in which these 
events are emitted may vary depending on timing.

-	ev1, err := w.Expect(ctx, &EventRoutingUpdated{})
+	ev1, err := rn.Expect(ctx, &EventRoutingUpdated{})
 	require.NoError(t, err)
 	tev1 := ev1.(*EventRoutingUpdated)

-	ev2, err := w.Expect(ctx, &EventRoutingUpdated{})
+	ev2, err := rn.Expect(ctx, &EventRoutingUpdated{})
 	require.NoError(t, err)
 	tev2 := ev2.(*EventRoutingUpdated)

-	if tev1.NodeInfo.ID == nodes[2].NodeInfo.ID {
-		require.Equal(t, nodes[3].NodeInfo.ID, tev2.NodeInfo.ID)
-	} else if tev2.NodeInfo.ID == nodes[2].NodeInfo.ID {
-		require.Equal(t, nodes[3].NodeInfo.ID, tev1.NodeInfo.ID)
+	if tev1.NodeID.Equal(nodes[2].NodeID) {
+		require.Equal(t, nodes[3].NodeID, tev2.NodeID)
+	} else if tev2.NodeID.Equal(nodes[2].NodeID) {
+		require.Equal(t, nodes[3].NodeID, tev1.NodeID)
 	} else {
-		require.Failf(t, "did not see routing updated event for %s", nodes[2].NodeInfo.ID.String())
+		require.Failf(t, "missing routing updated event", "did not see routing updated event for %s", nodes[2].NodeID.String())
 	}
 }

@@ -253,19 +203,19 @@ func TestBootstrap(t *testing.T) {
 	ccfg.Clock = clk
 	ccfg.PeerstoreTTL = peerstoreTTL

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	rn := NewBufferedRoutingNotifier()
+	ccfg.RoutingNotifier = rn
+
+	self := kadt.PeerID(nodes[0].NodeID)
 	d, err := NewCoordinator(self, nodes[0].Router, nodes[0].RoutingTable, ccfg)
 	require.NoError(t, err)

-	w := new(notificationWatcher)
-	w.Watch(t, ctx, d.RoutingNotifications())
-
-	seeds := []peer.ID{nodes[1].NodeInfo.ID}
+	seeds := []kadt.PeerID{nodes[1].NodeID}
 	err = d.Bootstrap(ctx, seeds)
 	require.NoError(t, err)

 	// the query run by the dht should have completed
-	ev, err := w.Expect(ctx, &EventBootstrapFinished{})
+	ev, err := rn.Expect(ctx, &EventBootstrapFinished{})
 	require.NoError(t, err)
 	require.IsType(t, &EventBootstrapFinished{}, ev)

@@ -274,22 +224,22 @@
 	require.Equal(t, 3, tevf.Stats.Success)
 	require.Equal(t, 0, tevf.Stats.Failure)

-	_, err = w.Expect(ctx, &EventRoutingUpdated{})
+	_, err = rn.Expect(ctx, &EventRoutingUpdated{})
 	require.NoError(t, err)

-	_, err = w.Expect(ctx, &EventRoutingUpdated{})
+	_, err = rn.Expect(ctx, &EventRoutingUpdated{})
 	require.NoError(t, err)

 	// coordinator will have node1 in its routing table
-	_, err = d.GetNode(ctx, nodes[1].NodeInfo.ID)
+	_, err = d.GetNode(ctx, nodes[1].NodeID)
 	require.NoError(t, err)

 	// coordinator should now have node2 in its routing table
-	_, err = d.GetNode(ctx, nodes[2].NodeInfo.ID)
+	_, err = d.GetNode(ctx, nodes[2].NodeID)
 	require.NoError(t, err)

 	// coordinator should now have node3 in its routing table
-	_, err = d.GetNode(ctx, nodes[3].NodeInfo.ID)
+	_, err = d.GetNode(ctx, nodes[3].NodeID)
 	require.NoError(t, err)
 }

@@ -305,33 +255,33 @@ func TestIncludeNode(t *testing.T) {
 	ccfg.Clock = clk
 	ccfg.PeerstoreTTL = peerstoreTTL

-	candidate := nodes[len(nodes)-1].NodeInfo // not in nodes[0] routing table
+	rn := NewBufferedRoutingNotifier()
+	ccfg.RoutingNotifier = rn

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	candidate := nodes[len(nodes)-1].NodeID // not in nodes[0] routing table
+
+	self := nodes[0].NodeID
 	d, err := NewCoordinator(self, nodes[0].Router, nodes[0].RoutingTable, ccfg)
 	if err != nil {
 		log.Fatalf("unexpected error creating dht: %v", err)
 	}

 	// the routing table should not contain the node yet
-	_, err = d.GetNode(ctx, candidate.ID)
+	_, err = d.GetNode(ctx, candidate)
 	require.ErrorIs(t, err, ErrNodeNotFound)

-	w := new(notificationWatcher)
-	w.Watch(t, ctx, d.RoutingNotifications())
-
 	// inject a new node
-	err = d.AddNodes(ctx,
[]peer.AddrInfo{candidate}) + err = d.AddNodes(ctx, []kadt.PeerID{candidate}) require.NoError(t, err) // the include state machine runs in the background and eventually should add the node to routing table - ev, err := w.Expect(ctx, &EventRoutingUpdated{}) + ev, err := rn.Expect(ctx, &EventRoutingUpdated{}) require.NoError(t, err) tev := ev.(*EventRoutingUpdated) - require.Equal(t, candidate.ID, tev.NodeInfo.ID) + require.Equal(t, candidate, tev.NodeID) // the routing table should now contain the node - _, err = d.GetNode(ctx, candidate.ID) + _, err = d.GetNode(ctx, candidate) require.NoError(t, err) } diff --git a/v2/coord/coretypes.go b/v2/coord/coretypes.go index 8da79942..0f72cebf 100644 --- a/v2/coord/coretypes.go +++ b/v2/coord/coretypes.go @@ -5,12 +5,9 @@ import ( "errors" "time" - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" - "github.com/plprobelab/go-kademlia/network/address" + "github.com/plprobelab/go-kademlia/kad" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" - "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) // Value is a value that may be stored in the DHT. @@ -22,10 +19,7 @@ type Value interface { // Node represents the local or a remote node participating in the DHT. type Node interface { // ID returns the peer ID identifying this node. - ID() peer.ID - - // Addresses returns the network addresses associated with the given node. - Addresses() []ma.Multiaddr + ID() kadt.PeerID // GetClosestNodes requests the n closest nodes to the key from the node's // local routing table. The node may return fewer nodes than requested. @@ -74,17 +68,13 @@ var ( ErrSkipRemaining = errors.New("skip remaining nodes") ) -// Router its a work in progress -// TODO figure out the role of protocol identifiers -type Router interface { - // SendMessage attempts to send a request to another node. The Router will absorb the addresses in to into its - // internal nodestore. This method blocks until a response is received or an error is encountered. - SendMessage(ctx context.Context, to peer.AddrInfo, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) +type Message interface{} - AddNodeInfo(ctx context.Context, info peer.AddrInfo, ttl time.Duration) error - GetNodeInfo(ctx context.Context, id peer.ID) (peer.AddrInfo, error) +type Router[K kad.Key[K], N kad.NodeID[K], M Message] interface { + // SendMessage attempts to send a request to another node. This method blocks until a response is received or an error is encountered. + SendMessage(ctx context.Context, to N, req M) (M, error) // GetClosestNodes attempts to send a request to another node asking it for nodes that it considers to be // closest to the target key. 
- GetClosestNodes(ctx context.Context, to peer.AddrInfo, target kadt.Key) ([]peer.AddrInfo, error) + GetClosestNodes(ctx context.Context, to N, target K) ([]N, error) } diff --git a/v2/coord/event.go b/v2/coord/event.go index 69a9d5d7..663cfee9 100644 --- a/v2/coord/event.go +++ b/v2/coord/event.go @@ -1,11 +1,6 @@ package coord import ( - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" - "github.com/plprobelab/go-kademlia/kad" - "github.com/plprobelab/go-kademlia/network/address" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) @@ -48,9 +43,7 @@ type RoutingNotification interface { } type EventStartBootstrap struct { - ProtocolID address.ProtocolID - Message kad.Request[kadt.Key, ma.Multiaddr] - SeedNodes []peer.ID // TODO: peer.AddrInfo + SeedNodes []kadt.PeerID } func (*EventStartBootstrap) behaviourEvent() {} @@ -58,7 +51,7 @@ func (*EventStartBootstrap) routingCommand() {} type EventOutboundGetCloserNodes struct { QueryID query.QueryID - To peer.AddrInfo + To kadt.PeerID Target kadt.Key Notify Notify[BehaviourEvent] } @@ -70,9 +63,7 @@ func (*EventOutboundGetCloserNodes) networkCommand() {} type EventStartQuery struct { QueryID query.QueryID Target kadt.Key - ProtocolID address.ProtocolID - Message kad.Request[kadt.Key, ma.Multiaddr] - KnownClosestNodes []peer.ID + KnownClosestNodes []kadt.PeerID Notify NotifyCloser[BehaviourEvent] } @@ -86,22 +77,21 @@ type EventStopQuery struct { func (*EventStopQuery) behaviourEvent() {} func (*EventStopQuery) queryCommand() {} -// EventAddAddrInfo notifies the routing behaviour of a potential new peer or of additional addresses for -// an existing peer. -type EventAddAddrInfo struct { - NodeInfo peer.AddrInfo +// EventAddNode notifies the routing behaviour of a potential new peer. +type EventAddNode struct { + NodeID kadt.PeerID } -func (*EventAddAddrInfo) behaviourEvent() {} -func (*EventAddAddrInfo) routingCommand() {} +func (*EventAddNode) behaviourEvent() {} +func (*EventAddNode) routingCommand() {} // EventGetCloserNodesSuccess notifies a behaviour that a GetCloserNodes request, initiated by an // [EventOutboundGetCloserNodes] event has produced a successful response. type EventGetCloserNodesSuccess struct { QueryID query.QueryID - To peer.AddrInfo // To is the peer address that the GetCloserNodes request was sent to. + To kadt.PeerID // To is the peer that the GetCloserNodes request was sent to. Target kadt.Key - CloserNodes []peer.AddrInfo + CloserNodes []kadt.PeerID } func (*EventGetCloserNodesSuccess) behaviourEvent() {} @@ -111,7 +101,7 @@ func (*EventGetCloserNodesSuccess) nodeHandlerResponse() {} // [EventOutboundGetCloserNodes] event has failed to produce a valid response. type EventGetCloserNodesFailure struct { QueryID query.QueryID - To peer.AddrInfo // To is the peer address that the GetCloserNodes request was sent to. + To kadt.PeerID // To is the peer that the GetCloserNodes request was sent to. Target kadt.Key Err error } @@ -123,8 +113,8 @@ func (*EventGetCloserNodesFailure) nodeHandlerResponse() {} // response from a node. type EventQueryProgressed struct { QueryID query.QueryID - NodeID peer.ID - Response kad.Response[kadt.Key, ma.Multiaddr] + NodeID kadt.PeerID + Response Message Stats query.QueryStats } @@ -141,7 +131,7 @@ func (*EventQueryFinished) behaviourEvent() {} // EventRoutingUpdated is emitted by the coordinator when a new node has been verified and added to the routing table. 
 type EventRoutingUpdated struct {
-	NodeInfo peer.AddrInfo
+	NodeID kadt.PeerID
 }

 func (*EventRoutingUpdated) behaviourEvent() {}
@@ -149,7 +139,7 @@ func (*EventRoutingUpdated) routingNotification() {}

-// EventRoutingRemoved is emitted by the coordinator when new node has been removed from the routing table.
+// EventRoutingRemoved is emitted by the coordinator when a node has been removed from the routing table.
 type EventRoutingRemoved struct {
-	NodeID peer.ID
+	NodeID kadt.PeerID
 }

 func (*EventRoutingRemoved) behaviourEvent() {}
@@ -169,7 +159,7 @@ func (*EventBootstrapFinished) routingNotification() {}
 // general connections to the host but only when it is confirmed that the peer responds to requests for closer
 // nodes.
 type EventNotifyConnectivity struct {
-	NodeInfo peer.AddrInfo
+	NodeID kadt.PeerID
 }

 func (*EventNotifyConnectivity) behaviourEvent() {}
@@ -178,7 +168,7 @@ func (*EventNotifyConnectivity) routingNotification() {}

-// EventNotifyNonConnectivity notifies a behaviour that a peer does not have connectivity and/or does not support
-// finding closer nodes is known.
+// EventNotifyNonConnectivity notifies a behaviour that a peer does not have connectivity and/or does not support
+// finding closer nodes.
 type EventNotifyNonConnectivity struct {
-	NodeID peer.ID
+	NodeID kadt.PeerID
 }

 func (*EventNotifyNonConnectivity) behaviourEvent() {}
diff --git a/v2/coord/event_test.go b/v2/coord/event_test.go
index b6afdd4a..2944be13 100644
--- a/v2/coord/event_test.go
+++ b/v2/coord/event_test.go
@@ -3,7 +3,7 @@ package coord
 var _ NetworkCommand = (*EventOutboundGetCloserNodes)(nil)

 var (
-	_ RoutingCommand = (*EventAddAddrInfo)(nil)
+	_ RoutingCommand = (*EventAddNode)(nil)
 	_ RoutingCommand = (*EventStartBootstrap)(nil)
 )
diff --git a/v2/coord/internal/nettest/layouts.go b/v2/coord/internal/nettest/layouts.go
index c90e544b..7fce42f0 100644
--- a/v2/coord/internal/nettest/layouts.go
+++ b/v2/coord/internal/nettest/layouts.go
@@ -2,10 +2,8 @@ package nettest

 import (
 	"context"
-	"fmt"

 	"github.com/benbjohnson/clock"
-	ma "github.com/multiformats/go-multiaddr"
 	"github.com/plprobelab/go-kademlia/routing/simplert"

 	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
@@ -17,45 +15,40 @@
 // The topology is not a ring: nodes[0] only has nodes[1] in its table and nodes[n-1] only has nodes[n-2] in its table.
 // nodes[1] has nodes[0] and nodes[2] in its routing table.
 // If n > 2 then the first and last nodes will not have one another in their routing tables.
-func LinearTopology(n int, clk clock.Clock) (*Topology, []*Node, error) { - nodes := make([]*Node, n) +func LinearTopology(n int, clk clock.Clock) (*Topology, []*Peer, error) { + nodes := make([]*Peer, n) top := NewTopology(clk) for i := range nodes { - a, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 2000+i)) + id, err := NewPeerID() if err != nil { return nil, nil, err } - ai, err := NewAddrInfo([]ma.Multiaddr{a}) - if err != nil { - return nil, nil, err - } - - nodes[i] = &Node{ - NodeInfo: ai, - Router: NewRouter(ai.ID, top), - RoutingTable: simplert.New[kadt.Key, kadt.PeerID](kadt.PeerID(ai.ID), 20), + nodes[i] = &Peer{ + NodeID: id, + Router: NewRouter(id, top), + RoutingTable: simplert.New[kadt.Key, kadt.PeerID](id, 20), } } // Define the network topology, with default network links between every node for i := 0; i < len(nodes); i++ { for j := i + 1; j < len(nodes); j++ { - top.ConnectNodes(nodes[i], nodes[j]) + top.ConnectPeers(nodes[i], nodes[j]) } } // Connect nodes in a chain for i := 0; i < len(nodes); i++ { if i > 0 { - nodes[i].Router.AddNodeInfo(context.Background(), nodes[i-1].NodeInfo, 0) - nodes[i].RoutingTable.AddNode(kadt.PeerID(nodes[i-1].NodeInfo.ID)) + nodes[i].Router.AddToPeerStore(context.Background(), nodes[i-1].NodeID) + nodes[i].RoutingTable.AddNode(kadt.PeerID(nodes[i-1].NodeID)) } if i < len(nodes)-1 { - nodes[i].Router.AddNodeInfo(context.Background(), nodes[i+1].NodeInfo, 0) - nodes[i].RoutingTable.AddNode(kadt.PeerID(nodes[i+1].NodeInfo.ID)) + nodes[i].Router.AddToPeerStore(context.Background(), nodes[i+1].NodeID) + nodes[i].RoutingTable.AddNode(kadt.PeerID(nodes[i+1].NodeID)) } } diff --git a/v2/coord/internal/nettest/routing.go b/v2/coord/internal/nettest/routing.go index 7553674f..880e27e4 100644 --- a/v2/coord/internal/nettest/routing.go +++ b/v2/coord/internal/nettest/routing.go @@ -9,7 +9,6 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" "github.com/plprobelab/go-kademlia/network/address" @@ -21,20 +20,17 @@ import ( var rng = rand.New(rand.NewSource(6283185)) -func NewAddrInfo(addrs []ma.Multiaddr) (peer.AddrInfo, error) { +func NewPeerID() (kadt.PeerID, error) { _, pub, err := crypto.GenerateEd25519Key(rng) if err != nil { - return peer.AddrInfo{}, err + return kadt.PeerID(""), err } pid, err := peer.IDFromPublicKey(pub) if err != nil { - return peer.AddrInfo{}, err + return kadt.PeerID(""), err } - return peer.AddrInfo{ - ID: pid, - Addrs: addrs, - }, nil + return kadt.PeerID(pid), nil } // Link represents the route between two nodes. It allows latency and transport failures to be simulated. 
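For orientation, here is how the reworked helpers fit together once everything speaks kadt.PeerID; a sketch of a test using only identifiers from this patch (the test name is hypothetical):

```go
package nettest_test

import (
	"context"
	"testing"

	"github.com/benbjohnson/clock"
	"github.com/stretchr/testify/require"

	"github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest"
)

func TestLinearTopologySketch(t *testing.T) {
	clk := clock.NewMock()

	// Four peers in a chain: peer 0 knows only peer 1, peer 3 only peer 2.
	_, nodes, err := nettest.LinearTopology(4, clk)
	require.NoError(t, err)

	// Ask peer 1 for the peers it considers closest to peer 3's key. Peer 1
	// has peers 0 and 2 in its peer store, so the response draws from those.
	closer, err := nodes[0].Router.GetClosestNodes(context.Background(), nodes[1].NodeID, nodes[3].NodeID.Key())
	require.NoError(t, err)
	require.NotEmpty(t, closer)
}
```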
@@ -53,22 +49,22 @@ func (l *DefaultLink) ConnLatency() time.Duration { return 0 } func (l *DefaultLink) DialLatency() time.Duration { return 0 } type Router struct { - self peer.ID + self kadt.PeerID top *Topology mu sync.Mutex // guards nodes - nodes map[peer.ID]*nodeStatus + nodes map[string]*nodeStatus } type nodeStatus struct { - NodeInfo peer.AddrInfo + NodeID kadt.PeerID Connectedness endpoint.Connectedness } -func NewRouter(self peer.ID, top *Topology) *Router { +func NewRouter(self kadt.PeerID, top *Topology) *Router { return &Router{ self: self, top: top, - nodes: make(map[peer.ID]*nodeStatus), + nodes: make(map[string]*nodeStatus), } } @@ -76,28 +72,16 @@ func (r *Router) NodeID() kad.NodeID[kadt.Key] { return kadt.PeerID(r.self) } -func (r *Router) SendMessage(ctx context.Context, to peer.AddrInfo, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) { - if err := r.AddNodeInfo(ctx, to, 0); err != nil { - return nil, fmt.Errorf("add node info: %w", err) - } - - if err := r.Dial(ctx, to); err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - return r.top.RouteMessage(ctx, r.self, to.ID, protoID, req) -} - -func (r *Router) HandleMessage(ctx context.Context, n peer.ID, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) { +func (r *Router) handleMessage(ctx context.Context, n kadt.PeerID, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) { closer := make([]*pb.Message_Peer, 0) r.mu.Lock() for _, n := range r.nodes { // only include self if it was the target of the request - if n.NodeInfo.ID == r.self && !key.Equal(kadt.PeerID(n.NodeInfo.ID).Key(), req.Target()) { + if n.NodeID.Equal(r.self) && !key.Equal(n.NodeID.Key(), req.Target()) { continue } - closer = append(closer, pb.FromAddrInfo(n.NodeInfo)) + closer = append(closer, pb.FromAddrInfo(peer.AddrInfo{ID: peer.ID(n.NodeID)})) } r.mu.Unlock() @@ -110,65 +94,68 @@ func (r *Router) HandleMessage(ctx context.Context, n peer.ID, protoID address.P return resp, nil } -func (r *Router) Dial(ctx context.Context, to peer.AddrInfo) error { +func (r *Router) dial(ctx context.Context, to kadt.PeerID) error { r.mu.Lock() - status, ok := r.nodes[to.ID] + status, ok := r.nodes[to.String()] r.mu.Unlock() - if ok { - switch status.Connectedness { - case endpoint.Connected: - return nil - case endpoint.CanConnect: - if _, err := r.top.Dial(ctx, r.self, to.ID); err != nil { - return err - } - - status.Connectedness = endpoint.Connected - r.mu.Lock() - r.nodes[to.ID] = status - r.mu.Unlock() - return nil + if !ok { + status = &nodeStatus{ + NodeID: to, + Connectedness: endpoint.CanConnect, } } - return endpoint.ErrUnknownPeer + + if status.Connectedness == endpoint.Connected { + return nil + } + if err := r.top.Dial(ctx, r.self, to); err != nil { + return err + } + + status.Connectedness = endpoint.Connected + r.mu.Lock() + r.nodes[to.String()] = status + r.mu.Unlock() + return nil } -func (r *Router) AddNodeInfo(ctx context.Context, info peer.AddrInfo, ttl time.Duration) error { +func (r *Router) AddToPeerStore(ctx context.Context, id kadt.PeerID) error { r.mu.Lock() defer r.mu.Unlock() - if _, ok := r.nodes[info.ID]; !ok { - r.nodes[info.ID] = &nodeStatus{ - NodeInfo: info, + if _, ok := r.nodes[id.String()]; !ok { + r.nodes[id.String()] = &nodeStatus{ + NodeID: id, Connectedness: endpoint.CanConnect, } } return nil } -func (r *Router) GetNodeInfo(ctx context.Context, id peer.ID) (peer.AddrInfo, error) { - r.mu.Lock() - defer r.mu.Unlock() - - status, ok := r.nodes[id] - if !ok { - return 
peer.AddrInfo{}, fmt.Errorf("unknown node") +func (r *Router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (*pb.Message, error) { + if err := r.dial(ctx, to); err != nil { + return nil, fmt.Errorf("dial: %w", err) } - return status.NodeInfo, nil -} -func (r *Router) GetClosestNodes(ctx context.Context, to peer.AddrInfo, target kadt.Key) ([]peer.AddrInfo, error) { - protoID := address.ProtocolID("/test/1.0.0") + return r.top.RouteMessage(ctx, r.self, to, "", req) +} +func (r *Router) GetClosestNodes(ctx context.Context, to kadt.PeerID, target kadt.Key) ([]kadt.PeerID, error) { req := &pb.Message{ Type: pb.Message_FIND_NODE, Key: []byte("random-key"), } - resp, err := r.SendMessage(ctx, to, protoID, req) + resp, err := r.SendMessage(ctx, to, req) if err != nil { return nil, err } - return resp.CloserPeersAddrInfos(), nil + + // possibly learned about some new nodes + for _, id := range resp.CloserNodes() { + r.AddToPeerStore(ctx, id) + } + + return resp.CloserNodes(), nil } diff --git a/v2/coord/internal/nettest/topology.go b/v2/coord/internal/nettest/topology.go index 61653f23..96d6380a 100644 --- a/v2/coord/internal/nettest/topology.go +++ b/v2/coord/internal/nettest/topology.go @@ -13,8 +13,8 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) -type Node struct { - NodeInfo peer.AddrInfo +type Peer struct { + NodeID kadt.PeerID Router *Router RoutingTable routing.RoutingTableCpl[kadt.Key, kadt.PeerID] } @@ -22,37 +22,37 @@ type Node struct { type Topology struct { clk clock.Clock links map[string]Link - nodes []*Node - nodeIndex map[peer.ID]*Node - routers map[peer.ID]*Router + nodes []*Peer + nodeIndex map[string]*Peer + routers map[string]*Router } func NewTopology(clk clock.Clock) *Topology { return &Topology{ clk: clk, links: make(map[string]Link), - nodeIndex: make(map[peer.ID]*Node), - routers: make(map[peer.ID]*Router), + nodeIndex: make(map[string]*Peer), + routers: make(map[string]*Router), } } -func (t *Topology) Nodes() []*Node { +func (t *Topology) Peers() []*Peer { return t.nodes } -func (t *Topology) ConnectNodes(a *Node, b *Node) { - t.ConnectNodesWithRoute(a, b, &DefaultLink{}) +func (t *Topology) ConnectPeers(a *Peer, b *Peer) { + t.ConnectPeersWithRoute(a, b, &DefaultLink{}) } -func (t *Topology) ConnectNodesWithRoute(a *Node, b *Node, l Link) { - akey := a.NodeInfo.ID +func (t *Topology) ConnectPeersWithRoute(a *Peer, b *Peer, l Link) { + akey := a.NodeID.String() if _, exists := t.nodeIndex[akey]; !exists { t.nodeIndex[akey] = a t.nodes = append(t.nodes, a) t.routers[akey] = a.Router } - bkey := b.NodeInfo.ID + bkey := b.NodeID.String() if _, exists := t.nodeIndex[bkey]; !exists { t.nodeIndex[bkey] = b t.nodes = append(t.nodes, b) @@ -67,8 +67,8 @@ func (t *Topology) ConnectNodesWithRoute(a *Node, b *Node, l Link) { t.links[btoa] = l } -func (t *Topology) findRoute(ctx context.Context, from peer.ID, to peer.ID) (Link, error) { - key := fmt.Sprintf("%s->%s", from, to) +func (t *Topology) findRoute(ctx context.Context, from kadt.PeerID, to kadt.PeerID) (Link, error) { + key := fmt.Sprintf("%s->%s", peer.ID(from), peer.ID(to)) route, ok := t.links[key] if !ok { @@ -78,19 +78,19 @@ func (t *Topology) findRoute(ctx context.Context, from peer.ID, to peer.ID) (Lin return route, nil } -func (t *Topology) Dial(ctx context.Context, from peer.ID, to peer.ID) (peer.AddrInfo, error) { +func (t *Topology) Dial(ctx context.Context, from kadt.PeerID, to kadt.PeerID) error { if from == to { - node, ok := t.nodeIndex[to] + _, ok := t.nodeIndex[to.String()] if 
!ok { - return peer.AddrInfo{}, fmt.Errorf("unknown node") + return fmt.Errorf("unknown node") } - return node.NodeInfo, nil + return nil } route, err := t.findRoute(ctx, from, to) if err != nil { - return peer.AddrInfo{}, fmt.Errorf("find route: %w", err) + return fmt.Errorf("find route: %w", err) } latency := route.DialLatency() @@ -99,25 +99,25 @@ func (t *Topology) Dial(ctx context.Context, from peer.ID, to peer.ID) (peer.Add } if err := route.DialErr(); err != nil { - return peer.AddrInfo{}, err + return err } - node, ok := t.nodeIndex[to] + _, ok := t.nodeIndex[to.String()] if !ok { - return peer.AddrInfo{}, fmt.Errorf("unknown node") + return fmt.Errorf("unknown node") } - return node.NodeInfo, nil + return nil } -func (t *Topology) RouteMessage(ctx context.Context, from peer.ID, to peer.ID, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) { +func (t *Topology) RouteMessage(ctx context.Context, from kadt.PeerID, to kadt.PeerID, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) { if from == to { - node, ok := t.nodeIndex[to] + node, ok := t.nodeIndex[to.String()] if !ok { return nil, fmt.Errorf("unknown node") } - return node.Router.HandleMessage(ctx, from, protoID, req) + return node.Router.handleMessage(ctx, from, protoID, req) } route, err := t.findRoute(ctx, from, to) @@ -130,10 +130,10 @@ func (t *Topology) RouteMessage(ctx context.Context, from peer.ID, to peer.ID, p t.clk.Sleep(latency) } - node, ok := t.nodeIndex[to] + node, ok := t.nodeIndex[to.String()] if !ok { return nil, fmt.Errorf("no route to node") } - return node.Router.HandleMessage(ctx, from, protoID, req) + return node.Router.handleMessage(ctx, from, protoID, req) } diff --git a/v2/coord/network.go b/v2/coord/network.go index d2da896c..72369b6f 100644 --- a/v2/coord/network.go +++ b/v2/coord/network.go @@ -5,20 +5,18 @@ import ( "fmt" "sync" - "github.com/libp2p/go-libp2p/core/peer" - ma "github.com/multiformats/go-multiaddr" - "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) type NetworkBehaviour struct { // rtr is the message router used to send messages - rtr Router + rtr Router[kadt.Key, kadt.PeerID, *pb.Message] nodeHandlersMu sync.Mutex nodeHandlers map[kadt.PeerID]*NodeHandler // TODO: garbage collect node handlers @@ -31,7 +29,7 @@ type NetworkBehaviour struct { tracer trace.Tracer } -func NewNetworkBehaviour(rtr Router, logger *slog.Logger, tracer trace.Tracer) *NetworkBehaviour { +func NewNetworkBehaviour(rtr Router[kadt.Key, kadt.PeerID, *pb.Message], logger *slog.Logger, tracer trace.Tracer) *NetworkBehaviour { b := &NetworkBehaviour{ rtr: rtr, nodeHandlers: make(map[kadt.PeerID]*NodeHandler), @@ -53,10 +51,11 @@ func (b *NetworkBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { switch ev := ev.(type) { case *EventOutboundGetCloserNodes: b.nodeHandlersMu.Lock() - nh, ok := b.nodeHandlers[kadt.PeerID(ev.To.ID)] + p := kadt.PeerID(ev.To) + nh, ok := b.nodeHandlers[p] if !ok { - nh = NewNodeHandler(ev.To, b.rtr, b.logger, b.tracer) - b.nodeHandlers[kadt.PeerID(ev.To.ID)] = nh + nh = NewNodeHandler(p, b.rtr, b.logger, b.tracer) + b.nodeHandlers[p] = nh } b.nodeHandlersMu.Unlock() nh.Notify(ctx, ev) @@ -103,12 +102,8 @@ func (b *NetworkBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) { func (b *NetworkBehaviour) 
getNodeHandler(ctx context.Context, id kadt.PeerID) (*NodeHandler, error) { b.nodeHandlersMu.Lock() nh, ok := b.nodeHandlers[id] - if !ok || len(nh.Addresses()) == 0 { - info, err := b.rtr.GetNodeInfo(ctx, peer.ID(id)) - if err != nil { - return nil, err - } - nh = NewNodeHandler(info, b.rtr, b.logger, b.tracer) + if !ok { + nh = NewNodeHandler(id, b.rtr, b.logger, b.tracer) b.nodeHandlers[id] = nh } b.nodeHandlersMu.Unlock() @@ -116,14 +111,14 @@ func (b *NetworkBehaviour) getNodeHandler(ctx context.Context, id kadt.PeerID) ( } type NodeHandler struct { - self peer.AddrInfo - rtr Router + self kadt.PeerID + rtr Router[kadt.Key, kadt.PeerID, *pb.Message] queue *WorkQueue[NodeHandlerRequest] logger *slog.Logger tracer trace.Tracer } -func NewNodeHandler(self peer.AddrInfo, rtr Router, logger *slog.Logger, tracer trace.Tracer) *NodeHandler { +func NewNodeHandler(self kadt.PeerID, rtr Router[kadt.Key, kadt.PeerID, *pb.Message], logger *slog.Logger, tracer trace.Tracer) *NodeHandler { h := &NodeHandler{ self: self, rtr: rtr, @@ -172,12 +167,8 @@ func (h *NodeHandler) send(ctx context.Context, ev NodeHandlerRequest) bool { return false } -func (h *NodeHandler) ID() peer.ID { - return h.self.ID -} - -func (h *NodeHandler) Addresses() []ma.Multiaddr { - return h.self.Addrs +func (h *NodeHandler) ID() kadt.PeerID { + return h.self } // GetClosestNodes requests the n closest nodes to the key from the node's local routing table. @@ -233,20 +224,3 @@ func (h *NodeHandler) GetValue(ctx context.Context, key kadt.Key) (Value, error) func (h *NodeHandler) PutValue(ctx context.Context, r Value, q int) error { panic("not implemented") } - -type fakeMessage struct { - key kadt.Key - infos []kad.NodeInfo[kadt.Key, ma.Multiaddr] -} - -func (r fakeMessage) Target() kadt.Key { - return r.key -} - -func (r fakeMessage) CloserNodes() []kad.NodeInfo[kadt.Key, ma.Multiaddr] { - return r.infos -} - -func (r fakeMessage) EmptyResponse() kad.Response[kadt.Key, ma.Multiaddr] { - return &fakeMessage{} -} diff --git a/v2/coord/network_test.go b/v2/coord/network_test.go index 4d2ca5b5..6baacf53 100644 --- a/v2/coord/network_test.go +++ b/v2/coord/network_test.go @@ -10,7 +10,6 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" - "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) // TODO: this is just a basic is-it-working test that needs to be improved @@ -21,10 +20,10 @@ func TestGetClosestNodes(t *testing.T) { _, nodes, err := nettest.LinearTopology(4, clk) require.NoError(t, err) - h := NewNodeHandler(nodes[1].NodeInfo, nodes[1].Router, slog.Default(), trace.NewNoopTracerProvider().Tracer("")) + h := NewNodeHandler(nodes[1].NodeID, nodes[1].Router, slog.Default(), trace.NewNoopTracerProvider().Tracer("")) // node 1 has node 2 in its routing table so it will return it along with node 0 - found, err := h.GetClosestNodes(ctx, kadt.PeerID(nodes[2].NodeInfo.ID).Key(), 2) + found, err := h.GetClosestNodes(ctx, nodes[2].NodeID.Key(), 2) require.NoError(t, err) for _, f := range found { t.Logf("found node %v", f.ID()) diff --git a/v2/coord/query.go b/v2/coord/query.go index 6857fc6d..b8ebb982 100644 --- a/v2/coord/query.go +++ b/v2/coord/query.go @@ -48,7 +48,7 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { cmd = &query.EventPoolAddQuery[kadt.Key, kadt.PeerID]{ QueryID: ev.QueryID, Target: ev.Target, - KnownClosestNodes: sliceOfPeerIDToSliceOfKadPeerID(ev.KnownClosestNodes), + KnownClosestNodes: 
ev.KnownClosestNodes,
 		}
 		if ev.Notify != nil {
 			p.waiters[ev.QueryID] = ev.Notify
 		}
@@ -60,34 +60,36 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) {
 		}

 	case *EventGetCloserNodesSuccess:
+		// TODO: add addresses for discovered nodes in DHT
+
 		for _, info := range ev.CloserNodes {
 			// TODO: do this after advancing pool
-			p.pending = append(p.pending, &EventAddAddrInfo{
-				NodeInfo: info,
+			p.pending = append(p.pending, &EventAddNode{
+				NodeID: info,
 			})
 		}
 		waiter, ok := p.waiters[ev.QueryID]
 		if ok {
 			waiter.Notify(ctx, &EventQueryProgressed{
-				NodeID:  ev.To.ID,
+				NodeID:  ev.To,
 				QueryID: ev.QueryID,
 				// CloserNodes: CloserNodeIDs(ev.CloserNodes),
 				// Stats:    stats,
 			})
 		}
 		cmd = &query.EventPoolFindCloserResponse[kadt.Key, kadt.PeerID]{
-			NodeID:      kadt.PeerID(ev.To.ID),
+			NodeID:      ev.To,
 			QueryID:     ev.QueryID,
-			CloserNodes: sliceOfAddrInfoToSliceOfKadPeerID(ev.CloserNodes),
+			CloserNodes: ev.CloserNodes,
 		}
 	case *EventGetCloserNodesFailure:
 		// queue an event that will notify the routing behaviour of a failed node
 		p.pending = append(p.pending, &EventNotifyNonConnectivity{
-			ev.To.ID,
+			ev.To,
 		})
 		cmd = &query.EventPoolFindCloserFailure[kadt.Key, kadt.PeerID]{
-			NodeID:  kadt.PeerID(ev.To.ID),
+			NodeID:  ev.To,
 			QueryID: ev.QueryID,
 			Error:   ev.Err,
 		}
@@ -156,7 +158,7 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve
 	case *query.StatePoolFindCloser[kadt.Key, kadt.PeerID]:
 		return &EventOutboundGetCloserNodes{
 			QueryID: st.QueryID,
-			To:      kadPeerIDToAddrInfo(st.NodeID),
+			To:      st.NodeID,
 			Target:  st.Target,
 			Notify:  p,
 		}, true
diff --git a/v2/coord/routing.go b/v2/coord/routing.go
index f9edbe3f..ead1b107 100644
--- a/v2/coord/routing.go
+++ b/v2/coord/routing.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"sync"

-	"github.com/libp2p/go-libp2p/core/peer"
 	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/trace"
 	"golang.org/x/exp/slog"
@@ -66,7 +65,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 	case *EventStartBootstrap:
 		span.SetAttributes(attribute.String("event", "EventStartBootstrap"))
 		cmd := &routing.EventBootstrapStart[kadt.Key, kadt.PeerID]{
-			KnownClosestNodes: sliceOfPeerIDToSliceOfKadPeerID(ev.SeedNodes),
+			KnownClosestNodes: ev.SeedNodes,
 		}
 		// attempt to advance the bootstrap
 		next, ok := r.advanceBootstrap(ctx, cmd)
@@ -74,15 +73,15 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 			r.pending = append(r.pending, next)
 		}

-	case *EventAddAddrInfo:
-		span.SetAttributes(attribute.String("event", "EventAddAddrInfo"))
+	case *EventAddNode:
+		span.SetAttributes(attribute.String("event", "EventAddNode"))
 		// Ignore self
-		if ev.NodeInfo.ID == peer.ID(r.self) {
+		if r.self.Equal(ev.NodeID) {
 			break
 		}
 		// TODO: apply ttl
 		cmd := &routing.EventIncludeAddCandidate[kadt.Key, kadt.PeerID]{
-			NodeID: kadt.PeerID(ev.NodeInfo.ID),
+			NodeID: ev.NodeID,
 		}
 		// attempt to advance the include
 		next, ok := r.advanceInclude(ctx, cmd)
@@ -91,9 +90,9 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 		}

 	case *EventRoutingUpdated:
-		span.SetAttributes(attribute.String("event", "EventRoutingUpdated"), attribute.String("nodeid", ev.NodeInfo.ID.String()))
+		span.SetAttributes(attribute.String("event", "EventRoutingUpdated"), attribute.String("nodeid", ev.NodeID.String()))
 		cmd := &routing.EventProbeAdd[kadt.Key, kadt.PeerID]{
-			NodeID: addrInfoToKadPeerID(ev.NodeInfo),
+			NodeID: ev.NodeID,
 		}
 		// attempt to advance the probe state machine
 		next, ok := r.advanceProbe(ctx, cmd)
@@ -107,13 +106,13 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 		case "bootstrap":
 			for _, info := range ev.CloserNodes {
 				// TODO: do this after advancing bootstrap
-				r.pending = append(r.pending, &EventAddAddrInfo{
-					NodeInfo: info,
+				r.pending = append(r.pending, &EventAddNode{
+					NodeID: info,
 				})
 			}

 			cmd := &routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]{
-				NodeID:      kadt.PeerID(ev.To.ID),
-				CloserNodes: sliceOfAddrInfoToSliceOfKadPeerID(ev.CloserNodes),
+				NodeID:      ev.To,
+				CloserNodes: ev.CloserNodes,
 			}
 			// attempt to advance the bootstrap
 			next, ok := r.advanceBootstrap(ctx, cmd)
@@ -123,15 +122,14 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 		case "include":
 			var cmd routing.IncludeEvent
-
 			// require that the node responded with at least one closer node
 			if len(ev.CloserNodes) > 0 {
 				cmd = &routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]{
-					NodeID: kadt.PeerID(ev.To.ID),
+					NodeID: ev.To,
 				}
 			} else {
 				cmd = &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{
-					NodeID: kadt.PeerID(ev.To.ID),
+					NodeID: ev.To,
 					Error:  fmt.Errorf("response did not include any closer nodes"),
 				}
 			}
@@ -146,11 +144,11 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 			// require that the node responded with at least one closer node
 			if len(ev.CloserNodes) > 0 {
 				cmd = &routing.EventProbeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]{
-					NodeID: kadt.PeerID(ev.To.ID),
+					NodeID: ev.To,
 				}
 			} else {
 				cmd = &routing.EventProbeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{
-					NodeID: kadt.PeerID(ev.To.ID),
+					NodeID: ev.To,
 					Error:  fmt.Errorf("response did not include any closer nodes"),
 				}
 			}
@@ -169,7 +167,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 		switch ev.QueryID {
 		case "bootstrap":
 			cmd := &routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]{
-				NodeID: kadt.PeerID(ev.To.ID),
+				NodeID: ev.To,
 				Error:  ev.Err,
 			}
 			// attempt to advance the bootstrap
@@ -179,7 +177,7 @@
 		case "include":
 			cmd := &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{
-				NodeID: kadt.PeerID(ev.To.ID),
+				NodeID: ev.To,
 				Error:  ev.Err,
 			}
 			// attempt to advance the include state machine
@@ -189,7 +187,7 @@
 		case "probe":
 			cmd := &routing.EventProbeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{
-				NodeID: kadt.PeerID(ev.To.ID),
+				NodeID: ev.To,
 				Error:  ev.Err,
 			}
 			// attempt to advance the probe state machine
@@ -202,14 +200,14 @@
 			panic(fmt.Sprintf("unexpected query id: %s", ev.QueryID))
 		}
 	case *EventNotifyConnectivity:
-		span.SetAttributes(attribute.String("event", "EventNotifyConnectivity"), attribute.String("nodeid", ev.NodeInfo.ID.String()))
+		span.SetAttributes(attribute.String("event", "EventNotifyConnectivity"), attribute.String("nodeid", ev.NodeID.String()))
 		// ignore self
-		if ev.NodeInfo.ID == peer.ID(r.self) {
+		if r.self.Equal(ev.NodeID) {
 			break
 		}
 		// tell the include state machine in case this is a new peer that could be added to the routing table
 		cmd := &routing.EventIncludeAddCandidate[kadt.Key, kadt.PeerID]{
-			NodeID: kadt.PeerID(ev.NodeInfo.ID),
+			NodeID: ev.NodeID,
 		}
 		next, ok := r.advanceInclude(ctx, cmd)
 		if ok {
@@ -218,7 +216,7 @@ func (r *RoutingBehaviour) notify(ctx context.Context, ev BehaviourEvent) {
 		// tell the probe state machine in case there are connectivity checks that could be satisfied
 		cmdProbe := &routing.EventProbeNotifyConnectivity[kadt.Key, kadt.PeerID]{
-			NodeID: kadt.PeerID(ev.NodeInfo.ID),
+			NodeID: ev.NodeID,
 		}
 		nextProbe, ok := r.advanceProbe(ctx, cmdProbe)
 		if ok {
@@ -308,7 +306,7 @@ func (r *RoutingBehaviour) advanceBootstrap(ctx context.Context, ev routing.Boot
 	case *routing.StateBootstrapFindCloser[kadt.Key, kadt.PeerID]:
 		return &EventOutboundGetCloserNodes{
 			QueryID: "bootstrap",
-			To:      kadPeerIDToAddrInfo(st.NodeID),
+			To:      st.NodeID,
 			Target:  st.Target,
 			Notify:  r,
 		}, true
@@ -339,7 +337,7 @@ func (r *RoutingBehaviour) advanceInclude(ctx context.Context, ev routing.Includ
 		// include wants to send a find node message to a node
 		return &EventOutboundGetCloserNodes{
 			QueryID: "include",
-			To:      kadPeerIDToAddrInfo(st.NodeID),
+			To:      st.NodeID,
 			Target:  st.NodeID.Key(),
 			Notify:  r,
 		}, true
@@ -349,13 +347,13 @@
 		// notify other routing state machines that there is a new node in the routing table
 		r.notify(ctx, &EventRoutingUpdated{
-			NodeInfo: kadPeerIDToAddrInfo(st.NodeID),
+			NodeID: st.NodeID,
 		})

 		// return the event to notify outwards too
 		span.SetAttributes(attribute.String("out_event", "EventRoutingUpdated"))
 		return &EventRoutingUpdated{
-			NodeInfo: kadPeerIDToAddrInfo(st.NodeID),
+			NodeID: st.NodeID,
 		}, true
 	case *routing.StateIncludeWaitingAtCapacity:
 		// nothing to do except wait for message response or timeout
@@ -381,7 +379,7 @@ func (r *RoutingBehaviour) advanceProbe(ctx context.Context, ev routing.ProbeEve
-		// include wants to send a find node message to a node
+		// probe wants to send a find node message to a node
 		return &EventOutboundGetCloserNodes{
 			QueryID: "probe",
-			To:      kadPeerIDToAddrInfo(st.NodeID),
+			To:      st.NodeID,
 			Target:  st.NodeID.Key(),
 			Notify:  r,
 		}, true
@@ -390,12 +388,12 @@
 		// emit an EventRoutingRemoved event to notify clients that the node has been removed
 		r.pending = append(r.pending, &EventRoutingRemoved{
-			NodeID: peer.ID(st.NodeID),
+			NodeID: st.NodeID,
 		})

 		// add the node to the inclusion list for a second chance
-		r.notify(ctx, &EventAddAddrInfo{
-			NodeInfo: kadPeerIDToAddrInfo(st.NodeID),
+		r.notify(ctx, &EventAddNode{
+			NodeID: st.NodeID,
 		})
 	case *routing.StateProbeWaitingAtCapacity:
 		// the probe state machine is waiting for responses for checks and the maximum number of concurrent checks has been reached.
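The ExpectRoutingUpdated helper introduced earlier is not exercised by the updated tests, which use the type-based Expect. A sketch of how it condenses the TestIncludeNode pattern to wait for the update of one specific peer; placed alongside the tests above (package coord, same imports plus context and time; the test name is hypothetical):

```go
func TestIncludeNodeSketch(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clk := clock.New()
	_, nodes, err := nettest.LinearTopology(4, clk)
	require.NoError(t, err)

	rn := NewBufferedRoutingNotifier()
	ccfg := DefaultCoordinatorConfig()
	ccfg.Clock = clk
	ccfg.RoutingNotifier = rn

	candidate := nodes[len(nodes)-1].NodeID // not in nodes[0]'s routing table

	c, err := NewCoordinator(nodes[0].NodeID, nodes[0].Router, nodes[0].RoutingTable, ccfg)
	require.NoError(t, err)

	// suggest the peer; the include state machine verifies it in the background
	require.NoError(t, c.AddNodes(ctx, []kadt.PeerID{candidate}))

	// block until a routing table update for exactly that peer is observed
	ev, err := rn.ExpectRoutingUpdated(ctx, candidate)
	require.NoError(t, err)
	require.True(t, candidate.Equal(ev.NodeID))
}
```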
diff --git a/v2/coord/routing_test.go b/v2/coord/routing_test.go
index ded02c3b..2b07d6d1 100644
--- a/v2/coord/routing_test.go
+++ b/v2/coord/routing_test.go
@@ -9,7 +9,6 @@ import (

 	"github.com/benbjohnson/clock"
 	"github.com/libp2p/go-libp2p/core/peer"
-	"github.com/plprobelab/go-kademlia/network/address"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/exp/slog"

@@ -18,7 +17,6 @@ import (
 	"github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
-	"github.com/libp2p/go-libp2p-kad-dht/v2/pb"
 )

 func TestRoutingStartBootstrapSendsEvent(t *testing.T) {
@@ -28,7 +26,7 @@ func TestRoutingStartBootstrapSendsEvent(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to bootstrap
 	bootstrap := NewRecordingSM[routing.BootstrapEvent, routing.BootstrapState](&routing.StateBootstrapIdle{})
@@ -37,22 +35,15 @@ func TestRoutingStartBootstrapSendsEvent(t *testing.T) {

 	routingBehaviour := NewRoutingBehaviour(self, bootstrap, include, probe, slog.Default(), otel.Tracer("test"))

-	req := &pb.Message{
-		Type: pb.Message_FIND_NODE,
-		Key:  []byte(self),
-	}
-
 	ev := &EventStartBootstrap{
-		ProtocolID: address.ProtocolID("test"),
-		Message:    req,
-		SeedNodes:  []peer.ID{nodes[1].NodeInfo.ID},
+		SeedNodes: []kadt.PeerID{nodes[1].NodeID},
 	}

 	routingBehaviour.Notify(ctx, ev)

 	// the event that should be passed to the bootstrap state machine
 	expected := &routing.EventBootstrapStart[kadt.Key, kadt.PeerID]{
-		KnownClosestNodes: sliceOfPeerIDToSliceOfKadPeerID(ev.SeedNodes),
+		KnownClosestNodes: ev.SeedNodes,
 	}
 	require.Equal(t, expected, bootstrap.Received)
 }
@@ -64,7 +55,7 @@ func TestRoutingBootstrapGetClosestNodesSuccess(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to bootstrap
 	bootstrap := NewRecordingSM[routing.BootstrapEvent, routing.BootstrapState](&routing.StateBootstrapIdle{})
@@ -75,9 +66,9 @@ func TestRoutingBootstrapGetClosestNodesSuccess(t *testing.T) {

 	ev := &EventGetCloserNodesSuccess{
 		QueryID:     query.QueryID("bootstrap"),
-		To:          nodes[1].NodeInfo,
-		Target:      kadt.PeerID(nodes[0].NodeInfo.ID).Key(),
-		CloserNodes: []peer.AddrInfo{nodes[2].NodeInfo},
+		To:          nodes[1].NodeID,
+		Target:      nodes[0].NodeID.Key(),
+		CloserNodes: []kadt.PeerID{nodes[2].NodeID},
 	}

 	routingBehaviour.Notify(ctx, ev)
@@ -86,8 +77,8 @@ func TestRoutingBootstrapGetClosestNodesSuccess(t *testing.T) {
 	require.IsType(t, &routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID]{}, bootstrap.Received)

 	rev := bootstrap.Received.(*routing.EventBootstrapFindCloserResponse[kadt.Key, kadt.PeerID])
-	require.Equal(t, nodes[1].NodeInfo.ID, peer.ID(rev.NodeID))
-	require.Equal(t, sliceOfAddrInfoToSliceOfKadPeerID(ev.CloserNodes), rev.CloserNodes)
+	require.True(t, nodes[1].NodeID.Equal(rev.NodeID))
+	require.Equal(t, ev.CloserNodes, rev.CloserNodes)
 }

 func TestRoutingBootstrapGetClosestNodesFailure(t *testing.T) {
@@ -97,7 +88,7 @@ func TestRoutingBootstrapGetClosestNodesFailure(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to bootstrap
 	bootstrap := NewRecordingSM[routing.BootstrapEvent, routing.BootstrapState](&routing.StateBootstrapIdle{})
@@ -109,8 +100,8 @@ func TestRoutingBootstrapGetClosestNodesFailure(t *testing.T) {
 	failure := errors.New("failed")
 	ev := &EventGetCloserNodesFailure{
 		QueryID: query.QueryID("bootstrap"),
-		To:      nodes[1].NodeInfo,
-		Target:  kadt.PeerID(nodes[0].NodeInfo.ID).Key(),
+		To:      nodes[1].NodeID,
+		Target:  nodes[0].NodeID.Key(),
 		Err:     failure,
 	}

@@ -120,7 +111,7 @@ func TestRoutingBootstrapGetClosestNodesFailure(t *testing.T) {
 	require.IsType(t, &routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID]{}, bootstrap.Received)

 	rev := bootstrap.Received.(*routing.EventBootstrapFindCloserFailure[kadt.Key, kadt.PeerID])
-	require.Equal(t, nodes[1].NodeInfo.ID, peer.ID(rev.NodeID))
+	require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID))
 	require.Equal(t, failure, rev.Error)
 }

@@ -131,7 +122,7 @@ func TestRoutingAddNodeInfoSendsEvent(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to include
 	include := NewRecordingSM[routing.IncludeEvent, routing.IncludeState](&routing.StateIncludeIdle{})
@@ -141,15 +132,15 @@ func TestRoutingAddNodeInfoSendsEvent(t *testing.T) {

 	routingBehaviour := NewRoutingBehaviour(self, bootstrap, include, probe, slog.Default(), otel.Tracer("test"))

-	ev := &EventAddAddrInfo{
-		NodeInfo: nodes[2].NodeInfo,
+	ev := &EventAddNode{
+		NodeID: nodes[2].NodeID,
 	}

 	routingBehaviour.Notify(ctx, ev)

 	// the event that should be passed to the include state machine
 	expected := &routing.EventIncludeAddCandidate[kadt.Key, kadt.PeerID]{
-		NodeID: kadt.PeerID(ev.NodeInfo.ID),
+		NodeID: ev.NodeID,
 	}
 	require.Equal(t, expected, include.Received)
 }
@@ -161,7 +152,7 @@ func TestRoutingIncludeGetClosestNodesSuccess(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to include
 	include := NewRecordingSM[routing.IncludeEvent, routing.IncludeState](&routing.StateIncludeIdle{})
@@ -173,9 +164,9 @@ func TestRoutingIncludeGetClosestNodesSuccess(t *testing.T) {

 	ev := &EventGetCloserNodesSuccess{
 		QueryID:     query.QueryID("include"),
-		To:          nodes[1].NodeInfo,
-		Target:      kadt.PeerID(nodes[0].NodeInfo.ID).Key(),
-		CloserNodes: []peer.AddrInfo{nodes[2].NodeInfo},
+		To:          nodes[1].NodeID,
+		Target:      nodes[0].NodeID.Key(),
+		CloserNodes: []kadt.PeerID{nodes[2].NodeID},
 	}

 	routingBehaviour.Notify(ctx, ev)
@@ -184,7 +175,7 @@ func TestRoutingIncludeGetClosestNodesSuccess(t *testing.T) {
 	require.IsType(t, &routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID]{}, include.Received)

 	rev := include.Received.(*routing.EventIncludeConnectivityCheckSuccess[kadt.Key, kadt.PeerID])
-	require.Equal(t, nodes[1].NodeInfo.ID, peer.ID(rev.NodeID))
+	require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID))
 }

 func TestRoutingIncludeGetClosestNodesFailure(t *testing.T) {
@@ -194,7 +185,7 @@ func TestRoutingIncludeGetClosestNodesFailure(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	// records the event passed to include
 	include := NewRecordingSM[routing.IncludeEvent, routing.IncludeState](&routing.StateIncludeIdle{})
@@ -207,8 +198,8 @@ func TestRoutingIncludeGetClosestNodesFailure(t *testing.T) {
 	failure := errors.New("failed")
 	ev := &EventGetCloserNodesFailure{
 		QueryID: query.QueryID("include"),
-		To:      nodes[1].NodeInfo,
-		Target:  kadt.PeerID(nodes[0].NodeInfo.ID).Key(),
+		To:      nodes[1].NodeID,
+		Target:  nodes[0].NodeID.Key(),
 		Err:     failure,
 	}
@@ -218,7 +209,7 @@ func TestRoutingIncludeGetClosestNodesFailure(t *testing.T) {
 	require.IsType(t, &routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID]{}, include.Received)

 	rev := include.Received.(*routing.EventIncludeConnectivityCheckFailure[kadt.Key, kadt.PeerID])
-	require.Equal(t, nodes[1].NodeInfo.ID, peer.ID(rev.NodeID))
+	require.Equal(t, peer.ID(nodes[1].NodeID), peer.ID(rev.NodeID))
 	require.Equal(t, failure, rev.Error)
 }

@@ -229,7 +220,7 @@ func TestRoutingIncludedNodeAddToProbeList(t *testing.T) {
 	_, nodes, err := nettest.LinearTopology(4, clk)
 	require.NoError(t, err)

-	self := kadt.PeerID(nodes[0].NodeInfo.ID)
+	self := nodes[0].NodeID

 	rt := nodes[0].RoutingTable

 	includeCfg := routing.DefaultIncludeConfig()
@@ -249,15 +240,15 @@ func TestRoutingIncludedNodeAddToProbeList(t *testing.T) {
 	routingBehaviour := NewRoutingBehaviour(self, bootstrap, include, probe, slog.Default(), otel.Tracer("test"))

 	// a new node to be included
-	candidate := nodes[len(nodes)-1].NodeInfo
+	candidate := nodes[len(nodes)-1].NodeID

 	// the routing table should not contain the node yet
-	_, intable := rt.GetNode(kadt.PeerID(candidate.ID).Key())
+	_, intable := rt.GetNode(candidate.Key())
 	require.False(t, intable)

 	// notify that there is a new node to be included
-	routingBehaviour.Notify(ctx, &EventAddAddrInfo{
-		NodeInfo: candidate,
+	routingBehaviour.Notify(ctx, &EventAddNode{
+		NodeID: candidate,
 	})

 	// collect the result of the notify
@@ -277,11 +268,11 @@ func TestRoutingIncludedNodeAddToProbeList(t *testing.T) {
 		QueryID:     oev.QueryID,
 		To:          oev.To,
 		Target:      oev.Target,
-		CloserNodes: []peer.AddrInfo{nodes[1].NodeInfo}, // must include one for include check to pass
+		CloserNodes: []kadt.PeerID{nodes[1].NodeID}, // must include one for include check to pass
 	})

 	// the routing table should now contain the node
-	_, intable = rt.GetNode(kadt.PeerID(candidate.ID).Key())
+	_, intable = rt.GetNode(candidate.Key())
 	require.True(t, intable)

 	// routing update event should be emitted from the include state machine
@@ -300,5 +291,5 @@ func TestRoutingIncludedNodeAddToProbeList(t *testing.T) {
 	// confirm that the message is for the correct node
 	oev = dev.(*EventOutboundGetCloserNodes)
 	require.Equal(t, query.QueryID("probe"), oev.QueryID)
-	require.Equal(t, candidate.ID, oev.To.ID)
+	require.Equal(t, candidate, oev.To)
 }
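Aside: the tests above hinge on NewRecordingSM, a stub that takes the place of a real state machine and records the last event forwarded to it. A sketch of the idea (the Advance signature and field layout are assumed, not taken from the repo):

import "context"

// RecordingSM returns a canned state from every Advance call and remembers
// the most recent event, so a test can assert exactly which event a
// behaviour forwarded to its state machine.
type RecordingSM[E any, S any] struct {
	state    S
	Received E
}

func NewRecordingSM[E any, S any](state S) *RecordingSM[E, S] {
	return &RecordingSM[E, S]{state: state}
}

func (r *RecordingSM[E, S]) Advance(ctx context.Context, ev E) S {
	r.Received = ev
	return r.state
}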
diff --git a/v2/dht.go b/v2/dht.go
index 06086d81..559a5288 100644
--- a/v2/dht.go
+++ b/v2/dht.go
@@ -150,12 +150,12 @@ func New(h host.Host, cfg *Config) (*DHT, error) {
 	}

 	// instantiate a new Kademlia DHT coordinator.
-	coordCfg := coord.DefaultCoordinatorConfig()
+	coordCfg := cfg.Kademlia
 	coordCfg.Clock = cfg.Clock
 	coordCfg.MeterProvider = cfg.MeterProvider
 	coordCfg.TracerProvider = cfg.TracerProvider

-	d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), &Router{host: h}, d.rt, coordCfg)
+	d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), &Router{host: h, ProtocolID: cfg.ProtocolID}, d.rt, coordCfg)
 	if err != nil {
 		return nil, fmt.Errorf("new coordinator: %w", err)
 	}
@@ -309,12 +309,16 @@ func (d *DHT) AddAddresses(ctx context.Context, ais []peer.AddrInfo, ttl time.Du
 	ctx, span := d.tele.Tracer.Start(ctx, "DHT.AddAddresses")
 	defer span.End()

+	ids := make([]kadt.PeerID, 0, len(ais))
+
 	ps := d.host.Peerstore()
 	for _, ai := range ais {
+		// TODO: apply address filter
 		ps.AddAddrs(ai.ID, ai.Addrs, ttl)
+		ids = append(ids, kadt.PeerID(ai.ID))
 	}

-	return d.kad.AddNodes(ctx, ais)
+	return d.kad.AddNodes(ctx, ids)
 }

 // newSHA256Key returns a [kadt.KadKey] that conforms to the [kad.Key] interface by
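Aside: callers of AddAddresses are unaffected by this split: they keep passing full peer.AddrInfo values, the addresses are retained in the peerstore, and only ID-typed handles reach the coordinator. A usage sketch (d, ctx, remoteID and remoteAddrs are assumed to be in scope):

ai := peer.AddrInfo{ID: remoteID, Addrs: remoteAddrs}

// the addresses end up in d.host.Peerstore() for the given TTL, while the
// coordinator only ever sees kadt.PeerID(remoteID), per the loop above
err := d.AddAddresses(ctx, []peer.AddrInfo{ai}, time.Hour)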
diff --git a/v2/dht_test.go b/v2/dht_test.go
index 29993a58..6296dbf3 100644
--- a/v2/dht_test.go
+++ b/v2/dht_test.go
@@ -1,9 +1,6 @@
 package dht

 import (
-	"context"
-	"fmt"
-	"reflect"
 	"testing"
 	"time"

@@ -14,6 +11,7 @@ import (

 	"github.com/libp2p/go-libp2p-kad-dht/v2/coord"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest"
+	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
 )

 func TestNew(t *testing.T) {
@@ -75,26 +73,12 @@ func TestNew(t *testing.T) {
 	}
 }

-// expectEventType selects on the event channel until an event of the expected type is sent.
-func expectEventType(t *testing.T, ctx context.Context, events <-chan coord.RoutingNotification, expected coord.RoutingNotification) (coord.RoutingNotification, error) {
-	t.Helper()
-	for {
-		select {
-		case ev := <-events:
-			t.Logf("saw event: %T\n", ev)
-			if reflect.TypeOf(ev) == reflect.TypeOf(expected) {
-				return ev, nil
-			}
-		case <-ctx.Done():
-			return nil, fmt.Errorf("test deadline exceeded while waiting for event %T", expected)
-		}
-	}
-}
-
 func TestAddAddresses(t *testing.T) {
 	ctx := kadtest.CtxShort(t)

 	localCfg := DefaultConfig()
+
+	rn := coord.NewBufferedRoutingNotifier()
+	localCfg.Kademlia.RoutingNotifier = rn

 	local := newClientDht(t, localCfg)

@@ -104,7 +88,7 @@ func TestAddAddresses(t *testing.T) {
 	fillRoutingTable(t, remote, 1)

 	// local routing table should not contain the node
-	_, err := local.kad.GetNode(ctx, remote.host.ID())
+	_, err := local.kad.GetNode(ctx, kadt.PeerID(remote.host.ID()))
 	require.ErrorIs(t, err, coord.ErrNodeNotFound)

 	remoteAddrInfo := peer.AddrInfo{
@@ -119,10 +103,10 @@ func TestAddAddresses(t *testing.T) {
 	require.NoError(t, err)

 	// the include state machine runs in the background and eventually should add the node to routing table
-	_, err = expectEventType(t, ctx, local.kad.RoutingNotifications(), &coord.EventRoutingUpdated{})
+	_, err = rn.Expect(ctx, &coord.EventRoutingUpdated{})
 	require.NoError(t, err)

 	// the routing table should now contain the node
-	_, err = local.kad.GetNode(ctx, remote.host.ID())
+	_, err = local.kad.GetNode(ctx, kadt.PeerID(remote.host.ID()))
 	require.NoError(t, err)
 }
diff --git a/v2/handlers.go b/v2/handlers.go
index bcd89f9a..5b8536f3 100644
--- a/v2/handlers.go
+++ b/v2/handlers.go
@@ -25,7 +25,7 @@ func (d *DHT) handleFindPeer(ctx context.Context, remote peer.ID, req *pb.Messag
 	}

 	// tell the coordinator that this peer supports finding closer nodes
-	d.kad.NotifyConnectivity(ctx, remote)
+	d.kad.NotifyConnectivity(ctx, kadt.PeerID(remote))

 	// "parse" requested peer ID from the key field
 	target := peer.ID(req.GetKey())
diff --git a/v2/internal/kadtest/context.go b/v2/internal/kadtest/context.go
index 41623c08..8a69328c 100644
--- a/v2/internal/kadtest/context.go
+++ b/v2/internal/kadtest/context.go
@@ -2,6 +2,7 @@ package kadtest

 import (
 	"context"
+	"runtime"
 	"testing"
 	"time"
 )
@@ -13,7 +14,13 @@ import (
 func CtxShort(t *testing.T) context.Context {
 	t.Helper()

-	timeout := 10 * time.Second
+	var timeout time.Duration
+	// Increase the timeout for 32-bit Windows
+	if runtime.GOOS == "windows" && runtime.GOARCH == "386" {
+		timeout = 60 * time.Second
+	} else {
+		timeout = 10 * time.Second
+	}

 	goal := time.Now().Add(timeout)

 	deadline, ok := t.Deadline()
diff --git a/v2/kadt/kadt.go b/v2/kadt/kadt.go
index 9de3e6e9..f87057a8 100644
--- a/v2/kadt/kadt.go
+++ b/v2/kadt/kadt.go
@@ -39,6 +39,11 @@ func (p PeerID) String() string {
 	return peer.ID(p).String()
 }

+// Equal compares the [PeerID] with another by comparing the underlying [peer.ID].
+func (p PeerID) Equal(o PeerID) bool {
+	return peer.ID(p) == peer.ID(o)
+}
+
 // AddrInfo is a type that wraps peer.AddrInfo and implements the kad.NodeInfo
 // interface. This means we can use AddrInfo for any operation that interfaces
 // with go-kademlia.
diff --git a/v2/notifee.go b/v2/notifee.go
index d1889428..0666e836 100644
--- a/v2/notifee.go
+++ b/v2/notifee.go
@@ -4,9 +4,9 @@ import (
 	"context"
 	"fmt"

+	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
 	"github.com/libp2p/go-libp2p/core/event"
 	"github.com/libp2p/go-libp2p/core/network"
-	"github.com/libp2p/go-libp2p/core/peer"
 )

 // networkEventsSubscription registers a subscription on the libp2p event bus
@@ -90,7 +90,5 @@ func (d *DHT) onEvtLocalReachabilityChanged(evt event.EvtLocalReachabilityChange

 func (d *DHT) onEvtPeerIdentificationCompleted(evt event.EvtPeerIdentificationCompleted) {
 	// tell the coordinator about a new candidate for inclusion in the routing table
-	d.kad.AddNodes(context.Background(), []peer.AddrInfo{
-		{ID: evt.Peer},
-	})
+	d.kad.AddNodes(context.Background(), []kadt.PeerID{kadt.PeerID(evt.Peer)})
 }
diff --git a/v2/notifee_test.go b/v2/notifee_test.go
index a42f82bf..b7079ac6 100644
--- a/v2/notifee_test.go
+++ b/v2/notifee_test.go
@@ -4,7 +4,9 @@ import (
 	"testing"
 	"time"

+	"github.com/libp2p/go-libp2p-kad-dht/v2/coord"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest"
+	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
 	"github.com/libp2p/go-libp2p/core/network"

 	"github.com/libp2p/go-libp2p/core/event"
@@ -72,7 +74,11 @@ func TestDHT_consumeNetworkEvents_onEvtLocalReachabilityChanged(t *testing.T) {
 func TestDHT_consumeNetworkEvents_onEvtPeerIdentificationCompleted(t *testing.T) {
 	ctx := kadtest.CtxShort(t)

-	d1 := newServerDht(t, nil)
+	cfg1 := DefaultConfig()
+	rn1 := coord.NewBufferedRoutingNotifier()
+	cfg1.Kademlia.RoutingNotifier = rn1
+	d1 := newServerDht(t, cfg1)
+
 	d2 := newServerDht(t, nil)

 	// make sure d1 has the address of d2 in its peerstore
@@ -83,6 +89,6 @@ func TestDHT_consumeNetworkEvents_onEvtPeerIdentificationCompleted(t *testing.T)
 		Peer: d2.host.ID(),
 	})

-	_, err := expectRoutingUpdated(t, ctx, d1.kad.RoutingNotifications(), d2.host.ID())
+	_, err := rn1.ExpectRoutingUpdated(ctx, kadt.PeerID(d2.host.ID()))
 	require.NoError(t, err)
 }
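Aside: the test wiring above is the general recipe for observing routing changes now that the notification channel is gone: install a BufferedRoutingNotifier via the config and block on the event you care about. A sketch in package dht scope (h, ctx and p are assumed inputs):

func awaitPeerInTable(ctx context.Context, h host.Host, p peer.ID) error {
	cfg := DefaultConfig()

	// route routing events into a buffered notifier instead of a channel
	rn := coord.NewBufferedRoutingNotifier()
	cfg.Kademlia.RoutingNotifier = rn

	d, err := New(h, cfg)
	if err != nil {
		return err
	}
	_ = d // the DHT runs in the background; rn receives its routing events

	// blocks until the include state machine reports that p entered the
	// routing table, or until ctx expires
	_, err = rn.ExpectRoutingUpdated(ctx, kadt.PeerID(p))
	return err
}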
"github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" "golang.org/x/exp/slog" ) @@ -15,20 +14,11 @@ import ( // this file contains auxiliary methods to augment the protobuf generated types. // It is used to let these types conform to interfaces or add convenience methods. -var _ kad.Request[kadt.Key, ma.Multiaddr] = (*Message)(nil) - func (m *Message) Target() kadt.Key { b := sha256.Sum256(m.Key) return key.NewKey256(b[:]) } -func (m *Message) EmptyResponse() kad.Response[kadt.Key, ma.Multiaddr] { - return &Message{ - Type: m.Type, - Key: m.Key, - } -} - // FromAddrInfo constructs a [Message_Peer] from the given [peer.AddrInfo]. func FromAddrInfo(p peer.AddrInfo) *Message_Peer { mp := &Message_Peer{ @@ -90,20 +80,17 @@ func (m *Message) CloserPeersAddrInfos() []peer.AddrInfo { return addrInfos } -func (m *Message) CloserNodes() []kad.NodeInfo[kadt.Key, ma.Multiaddr] { +func (m *Message) CloserNodes() []kadt.PeerID { if m == nil { return nil } - infos := make([]kad.NodeInfo[kadt.Key, ma.Multiaddr], 0, len(m.CloserPeers)) + ids := make([]kadt.PeerID, 0, len(m.CloserPeers)) for _, p := range m.CloserPeers { - infos = append(infos, &kadt.AddrInfo{Info: peer.AddrInfo{ - ID: peer.ID(p.Id), - Addrs: p.Addresses(), - }}) + ids = append(ids, kadt.PeerID(peer.ID(p.Id))) } - return infos + return ids } // Addresses returns the Multiaddresses associated with the Message_Peer entry diff --git a/v2/query_test.go b/v2/query_test.go index b96c0b33..bf8e3ee1 100644 --- a/v2/query_test.go +++ b/v2/query_test.go @@ -2,7 +2,6 @@ package dht import ( "context" - "fmt" "testing" "time" @@ -87,43 +86,7 @@ func newClientDht(t testing.TB, cfg *Config) *DHT { return d } -// expectRoutingUpdated selects on the event channel until an EventRoutingUpdated event is seen for the specified peer id -func expectRoutingUpdated(t *testing.T, ctx context.Context, events <-chan coord.RoutingNotification, id peer.ID) (*coord.EventRoutingUpdated, error) { - t.Helper() - for { - select { - case ev := <-events: - if tev, ok := ev.(*coord.EventRoutingUpdated); ok { - if tev.NodeInfo.ID == id { - return tev, nil - } - t.Logf("saw routing update for %s", tev.NodeInfo.ID) - } - case <-ctx.Done(): - return nil, fmt.Errorf("test deadline exceeded while waiting for routing update event") - } - } -} - -// expectRoutingUpdated selects on the event channel until an EventRoutingUpdated event is seen for the specified peer id -func expectRoutingRemoved(t *testing.T, ctx context.Context, events <-chan coord.RoutingNotification, id peer.ID) (*coord.EventRoutingRemoved, error) { - t.Helper() - for { - select { - case ev := <-events: - if tev, ok := ev.(*coord.EventRoutingRemoved); ok { - if tev.NodeID == id { - return tev, nil - } - t.Logf("saw routing removed for %s", tev.NodeID) - } - case <-ctx.Done(): - return nil, fmt.Errorf("test deadline exceeded while waiting for routing removed event") - } - } -} - -func connect(t *testing.T, ctx context.Context, a, b *DHT) { +func connect(t *testing.T, ctx context.Context, a, b *DHT, arn *coord.BufferedRoutingNotifier) { t.Helper() remoteAddrInfo := peer.AddrInfo{ @@ -136,42 +99,44 @@ func connect(t *testing.T, ctx context.Context, a, b *DHT) { require.NoError(t, err) // the include state machine runs in the background for a and eventually should add the node to routing table - _, err = expectRoutingUpdated(t, ctx, a.kad.RoutingNotifications(), b.host.ID()) + _, err = arn.ExpectRoutingUpdated(ctx, kadt.PeerID(b.host.ID())) require.NoError(t, err) // the routing table should now 
diff --git a/v2/query_test.go b/v2/query_test.go
index b96c0b33..bf8e3ee1 100644
--- a/v2/query_test.go
+++ b/v2/query_test.go
@@ -2,7 +2,6 @@ package dht

 import (
 	"context"
-	"fmt"
 	"testing"
 	"time"

@@ -87,43 +86,7 @@ func newClientDht(t testing.TB, cfg *Config) *DHT {
 	return d
 }

-// expectRoutingUpdated selects on the event channel until an EventRoutingUpdated event is seen for the specified peer id
-func expectRoutingUpdated(t *testing.T, ctx context.Context, events <-chan coord.RoutingNotification, id peer.ID) (*coord.EventRoutingUpdated, error) {
-	t.Helper()
-	for {
-		select {
-		case ev := <-events:
-			if tev, ok := ev.(*coord.EventRoutingUpdated); ok {
-				if tev.NodeInfo.ID == id {
-					return tev, nil
-				}
-				t.Logf("saw routing update for %s", tev.NodeInfo.ID)
-			}
-		case <-ctx.Done():
-			return nil, fmt.Errorf("test deadline exceeded while waiting for routing update event")
-		}
-	}
-}
-
-// expectRoutingUpdated selects on the event channel until an EventRoutingUpdated event is seen for the specified peer id
-func expectRoutingRemoved(t *testing.T, ctx context.Context, events <-chan coord.RoutingNotification, id peer.ID) (*coord.EventRoutingRemoved, error) {
-	t.Helper()
-	for {
-		select {
-		case ev := <-events:
-			if tev, ok := ev.(*coord.EventRoutingRemoved); ok {
-				if tev.NodeID == id {
-					return tev, nil
-				}
-				t.Logf("saw routing removed for %s", tev.NodeID)
-			}
-		case <-ctx.Done():
-			return nil, fmt.Errorf("test deadline exceeded while waiting for routing removed event")
-		}
-	}
-}
-
-func connect(t *testing.T, ctx context.Context, a, b *DHT) {
+func connect(t *testing.T, ctx context.Context, a, b *DHT, arn *coord.BufferedRoutingNotifier) {
 	t.Helper()

 	remoteAddrInfo := peer.AddrInfo{
@@ -136,42 +99,44 @@ func connect(t *testing.T, ctx context.Context, a, b *DHT) {
 	require.NoError(t, err)

 	// the include state machine runs in the background for a and eventually should add the node to routing table
-	_, err = expectRoutingUpdated(t, ctx, a.kad.RoutingNotifications(), b.host.ID())
+	_, err = arn.ExpectRoutingUpdated(ctx, kadt.PeerID(b.host.ID()))
 	require.NoError(t, err)

 	// the routing table should now contain the node
-	_, err = a.kad.GetNode(ctx, b.host.ID())
+	_, err = a.kad.GetNode(ctx, kadt.PeerID(b.host.ID()))
 	require.NoError(t, err)
 }

-// connectLinearChain connects the dhts together in a linear chain.
-// The dhts are configured with routing tables that contain immediate neighbours.
-func connectLinearChain(t *testing.T, ctx context.Context, dhts ...*DHT) {
-	for i := 1; i < len(dhts); i++ {
-		connect(t, ctx, dhts[i-1], dhts[i])
-		connect(t, ctx, dhts[i], dhts[i-1])
-	}
-}
-
 func TestRTAdditionOnSuccessfulQuery(t *testing.T) {
 	ctx := kadtest.CtxShort(t)
-	ctx, tp := kadtest.MaybeTrace(t, ctx)

-	cfg := DefaultConfig()
-	cfg.TracerProvider = tp
+	// create dhts and associated routing notifiers so we can inspect routing events
+	cfg1 := DefaultConfig()
+	rn1 := coord.NewBufferedRoutingNotifier()
+	cfg1.Kademlia.RoutingNotifier = rn1
+	d1 := newServerDht(t, cfg1)
+
+	cfg2 := DefaultConfig()
+	rn2 := coord.NewBufferedRoutingNotifier()
+	cfg2.Kademlia.RoutingNotifier = rn2
+	d2 := newServerDht(t, cfg2)

-	d1 := newServerDht(t, cfg)
-	d2 := newServerDht(t, cfg)
-	d3 := newServerDht(t, cfg)
+	cfg3 := DefaultConfig()
+	rn3 := coord.NewBufferedRoutingNotifier()
+	cfg3.Kademlia.RoutingNotifier = rn3
+	d3 := newServerDht(t, cfg3)

-	connectLinearChain(t, ctx, d1, d2, d3)
+	connect(t, ctx, d1, d2, rn1)
+	connect(t, ctx, d2, d1, rn2)
+	connect(t, ctx, d2, d3, rn2)
+	connect(t, ctx, d3, d2, rn3)

 	// d3 does not know about d1
-	_, err := d3.kad.GetNode(ctx, d1.host.ID())
+	_, err := d3.kad.GetNode(ctx, kadt.PeerID(d1.host.ID()))
 	require.ErrorIs(t, err, coord.ErrNodeNotFound)

 	// d1 does not know about d3
-	_, err = d1.kad.GetNode(ctx, d3.host.ID())
+	_, err = d1.kad.GetNode(ctx, kadt.PeerID(d3.host.ID()))
 	require.ErrorIs(t, err, coord.ErrNodeNotFound)

 	// // but when d3 queries d2, d1 and d3 discover each other
@@ -179,31 +144,37 @@ func TestRTAdditionOnSuccessfulQuery(t *testing.T) {
 	// ignore the error

 	// d3 should update its routing table to include d1 during the query
-	_, err = expectRoutingUpdated(t, ctx, d3.kad.RoutingNotifications(), d1.host.ID())
+	_, err = rn3.ExpectRoutingUpdated(ctx, kadt.PeerID(d1.host.ID()))
 	require.NoError(t, err)

 	// d3 now has d1 in its routing table
-	_, err = d3.kad.GetNode(ctx, d1.host.ID())
+	_, err = d3.kad.GetNode(ctx, kadt.PeerID(d1.host.ID()))
 	require.NoError(t, err)

 	// d1 should update its routing table to include d3 during the query
-	_, err = expectRoutingUpdated(t, ctx, d1.kad.RoutingNotifications(), d3.host.ID())
+	_, err = rn1.ExpectRoutingUpdated(ctx, kadt.PeerID(d3.host.ID()))
 	require.NoError(t, err)

 	// d1 now has d3 in its routing table
-	_, err = d1.kad.GetNode(ctx, d3.host.ID())
+	_, err = d1.kad.GetNode(ctx, kadt.PeerID(d3.host.ID()))
 	require.NoError(t, err)
 }

 func TestRTEvictionOnFailedQuery(t *testing.T) {
 	ctx := kadtest.CtxShort(t)

-	cfg := DefaultConfig()
+	cfg1 := DefaultConfig()
+	rn1 := coord.NewBufferedRoutingNotifier()
+	cfg1.Kademlia.RoutingNotifier = rn1
+	d1 := newServerDht(t, cfg1)
+
+	cfg2 := DefaultConfig()
+	rn2 := coord.NewBufferedRoutingNotifier()
+	cfg2.Kademlia.RoutingNotifier = rn2
+	d2 := newServerDht(t, cfg2)

-	d1 := newServerDht(t, cfg)
-	d2 := newServerDht(t, cfg)

-	connect(t, ctx, d1, d2)
-	connect(t, ctx, d2, d1)
+	connect(t, ctx, d1, d2, rn1)
+	connect(t, ctx, d2, d1, rn2)

 	// close both hosts so query fails
 	require.NoError(t, d1.host.Close())
@@ -213,17 +184,17 @@ func TestRTEvictionOnFailedQuery(t *testing.T) {
 	// no scheduled probes will have taken place

 	// d1 still has d2 in the routing table
-	_, err := d1.kad.GetNode(ctx, d2.host.ID())
+	_, err := d1.kad.GetNode(ctx, kadt.PeerID(d2.host.ID()))
 	require.NoError(t, err)

 	// d2 still has d1 in the routing table
-	_, err = d2.kad.GetNode(ctx, d1.host.ID())
+	_, err = d2.kad.GetNode(ctx, kadt.PeerID(d1.host.ID()))
 	require.NoError(t, err)

 	// failed queries should remove the queried peers from the routing table
 	_, _ = d1.FindPeer(ctx, "test")

 	// d1 should update its routing table to remove d2 because of the failure
-	_, err = expectRoutingRemoved(t, ctx, d1.kad.RoutingNotifications(), d2.host.ID())
+	_, err = rn1.ExpectRoutingRemoved(ctx, kadt.PeerID(d2.host.ID()))
 	require.NoError(t, err)
 }
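Aside: with the protocol ID carried by the Router itself (see the next file), a caller configures it once and every request reuses it. A sketch that must live inside package dht, since Router's host field is unexported (h, ctx and the peer p are assumed; cfg.ProtocolID mirrors how dht.go constructs the Router):

cfg := DefaultConfig()
rtr := &Router{host: h, ProtocolID: cfg.ProtocolID}

// ask p for the peers it knows closest to p's own key
target := kadt.PeerID(p).Key()
closer, err := rtr.GetClosestNodes(ctx, kadt.PeerID(p), target)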
diff --git a/v2/router.go b/v2/router.go
index 2c5ed505..14db2cd9 100644
--- a/v2/router.go
+++ b/v2/router.go
@@ -11,10 +11,7 @@ import (
 	"github.com/libp2p/go-libp2p/core/protocol"
 	"github.com/libp2p/go-msgio"
 	"github.com/libp2p/go-msgio/pbio"
-	"github.com/plprobelab/go-kademlia/kad"
-	"github.com/plprobelab/go-kademlia/network/address"
 	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/reflect/protoreflect"

 	"github.com/libp2p/go-libp2p-kad-dht/v2/coord"
 	"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
@@ -23,42 +20,26 @@ import (

 type Router struct {
 	host host.Host
+	// ProtocolID represents the DHT [protocol] we can query with and respond to.
+	//
+	// [protocol]: https://docs.libp2p.io/concepts/fundamentals/protocols/
+	ProtocolID protocol.ID
 }

-var _ coord.Router = (*Router)(nil)
+var _ coord.Router[kadt.Key, kadt.PeerID, *pb.Message] = (*Router)(nil)

-func WriteMsg(s network.Stream, msg protoreflect.ProtoMessage) error {
-	w := pbio.NewDelimitedWriter(s)
-	return w.WriteMsg(msg)
-}
-
-func ReadMsg(s network.Stream, msg proto.Message) error {
-	r := pbio.NewDelimitedReader(s, network.MessageSizeMax)
-	return r.ReadMsg(msg)
-}
-
-type ProtoKadMessage interface {
-	proto.Message
-}
-
-type ProtoKadRequestMessage[K kad.Key[K], A kad.Address[A]] interface {
-	ProtoKadMessage
-	kad.Request[K, A]
-}
-
-type ProtoKadResponseMessage[K kad.Key[K], A kad.Address[A]] interface {
-	ProtoKadMessage
-	kad.Response[K, A]
-}
-
-func (r *Router) SendMessage(ctx context.Context, to peer.AddrInfo, protoID address.ProtocolID, req *pb.Message) (*pb.Message, error) {
-	if err := r.AddNodeInfo(ctx, to, time.Hour); err != nil {
-		return nil, fmt.Errorf("add node info: %w", err)
+func FindKeyRequest(k kadt.Key) *pb.Message {
+	marshalledKey, _ := k.MarshalBinary()
+	return &pb.Message{
+		Type: pb.Message_FIND_NODE,
+		Key:  marshalledKey,
 	}
+}

+func (r *Router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (*pb.Message, error) {
 	// TODO: what to do with addresses in peer.AddrInfo?
-	if len(r.host.Peerstore().Addrs(to.ID)) == 0 {
-		return nil, fmt.Errorf("no address for peer %s", to.ID)
+	if len(r.host.Peerstore().Addrs(peer.ID(to))) == 0 {
+		return nil, fmt.Errorf("no address for peer %s", to)
 	}

 	var cancel context.CancelFunc
@@ -68,7 +49,7 @@ func (r *Router) SendMessage(ctx context.Context, to peer.AddrInfo, protoID addr
 	var err error
 	var s network.Stream

-	s, err = r.host.NewStream(ctx, to.ID, protocol.ID(protoID))
+	s, err = r.host.NewStream(ctx, peer.ID(to), r.ProtocolID)
 	if err != nil {
 		return nil, fmt.Errorf("stream creation: %w", err)
 	}
@@ -92,39 +73,27 @@ func (r *Router) SendMessage(ctx context.Context, to peer.AddrInfo, protoID addr
 	}

 	for _, info := range protoResp.CloserPeersAddrInfos() {
-		_ = r.AddNodeInfo(ctx, info, time.Hour)
+		_ = r.addToPeerStore(ctx, info, time.Hour) // TODO: replace hard coded time.Hour with config
 	}

 	return &protoResp, err
 }

-func (r *Router) AddNodeInfo(ctx context.Context, ai peer.AddrInfo, ttl time.Duration) error {
-	// Don't add addresses for self or our connected peers. We have better ones.
-	if ai.ID == r.host.ID() || r.host.Network().Connectedness(ai.ID) == network.Connected {
-		return nil
-	}
-
-	r.host.Peerstore().AddAddrs(ai.ID, ai.Addrs, ttl)
-	return nil
-}
-
-func (r *Router) GetNodeInfo(ctx context.Context, id peer.ID) (peer.AddrInfo, error) {
-	return r.host.Peerstore().PeerInfo(id), nil
-}
-
-func (r *Router) GetClosestNodes(ctx context.Context, to peer.AddrInfo, target kadt.Key) ([]peer.AddrInfo, error) {
-	resp, err := r.SendMessage(ctx, to, address.ProtocolID(ProtocolIPFS), FindKeyRequest(target))
+func (r *Router) GetClosestNodes(ctx context.Context, to kadt.PeerID, target kadt.Key) ([]kadt.PeerID, error) {
+	resp, err := r.SendMessage(ctx, to, FindKeyRequest(target))
 	if err != nil {
 		return nil, err
 	}
-	return resp.CloserPeersAddrInfos(), nil
+	return resp.CloserNodes(), nil
 }

-func FindKeyRequest(k kadt.Key) *pb.Message {
-	marshalledKey, _ := k.MarshalBinary()
-	return &pb.Message{
-		Type: pb.Message_FIND_NODE,
-		Key:  marshalledKey,
+func (r *Router) addToPeerStore(ctx context.Context, ai peer.AddrInfo, ttl time.Duration) error {
+	// Don't add addresses for self or our connected peers. We have better ones.
+	if ai.ID == r.host.ID() || r.host.Network().Connectedness(ai.ID) == network.Connected {
+		return nil
 	}
+
+	r.host.Peerstore().AddAddrs(ai.ID, ai.Addrs, ttl)
+	return nil
 }
diff --git a/v2/routing.go b/v2/routing.go
index e17ae434..396104a9 100644
--- a/v2/routing.go
+++ b/v2/routing.go
@@ -43,7 +43,7 @@ func (d *DHT) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) {

 	var foundNode coord.Node
 	fn := func(ctx context.Context, node coord.Node, stats coord.QueryStats) error {
-		if node.ID() == id {
+		if peer.ID(node.ID()) == id {
 			foundNode = node
 			return coord.ErrSkipRemaining
 		}
@@ -59,10 +59,7 @@ func (d *DHT) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) {
 		return peer.AddrInfo{}, fmt.Errorf("peer record not found")
 	}

-	return peer.AddrInfo{
-		ID:    foundNode.ID(),
-		Addrs: foundNode.Addresses(),
-	}, nil
+	return d.host.Peerstore().PeerInfo(peer.ID(foundNode.ID())), nil
 }

 func (d *DHT) Provide(ctx context.Context, c cid.Cid, brdcst bool) error {
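Aside: the FindPeer change means returned addresses now come from the peerstore, which Router.SendMessage keeps fresh as responses arrive, rather than from the query result itself. A usage sketch (d, ctx and p are assumed):

info, err := d.FindPeer(ctx, p)
if err == nil {
	// info.Addrs is read back from d.host.Peerstore(), so it reflects the
	// addresses learned during the query, not just the final response
	fmt.Println(info.ID, info.Addrs)
}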