diff --git a/v2/config.go b/v2/config.go index ba41447c..696b8dfe 100644 --- a/v2/config.go +++ b/v2/config.go @@ -19,8 +19,7 @@ import ( "go.uber.org/zap/exp/zapslog" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) @@ -112,8 +111,8 @@ type Config struct { // between both automatically (see ModeOpt). Mode ModeOpt - // Kademlia holds the configuration of the underlying Kademlia implementation. - Kademlia *coord.CoordinatorConfig + // Query holds the configuration used for queries managed by the DHT. + Query *QueryConfig // BucketSize determines the number of closer peers to return BucketSize int @@ -132,7 +131,7 @@ type Config struct { // [triert.TrieRT] routing table will be used. This field will be nil // in the default configuration because a routing table requires information // about the local node. - RoutingTable routing.RoutingTableCpl[kadt.Key, kadt.PeerID] + RoutingTable kadt.RoutingTable // The Backends field holds a map of key namespaces to their corresponding // backend implementation. For example, if we received an IPNS record, the @@ -193,7 +192,6 @@ func DefaultConfig() *Config { return &Config{ Clock: clock.New(), Mode: ModeOptAutoClient, - Kademlia: coord.DefaultCoordinatorConfig(), BucketSize: 20, // MAGIC BootstrapPeers: DefaultBootstrapPeers(), ProtocolID: ProtocolIPFS, @@ -205,6 +203,7 @@ func DefaultConfig() *Config { AddressFilter: AddrFilterPrivate, MeterProvider: otel.GetMeterProvider(), TracerProvider: otel.GetTracerProvider(), + Query: DefaultQueryConfig(), } } @@ -242,62 +241,104 @@ func (c *Config) Validate() error { return fmt.Errorf("invalid mode option: %s", c.Mode) } - if c.Kademlia == nil { - return fmt.Errorf("kademlia configuration must not be nil") + if c.Query == nil { + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("query configuration must not be nil"), + } } - if err := c.Kademlia.Validate(); err != nil { - return fmt.Errorf("invalid kademlia configuration: %w", err) + if err := c.Query.Validate(); err != nil { + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("invalid query configuration: %w", err), + } } if c.BucketSize == 0 { - return fmt.Errorf("bucket size must not be 0") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("bucket size must not be 0"), + } } if len(c.BootstrapPeers) == 0 { - return fmt.Errorf("no bootstrap peer") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("no bootstrap peer"), + } } if c.ProtocolID == "" { - return fmt.Errorf("protocolID must not be empty") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("protocolID must not be empty"), + } } if c.Logger == nil { - return fmt.Errorf("logger must not be nil") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("logger must not be nil"), + } } if c.TimeoutStreamIdle <= 0 { - return fmt.Errorf("stream idle timeout must be a positive duration") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("stream idle timeout must be a positive duration"), + } } if c.ProtocolID == ProtocolIPFS && len(c.Backends) != 0 { if len(c.Backends) != 3 { - return fmt.Errorf("ipfs protocol requires exactly three backends") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("ipfs protocol requires exactly three backends"), + } } if _, found := 
c.Backends[namespaceIPNS]; !found { - return fmt.Errorf("ipfs protocol requires an IPNS backend") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("ipfs protocol requires an IPNS backend"), + } } if _, found := c.Backends[namespacePublicKey]; !found { - return fmt.Errorf("ipfs protocol requires a public key backend") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("ipfs protocol requires a public key backend"), + } } if _, found := c.Backends[namespaceProviders]; !found { - return fmt.Errorf("ipfs protocol requires a providers backend") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("ipfs protocol requires a providers backend"), + } } } if c.AddressFilter == nil { - return fmt.Errorf("address filter must not be nil - use AddrFilterIdentity to disable filtering") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("address filter must not be nil - use AddrFilterIdentity to disable filtering"), + } } if c.MeterProvider == nil { - return fmt.Errorf("opentelemetry meter provider must not be nil") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("opentelemetry meter provider must not be nil"), + } } if c.TracerProvider == nil { - return fmt.Errorf("opentelemetry tracer provider must not be nil") + return &ConfigurationError{ + Component: "Config", + Err: fmt.Errorf("opentelemetry tracer provider must not be nil"), + } } return nil @@ -322,3 +363,61 @@ func AddrFilterPrivate(maddrs []ma.Multiaddr) []ma.Multiaddr { func AddrFilterPublic(maddrs []ma.Multiaddr) []ma.Multiaddr { return ma.FilterAddrs(maddrs, func(maddr ma.Multiaddr) bool { return !manet.IsIPLoopback(maddr) }) } + +// QueryConfig contains the configuration options for queries managed by a [DHT]. +type QueryConfig struct { + // Concurrency defines the maximum number of in-flight queries that may be waiting for message responses at any one time. + Concurrency int + + // Timeout defines the time to wait before terminating a query that is not making progress + Timeout time.Duration + + // RequestConcurrency defines the maximum number of concurrent requests that each query may have in flight. + // The maximum number of concurrent requests is equal to [RequestConcurrency] multiplied by [Concurrency]. + RequestConcurrency int + + // RequestTimeout defines the time to wait before terminating a request to a node that has not responded. + RequestTimeout time.Duration +} + +// DefaultQueryConfig returns the default query configuration options for a DHT. +func DefaultQueryConfig() *QueryConfig { + return &QueryConfig{ + Concurrency: 3, // MAGIC + Timeout: 5 * time.Minute, // MAGIC + RequestConcurrency: 3, // MAGIC + RequestTimeout: time.Minute, // MAGIC + } +} + +// Validate checks the configuration options and returns an error if any have invalid values. 
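+//
+// Validate is normally invoked indirectly through [Config.Validate]. A minimal,
+// hypothetical sketch of overriding the defaults before validation (illustrative
+// only; the field values are assumptions, not recommendations):
+//
+//	cfg := DefaultConfig()
+//	cfg.Query.Concurrency = 5                   // allow more in-flight queries
+//	cfg.Query.RequestTimeout = 30 * time.Second // give up on slow nodes sooner
+//	if err := cfg.Validate(); err != nil {
+//		// handle the *ConfigurationError
+//	}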
+func (cfg *QueryConfig) Validate() error { + if cfg.Concurrency < 1 { + return &ConfigurationError{ + Component: "QueryConfig", + Err: fmt.Errorf("concurrency must be greater than zero"), + } + } + if cfg.Timeout < 1 { + return &ConfigurationError{ + Component: "QueryConfig", + Err: fmt.Errorf("timeout must be greater than zero"), + } + } + + if cfg.RequestConcurrency < 1 { + return &ConfigurationError{ + Component: "QueryConfig", + Err: fmt.Errorf("request concurrency must be greater than zero"), + } + } + + if cfg.RequestTimeout < 1 { + return &ConfigurationError{ + Component: "QueryConfig", + Err: fmt.Errorf("request timeout must be greater than zero"), + } + } + + return nil +} diff --git a/v2/config_test.go b/v2/config_test.go index 6787fff9..739216ab 100644 --- a/v2/config_test.go +++ b/v2/config_test.go @@ -20,15 +20,15 @@ func TestConfig_Validate(t *testing.T) { assert.Error(t, cfg.Validate()) }) - t.Run("nil Kademlia configuration", func(t *testing.T) { + t.Run("nil Query configuration", func(t *testing.T) { cfg := DefaultConfig() - cfg.Kademlia = nil + cfg.Query = nil assert.Error(t, cfg.Validate()) }) - t.Run("invalid Kademlia configuration", func(t *testing.T) { + t.Run("invalid Query configuration", func(t *testing.T) { cfg := DefaultConfig() - cfg.Kademlia.Clock = nil + cfg.Query.Concurrency = -1 assert.Error(t, cfg.Validate()) }) @@ -114,3 +114,46 @@ func TestConfig_Validate(t *testing.T) { assert.Error(t, cfg.Validate()) }) } + +func TestQueryConfig_Validate(t *testing.T) { + t.Run("default is valid", func(t *testing.T) { + cfg := DefaultQueryConfig() + assert.NoError(t, cfg.Validate()) + }) + + t.Run("concurrency positive", func(t *testing.T) { + cfg := DefaultQueryConfig() + + cfg.Concurrency = 0 + assert.Error(t, cfg.Validate()) + cfg.Concurrency = -1 + assert.Error(t, cfg.Validate()) + }) + + t.Run("timeout positive", func(t *testing.T) { + cfg := DefaultQueryConfig() + + cfg.Timeout = 0 + assert.Error(t, cfg.Validate()) + cfg.Timeout = -1 + assert.Error(t, cfg.Validate()) + }) + + t.Run("request concurrency positive", func(t *testing.T) { + cfg := DefaultQueryConfig() + + cfg.RequestConcurrency = 0 + assert.Error(t, cfg.Validate()) + cfg.RequestConcurrency = -1 + assert.Error(t, cfg.Validate()) + }) + + t.Run("request timeout positive", func(t *testing.T) { + cfg := DefaultQueryConfig() + + cfg.RequestTimeout = 0 + assert.Error(t, cfg.Validate()) + cfg.RequestTimeout = -1 + assert.Error(t, cfg.Validate()) + }) +} diff --git a/v2/dht.go b/v2/dht.go index 1fd857d7..0afeb408 100644 --- a/v2/dht.go +++ b/v2/dht.go @@ -16,8 +16,8 @@ import ( "github.com/plprobelab/go-kademlia/key" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) @@ -107,12 +107,16 @@ func New(h host.Host, cfg *Config) (*DHT, error) { } // instantiate a new Kademlia DHT coordinator. 
- coordCfg := cfg.Kademlia + coordCfg := coord.DefaultCoordinatorConfig() + coordCfg.QueryConcurrency = cfg.Query.Concurrency + coordCfg.QueryTimeout = cfg.Query.Timeout + coordCfg.RequestConcurrency = cfg.Query.RequestConcurrency + coordCfg.RequestTimeout = cfg.Query.RequestTimeout coordCfg.Clock = cfg.Clock coordCfg.MeterProvider = cfg.MeterProvider coordCfg.TracerProvider = cfg.TracerProvider - d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), &Router{host: h, ProtocolID: cfg.ProtocolID}, d.rt, coordCfg) + d.kad, err = coord.NewCoordinator(kadt.PeerID(d.host.ID()), &router{host: h, ProtocolID: cfg.ProtocolID}, d.rt, coordCfg) if err != nil { return nil, fmt.Errorf("new coordinator: %w", err) } diff --git a/v2/dht_test.go b/v2/dht_test.go index 2f45e82c..44d68bc4 100644 --- a/v2/dht_test.go +++ b/v2/dht_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) diff --git a/v2/errors.go b/v2/errors.go new file mode 100644 index 00000000..55c6f84b --- /dev/null +++ b/v2/errors.go @@ -0,0 +1,22 @@ +package dht + +import "fmt" + +// A ConfigurationError is returned when a component's configuration is found to be invalid or unusable. +type ConfigurationError struct { + Component string + Err error +} + +var _ error = (*ConfigurationError)(nil) + +func (e *ConfigurationError) Error() string { + if e.Err == nil { + return fmt.Sprintf("configuration error: %s", e.Component) + } + return fmt.Sprintf("configuration error: %s: %s", e.Component, e.Err.Error()) +} + +func (e *ConfigurationError) Unwrap() error { + return e.Err +} diff --git a/v2/coord/behaviour.go b/v2/internal/coord/behaviour.go similarity index 100% rename from v2/coord/behaviour.go rename to v2/internal/coord/behaviour.go diff --git a/v2/coord/behaviour_test.go b/v2/internal/coord/behaviour_test.go similarity index 100% rename from v2/coord/behaviour_test.go rename to v2/internal/coord/behaviour_test.go diff --git a/v2/coord/coordinator.go b/v2/internal/coord/coordinator.go similarity index 78% rename from v2/coord/coordinator.go rename to v2/internal/coord/coordinator.go index 63bad31c..d3d619b2 100644 --- a/v2/coord/coordinator.go +++ b/v2/internal/coord/coordinator.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "sync" + "sync/atomic" "time" "github.com/benbjohnson/clock" @@ -19,8 +20,8 @@ import ( "go.uber.org/zap/exp/zapslog" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) @@ -58,6 +59,15 @@ type Coordinator struct { // tele provides tracing and metric reporting capabilities tele *Telemetry + + // routingNotifierMu guards access to routingNotifier which may be changed during coordinator operation + routingNotifierMu sync.RWMutex + + // routingNotifier receives routing notifications + routingNotifier RoutingNotifier + + // lastQueryID holds the last numeric query id generated + lastQueryID atomic.Uint64 } type RoutingNotifier interface { @@ -65,8 +75,6 @@ type RoutingNotifier interface { } type CoordinatorConfig struct { - PeerstoreTTL time.Duration // 
duration for which a peer is kept in the peerstore - Clock clock.Clock // a clock that may replaced by a mock when testing QueryConcurrency int // the maximum number of queries that may be waiting for message responses at any one time @@ -79,8 +87,6 @@ type CoordinatorConfig struct { MeterProvider metric.MeterProvider // the meter provider to use when initialising metric instruments TracerProvider trace.TracerProvider // the tracer provider to use when initialising tracing - - RoutingNotifier RoutingNotifier // receives notifications of routing events } // Validate checks the configuration options and returns an error if any have invalid values. @@ -140,20 +146,12 @@ func (cfg *CoordinatorConfig) Validate() error { } } - if cfg.RoutingNotifier == nil { - return &kaderr.ConfigurationError{ - Component: "CoordinatorConfig", - Err: fmt.Errorf("routing notifier must not be nil"), - } - } - return nil } func DefaultCoordinatorConfig() *CoordinatorConfig { return &CoordinatorConfig{ Clock: clock.New(), - PeerstoreTTL: 10 * time.Minute, QueryConcurrency: 3, QueryTimeout: 5 * time.Minute, RequestConcurrency: 3, @@ -161,7 +159,6 @@ func DefaultCoordinatorConfig() *CoordinatorConfig { Logger: slog.New(zapslog.NewHandler(logging.Logger("coord").Desugar().Core())), MeterProvider: otel.GetMeterProvider(), TracerProvider: otel.GetTracerProvider(), - RoutingNotifier: nullRoutingNotifier{}, } } @@ -185,7 +182,7 @@ func NewCoordinator(self kadt.PeerID, rtr Router[kadt.Key, kadt.PeerID, *pb.Mess qpCfg.QueryConcurrency = cfg.RequestConcurrency qpCfg.RequestTimeout = cfg.RequestTimeout - qp, err := query.NewPool[kadt.Key](self, qpCfg) + qp, err := query.NewPool[kadt.Key, kadt.PeerID, *pb.Message](self, qpCfg) if err != nil { return nil, fmt.Errorf("query pool: %w", err) } @@ -197,7 +194,7 @@ func NewCoordinator(self kadt.PeerID, rtr Router[kadt.Key, kadt.PeerID, *pb.Mess bootstrapCfg.RequestConcurrency = cfg.RequestConcurrency bootstrapCfg.RequestTimeout = cfg.RequestTimeout - bootstrap, err := routing.NewBootstrap[kadt.Key](kadt.PeerID(self), bootstrapCfg) + bootstrap, err := routing.NewBootstrap(kadt.PeerID(self), bootstrapCfg) if err != nil { return nil, fmt.Errorf("bootstrap: %w", err) } @@ -245,6 +242,7 @@ func NewCoordinator(self kadt.PeerID, rtr Router[kadt.Key, kadt.PeerID, *pb.Mess networkBehaviour: networkBehaviour, routingBehaviour: routingBehaviour, queryBehaviour: queryBehaviour, + routingNotifier: nullRoutingNotifier{}, } go d.eventLoop(ctx) @@ -301,12 +299,21 @@ func (c *Coordinator) dispatchEvent(ctx context.Context, ev BehaviourEvent) { case RoutingCommand: c.routingBehaviour.Notify(ctx, ev) case RoutingNotification: - c.cfg.RoutingNotifier.Notify(ctx, ev) + c.routingNotifierMu.RLock() + rn := c.routingNotifier + c.routingNotifierMu.RUnlock() + rn.Notify(ctx, ev) default: panic(fmt.Sprintf("unexpected event: %T", ev)) } } +func (c *Coordinator) SetRoutingNotifier(rn RoutingNotifier) { + c.routingNotifierMu.Lock() + c.routingNotifier = rn + c.routingNotifierMu.Unlock() +} + // GetNode retrieves the node associated with the given node id from the DHT's local routing table. // If the node isn't found in the table, it returns ErrNodeNotFound. func (c *Coordinator) GetNode(ctx context.Context, id kadt.PeerID) (Node, error) { @@ -351,8 +358,18 @@ func (c *Coordinator) PutValue(ctx context.Context, r Value, q int) error { panic("not implemented") } -// Query traverses the DHT calling fn for each node visited. 
-func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) (QueryStats, error) { +// QueryClosest starts a query that attempts to find the closest nodes to the target key. +// It returns the closest nodes found to the target key and statistics on the actions of the query. +// +// The supplied [QueryFunc] is called after each successful request to a node with the ID of the node, +// the response received from the find nodes request made to the node and the current query stats. The query +// terminates when [QueryFunc] returns an error or when the query has visited the configured minimum number +// of closest nodes (default 20) +// +// numResults specifies the minimum number of nodes to successfully contact before considering iteration complete. +// The query is considered to be exhausted when it has received responses from at least this number of nodes +// and there are no closer nodes remaining to be contacted. A default of 20 is used if this value is less than 1. +func (c *Coordinator) QueryClosest(ctx context.Context, target kadt.Key, fn QueryFunc, numResults int) ([]kadt.PeerID, QueryStats, error) { ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.Query") defer span.End() @@ -361,7 +378,7 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) seeds, err := c.GetClosestNodes(ctx, target, 20) if err != nil { - return QueryStats{}, err + return nil, QueryStats{}, err } seedIDs := make([]kadt.PeerID, 0, len(seeds)) @@ -370,23 +387,79 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) } waiter := NewWaiter[BehaviourEvent]() - queryID := query.QueryID("foo") // TODO: choose query ID + queryID := c.newQueryID() - cmd := &EventStartQuery{ + cmd := &EventStartFindCloserQuery{ QueryID: queryID, Target: target, KnownClosestNodes: seedIDs, Notify: waiter, + NumResults: numResults, } // queue the start of the query c.queryBehaviour.Notify(ctx, cmd) + return c.waitForQuery(ctx, queryID, waiter, fn) +} + +// QueryMessage starts a query that iterates over the closest nodes to the target key in the supplied message. +// The message is sent to each node that is visited. +// +// The supplied [QueryFunc] is called after each successful request to a node with the ID of the node, +// the response received from the find nodes request made to the node and the current query stats. The query +// terminates when [QueryFunc] returns an error or when the query has visited the configured minimum number +// of closest nodes (default 20) +// +// numResults specifies the minimum number of nodes to successfully contact before considering iteration complete. +// The query is considered to be exhausted when it has received responses from at least this number of nodes +// and there are no closer nodes remaining to be contacted. A default of 20 is used if this value is less than 1. 
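+//
+// A minimal, hypothetical sketch of a [QueryFunc] that records which peers
+// responded (it mirrors the pattern used by the coordinator tests; the visited
+// map is an assumption of the example, not part of the API):
+//
+//	visited := make(map[string]int)
+//	qfn := func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats QueryStats) error {
+//		visited[id.String()]++
+//		return nil
+//	}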
+func (c *Coordinator) QueryMessage(ctx context.Context, msg *pb.Message, fn QueryFunc, numResults int) (QueryStats, error) { + ctx, span := c.tele.Tracer.Start(ctx, "Coordinator.QueryMessage") + defer span.End() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if numResults < 1 { + numResults = 20 + } + + seeds, err := c.GetClosestNodes(ctx, msg.Target(), numResults) + if err != nil { + return QueryStats{}, err + } + + seedIDs := make([]kadt.PeerID, 0, len(seeds)) + for _, s := range seeds { + seedIDs = append(seedIDs, kadt.PeerID(s.ID())) + } + + waiter := NewWaiter[BehaviourEvent]() + queryID := c.newQueryID() + + cmd := &EventStartMessageQuery{ + QueryID: queryID, + Target: msg.Target(), + Message: msg, + KnownClosestNodes: seedIDs, + Notify: waiter, + NumResults: numResults, + } + + // queue the start of the query + c.queryBehaviour.Notify(ctx, cmd) + + _, stats, err := c.waitForQuery(ctx, queryID, waiter, fn) + return stats, err +} + +func (c *Coordinator) waitForQuery(ctx context.Context, queryID query.QueryID, waiter *Waiter[BehaviourEvent], fn QueryFunc) ([]kadt.PeerID, QueryStats, error) { var lastStats QueryStats for { select { case <-ctx.Done(): - return lastStats, ctx.Err() + return nil, lastStats, ctx.Err() case wev := <-waiter.Chan(): ctx, ev := wev.Ctx, wev.Event switch ev := ev.(type) { @@ -403,26 +476,22 @@ func (c *Coordinator) Query(ctx context.Context, target kadt.Key, fn QueryFunc) break } - err = fn(ctx, nh, lastStats) + err = fn(ctx, nh.ID(), ev.Response, lastStats) if errors.Is(err, ErrSkipRemaining) { // done c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) - return lastStats, nil - } - if errors.Is(err, ErrSkipNode) { - // TODO: don't add closer nodes from this node - break + return nil, lastStats, nil } if err != nil { // user defined error that terminates the query c.queryBehaviour.Notify(ctx, &EventStopQuery{QueryID: queryID}) - return lastStats, err + return nil, lastStats, err } case *EventQueryFinished: // query is done lastStats.Exhausted = true - return lastStats, nil + return ev.ClosestNodes, lastStats, nil default: panic(fmt.Sprintf("unexpected event: %T", ev)) @@ -490,6 +559,11 @@ func (c *Coordinator) NotifyNonConnectivity(ctx context.Context, id kadt.PeerID) return nil } +func (c *Coordinator) newQueryID() query.QueryID { + next := c.lastQueryID.Add(1) + return query.QueryID(fmt.Sprintf("%016x", next)) +} + // A BufferedRoutingNotifier is a [RoutingNotifier] that buffers [RoutingNotification] events and provides methods // to expect occurrences of specific events. It is designed for use in a test environment. 
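//
// A hypothetical wiring sketch (illustrative only), mirroring how the coordinator
// tests attach it with the new [Coordinator.SetRoutingNotifier] method:
//
//	rn := NewBufferedRoutingNotifier()
//	c.SetRoutingNotifier(rn)
//	// ... perform operations, then use the notifier's expectation
//	// methods to assert that specific routing events were observed.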
type BufferedRoutingNotifier struct { diff --git a/v2/coord/coordinator_test.go b/v2/internal/coord/coordinator_test.go similarity index 90% rename from v2/coord/coordinator_test.go rename to v2/internal/coord/coordinator_test.go index ba32444e..c267b4a0 100644 --- a/v2/coord/coordinator_test.go +++ b/v2/internal/coord/coordinator_test.go @@ -4,18 +4,16 @@ import ( "context" "log" "testing" - "time" "github.com/benbjohnson/clock" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/nettest" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) -const peerstoreTTL = 10 * time.Minute - func TestConfigValidate(t *testing.T) { t.Run("default is valid", func(t *testing.T) { cfg := DefaultCoordinatorConfig() @@ -53,7 +51,7 @@ func TestConfigValidate(t *testing.T) { cfg.RequestConcurrency = 0 require.Error(t, cfg.Validate()) - cfg.QueryConcurrency = -1 + cfg.RequestConcurrency = -1 require.Error(t, cfg.Validate()) }) @@ -84,12 +82,6 @@ func TestConfigValidate(t *testing.T) { cfg.TracerProvider = nil require.Error(t, cfg.Validate()) }) - - t.Run("routing notifier not nil", func(t *testing.T) { - cfg := DefaultCoordinatorConfig() - cfg.RoutingNotifier = nil - require.Error(t, cfg.Validate()) - }) } func TestExhaustiveQuery(t *testing.T) { @@ -101,7 +93,6 @@ func TestExhaustiveQuery(t *testing.T) { ccfg := DefaultCoordinatorConfig() ccfg.Clock = clk - ccfg.PeerstoreTTL = peerstoreTTL // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) @@ -115,13 +106,13 @@ func TestExhaustiveQuery(t *testing.T) { visited := make(map[string]int) // Record the nodes as they are visited - qfn := func(ctx context.Context, node Node, stats QueryStats) error { - visited[node.ID().String()]++ + qfn := func(ctx context.Context, id kadt.PeerID, msg *pb.Message, stats QueryStats) error { + visited[id.String()]++ return nil } // Run a query to find the value - _, err = c.Query(ctx, target, qfn) + _, _, err = c.QueryClosest(ctx, target, qfn, 20) require.NoError(t, err) require.Equal(t, 3, len(visited)) @@ -140,10 +131,6 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { ccfg := DefaultCoordinatorConfig() ccfg.Clock = clk - ccfg.PeerstoreTTL = peerstoreTTL - - rn := NewBufferedRoutingNotifier() - ccfg.RoutingNotifier = rn // A (ids[0]) is looking for D (ids[3]) // A will first ask B, B will reply with C's address (and A's address) @@ -154,13 +141,16 @@ func TestRoutingUpdatedEventEmittedForCloserNodes(t *testing.T) { log.Fatalf("unexpected error creating coordinator: %v", err) } - qfn := func(ctx context.Context, node Node, stats QueryStats) error { + rn := NewBufferedRoutingNotifier() + c.SetRoutingNotifier(rn) + + qfn := func(ctx context.Context, id kadt.PeerID, msg *pb.Message, stats QueryStats) error { return nil } // Run a query to find the value target := nodes[3].NodeID.Key() - _, err = c.Query(ctx, target, qfn) + _, _, err = c.QueryClosest(ctx, target, qfn, 20) require.NoError(t, err) // the query run by the dht should have received a response from nodes[1] with closer nodes @@ -201,15 +191,14 @@ func TestBootstrap(t *testing.T) { ccfg := DefaultCoordinatorConfig() ccfg.Clock = clk - ccfg.PeerstoreTTL = peerstoreTTL - - rn := NewBufferedRoutingNotifier() - ccfg.RoutingNotifier = rn self := kadt.PeerID(nodes[0].NodeID) d, err := 
NewCoordinator(self, nodes[0].Router, nodes[0].RoutingTable, ccfg) require.NoError(t, err) + rn := NewBufferedRoutingNotifier() + d.SetRoutingNotifier(rn) + seeds := []kadt.PeerID{nodes[1].NodeID} err = d.Bootstrap(ctx, seeds) require.NoError(t, err) @@ -253,10 +242,6 @@ func TestIncludeNode(t *testing.T) { ccfg := DefaultCoordinatorConfig() ccfg.Clock = clk - ccfg.PeerstoreTTL = peerstoreTTL - - rn := NewBufferedRoutingNotifier() - ccfg.RoutingNotifier = rn candidate := nodes[len(nodes)-1].NodeID // not in nodes[0] routing table @@ -265,6 +250,8 @@ func TestIncludeNode(t *testing.T) { if err != nil { log.Fatalf("unexpected error creating dht: %v", err) } + rn := NewBufferedRoutingNotifier() + d.SetRoutingNotifier(rn) // the routing table should not contain the node yet _, err = d.GetNode(ctx, candidate) diff --git a/v2/coord/coretypes.go b/v2/internal/coord/coretypes.go similarity index 95% rename from v2/coord/coretypes.go rename to v2/internal/coord/coretypes.go index 0f72cebf..12c9ba26 100644 --- a/v2/coord/coretypes.go +++ b/v2/internal/coord/coretypes.go @@ -8,6 +8,7 @@ import ( "github.com/plprobelab/go-kademlia/kad" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) // Value is a value that may be stored in the DHT. @@ -49,7 +50,7 @@ var ( // Query stops entirely and returns that error. // // The stats argument contains statistics on the progress of the query so far. -type QueryFunc func(ctx context.Context, node Node, stats QueryStats) error +type QueryFunc func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats QueryStats) error type QueryStats struct { Start time.Time // Start is the time the query began executing. diff --git a/v2/coord/event.go b/v2/internal/coord/event.go similarity index 70% rename from v2/coord/event.go rename to v2/internal/coord/event.go index 663cfee9..a0037732 100644 --- a/v2/coord/event.go +++ b/v2/internal/coord/event.go @@ -1,8 +1,9 @@ package coord import ( - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) type BehaviourEvent interface { @@ -60,15 +61,39 @@ func (*EventOutboundGetCloserNodes) behaviourEvent() {} func (*EventOutboundGetCloserNodes) nodeHandlerRequest() {} func (*EventOutboundGetCloserNodes) networkCommand() {} -type EventStartQuery struct { +type EventOutboundSendMessage struct { + QueryID query.QueryID + To kadt.PeerID + Message *pb.Message + Notify Notify[BehaviourEvent] +} + +func (*EventOutboundSendMessage) behaviourEvent() {} +func (*EventOutboundSendMessage) nodeHandlerRequest() {} +func (*EventOutboundSendMessage) networkCommand() {} + +type EventStartMessageQuery struct { + QueryID query.QueryID + Target kadt.Key + Message *pb.Message + KnownClosestNodes []kadt.PeerID + Notify NotifyCloser[BehaviourEvent] + NumResults int // the minimum number of nodes to successfully contact before considering iteration complete +} + +func (*EventStartMessageQuery) behaviourEvent() {} +func (*EventStartMessageQuery) queryCommand() {} + +type EventStartFindCloserQuery struct { QueryID query.QueryID Target kadt.Key KnownClosestNodes []kadt.PeerID Notify NotifyCloser[BehaviourEvent] + NumResults int // the minimum number of nodes to successfully contact before considering iteration complete } -func (*EventStartQuery) behaviourEvent() {} -func (*EventStartQuery) queryCommand() {} +func (*EventStartFindCloserQuery) behaviourEvent() {} 
+func (*EventStartFindCloserQuery) queryCommand() {} type EventStopQuery struct { QueryID query.QueryID @@ -109,12 +134,36 @@ type EventGetCloserNodesFailure struct { func (*EventGetCloserNodesFailure) behaviourEvent() {} func (*EventGetCloserNodesFailure) nodeHandlerResponse() {} +// EventSendMessageSuccess notifies a behaviour that a SendMessage request, initiated by an +// [EventOutboundSendMessage] event has produced a successful response. +type EventSendMessageSuccess struct { + QueryID query.QueryID + To kadt.PeerID // To is the peer that the SendMessage request was sent to. + Response *pb.Message + CloserNodes []kadt.PeerID +} + +func (*EventSendMessageSuccess) behaviourEvent() {} +func (*EventSendMessageSuccess) nodeHandlerResponse() {} + +// EventSendMessageFailure notifies a behaviour that a SendMessage request, initiated by an +// [EventOutboundSendMessage] event has failed to produce a valid response. +type EventSendMessageFailure struct { + QueryID query.QueryID + To kadt.PeerID // To is the peer that the SendMessage request was sent to. + Target kadt.Key + Err error +} + +func (*EventSendMessageFailure) behaviourEvent() {} +func (*EventSendMessageFailure) nodeHandlerResponse() {} + // EventQueryProgressed is emitted by the coordinator when a query has received a // response from a node. type EventQueryProgressed struct { QueryID query.QueryID NodeID kadt.PeerID - Response Message + Response *pb.Message Stats query.QueryStats } @@ -123,8 +172,9 @@ func (*EventQueryProgressed) behaviourEvent() {} // EventQueryFinished is emitted by the coordinator when a query has finished, either through // running to completion or by being canceled. type EventQueryFinished struct { - QueryID query.QueryID - Stats query.QueryStats + QueryID query.QueryID + Stats query.QueryStats + ClosestNodes []kadt.PeerID } func (*EventQueryFinished) behaviourEvent() {} diff --git a/v2/coord/event_test.go b/v2/internal/coord/event_test.go similarity index 84% rename from v2/coord/event_test.go rename to v2/internal/coord/event_test.go index 2944be13..99abc2fc 100644 --- a/v2/coord/event_test.go +++ b/v2/internal/coord/event_test.go @@ -8,7 +8,8 @@ var ( ) var ( - _ QueryCommand = (*EventStartQuery)(nil) + _ QueryCommand = (*EventStartMessageQuery)(nil) + _ QueryCommand = (*EventStartFindCloserQuery)(nil) _ QueryCommand = (*EventStopQuery)(nil) ) diff --git a/v2/coord/internal/nettest/layouts.go b/v2/internal/coord/internal/nettest/layouts.go similarity index 100% rename from v2/coord/internal/nettest/layouts.go rename to v2/internal/coord/internal/nettest/layouts.go diff --git a/v2/coord/internal/nettest/routing.go b/v2/internal/coord/internal/nettest/routing.go similarity index 100% rename from v2/coord/internal/nettest/routing.go rename to v2/internal/coord/internal/nettest/routing.go diff --git a/v2/coord/internal/nettest/topology.go b/v2/internal/coord/internal/nettest/topology.go similarity index 97% rename from v2/coord/internal/nettest/topology.go rename to v2/internal/coord/internal/nettest/topology.go index 96d6380a..dda13ade 100644 --- a/v2/coord/internal/nettest/topology.go +++ b/v2/internal/coord/internal/nettest/topology.go @@ -8,7 +8,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/plprobelab/go-kademlia/network/address" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) diff --git a/v2/coord/internal/tiny/node.go 
b/v2/internal/coord/internal/tiny/node.go similarity index 92% rename from v2/coord/internal/tiny/node.go rename to v2/internal/coord/internal/tiny/node.go index 72c67887..2ad224cc 100644 --- a/v2/coord/internal/tiny/node.go +++ b/v2/internal/coord/internal/tiny/node.go @@ -12,6 +12,10 @@ type Node struct { key Key } +type Message struct { + Content string +} + var _ kad.NodeID[Key] = Node{} func NewNode(k Key) Node { diff --git a/v2/coord/internal/tiny/node_test.go b/v2/internal/coord/internal/tiny/node_test.go similarity index 100% rename from v2/coord/internal/tiny/node_test.go rename to v2/internal/coord/internal/tiny/node_test.go diff --git a/v2/coord/network.go b/v2/internal/coord/network.go similarity index 87% rename from v2/coord/network.go rename to v2/internal/coord/network.go index 72369b6f..d4087564 100644 --- a/v2/coord/network.go +++ b/v2/internal/coord/network.go @@ -9,7 +9,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) @@ -59,6 +59,16 @@ func (b *NetworkBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { } b.nodeHandlersMu.Unlock() nh.Notify(ctx, ev) + case *EventOutboundSendMessage: + b.nodeHandlersMu.Lock() + p := kadt.PeerID(ev.To) + nh, ok := b.nodeHandlers[p] + if !ok { + nh = NewNodeHandler(p, b.rtr, b.logger, b.tracer) + b.nodeHandlers[p] = nh + } + b.nodeHandlersMu.Unlock() + nh.Notify(ctx, ev) default: panic(fmt.Sprintf("unexpected dht event: %T", ev)) } @@ -160,6 +170,26 @@ func (h *NodeHandler) send(ctx context.Context, ev NodeHandlerRequest) bool { Target: cmd.Target, CloserNodes: nodes, }) + case *EventOutboundSendMessage: + if cmd.Notify == nil { + break + } + resp, err := h.rtr.SendMessage(ctx, h.self, cmd.Message) + if err != nil { + cmd.Notify.Notify(ctx, &EventSendMessageFailure{ + QueryID: cmd.QueryID, + To: h.self, + Err: fmt.Errorf("NodeHandler: %w", err), + }) + return false + } + + cmd.Notify.Notify(ctx, &EventSendMessageSuccess{ + QueryID: cmd.QueryID, + To: h.self, + Response: resp, + CloserNodes: resp.CloserNodes(), + }) default: panic(fmt.Sprintf("unexpected command type: %T", cmd)) } diff --git a/v2/coord/network_test.go b/v2/internal/coord/network_test.go similarity index 92% rename from v2/coord/network_test.go rename to v2/internal/coord/network_test.go index 6baacf53..924565b8 100644 --- a/v2/coord/network_test.go +++ b/v2/internal/coord/network_test.go @@ -8,7 +8,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/nettest" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" ) diff --git a/v2/coord/query.go b/v2/internal/coord/query.go similarity index 67% rename from v2/coord/query.go rename to v2/internal/coord/query.go index b8ebb982..91cbab09 100644 --- a/v2/coord/query.go +++ b/v2/internal/coord/query.go @@ -6,14 +6,15 @@ import ( "sync" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" ) type PooledQueryBehaviour struct { - pool *query.Pool[kadt.Key, kadt.PeerID] + pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message] waiters 
map[query.QueryID]NotifyCloser[BehaviourEvent] pendingMu sync.Mutex @@ -24,7 +25,7 @@ type PooledQueryBehaviour struct { tracer trace.Tracer } -func NewPooledQueryBehaviour(pool *query.Pool[kadt.Key, kadt.PeerID], logger *slog.Logger, tracer trace.Tracer) *PooledQueryBehaviour { +func NewPooledQueryBehaviour(pool *query.Pool[kadt.Key, kadt.PeerID, *pb.Message], logger *slog.Logger, tracer trace.Tracer) *PooledQueryBehaviour { h := &PooledQueryBehaviour{ pool: pool, waiters: make(map[query.QueryID]NotifyCloser[BehaviourEvent]), @@ -44,8 +45,8 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { var cmd query.PoolEvent switch ev := ev.(type) { - case *EventStartQuery: - cmd = &query.EventPoolAddQuery[kadt.Key, kadt.PeerID]{ + case *EventStartFindCloserQuery: + cmd = &query.EventPoolAddFindCloserQuery[kadt.Key, kadt.PeerID]{ QueryID: ev.QueryID, Target: ev.Target, KnownClosestNodes: ev.KnownClosestNodes, @@ -53,6 +54,16 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { if ev.Notify != nil { p.waiters[ev.QueryID] = ev.Notify } + case *EventStartMessageQuery: + cmd = &query.EventPoolAddQuery[kadt.Key, kadt.PeerID, *pb.Message]{ + QueryID: ev.QueryID, + Target: ev.Target, + Message: ev.Message, + KnownClosestNodes: ev.KnownClosestNodes, + } + if ev.Notify != nil { + p.waiters[ev.QueryID] = ev.Notify + } case *EventStopQuery: cmd = &query.EventPoolStopQuery{ @@ -60,8 +71,6 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { } case *EventGetCloserNodesSuccess: - // TODO: add addresses for discovered nodes in DHT - for _, info := range ev.CloserNodes { // TODO: do this after advancing pool p.pending = append(p.pending, &EventAddNode{ @@ -77,7 +86,7 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { // Stats: stats, }) } - cmd = &query.EventPoolFindCloserResponse[kadt.Key, kadt.PeerID]{ + cmd = &query.EventPoolNodeResponse[kadt.Key, kadt.PeerID]{ NodeID: ev.To, QueryID: ev.QueryID, CloserNodes: ev.CloserNodes, @@ -88,7 +97,38 @@ func (p *PooledQueryBehaviour) Notify(ctx context.Context, ev BehaviourEvent) { ev.To, }) - cmd = &query.EventPoolFindCloserFailure[kadt.Key, kadt.PeerID]{ + cmd = &query.EventPoolNodeFailure[kadt.Key, kadt.PeerID]{ + NodeID: ev.To, + QueryID: ev.QueryID, + Error: ev.Err, + } + case *EventSendMessageSuccess: + for _, info := range ev.CloserNodes { + // TODO: do this after advancing pool + p.pending = append(p.pending, &EventAddNode{ + NodeID: info, + }) + } + waiter, ok := p.waiters[ev.QueryID] + if ok { + waiter.Notify(ctx, &EventQueryProgressed{ + NodeID: ev.To, + QueryID: ev.QueryID, + Response: ev.Response, + }) + } + cmd = &query.EventPoolNodeResponse[kadt.Key, kadt.PeerID]{ + NodeID: ev.To, + QueryID: ev.QueryID, + CloserNodes: ev.CloserNodes, + } + case *EventSendMessageFailure: + // queue an event that will notify the routing behaviour of a failed node + p.pending = append(p.pending, &EventNotifyNonConnectivity{ + ev.To, + }) + + cmd = &query.EventPoolNodeFailure[kadt.Key, kadt.PeerID]{ NodeID: ev.To, QueryID: ev.QueryID, Error: ev.Err, @@ -162,16 +202,24 @@ func (p *PooledQueryBehaviour) advancePool(ctx context.Context, ev query.PoolEve Target: st.Target, Notify: p, }, true + case *query.StatePoolSendMessage[kadt.Key, kadt.PeerID, *pb.Message]: + return &EventOutboundSendMessage{ + QueryID: st.QueryID, + To: st.NodeID, + Message: st.Message, + Notify: p, + }, true case *query.StatePoolWaitingAtCapacity: // nothing to do except wait for 
message response or timeout case *query.StatePoolWaitingWithCapacity: // nothing to do except wait for message response or timeout - case *query.StatePoolQueryFinished: + case *query.StatePoolQueryFinished[kadt.Key, kadt.PeerID]: waiter, ok := p.waiters[st.QueryID] if ok { waiter.Notify(ctx, &EventQueryFinished{ - QueryID: st.QueryID, - Stats: st.Stats, + QueryID: st.QueryID, + Stats: st.Stats, + ClosestNodes: st.ClosestNodes, }) waiter.Close() } diff --git a/v2/coord/query/iter.go b/v2/internal/coord/query/iter.go similarity index 100% rename from v2/coord/query/iter.go rename to v2/internal/coord/query/iter.go diff --git a/v2/coord/query/iter_test.go b/v2/internal/coord/query/iter_test.go similarity index 97% rename from v2/coord/query/iter_test.go rename to v2/internal/coord/query/iter_test.go index d5d02de9..cb987349 100644 --- a/v2/coord/query/iter_test.go +++ b/v2/internal/coord/query/iter_test.go @@ -7,7 +7,7 @@ import ( "github.com/plprobelab/go-kademlia/key" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" ) var ( diff --git a/v2/coord/query/node.go b/v2/internal/coord/query/node.go similarity index 100% rename from v2/coord/query/node.go rename to v2/internal/coord/query/node.go diff --git a/v2/coord/query/pool.go b/v2/internal/coord/query/pool.go similarity index 64% rename from v2/coord/query/pool.go rename to v2/internal/coord/query/pool.go index 2fffc706..a94566cb 100644 --- a/v2/coord/query/pool.go +++ b/v2/internal/coord/query/pool.go @@ -12,11 +12,13 @@ import ( "github.com/libp2p/go-libp2p-kad-dht/v2/tele" ) -type Pool[K kad.Key[K], N kad.NodeID[K]] struct { +type Message interface{} + +type Pool[K kad.Key[K], N kad.NodeID[K], M Message] struct { // self is the node id of the system the pool is running on self N - queries []*Query[K, N] - queryIndex map[QueryID]*Query[K, N] + queries []*Query[K, N, M] + queryIndex map[QueryID]*Query[K, N, M] // cfg is a copy of the optional configuration supplied to the pool cfg PoolConfig @@ -92,23 +94,23 @@ func DefaultPoolConfig() *PoolConfig { } } -func NewPool[K kad.Key[K], N kad.NodeID[K]](self N, cfg *PoolConfig) (*Pool[K, N], error) { +func NewPool[K kad.Key[K], N kad.NodeID[K], M Message](self N, cfg *PoolConfig) (*Pool[K, N, M], error) { if cfg == nil { cfg = DefaultPoolConfig() } else if err := cfg.Validate(); err != nil { return nil, err } - return &Pool[K, N]{ + return &Pool[K, N, M]{ self: self, cfg: *cfg, - queries: make([]*Query[K, N], 0), - queryIndex: make(map[QueryID]*Query[K, N]), + queries: make([]*Query[K, N, M], 0), + queryIndex: make(map[QueryID]*Query[K, N, M]), }, nil } // Advance advances the state of the pool by attempting to advance one of its queries -func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { +func (p *Pool[K, N, M]) Advance(ctx context.Context, ev PoolEvent) PoolState { ctx, span := tele.StartSpan(ctx, "Pool.Advance") defer span.End() @@ -120,8 +122,10 @@ func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { eventQueryID := InvalidQueryID switch tev := ev.(type) { - case *EventPoolAddQuery[K, N]: - p.addQuery(ctx, tev.QueryID, tev.Target, tev.KnownClosestNodes) + case *EventPoolAddFindCloserQuery[K, N]: + p.addFindCloserQuery(ctx, tev.QueryID, tev.Target, tev.KnownClosestNodes, tev.NumResults) + case *EventPoolAddQuery[K, N, M]: + p.addQuery(ctx, tev.QueryID, tev.Target, tev.Message, tev.KnownClosestNodes, tev.NumResults) // TODO: 
return error as state case *EventPoolStopQuery: if qry, ok := p.queryIndex[tev.QueryID]; ok { @@ -131,9 +135,9 @@ func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { } eventQueryID = qry.id } - case *EventPoolFindCloserResponse[K, N]: + case *EventPoolNodeResponse[K, N]: if qry, ok := p.queryIndex[tev.QueryID]; ok { - state, terminal := p.advanceQuery(ctx, qry, &EventQueryFindCloserResponse[K, N]{ + state, terminal := p.advanceQuery(ctx, qry, &EventQueryNodeResponse[K, N]{ NodeID: tev.NodeID, CloserNodes: tev.CloserNodes, }) @@ -142,9 +146,9 @@ func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { } eventQueryID = qry.id } - case *EventPoolFindCloserFailure[K, N]: + case *EventPoolNodeFailure[K, N]: if qry, ok := p.queryIndex[tev.QueryID]; ok { - state, terminal := p.advanceQuery(ctx, qry, &EventQueryFindCloserFailure[K, N]{ + state, terminal := p.advanceQuery(ctx, qry, &EventQueryNodeFailure[K, N]{ NodeID: tev.NodeID, Error: tev.Error, }) @@ -170,7 +174,7 @@ func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { continue } - state, terminal := p.advanceQuery(ctx, qry, nil) + state, terminal := p.advanceQuery(ctx, qry, &EventQueryPoll{}) if terminal { return state } @@ -188,7 +192,7 @@ func (p *Pool[K, N]) Advance(ctx context.Context, ev PoolEvent) PoolState { return &StatePoolIdle{} } -func (p *Pool[K, N]) advanceQuery(ctx context.Context, qry *Query[K, N], qev QueryEvent) (PoolState, bool) { +func (p *Pool[K, N, M]) advanceQuery(ctx context.Context, qry *Query[K, N, M], qev QueryEvent) (PoolState, bool) { state := qry.Advance(ctx, qev) switch st := state.(type) { case *StateQueryFindCloser[K, N]: @@ -199,11 +203,20 @@ func (p *Pool[K, N]) advanceQuery(ctx context.Context, qry *Query[K, N], qev Que NodeID: st.NodeID, Target: st.Target, }, true - case *StateQueryFinished: - p.removeQuery(qry.id) - return &StatePoolQueryFinished{ + case *StateQuerySendMessage[K, N, M]: + p.queriesInFlight++ + return &StatePoolSendMessage[K, N, M]{ QueryID: st.QueryID, Stats: st.Stats, + NodeID: st.NodeID, + Message: st.Message, + }, true + case *StateQueryFinished[K, N]: + p.removeQuery(qry.id) + return &StatePoolQueryFinished[K, N]{ + QueryID: st.QueryID, + Stats: st.Stats, + ClosestNodes: st.ClosestNodes, }, true case *StateQueryWaitingAtCapacity: elapsed := p.cfg.Clock.Since(qry.stats.Start) @@ -229,7 +242,7 @@ func (p *Pool[K, N]) advanceQuery(ctx context.Context, qry *Query[K, N], qev Que return nil, false } -func (p *Pool[K, N]) removeQuery(queryID QueryID) { +func (p *Pool[K, N, M]) removeQuery(queryID QueryID) { for i := range p.queries { if p.queries[i].id != queryID { continue @@ -245,18 +258,49 @@ func (p *Pool[K, N]) removeQuery(queryID QueryID) { // addQuery adds a query to the pool, returning the new query id // TODO: remove target argument and use msg.Target -func (p *Pool[K, N]) addQuery(ctx context.Context, queryID QueryID, target K, knownClosestNodes []N) error { +func (p *Pool[K, N, M]) addQuery(ctx context.Context, queryID QueryID, target K, msg M, knownClosestNodes []N, numResults int) error { if _, exists := p.queryIndex[queryID]; exists { return fmt.Errorf("query id already in use") } iter := NewClosestNodesIter[K, N](target) - qryCfg := DefaultQueryConfig[K]() + qryCfg := DefaultQueryConfig() qryCfg.Clock = p.cfg.Clock qryCfg.Concurrency = p.cfg.QueryConcurrency qryCfg.RequestTimeout = p.cfg.RequestTimeout - qry, err := NewQuery[K, N](p.self, queryID, target, iter, knownClosestNodes, qryCfg) + if numResults > 0 { + 
qryCfg.NumResults = numResults + } + + qry, err := NewQuery[K, N, M](p.self, queryID, target, msg, iter, knownClosestNodes, qryCfg) + if err != nil { + return fmt.Errorf("new query: %w", err) + } + + p.queries = append(p.queries, qry) + p.queryIndex[queryID] = qry + + return nil +} + +// addFindCloserQuery adds a find closer query to the pool, returning the new query id +func (p *Pool[K, N, M]) addFindCloserQuery(ctx context.Context, queryID QueryID, target K, knownClosestNodes []N, numResults int) error { + if _, exists := p.queryIndex[queryID]; exists { + return fmt.Errorf("query id already in use") + } + iter := NewClosestNodesIter[K, N](target) + + qryCfg := DefaultQueryConfig() + qryCfg.Clock = p.cfg.Clock + qryCfg.Concurrency = p.cfg.QueryConcurrency + qryCfg.RequestTimeout = p.cfg.RequestTimeout + + if numResults > 0 { + qryCfg.NumResults = numResults + } + + qry, err := NewFindCloserQuery[K, N, M](p.self, queryID, target, iter, knownClosestNodes, qryCfg) if err != nil { return fmt.Errorf("new query: %w", err) } @@ -284,6 +328,14 @@ type StatePoolFindCloser[K kad.Key[K], N kad.NodeID[K]] struct { Stats QueryStats } +// StatePoolSendMessage indicates that a pool query wants to send a message to a node. +type StatePoolSendMessage[K kad.Key[K], N kad.NodeID[K], M Message] struct { + QueryID QueryID + NodeID N // the node to send the message to + Message M + Stats QueryStats +} + // StatePoolWaitingAtCapacity indicates that at least one query is waiting for results and the pool has reached // its maximum number of concurrent queries. type StatePoolWaitingAtCapacity struct{} @@ -293,9 +345,10 @@ type StatePoolWaitingWithCapacity struct{} // StatePoolQueryFinished indicates that a query has finished. -type StatePoolQueryFinished struct { - QueryID QueryID - Stats QueryStats +type StatePoolQueryFinished[K kad.Key[K], N kad.NodeID[K]] struct { + QueryID QueryID + Stats QueryStats + ClosestNodes []N } // StatePoolQueryTimeout indicates that a query has timed out. @@ -305,23 +358,34 @@ type StatePoolQueryTimeout struct { } // poolState() ensures that only Pool states can be assigned to the PoolState interface. -func (*StatePoolIdle) poolState() {} -func (*StatePoolFindCloser[K, N]) poolState() {} -func (*StatePoolWaitingAtCapacity) poolState() {} -func (*StatePoolWaitingWithCapacity) poolState() {} -func (*StatePoolQueryFinished) poolState() {} -func (*StatePoolQueryTimeout) poolState() {} +func (*StatePoolIdle) poolState() {} +func (*StatePoolFindCloser[K, N]) poolState() {} +func (*StatePoolSendMessage[K, N, M]) poolState() {} +func (*StatePoolWaitingAtCapacity) poolState() {} +func (*StatePoolWaitingWithCapacity) poolState() {} +func (*StatePoolQueryFinished[K, N]) poolState() {} +func (*StatePoolQueryTimeout) poolState() {} // PoolEvent is an event intended to advance the state of a pool. type PoolEvent interface { poolEvent() } -// EventPoolAddQuery is an event that attempts to add a new query -type EventPoolAddQuery[K kad.Key[K], N kad.NodeID[K]] struct { +// EventPoolAddFindCloserQuery is an event that attempts to add a new query that finds closer nodes to a target key.
+type EventPoolAddFindCloserQuery[K kad.Key[K], N kad.NodeID[K]] struct { + QueryID QueryID // the id to use for the new query + Target K // the target key for the query + KnownClosestNodes []N // an initial set of close nodes the query should use + NumResults int // the minimum number of nodes to successfully contact before considering iteration complete +} + +// EventPoolAddQuery is an event that attempts to add a new query that sends a message. +type EventPoolAddQuery[K kad.Key[K], N kad.NodeID[K], M Message] struct { QueryID QueryID // the id to use for the new query Target K // the target key for the query + Message M // message to be sent to each node KnownClosestNodes []N // an initial set of close nodes the query should use + NumResults int // the minimum number of nodes to successfully contact before considering iteration complete } // EventPoolStopQuery notifies a [Pool] to stop a query. @@ -329,15 +393,15 @@ type EventPoolStopQuery struct { QueryID QueryID // the id of the query that should be stopped } -// EventPoolFindCloserResponse notifies a [Pool] that an attempt to find closer nodes has received a successful response. -type EventPoolFindCloserResponse[K kad.Key[K], N kad.NodeID[K]] struct { +// EventPoolNodeResponse notifies a [Pool] that an attempt to contact a node has received a successful response. +type EventPoolNodeResponse[K kad.Key[K], N kad.NodeID[K]] struct { QueryID QueryID // the id of the query that sent the message NodeID N // the node the message was sent to CloserNodes []N // the closer nodes sent by the node } -// EventPoolFindCloserFailure notifies a [Pool] that an attempt to find closer nodes has failed. -type EventPoolFindCloserFailure[K kad.Key[K], N kad.NodeID[K]] struct { +// EventPoolNodeFailure notifies a [Pool] that an attempt to contact a node has failed. +type EventPoolNodeFailure[K kad.Key[K], N kad.NodeID[K]] struct { QueryID QueryID // the id of the query that sent the message NodeID N // the node the message was sent to Error error // the error that caused the failure, if any @@ -347,8 +411,9 @@ type EventPoolFindCloserFailure[K kad.Key[K], N kad.NodeID[K]] struct { type EventPoolPoll struct{} // poolEvent() ensures that only events accepted by a [Pool] can be assigned to the [PoolEvent] interface. 
-func (*EventPoolAddQuery[K, N]) poolEvent() {} +func (*EventPoolAddQuery[K, N, M]) poolEvent() {} +func (*EventPoolAddFindCloserQuery[K, N]) poolEvent() {} func (*EventPoolStopQuery) poolEvent() {} -func (*EventPoolFindCloserResponse[K, N]) poolEvent() {} -func (*EventPoolFindCloserFailure[K, N]) poolEvent() {} +func (*EventPoolNodeResponse[K, N]) poolEvent() {} +func (*EventPoolNodeFailure[K, N]) poolEvent() {} func (*EventPoolPoll) poolEvent() {} diff --git a/v2/coord/query/pool_test.go b/v2/internal/coord/query/pool_test.go similarity index 75% rename from v2/coord/query/pool_test.go rename to v2/internal/coord/query/pool_test.go index 2f6ab26f..d54c6d23 100644 --- a/v2/coord/query/pool_test.go +++ b/v2/internal/coord/query/pool_test.go @@ -8,7 +8,7 @@ import ( "github.com/plprobelab/go-kademlia/key" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" ) func TestPoolConfigValidate(t *testing.T) { @@ -71,7 +71,7 @@ func TestPoolStartsIdle(t *testing.T) { cfg.Clock = clk self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) state := p.Advance(ctx, &EventPoolPoll{}) @@ -85,21 +85,21 @@ func TestPoolStopWhenNoQueries(t *testing.T) { cfg.Clock = clk self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) state := p.Advance(ctx, &EventPoolPoll{}) require.IsType(t, &StatePoolIdle{}, state) } -func TestPoolAddQueryStartsIfCapacity(t *testing.T) { +func TestPoolAddFindCloserQueryStartsIfCapacity(t *testing.T) { ctx := context.Background() clk := clock.NewMock() cfg := DefaultPoolConfig() cfg.Clock = clk self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) target := tiny.Key(0b00000001) @@ -108,7 +108,7 @@ func TestPoolAddQueryStartsIfCapacity(t *testing.T) { queryID := QueryID("test") // first thing the new pool should do is start the query - state := p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state := p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID, Target: target, KnownClosestNodes: []tiny.Node{a}, @@ -132,14 +132,55 @@ func TestPoolAddQueryStartsIfCapacity(t *testing.T) { require.IsType(t, &StatePoolWaitingWithCapacity{}, state) } -func TestPoolMessageResponse(t *testing.T) { +func TestPoolAddQueryStartsIfCapacity(t *testing.T) { + ctx := context.Background() + clk := clock.NewMock() + cfg := DefaultPoolConfig() + cfg.Clock = clk + + self := tiny.NewNode(0) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) + require.NoError(t, err) + + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + + queryID := QueryID("test") + msg := tiny.Message{Content: "msg"} + // first thing the new pool should do is start the query + state := p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node, tiny.Message]{ + QueryID: queryID, + Target: target, + Message: msg, + KnownClosestNodes: []tiny.Node{a}, + }) + require.IsType(t, &StatePoolSendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + + // the query should attempt to contact the node it was given + st := state.(*StatePoolSendMessage[tiny.Key, tiny.Node, tiny.Message]) + + // the query should be the one just added + require.Equal(t, queryID, st.QueryID) + + 
// the query should attempt to contact the node it was given + require.Equal(t, a, st.NodeID) + + // with the correct message + require.Equal(t, msg, st.Message) + + // now the pool reports that it is waiting + state = p.Advance(ctx, &EventPoolPoll{}) + require.IsType(t, &StatePoolWaitingWithCapacity{}, state) +} + +func TestPoolNodeResponse(t *testing.T) { ctx := context.Background() clk := clock.NewMock() cfg := DefaultPoolConfig() cfg.Clock = clk self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) target := tiny.Key(0b00000001) @@ -148,7 +189,7 @@ func TestPoolMessageResponse(t *testing.T) { queryID := QueryID("test") // first thing the new pool should do is start the query - state := p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state := p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID, Target: target, KnownClosestNodes: []tiny.Node{a}, @@ -161,15 +202,15 @@ func TestPoolMessageResponse(t *testing.T) { require.Equal(t, a, st.NodeID) // notify query that node was contacted successfully, but no closer nodes - state = p.Advance(ctx, &EventPoolFindCloserResponse[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolNodeResponse[tiny.Key, tiny.Node]{ QueryID: queryID, NodeID: a, }) // pool should respond that query has finished - require.IsType(t, &StatePoolQueryFinished{}, state) + require.IsType(t, &StatePoolQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StatePoolQueryFinished) + stf := state.(*StatePoolQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, queryID, stf.QueryID) require.Equal(t, 1, stf.Stats.Requests) require.Equal(t, 1, stf.Stats.Success) @@ -183,7 +224,7 @@ func TestPoolPrefersRunningQueriesOverNewOnes(t *testing.T) { cfg.Concurrency = 2 // allow two queries to run concurrently self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) target := tiny.Key(0b00000001) @@ -194,7 +235,7 @@ func TestPoolPrefersRunningQueriesOverNewOnes(t *testing.T) { // Add the first query queryID1 := QueryID("1") - state := p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state := p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID1, Target: target, KnownClosestNodes: []tiny.Node{a, b, c, d}, @@ -208,7 +249,7 @@ func TestPoolPrefersRunningQueriesOverNewOnes(t *testing.T) { // Add the second query queryID2 := QueryID("2") - state = p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID2, Target: target, KnownClosestNodes: []tiny.Node{a, b, c, d}, @@ -235,7 +276,7 @@ func TestPoolPrefersRunningQueriesOverNewOnes(t *testing.T) { require.Equal(t, a, st.NodeID) // notify first query that node was contacted successfully, but no closer nodes - state = p.Advance(ctx, &EventPoolFindCloserResponse[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolNodeResponse[tiny.Key, tiny.Node]{ QueryID: queryID1, NodeID: a, }) @@ -247,7 +288,7 @@ func TestPoolPrefersRunningQueriesOverNewOnes(t *testing.T) { require.Equal(t, d, st.NodeID) // notify first query that next node was contacted successfully, but no closer nodes - state = p.Advance(ctx, &EventPoolFindCloserResponse[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolNodeResponse[tiny.Key, tiny.Node]{ QueryID: queryID1, NodeID: b, }) @@ -268,7 +309,7 @@ 
func TestPoolRespectsConcurrency(t *testing.T) { cfg.QueryConcurrency = 1 // allow each query to have a single request in flight self := tiny.NewNode(0) - p, err := NewPool[tiny.Key](self, cfg) + p, err := NewPool[tiny.Key, tiny.Node, tiny.Message](self, cfg) require.NoError(t, err) target := tiny.Key(0b00000001) @@ -276,7 +317,7 @@ func TestPoolRespectsConcurrency(t *testing.T) { // Add the first query queryID1 := QueryID("1") - state := p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state := p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID1, Target: target, KnownClosestNodes: []tiny.Node{a}, @@ -290,7 +331,7 @@ func TestPoolRespectsConcurrency(t *testing.T) { // Add the second query queryID2 := QueryID("2") - state = p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID2, Target: target, KnownClosestNodes: []tiny.Node{a}, @@ -304,7 +345,7 @@ func TestPoolRespectsConcurrency(t *testing.T) { // Add a third query queryID3 := QueryID("3") - state = p.Advance(ctx, &EventPoolAddQuery[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolAddFindCloserQuery[tiny.Key, tiny.Node]{ QueryID: queryID3, Target: target, KnownClosestNodes: []tiny.Node{a}, @@ -314,14 +355,14 @@ func TestPoolRespectsConcurrency(t *testing.T) { require.IsType(t, &StatePoolWaitingAtCapacity{}, state) // notify first query that next node was contacted successfully, but no closer nodes - state = p.Advance(ctx, &EventPoolFindCloserResponse[tiny.Key, tiny.Node]{ + state = p.Advance(ctx, &EventPoolNodeResponse[tiny.Key, tiny.Node]{ QueryID: queryID1, NodeID: a, }) // first query is out of nodes so it has finished - require.IsType(t, &StatePoolQueryFinished{}, state) - stf := state.(*StatePoolQueryFinished) + require.IsType(t, &StatePoolQueryFinished[tiny.Key, tiny.Node]{}, state) + stf := state.(*StatePoolQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, queryID1, stf.QueryID) // advancing pool again allows query 3 to start diff --git a/v2/coord/query/query.go b/v2/internal/coord/query/query.go similarity index 65% rename from v2/coord/query/query.go rename to v2/internal/coord/query/query.go index 9b0d87eb..b0003a83 100644 --- a/v2/coord/query/query.go +++ b/v2/internal/coord/query/query.go @@ -27,7 +27,7 @@ type QueryStats struct { } // QueryConfig specifies optional configuration for a Query -type QueryConfig[K kad.Key[K]] struct { +type QueryConfig struct { Concurrency int // the maximum number of concurrent requests that may be in flight NumResults int // the minimum number of nodes to successfully contact before considering iteration complete RequestTimeout time.Duration // the timeout for contacting a single node @@ -35,7 +35,7 @@ type QueryConfig[K kad.Key[K]] struct { } // Validate checks the configuration options and returns an error if any have invalid values. -func (cfg *QueryConfig[K]) Validate() error { +func (cfg *QueryConfig) Validate() error { if cfg.Clock == nil { return &kaderr.ConfigurationError{ Component: "QueryConfig", @@ -65,8 +65,8 @@ func (cfg *QueryConfig[K]) Validate() error { // DefaultQueryConfig returns the default configuration options for a Query. 
// Options may be overridden before passing to NewQuery -func DefaultQueryConfig[K kad.Key[K]]() *QueryConfig[K] { - return &QueryConfig[K]{ +func DefaultQueryConfig() *QueryConfig { + return &QueryConfig{ Concurrency: 3, NumResults: 20, RequestTimeout: time.Minute, @@ -74,27 +74,44 @@ func DefaultQueryConfig[K kad.Key[K]]() *QueryConfig[K] { } } -type Query[K kad.Key[K], N kad.NodeID[K]] struct { +type Query[K kad.Key[K], N kad.NodeID[K], M Message] struct { self N id QueryID // cfg is a copy of the optional configuration supplied to the query - cfg QueryConfig[K] + cfg QueryConfig - iter NodeIter[K, N] - target K - stats QueryStats + iter NodeIter[K, N] + target K + msg M + findCloser bool + stats QueryStats // finished indicates that that the query has completed its work or has been stopped. finished bool + // targetNodes is the set of responsive nodes thought to be closest to the target. + // It is populated once the query has been marked as finished. + // This will contain up to [QueryConfig.NumResults] nodes. + targetNodes []N + // inFlight is number of requests in flight, will be <= concurrency inFlight int } -func NewQuery[K kad.Key[K], N kad.NodeID[K]](self N, id QueryID, target K, iter NodeIter[K, N], knownClosestNodes []N, cfg *QueryConfig[K]) (*Query[K, N], error) { +func NewFindCloserQuery[K kad.Key[K], N kad.NodeID[K], M Message](self N, id QueryID, target K, iter NodeIter[K, N], knownClosestNodes []N, cfg *QueryConfig) (*Query[K, N, M], error) { + var empty M + q, err := NewQuery[K, N, M](self, id, target, empty, iter, knownClosestNodes, cfg) + if err != nil { + return nil, err + } + q.findCloser = true + return q, nil +} + +func NewQuery[K kad.Key[K], N kad.NodeID[K], M Message](self N, id QueryID, target K, msg M, iter NodeIter[K, N], knownClosestNodes []N, cfg *QueryConfig) (*Query[K, N, M], error) { if cfg == nil { - cfg = DefaultQueryConfig[K]() + cfg = DefaultQueryConfig() } else if err := cfg.Validate(); err != nil { return nil, err } @@ -110,16 +127,17 @@ func NewQuery[K kad.Key[K], N kad.NodeID[K]](self N, id QueryID, target K, iter }) } - return &Query[K, N]{ + return &Query[K, N, M]{ self: self, id: id, cfg: *cfg, + msg: msg, iter: iter, target: target, }, nil } -func (q *Query[K, N]) Advance(ctx context.Context, ev QueryEvent) (out QueryState) { +func (q *Query[K, N, M]) Advance(ctx context.Context, ev QueryEvent) (out QueryState) { ctx, span := tele.StartSpan(ctx, "Query.Advance", trace.WithAttributes(tele.AttrInEvent(ev))) defer func() { span.SetAttributes(tele.AttrOutEvent(out)) @@ -127,26 +145,29 @@ func (q *Query[K, N]) Advance(ctx context.Context, ev QueryEvent) (out QueryStat }() if q.finished { - return &StateQueryFinished{ - QueryID: q.id, - Stats: q.stats, + return &StateQueryFinished[K, N]{ + QueryID: q.id, + Stats: q.stats, + ClosestNodes: q.targetNodes, } } switch tev := ev.(type) { case *EventQueryCancel: - q.markFinished() - return &StateQueryFinished{ - QueryID: q.id, - Stats: q.stats, + q.markFinished(ctx) + return &StateQueryFinished[K, N]{ + QueryID: q.id, + Stats: q.stats, + ClosestNodes: q.targetNodes, } - case *EventQueryFindCloserResponse[K, N]: - q.onMessageResponse(ctx, tev.NodeID, tev.CloserNodes) - case *EventQueryFindCloserFailure[K, N]: + case *EventQueryNodeResponse[K, N]: + q.onNodeResponse(ctx, tev.NodeID, tev.CloserNodes) + case *EventQueryNodeFailure[K, N]: span.RecordError(tev.Error) - q.onMessageFailure(ctx, tev.NodeID) - case nil: - // TEMPORARY: no event to process + q.onNodeFailure(ctx, tev.NodeID) + case *EventQueryPoll: + 
// no event to process + default: panic(fmt.Sprintf("unexpected event: %T", tev)) } @@ -191,10 +212,11 @@ func (q *Query[K, N]) Advance(ctx context.Context, ev QueryEvent) (out QueryStat // If the iterator is not progressing then it doesn't expect any more nodes to be added to the list. // If it has contacted at least NumResults nodes successfully then the iteration is done. if !progressing && successes >= q.cfg.NumResults { - q.markFinished() - returnState = &StateQueryFinished{ - QueryID: q.id, - Stats: q.stats, + q.markFinished(ctx) + returnState = &StateQueryFinished[K, N]{ + QueryID: q.id, + Stats: q.stats, + ClosestNodes: q.targetNodes, } return true } @@ -208,11 +230,21 @@ func (q *Query[K, N]) Advance(ctx context.Context, ev QueryEvent) (out QueryStat if q.stats.Start.IsZero() { q.stats.Start = q.cfg.Clock.Now() } - returnState = &StateQueryFindCloser[K, N]{ - NodeID: ni.NodeID, - QueryID: q.id, - Stats: q.stats, - Target: q.target, + + if q.findCloser { + returnState = &StateQueryFindCloser[K, N]{ + NodeID: ni.NodeID, + QueryID: q.id, + Stats: q.stats, + Target: q.target, + } + } else { + returnState = &StateQuerySendMessage[K, N, M]{ + NodeID: ni.NodeID, + QueryID: q.id, + Stats: q.stats, + Message: q.msg, + } } return true } @@ -248,22 +280,36 @@ func (q *Query[K, N]) Advance(ctx context.Context, ev QueryEvent) (out QueryStat // The iterator is finished because all available nodes have been contacted // and the iterator is not waiting for any more results. - q.markFinished() - return &StateQueryFinished{ - QueryID: q.id, - Stats: q.stats, + q.markFinished(ctx) + return &StateQueryFinished[K, N]{ + QueryID: q.id, + Stats: q.stats, + ClosestNodes: q.targetNodes, } } -func (q *Query[K, N]) markFinished() { +func (q *Query[K, N, M]) markFinished(ctx context.Context) { q.finished = true if q.stats.End.IsZero() { q.stats.End = q.cfg.Clock.Now() } + + q.targetNodes = make([]N, 0, q.cfg.NumResults) + + q.iter.Each(ctx, func(ctx context.Context, ni *NodeStatus[K, N]) bool { + switch ni.State.(type) { + case *StateNodeSucceeded: + q.targetNodes = append(q.targetNodes, ni.NodeID) + if len(q.targetNodes) >= q.cfg.NumResults { + return true + } + } + return false + }) } -// onMessageResponse processes the result of a successful response received from a node. -func (q *Query[K, N]) onMessageResponse(ctx context.Context, node N, closer []N) { +// onNodeResponse processes the result of a successful response received from a node. +func (q *Query[K, N, M]) onNodeResponse(ctx context.Context, node N, closer []N) { ni, found := q.iter.Find(node.Key()) if !found { // got a rogue message @@ -303,8 +349,8 @@ func (q *Query[K, N]) onMessageResponse(ctx context.Context, node N, closer []N) ni.State = &StateNodeSucceeded{} } -// onMessageFailure processes the result of a failed attempt to contact a node. -func (q *Query[K, N]) onMessageFailure(ctx context.Context, node N) { +// onNodeFailure processes the result of a failed attempt to contact a node. +func (q *Query[K, N, M]) onNodeFailure(ctx context.Context, node N) { ni, found := q.iter.Find(node.Key()) if !found { // got a rogue message @@ -338,9 +384,10 @@ type QueryState interface { } // StateQueryFinished indicates that the [Query] has finished. 
-type StateQueryFinished struct {
-	QueryID QueryID
-	Stats   QueryStats
+type StateQueryFinished[K kad.Key[K], N kad.NodeID[K]] struct {
+	QueryID      QueryID
+	Stats        QueryStats
+	ClosestNodes []N // contains the closest nodes to the target key that were found
 }
 
 // StateQueryFindCloser indicates that the [Query] wants to send a find closer nodes message to a node.
@@ -351,6 +398,14 @@ type StateQueryFindCloser[K kad.Key[K], N kad.NodeID[K]] struct {
 	Stats   QueryStats
 }
 
+// StateQuerySendMessage indicates that the [Query] wants to send a message to a node.
+type StateQuerySendMessage[K kad.Key[K], N kad.NodeID[K], M Message] struct {
+	QueryID QueryID
+	NodeID  N // the node to send the message to
+	Message M
+	Stats   QueryStats
+}
+
 // StateQueryWaitingAtCapacity indicates that the [Query] is waiting for results and is at capacity.
 type StateQueryWaitingAtCapacity struct {
 	QueryID QueryID
@@ -364,10 +419,11 @@ type StateQueryWaitingWithCapacity struct {
 }
 
 // queryState() ensures that only [Query] states can be assigned to a QueryState.
-func (*StateQueryFinished) queryState()            {}
-func (*StateQueryFindCloser[K, N]) queryState()    {}
-func (*StateQueryWaitingAtCapacity) queryState()   {}
-func (*StateQueryWaitingWithCapacity) queryState() {}
+func (*StateQueryFinished[K, N]) queryState()       {}
+func (*StateQueryFindCloser[K, N]) queryState()     {}
+func (*StateQuerySendMessage[K, N, M]) queryState() {}
+func (*StateQueryWaitingAtCapacity) queryState()    {}
+func (*StateQueryWaitingWithCapacity) queryState()  {}
 
 type QueryEvent interface {
 	queryEvent()
 }
 
 // EventQueryCancel notifies a query to stop all work and enter the finished state.
 type EventQueryCancel struct{}
 
-// EventQueryFindCloserResponse notifies a [Query] that an attempt to find closer nodes has received a successful response.
-type EventQueryFindCloserResponse[K kad.Key[K], N kad.NodeID[K]] struct {
+// EventQueryNodeResponse notifies a [Query] that an attempt to contact a node has received a successful response.
+type EventQueryNodeResponse[K kad.Key[K], N kad.NodeID[K]] struct {
 	NodeID      N   // the node the message was sent to
 	CloserNodes []N // the closer nodes sent by the node
 }
 
-// EventQueryFindCloserFailure notifies a [Query] that an attempt to find closer nodes has failed.
-type EventQueryFindCloserFailure[K kad.Key[K], N kad.NodeID[K]] struct {
+// EventQueryNodeFailure notifies a [Query] that an attempt to contact a node has failed.
+type EventQueryNodeFailure[K kad.Key[K], N kad.NodeID[K]] struct {
 	NodeID N     // the node the message was sent to
 	Error  error // the error that caused the failure, if any
 }
 
+// EventQueryPoll is an event that signals a [Query] that it can perform housekeeping work.
+type EventQueryPoll struct{}
+
 // queryEvent() ensures that only events accepted by [Query] can be assigned to a [QueryEvent].
-func (*EventQueryCancel) queryEvent() {} -func (*EventQueryFindCloserResponse[K, N]) queryEvent() {} -func (*EventQueryFindCloserFailure[K, N]) queryEvent() {} +func (*EventQueryCancel) queryEvent() {} +func (*EventQueryNodeResponse[K, N]) queryEvent() {} +func (*EventQueryNodeFailure[K, N]) queryEvent() {} +func (*EventQueryPoll) queryEvent() {} diff --git a/v2/coord/query/query_test.go b/v2/internal/coord/query/query_test.go similarity index 64% rename from v2/coord/query/query_test.go rename to v2/internal/coord/query/query_test.go index 49564dcd..6cb1d9d1 100644 --- a/v2/coord/query/query_test.go +++ b/v2/internal/coord/query/query_test.go @@ -9,23 +9,23 @@ import ( "github.com/plprobelab/go-kademlia/key" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" ) func TestQueryConfigValidate(t *testing.T) { t.Run("default is valid", func(t *testing.T) { - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() require.NoError(t, cfg.Validate()) }) t.Run("clock is not nil", func(t *testing.T) { - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = nil require.Error(t, cfg.Validate()) }) t.Run("request timeout positive", func(t *testing.T) { - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.RequestTimeout = 0 require.Error(t, cfg.Validate()) cfg.RequestTimeout = -1 @@ -33,7 +33,7 @@ func TestQueryConfigValidate(t *testing.T) { }) t.Run("concurrency positive", func(t *testing.T) { - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Concurrency = 0 require.Error(t, cfg.Validate()) cfg.Concurrency = -1 @@ -41,7 +41,7 @@ func TestQueryConfigValidate(t *testing.T) { }) t.Run("num results positive", func(t *testing.T) { - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.NumResults = 0 require.Error(t, cfg.Validate()) cfg.NumResults = -1 @@ -62,17 +62,17 @@ func TestQueryMessagesNode(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is request to send a message to the node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // check that we are messaging the correct node with the right message @@ -85,14 +85,14 @@ func TestQueryMessagesNode(t *testing.T) { require.Equal(t, 0, st.Stats.Success) // advancing now reports that the query is waiting for a response but its underlying query still has capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) stw := state.(*StateQueryWaitingWithCapacity) require.Equal(t, 1, stw.Stats.Requests) require.Equal(t, 0, st.Stats.Success) } -func TestQueryMessagesNearest(t *testing.T) { +func TestQueryFindCloserNearest(t *testing.T) { ctx := context.Background() target := tiny.Key(0b00000011) @@ -111,17 +111,17 @@ func TestQueryMessagesNearest(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := 
DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is message the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // check that we are contacting the nearest node first @@ -142,26 +142,26 @@ func TestQueryCancelFinishesQuery(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is request to send a message to the node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) clk.Add(time.Second) // cancel the query state = qry.Advance(ctx, &EventQueryCancel{}) - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 1, stf.Stats.Requests) // no successful responses were received before query was cancelled @@ -185,20 +185,20 @@ func TestQueryNoClosest(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) clk := clock.NewMock() - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // query is finished because there were no nodes to contat - state := qry.Advance(ctx, nil) - require.IsType(t, &StateQueryFinished{}, state) + state := qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) // no requests were made require.Equal(t, 0, stf.Stats.Requests) @@ -228,32 +228,32 @@ func TestQueryWaitsAtCapacity(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is request to send a message to the node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) require.Equal(t, 1, 
st.Stats.Requests) // advancing sends the message to the next node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) require.Equal(t, 2, st.Stats.Requests) // advancing now reports that the query is waiting at capacity since there are 2 messages in flight - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) stw := state.(*StateQueryWaitingAtCapacity) @@ -281,7 +281,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.RequestTimeout = 3 * time.Minute cfg.Concurrency = len(knownNodes) - 1 // one less than the number of initial nodes @@ -289,11 +289,11 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) @@ -306,7 +306,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { clk.Add(time.Minute) // while the query has capacity the query should contact the next nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) @@ -319,7 +319,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { clk.Add(time.Minute) // while the query has capacity the query should contact the second nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, c, st.NodeID) @@ -332,7 +332,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { clk.Add(time.Minute) // the query should be at capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) stwa := state.(*StateQueryWaitingAtCapacity) require.Equal(t, 3, stwa.Stats.Requests) @@ -343,7 +343,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { clk.Add(time.Minute) // the first node request should have timed out, making capacity for the last node to attempt connection - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, d, st.NodeID) @@ -357,7 +357,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { clk.Add(time.Minute) // advancing now makes more capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) stww := 
state.(*StateQueryWaitingWithCapacity) @@ -366,7 +366,7 @@ func TestQueryTimedOutNodeMakesCapacity(t *testing.T) { require.Equal(t, 2, stww.Stats.Failure) } -func TestQueryMessageResponseMakesCapacity(t *testing.T) { +func TestQueryFindCloserResponseMakesCapacity(t *testing.T) { ctx := context.Background() target := tiny.Key(0b00000001) @@ -387,18 +387,18 @@ func TestQueryMessageResponseMakesCapacity(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = len(knownNodes) - 1 // one less than the number of initial nodes queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) @@ -408,7 +408,7 @@ func TestQueryMessageResponseMakesCapacity(t *testing.T) { require.Equal(t, 0, stwm.Stats.Failure) // while the query has capacity the query should contact the next nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) @@ -418,7 +418,7 @@ func TestQueryMessageResponseMakesCapacity(t *testing.T) { require.Equal(t, 0, stwm.Stats.Failure) // while the query has capacity the query should contact the second nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, c, st.NodeID) @@ -428,11 +428,11 @@ func TestQueryMessageResponseMakesCapacity(t *testing.T) { require.Equal(t, 0, stwm.Stats.Failure) // the query should be at capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) // notify query that first node was contacted successfully, now node d can be contacted - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{NodeID: a}) + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{NodeID: a}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, d, st.NodeID) @@ -442,7 +442,7 @@ func TestQueryMessageResponseMakesCapacity(t *testing.T) { require.Equal(t, 0, stwm.Stats.Failure) // the query should be at capacity again - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) stwa := state.(*StateQueryWaitingAtCapacity) require.Equal(t, 4, stwa.Stats.Requests) @@ -471,28 +471,28 @@ func TestQueryCloserNodesAreAddedToIteration(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, 
target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, d, st.NodeID) // advancing reports query has capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) // notify query that first node was contacted successfully, with closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: d, CloserNodes: []tiny.Node{ b, @@ -527,34 +527,34 @@ func TestQueryCloserNodesIgnoresDuplicates(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) // next the query attempts to contact second nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, d, st.NodeID) // advancing reports query has no capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) // notify query that second node was contacted successfully, with closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: d, CloserNodes: []tiny.Node{ b, @@ -581,27 +581,27 @@ func TestQueryCancelFinishesIteration(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) // cancel the query so it is now finished state = qry.Advance(ctx, &EventQueryCancel{}) - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := 
state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 0, stf.Stats.Success) } @@ -619,43 +619,43 @@ func TestQueryFinishedIgnoresLaterEvents(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) // cancel the query so it is now finished state = qry.Advance(ctx, &EventQueryCancel{}) - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) // no successes - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 1, stf.Stats.Requests) require.Equal(t, 0, stf.Stats.Success) require.Equal(t, 0, stf.Stats.Failure) // notify query that second node was contacted successfully, with closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: b, CloserNodes: []tiny.Node{a}, }) // query remains finished - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) // still no successes since contact message was after query had been cancelled - stf = state.(*StateQueryFinished) + stf = state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 1, stf.Stats.Requests) require.Equal(t, 0, stf.Stats.Success) require.Equal(t, 0, stf.Stats.Failure) @@ -676,18 +676,18 @@ func TestQueryWithCloserIterIgnoresMessagesFromUnknownNodes(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, c, st.NodeID) @@ -697,7 +697,7 @@ func TestQueryWithCloserIterIgnoresMessagesFromUnknownNodes(t *testing.T) { require.Equal(t, 0, stwm.Stats.Failure) // notify query that second node was contacted successfully, with closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: b, CloserNodes: []tiny.Node{a}, }) @@ -727,7 +727,7 @@ func TestQueryWithCloserIterFinishesWhenNumResultsReached(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := 
DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 4 cfg.NumResults = 2 @@ -735,23 +735,23 @@ func TestQueryWithCloserIterFinishesWhenNumResultsReached(t *testing.T) { queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // contact first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) // contact second node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) // notify query that first node was contacted successfully - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: a, }) @@ -761,12 +761,15 @@ func TestQueryWithCloserIterFinishesWhenNumResultsReached(t *testing.T) { require.Equal(t, c, st.NodeID) // notify query that second node was contacted successfully - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: b, }) // query has finished since it contacted the NumResults closest nodes - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) + + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) + require.Equal(t, 2, len(stf.ClosestNodes)) } func TestQueryWithCloserIterContinuesUntilNumResultsReached(t *testing.T) { @@ -784,7 +787,7 @@ func TestQueryWithCloserIterContinuesUntilNumResultsReached(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 4 cfg.NumResults = 2 @@ -792,18 +795,18 @@ func TestQueryWithCloserIterContinuesUntilNumResultsReached(t *testing.T) { queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // contact first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, c, st.NodeID) // notify query that node was contacted successfully and tell it about // a closer one - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: c, CloserNodes: []tiny.Node{b}, }) @@ -815,7 +818,7 @@ func TestQueryWithCloserIterContinuesUntilNumResultsReached(t *testing.T) { // notify query that node was contacted successfully and tell it about // a closer one - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: b, CloserNodes: []tiny.Node{a}, }) @@ -828,14 +831,14 @@ func 
TestQueryWithCloserIterContinuesUntilNumResultsReached(t *testing.T) { require.Equal(t, a, st.NodeID) // notify query that second node was contacted successfully - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: a, }) // query has finished since it contacted the NumResults closest nodes - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 3, stf.Stats.Success) } @@ -857,50 +860,50 @@ func TestQueryNotContactedMakesCapacity(t *testing.T) { iter := NewSequentialIter[tiny.Key, tiny.Node]() clk := clock.NewMock() - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = len(knownNodes) - 1 // one less than the number of initial nodes queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, a, st.NodeID) // while the query has capacity the query should contact the next nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) // while the query has capacity the query should contact the second nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, c, st.NodeID) // the query should be at capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) // notify query that first node was not contacted, now node d can be contacted - state = qry.Advance(ctx, &EventQueryFindCloserFailure[tiny.Key, tiny.Node]{NodeID: a}) + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: a}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, d, st.NodeID) // the query should be at capacity again - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) } -func TestQueryAllNotContactedFinishes(t *testing.T) { +func TestFindCloserQueryAllNotContactedFinishes(t *testing.T) { ctx := context.Background() target := tiny.Key(0b00000001) @@ -915,47 +918,47 @@ func TestQueryAllNotContactedFinishes(t *testing.T) { iter := NewSequentialIter[tiny.Key, tiny.Node]() - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = len(knownNodes) // allow all to be contacted at once queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + 
qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // while the query has capacity the query should contact the next nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // while the query has capacity the query should contact the third nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // the query should be at capacity - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) // notify query that first node was not contacted - state = qry.Advance(ctx, &EventQueryFindCloserFailure[tiny.Key, tiny.Node]{NodeID: a}) + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: a}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) // notify query that second node was not contacted - state = qry.Advance(ctx, &EventQueryFindCloserFailure[tiny.Key, tiny.Node]{NodeID: b}) + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: b}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) // notify query that third node was not contacted - state = qry.Advance(ctx, &EventQueryFindCloserFailure[tiny.Key, tiny.Node]{NodeID: c}) + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: c}) // query has finished since it contacted all possible nodes - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 0, stf.Stats.Success) } @@ -973,7 +976,7 @@ func TestQueryAllContactedFinishes(t *testing.T) { iter := NewSequentialIter[tiny.Key, tiny.Node]() - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = len(knownNodes) // allow all to be contacted at once cfg.NumResults = len(knownNodes) + 1 // one more than the size of the network @@ -981,41 +984,41 @@ func TestQueryAllContactedFinishes(t *testing.T) { queryID := QueryID("test") self := tiny.NewNode(0) - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the nearest node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // while the query has capacity the query should contact the next nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // while the query has capacity the query should contact the third nearest node - state = qry.Advance(ctx, nil) + state = qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) // the query should be at capacity - state = qry.Advance(ctx, nil) + state = 
qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryWaitingAtCapacity{}, state) // notify query that first node was contacted successfully, but no closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{NodeID: a}) + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{NodeID: a}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) // notify query that second node was contacted successfully, but no closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{NodeID: b}) + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{NodeID: b}) require.IsType(t, &StateQueryWaitingWithCapacity{}, state) // notify query that third node was contacted successfully, but no closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{NodeID: c}) + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{NodeID: c}) // query has finished since it contacted all possible nodes, even though it didn't // reach the desired NumResults - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 3, stf.Stats.Success) } @@ -1033,34 +1036,273 @@ func TestQueryNeverMessagesSelf(t *testing.T) { iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) - cfg := DefaultQueryConfig[tiny.Key]() + cfg := DefaultQueryConfig() cfg.Clock = clk cfg.Concurrency = 2 queryID := QueryID("test") self := a - qry, err := NewQuery[tiny.Key, tiny.Node](self, queryID, target, iter, knownNodes, cfg) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) require.NoError(t, err) // first thing the new query should do is contact the first node - state := qry.Advance(ctx, nil) + state := qry.Advance(ctx, &EventQueryPoll{}) require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) st := state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) require.Equal(t, b, st.NodeID) // notify query that first node was contacted successfully, with closer nodes - state = qry.Advance(ctx, &EventQueryFindCloserResponse[tiny.Key, tiny.Node]{ + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ NodeID: b, CloserNodes: []tiny.Node{a}, }) // query is finished since it can't contact self - require.IsType(t, &StateQueryFinished{}, state) + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) // one successful message - stf := state.(*StateQueryFinished) + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) require.Equal(t, 1, stf.Stats.Requests) require.Equal(t, 1, stf.Stats.Success) require.Equal(t, 0, stf.Stats.Failure) } + +func TestQueryMessagesNearest(t *testing.T) { + ctx := context.Background() + + target := tiny.Key(0b00000011) + far := tiny.NewNode(0b11011011) + near := tiny.NewNode(0b00000110) + + // ensure near is nearer to target than far is + require.Less(t, target.Xor(near.Key()), target.Xor(far.Key())) + + // knownNodes are in "random" order with furthest before nearest + knownNodes := []tiny.Node{ + far, + near, + } + clk := clock.NewMock() + + iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) + + cfg := DefaultQueryConfig() + cfg.Clock = clk + + queryID := QueryID("test") + + self := tiny.NewNode(0) + msg := tiny.Message{Content: "msg"} + qry, err := NewQuery[tiny.Key, tiny.Node, tiny.Message](self, 
queryID, target, msg, iter, knownNodes, cfg) + require.NoError(t, err) + + // first thing the new query should do is message the nearest node + state := qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + + // check that we are contacting the nearest node first + st := state.(*StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]) + require.Equal(t, near, st.NodeID) +} + +func TestQueryMessageResponseMakesCapacity(t *testing.T) { + ctx := context.Background() + + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + b := tiny.NewNode(0b00001000) // 8 + c := tiny.NewNode(0b00010000) // 16 + d := tiny.NewNode(0b00100000) // 32 + + // ensure the order of the known nodes + require.True(t, target.Xor(a.Key()).Compare(target.Xor(b.Key())) == -1) + require.True(t, target.Xor(b.Key()).Compare(target.Xor(c.Key())) == -1) + require.True(t, target.Xor(c.Key()).Compare(target.Xor(d.Key())) == -1) + + // knownNodes are in "random" order + knownNodes := []tiny.Node{b, c, a, d} + + clk := clock.NewMock() + + iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) + + cfg := DefaultQueryConfig() + cfg.Clock = clk + cfg.Concurrency = len(knownNodes) - 1 // one less than the number of initial nodes + + queryID := QueryID("test") + + self := tiny.NewNode(0) + msg := tiny.Message{Content: "msg"} + qry, err := NewQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, msg, iter, knownNodes, cfg) + require.NoError(t, err) + + // first thing the new query should do is contact the nearest node + state := qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + st := state.(*StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]) + require.Equal(t, a, st.NodeID) + require.Equal(t, 1, st.Stats.Requests) + require.Equal(t, 0, st.Stats.Success) + require.Equal(t, 0, st.Stats.Failure) + + // while the query has capacity the query should contact the next nearest node + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + st = state.(*StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]) + require.Equal(t, b, st.NodeID) + require.Equal(t, 2, st.Stats.Requests) + require.Equal(t, 0, st.Stats.Success) + require.Equal(t, 0, st.Stats.Failure) + + // while the query has capacity the query should contact the second nearest node + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + st = state.(*StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]) + require.Equal(t, c, st.NodeID) + require.Equal(t, 3, st.Stats.Requests) + require.Equal(t, 0, st.Stats.Success) + require.Equal(t, 0, st.Stats.Failure) + + // the query should be at capacity + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryWaitingAtCapacity{}, state) + + // notify query that first node was contacted successfully, now node d can be contacted + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{NodeID: a}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + st = state.(*StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]) + require.Equal(t, d, st.NodeID) + require.Equal(t, 4, st.Stats.Requests) + require.Equal(t, 1, st.Stats.Success) + require.Equal(t, 0, st.Stats.Failure) + + // the query should be at capacity again + state = 
qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryWaitingAtCapacity{}, state) + stwa := state.(*StateQueryWaitingAtCapacity) + require.Equal(t, 4, stwa.Stats.Requests) + require.Equal(t, 1, stwa.Stats.Success) + require.Equal(t, 0, stwa.Stats.Failure) +} + +func TestQueryAllNotContactedFinishes(t *testing.T) { + ctx := context.Background() + + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + b := tiny.NewNode(0b00001000) // 8 + c := tiny.NewNode(0b00010000) // 16 + + // knownNodes are in "random" order + knownNodes := []tiny.Node{a, b, c} + + clk := clock.NewMock() + + iter := NewSequentialIter[tiny.Key, tiny.Node]() + + cfg := DefaultQueryConfig() + cfg.Clock = clk + cfg.Concurrency = len(knownNodes) // allow all to be contacted at once + + queryID := QueryID("test") + + self := tiny.NewNode(0) + msg := tiny.Message{Content: "msg"} + qry, err := NewQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, msg, iter, knownNodes, cfg) + require.NoError(t, err) + + // first thing the new query should do is contact the nearest node + state := qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + + // while the query has capacity the query should contact the next nearest node + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + + // while the query has capacity the query should contact the third nearest node + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQuerySendMessage[tiny.Key, tiny.Node, tiny.Message]{}, state) + + // the query should be at capacity + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryWaitingAtCapacity{}, state) + + // notify query that first node was not contacted + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: a}) + require.IsType(t, &StateQueryWaitingWithCapacity{}, state) + + // notify query that second node was not contacted + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: b}) + require.IsType(t, &StateQueryWaitingWithCapacity{}, state) + + // notify query that third node was not contacted + state = qry.Advance(ctx, &EventQueryNodeFailure[tiny.Key, tiny.Node]{NodeID: c}) + + // query has finished since it contacted all possible nodes + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) + + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) + require.Equal(t, 0, stf.Stats.Success) + require.Equal(t, 3, stf.Stats.Failure) +} + +func TestFindCloserQueryIncludesPartialClosestNodesWhenCancelled(t *testing.T) { + ctx := context.Background() + + target := tiny.Key(0b00000001) + a := tiny.NewNode(0b00000100) // 4 + b := tiny.NewNode(0b00001000) // 8 + c := tiny.NewNode(0b00010000) // 16 + d := tiny.NewNode(0b00100000) // 32 + + // one known node to start with + knownNodes := []tiny.Node{a, b, c, d} + + clk := clock.NewMock() + + iter := NewClosestNodesIter[tiny.Key, tiny.Node](target) + + cfg := DefaultQueryConfig() + cfg.Clock = clk + cfg.Concurrency = 4 + cfg.NumResults = 4 + + queryID := QueryID("test") + + self := tiny.NewNode(0) + qry, err := NewFindCloserQuery[tiny.Key, tiny.Node, tiny.Message](self, queryID, target, iter, knownNodes, cfg) + require.NoError(t, err) + + // contact first node + state := qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) + st := 
state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) + require.Equal(t, a, st.NodeID) + + // contact second node + state = qry.Advance(ctx, &EventQueryPoll{}) + require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) + st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) + require.Equal(t, b, st.NodeID) + + // notify query that first node was contacted successfully + state = qry.Advance(ctx, &EventQueryNodeResponse[tiny.Key, tiny.Node]{ + NodeID: a, + }) + + // query attempts to contact third node + require.IsType(t, &StateQueryFindCloser[tiny.Key, tiny.Node]{}, state) + st = state.(*StateQueryFindCloser[tiny.Key, tiny.Node]) + require.Equal(t, c, st.NodeID) + + // cancel query + state = qry.Advance(ctx, &EventQueryCancel{}) + + // query has finished + require.IsType(t, &StateQueryFinished[tiny.Key, tiny.Node]{}, state) + + stf := state.(*StateQueryFinished[tiny.Key, tiny.Node]) + require.Equal(t, 1, len(stf.ClosestNodes)) +} diff --git a/v2/coord/routing.go b/v2/internal/coord/routing.go similarity index 99% rename from v2/coord/routing.go rename to v2/internal/coord/routing.go index 1c34bca8..832bfa64 100644 --- a/v2/coord/routing.go +++ b/v2/internal/coord/routing.go @@ -9,7 +9,7 @@ import ( "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) diff --git a/v2/coord/routing/bootstrap.go b/v2/internal/coord/routing/bootstrap.go similarity index 93% rename from v2/coord/routing/bootstrap.go rename to v2/internal/coord/routing/bootstrap.go index e4b9d452..914f9615 100644 --- a/v2/coord/routing/bootstrap.go +++ b/v2/internal/coord/routing/bootstrap.go @@ -11,7 +11,7 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" "github.com/libp2p/go-libp2p-kad-dht/v2/tele" ) @@ -20,7 +20,7 @@ type Bootstrap[K kad.Key[K], N kad.NodeID[K]] struct { self N // qry is the query used by the bootstrap process - qry *query.Query[K, N] + qry *query.Query[K, N, any] // cfg is a copy of the optional configuration supplied to the Bootstrap cfg BootstrapConfig[K] @@ -101,29 +101,29 @@ func (b *Bootstrap[K, N]) Advance(ctx context.Context, ev BootstrapEvent) Bootst // TODO: ignore start event if query is already in progress iter := query.NewClosestNodesIter[K, N](b.self.Key()) - qryCfg := query.DefaultQueryConfig[K]() + qryCfg := query.DefaultQueryConfig() qryCfg.Clock = b.cfg.Clock qryCfg.Concurrency = b.cfg.RequestConcurrency qryCfg.RequestTimeout = b.cfg.RequestTimeout queryID := query.QueryID("bootstrap") - qry, err := query.NewQuery[K, N](b.self, queryID, b.self.Key(), iter, tev.KnownClosestNodes, qryCfg) + qry, err := query.NewFindCloserQuery[K, N, any](b.self, queryID, b.self.Key(), iter, tev.KnownClosestNodes, qryCfg) if err != nil { // TODO: don't panic panic(err) } b.qry = qry - return b.advanceQuery(ctx, nil) + return b.advanceQuery(ctx, &query.EventQueryPoll{}) case *EventBootstrapFindCloserResponse[K, N]: - return b.advanceQuery(ctx, &query.EventQueryFindCloserResponse[K, N]{ + return b.advanceQuery(ctx, &query.EventQueryNodeResponse[K, N]{ NodeID: tev.NodeID, CloserNodes: tev.CloserNodes, }) case *EventBootstrapFindCloserFailure[K, N]: span.RecordError(tev.Error) - return b.advanceQuery(ctx, &query.EventQueryFindCloserFailure[K, N]{ + return b.advanceQuery(ctx, 
&query.EventQueryNodeFailure[K, N]{ NodeID: tev.NodeID, Error: tev.Error, }) @@ -135,7 +135,7 @@ func (b *Bootstrap[K, N]) Advance(ctx context.Context, ev BootstrapEvent) Bootst } if b.qry != nil { - return b.advanceQuery(ctx, nil) + return b.advanceQuery(ctx, &query.EventQueryPoll{}) } return &StateBootstrapIdle{} @@ -154,7 +154,7 @@ func (b *Bootstrap[K, N]) advanceQuery(ctx context.Context, qev query.QueryEvent NodeID: st.NodeID, Target: st.Target, } - case *query.StateQueryFinished: + case *query.StateQueryFinished[K, N]: span.SetAttributes(attribute.String("out_state", "StateBootstrapFinished")) return &StateBootstrapFinished{ Stats: st.Stats, diff --git a/v2/coord/routing/bootstrap_test.go b/v2/internal/coord/routing/bootstrap_test.go similarity index 98% rename from v2/coord/routing/bootstrap_test.go rename to v2/internal/coord/routing/bootstrap_test.go index df1364df..70c8b6f0 100644 --- a/v2/coord/routing/bootstrap_test.go +++ b/v2/internal/coord/routing/bootstrap_test.go @@ -8,8 +8,8 @@ import ( "github.com/plprobelab/go-kademlia/key" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" ) func TestBootstrapConfigValidate(t *testing.T) { diff --git a/v2/coord/routing/include.go b/v2/internal/coord/routing/include.go similarity index 100% rename from v2/coord/routing/include.go rename to v2/internal/coord/routing/include.go diff --git a/v2/coord/routing/include_test.go b/v2/internal/coord/routing/include_test.go similarity index 99% rename from v2/coord/routing/include_test.go rename to v2/internal/coord/routing/include_test.go index a788e8d5..a565521a 100644 --- a/v2/coord/routing/include_test.go +++ b/v2/internal/coord/routing/include_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/benbjohnson/clock" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" "github.com/plprobelab/go-kademlia/key" "github.com/plprobelab/go-kademlia/routing/simplert" "github.com/stretchr/testify/require" diff --git a/v2/coord/routing/probe.go b/v2/internal/coord/routing/probe.go similarity index 100% rename from v2/coord/routing/probe.go rename to v2/internal/coord/routing/probe.go diff --git a/v2/coord/routing/probe_test.go b/v2/internal/coord/routing/probe_test.go similarity index 99% rename from v2/coord/routing/probe_test.go rename to v2/internal/coord/routing/probe_test.go index e07d6445..e97ddce3 100644 --- a/v2/coord/routing/probe_test.go +++ b/v2/internal/coord/routing/probe_test.go @@ -11,7 +11,7 @@ import ( "github.com/plprobelab/go-kademlia/routing/simplert" "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/tiny" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/tiny" ) var _ heap.Interface = (*nodeValuePendingList[tiny.Key, tiny.Node])(nil) diff --git a/v2/coord/routing_test.go b/v2/internal/coord/routing_test.go similarity index 97% rename from v2/coord/routing_test.go rename to v2/internal/coord/routing_test.go index 2b07d6d1..c789c9dc 100644 --- a/v2/coord/routing_test.go +++ b/v2/internal/coord/routing_test.go @@ -12,9 +12,9 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/slog" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/internal/nettest" - 
"github.com/libp2p/go-libp2p-kad-dht/v2/coord/query" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord/routing" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/internal/nettest" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/query" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/routing" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) diff --git a/v2/coord/telemetry.go b/v2/internal/coord/telemetry.go similarity index 100% rename from v2/coord/telemetry.go rename to v2/internal/coord/telemetry.go diff --git a/v2/kadt/kadt.go b/v2/kadt/kadt.go index f87057a8..2ad4bbef 100644 --- a/v2/kadt/kadt.go +++ b/v2/kadt/kadt.go @@ -1,7 +1,4 @@ // Package kadt contains the kademlia types for interacting with go-kademlia. -// It would be nicer to have these types in the top-level DHT package; however, -// we also need these types in, e.g., the pb package to let the -// [pb.Message] type conform to certain interfaces. package kadt import ( @@ -73,3 +70,15 @@ func (ai AddrInfo) Addresses() []ma.Multiaddr { copy(addrs, ai.Info.Addrs) return addrs } + +// RoutingTable is a mapping between [Key] and [PeerID] and provides methods to interact with the mapping +// and find PeerIDs close to a particular Key. +type RoutingTable interface { + kad.RoutingTable[Key, PeerID] + + // Cpl returns the longest common prefix length the supplied key shares with the table's key. + Cpl(kk Key) int + + // CplSize returns the number of nodes in the table whose longest common prefix with the table's key is of length cpl. + CplSize(cpl int) int +} diff --git a/v2/query_test.go b/v2/query_test.go index 3fa63336..86ea55c2 100644 --- a/v2/query_test.go +++ b/v2/query_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" "github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" ) diff --git a/v2/router.go b/v2/router.go index 14db2cd9..70bd69ca 100644 --- a/v2/router.go +++ b/v2/router.go @@ -13,12 +13,12 @@ import ( "github.com/libp2p/go-msgio/pbio" "google.golang.org/protobuf/proto" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p-kad-dht/v2/pb" ) -type Router struct { +type router struct { host host.Host // ProtocolID represents the DHT [protocol] we can query with and respond to. // @@ -26,7 +26,7 @@ type Router struct { ProtocolID protocol.ID } -var _ coord.Router[kadt.Key, kadt.PeerID, *pb.Message] = (*Router)(nil) +var _ coord.Router[kadt.Key, kadt.PeerID, *pb.Message] = (*router)(nil) func FindKeyRequest(k kadt.Key) *pb.Message { marshalledKey, _ := k.MarshalBinary() @@ -36,7 +36,7 @@ func FindKeyRequest(k kadt.Key) *pb.Message { } } -func (r *Router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (*pb.Message, error) { +func (r *router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Message) (*pb.Message, error) { // TODO: what to do with addresses in peer.AddrInfo? 
if len(r.host.Peerstore().Addrs(peer.ID(to))) == 0 { return nil, fmt.Errorf("no address for peer %s", to) @@ -79,7 +79,7 @@ func (r *Router) SendMessage(ctx context.Context, to kadt.PeerID, req *pb.Messag return &protoResp, err } -func (r *Router) GetClosestNodes(ctx context.Context, to kadt.PeerID, target kadt.Key) ([]kadt.PeerID, error) { +func (r *router) GetClosestNodes(ctx context.Context, to kadt.PeerID, target kadt.Key) ([]kadt.PeerID, error) { resp, err := r.SendMessage(ctx, to, FindKeyRequest(target)) if err != nil { return nil, err @@ -88,7 +88,7 @@ func (r *Router) GetClosestNodes(ctx context.Context, to kadt.PeerID, target kad return resp.CloserNodes(), nil } -func (r *Router) addToPeerStore(ctx context.Context, ai peer.AddrInfo, ttl time.Duration) error { +func (r *router) addToPeerStore(ctx context.Context, ai peer.AddrInfo, ttl time.Duration) error { // Don't add addresses for self or our connected peers. We have better ones. if ai.ID == r.host.ID() || r.host.Network().Connectedness(ai.ID) == network.Connected { return nil diff --git a/v2/routing.go b/v2/routing.go index bfb76805..569e66b9 100644 --- a/v2/routing.go +++ b/v2/routing.go @@ -8,8 +8,9 @@ import ( "github.com/ipfs/go-cid" ds "github.com/ipfs/go-datastore" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" + "github.com/libp2p/go-libp2p-kad-dht/v2/pb" record "github.com/libp2p/go-libp2p-record" recpb "github.com/libp2p/go-libp2p-record/pb" "github.com/libp2p/go-libp2p/core/network" @@ -40,25 +41,25 @@ func (d *DHT) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) { target := kadt.PeerID(id) - var foundNode coord.Node - fn := func(ctx context.Context, node coord.Node, stats coord.QueryStats) error { - if peer.ID(node.ID()) == id { - foundNode = node + var foundPeer peer.ID + fn := func(ctx context.Context, visited kadt.PeerID, msg *pb.Message, stats coord.QueryStats) error { + if peer.ID(visited) == id { + foundPeer = peer.ID(visited) return coord.ErrSkipRemaining } return nil } - _, err := d.kad.Query(ctx, target.Key(), fn) + _, _, err := d.kad.QueryClosest(ctx, target.Key(), fn, 20) if err != nil { return peer.AddrInfo{}, fmt.Errorf("failed to run query: %w", err) } - if foundNode == nil { + if foundPeer == "" { return peer.AddrInfo{}, fmt.Errorf("peer record not found") } - return d.host.Peerstore().PeerInfo(peer.ID(foundNode.ID())), nil + return d.host.Peerstore().PeerInfo(foundPeer), nil } func (d *DHT) Provide(ctx context.Context, c cid.Cid, brdcst bool) error { @@ -151,14 +152,44 @@ func (d *DHT) GetValue(ctx context.Context, key string, option ...routing.Option defer span.End() v, err := d.getValueLocal(ctx, key) - if err != nil { + if err == nil { return v, nil } if !errors.Is(err, ds.ErrNotFound) { return nil, fmt.Errorf("put value locally: %w", err) } - panic("implement me") + req := &pb.Message{ + Type: pb.Message_GET_VALUE, + Key: []byte(key), + } + + // TODO: quorum + var value []byte + fn := func(ctx context.Context, id kadt.PeerID, resp *pb.Message, stats coord.QueryStats) error { + if resp == nil { + return nil + } + + if resp.GetType() != pb.Message_GET_VALUE { + return nil + } + + if string(resp.GetKey()) != key { + return nil + } + + value = resp.GetRecord().GetValue() + + return coord.ErrSkipRemaining + } + + _, err = d.kad.QueryMessage(ctx, req, fn, d.cfg.BucketSize) + if err != nil { + return nil, fmt.Errorf("failed to run query: %w", err) + } + + return value, nil } // 
getValueLocal retrieves a value from the local datastore without querying the network. diff --git a/v2/routing_test.go b/v2/routing_test.go index 5204ae48..ec80da31 100644 --- a/v2/routing_test.go +++ b/v2/routing_test.go @@ -44,8 +44,6 @@ func TestGetSetValueLocal(t *testing.T) { } func TestGetValueOnePeer(t *testing.T) { - t.Skip("not implemented yet") - ctx := kadtest.CtxShort(t) top := NewTopology(t) local := top.AddServer(nil) diff --git a/v2/topology_test.go b/v2/topology_test.go index 1af5353f..189b494e 100644 --- a/v2/topology_test.go +++ b/v2/topology_test.go @@ -7,7 +7,7 @@ import ( "github.com/benbjohnson/clock" "github.com/libp2p/go-libp2p" - "github.com/libp2p/go-libp2p-kad-dht/v2/coord" + "github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord" "github.com/libp2p/go-libp2p-kad-dht/v2/kadt" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" @@ -55,12 +55,12 @@ func (t *Topology) AddServer(cfg *Config) *DHT { } cfg.Mode = ModeOptServer - rn := coord.NewBufferedRoutingNotifier() - cfg.Kademlia.RoutingNotifier = rn - d, err := New(h, cfg) require.NoError(t.tb, err) + rn := coord.NewBufferedRoutingNotifier() + d.kad.SetRoutingNotifier(rn) + // add at least 1 entry in the routing table so the server will pass connectivity checks fillRoutingTable(t.tb, d, 1) require.NotEmpty(t.tb, d.rt.NearestNodes(kadt.PeerID(d.host.ID()).Key(), 1)) @@ -97,12 +97,12 @@ func (t *Topology) AddClient(cfg *Config) *DHT { } cfg.Mode = ModeOptClient - rn := coord.NewBufferedRoutingNotifier() - cfg.Kademlia.RoutingNotifier = rn - d, err := New(h, cfg) require.NoError(t.tb, err) + rn := coord.NewBufferedRoutingNotifier() + d.kad.SetRoutingNotifier(rn) + t.tb.Cleanup(func() { if err = d.Close(); err != nil { t.tb.Logf("unexpected error when closing dht: %s", err)