diff --git a/dht_test.go b/dht_test.go
index 3b3e0fe7e..447000347 100644
--- a/dht_test.go
+++ b/dht_test.go
@@ -1403,10 +1403,6 @@ func testFindPeerQuery(t *testing.T,
 		connect(t, ctx, guy, others[i])
 	}
 
-	for _, d := range dhts {
-		d.RefreshRoutingTable()
-	}
-
 	var reachableIds []peer.ID
 	for i, d := range dhts {
 		lp := len(d.host.Network().Peers())
@@ -1471,8 +1467,8 @@ func TestFindClosestPeers(t *testing.T) {
 		out = append(out, p)
 	}
 
-	if len(out) != querier.bucketSize {
-		t.Fatalf("got wrong number of peers (got %d, expected %d)", len(out), querier.bucketSize)
+	if len(out) < querier.beta {
+		t.Fatalf("got wrong number of peers (got %d, expected at least %d)", len(out), querier.beta)
 	}
 }
 
diff --git a/lookup.go b/lookup.go
index 71b360ceb..38c5b0036 100644
--- a/lookup.go
+++ b/lookup.go
@@ -113,7 +113,7 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan pee
 			out <- p
 		}
 
-		if ctx.Err() == nil {
+		if ctx.Err() == nil && lookupRes.completed {
 			// refresh the cpl for this key as the query was successful
 			dht.routingTable.ResetCplRefreshedAtForID(kb.ConvertKey(key), time.Now())
 		}
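For illustration, a minimal standalone sketch of the rule the lookup.go hunk enforces: a lookup should only count as evidence that a region of the keyspace was explored when it finished on a live context. The names below (lookupResult, refreshTracker, markRefreshed) are invented for the example and are not the library's API.

// Sketch only, not code from the patch. Mirrors the condition
// `if ctx.Err() == nil && lookupRes.completed` added above.
package main

import (
	"context"
	"fmt"
	"time"
)

type lookupResult struct {
	completed bool // false when the lookup was stopped or cancelled early
}

// refreshTracker records when each keyspace bucket was last proven fresh.
type refreshTracker map[string]time.Time

// markRefreshed resets a bucket's refresh timestamp only for lookups that
// ran to completion on a context that is still alive.
func (rt refreshTracker) markRefreshed(ctx context.Context, bucket string, res *lookupResult) {
	if ctx.Err() == nil && res.completed {
		rt[bucket] = time.Now()
	}
}

func main() {
	rt := refreshTracker{}
	ctx := context.Background()

	rt.markRefreshed(ctx, "cpl-3", &lookupResult{completed: true})
	rt.markRefreshed(ctx, "cpl-7", &lookupResult{completed: false}) // no-op

	fmt.Println(len(rt)) // prints 1: only the completed lookup counts as a refresh
}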
diff --git a/query.go b/query.go
index cb6ad45af..68ac09752 100644
--- a/query.go
+++ b/query.go
@@ -68,11 +68,6 @@ func (dht *IpfsDHT) runLookup(ctx context.Context, d int, target string, queryFn
 		return nil, err
 	}
 
-	// return if the lookup has been externally stopped
-	if stopFn() || ctx.Err() != nil {
-		return lookupRes, nil
-	}
-
 	// query all of the top K peers we've either Heard about or have outstanding queries we're Waiting on.
 	// This ensures that all of the top K results have been queried which adds to resiliency against churn for query
 	// functions that carry state (e.g. FindProviders and GetValue) as well as establish connections that are needed
@@ -84,6 +79,16 @@ func (dht *IpfsDHT) runLookup(ctx context.Context, d int, target string, queryFn
 		}
 	}
 
+	if len(queryPeers) == 0 {
+		return lookupRes, nil
+	}
+
+	// return if the lookup has been externally stopped
+	if ctx.Err() != nil || stopFn() {
+		lookupRes.completed = false
+		return lookupRes, nil
+	}
+
 	doneCh := make(chan struct{}, 1)
 	followUpCtx, cancelFollowUp := context.WithCancel(ctx)
 	for _, p := range queryPeers {
@@ -96,17 +101,22 @@ func (dht *IpfsDHT) runLookup(ctx context.Context, d int, target string, queryFn
 
 	// wait for all queries to complete before returning, aborting ongoing queries if we've been externally stopped
 processFollowUp:
-	for i := 0; i < len(queryPeers); i++ {
-		select{
-		case <-doneCh:
-			if stopFn() {
-				cancelFollowUp()
+	for i := 0; i < len(queryPeers); i++ {
+		select{
+		case <-doneCh:
+			if stopFn() {
+				cancelFollowUp()
+				if i < len(queryPeers)-1 {
+					lookupRes.completed = false
+				}
+				break processFollowUp
+			}
+		case <-ctx.Done():
+			lookupRes.completed = false
 				break processFollowUp
 			}
-		case <-ctx.Done():
-			break processFollowUp
 		}
-	}
 
 	return lookupRes, nil
 }
@@ -115,9 +125,6 @@ func (dht *IpfsDHT) runLookup(ctx context.Context, d int, target string, queryFn
 func (dht *IpfsDHT) runDisjointQueries(ctx context.Context, d int, target string, queryFn queryFn, stopFn stopFn) (*lookupResult, error) {
 	queryCtx, cancelQuery := context.WithCancel(ctx)
 
-	numQueriesComplete := 0
-	queryDone := make(chan struct{}, d)
-
 	// pick the K closest peers to the key in our Routing table and shuffle them.
 	seedPeers := dht.routingTable.NearestPeers(kb.ConvertKey(target), dht.bucketSize)
 	if len(seedPeers) == 0 {
@@ -157,6 +164,7 @@ func (dht *IpfsDHT) runDisjointQueries(ctx context.Context, d int, target string
 	}
 
 	// start the "d" disjoint queries
+	queryDone := make(chan struct{}, d)
 	for i := 0; i < d; i++ {
 		query := queries[i]
 		go func() {
@@ -165,20 +173,16 @@ func (dht *IpfsDHT) runDisjointQueries(ctx context.Context, d int, target string
 		}()
 	}
 
-loop:
 	// wait for all the "d" disjoint queries to complete before we return
 	// XXX: Waiting until all queries are done is a vector for DoS attacks:
 	// The disjoint lookup paths that are taken over by adversarial peers
 	// can easily be fooled to go on forever.
+	numQueriesComplete := 0
 	for {
-		select {
-		case <-queryDone:
-			numQueriesComplete++
-			if numQueriesComplete == d {
-				break loop
-			}
-		case <-ctx.Done():
-			break loop
+		<-queryDone
+		numQueriesComplete++
+		if numQueriesComplete == d {
+			break
 		}
 	}
 
@@ -197,9 +201,20 @@ loop:
 	for _, q := range queries {
 		qp := q.queryPeers.GetClosestNotUnreachable(dht.bucketSize)
 		for _, p := range qp {
-			peerState[p] = q.queryPeers.GetState(p)
+			// Since the same peer can be seen in multiple queries use the "best" state for the peer
+			// Note: It's possible that a peer was marked undialable in one path, but wasn't yet tried in another path
+			// for now we're going to return that peer as long as some path does not think it is undialable. This may
+			// change in the future if we track addresses dialed per path.
+			state := q.queryPeers.GetState(p)
+			if currState, ok := peerState[p]; ok {
+				if state > currState {
+					peerState[p] = state
+				}
+			} else {
+				peerState[p] = state
+				peers = append(peers, p)
+			}
 		}
-		peers = append(peers, qp...)
 	}
 
 	// get the top K overall peers
@@ -242,13 +257,17 @@ func (q *query) runWithGreedyParallelism() {
 			q.updateState(update)
 		case <-pathCtx.Done():
 			q.terminate()
-			return
 		}
 
 		// termination is triggered on end-of-lookup conditions or starvation of unused peers
 		if q.readyToTerminate() {
 			q.terminate()
-			return
+
+			// exit once all goroutines have been cleaned up
+			if q.queryPeers.NumWaiting() == 0 {
+				return
+			}
+			continue
 		}
 
 		// if all "threads" are busy, wait until someone finishes
@@ -360,21 +379,24 @@ func (q *query) updateState(up *queryUpdate) {
 			continue
 		}
 		q.queryPeers.TryAdd(p)
-		q.queryPeers.SetState(p, qpeerset.PeerHeard)
 	}
 	for _, p := range up.queried {
 		if p == q.dht.self { // don't add self.
 			continue
 		}
 		q.queryPeers.TryAdd(p)
-		q.queryPeers.SetState(p, qpeerset.PeerQueried)
+		if st := q.queryPeers.GetState(p); st == qpeerset.PeerWaiting {
+			q.queryPeers.SetState(p, qpeerset.PeerQueried)
+		}
 	}
 	for _, p := range up.unreachable {
 		if p == q.dht.self { // don't add self.
 			continue
 		}
 		q.queryPeers.TryAdd(p)
-		q.queryPeers.SetState(p, qpeerset.PeerUnreachable)
+		if st := q.queryPeers.GetState(p); st == qpeerset.PeerWaiting {
+			q.queryPeers.SetState(p, qpeerset.PeerUnreachable)
+		}
 	}
 }
 
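The comment in the @@ -197,9 +201,20 @@ hunk carries the key reasoning of the result-merging change, so here is a minimal standalone sketch of that rule: when the same peer appears on several disjoint paths, keep the "best" state any path reported. The ordering of the constants below is an assumption chosen so that a numerically larger state is more useful (unreachable lowest); the real qpeerset package may order its constants differently.

// Sketch only, not code from the patch.
package main

import "fmt"

type PeerState int

const (
	PeerUnreachable PeerState = iota // some path failed to dial the peer
	PeerHeard                        // heard about, not yet tried
	PeerWaiting                      // a query to the peer is in flight
	PeerQueried                      // the peer answered a query
)

// mergeStates keeps, per peer, the best state any disjoint path reported,
// mirroring the patch's `if state > currState` comparison and its
// append-on-first-sighting behavior.
func mergeStates(paths []map[string]PeerState) map[string]PeerState {
	merged := make(map[string]PeerState)
	for _, path := range paths {
		for p, state := range path {
			if curr, ok := merged[p]; !ok || state > curr {
				merged[p] = state
			}
		}
	}
	return merged
}

func main() {
	paths := []map[string]PeerState{
		{"QmA": PeerUnreachable, "QmB": PeerQueried},
		{"QmA": PeerHeard}, // a second path never tried QmA, so QmA stays usable
	}
	// QmA merges to PeerHeard (1), QmB to PeerQueried (3).
	fmt.Println(mergeStates(paths))
}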
diff --git a/routing.go b/routing.go
index be34ba939..21e4d122e 100644
--- a/routing.go
+++ b/routing.go
@@ -564,7 +564,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
 		}
 	}
 
-	queries, err := dht.runLookup(ctx, dht.d, string(key),
+	lookupRes, err := dht.runLookup(ctx, dht.d, string(key),
 		func(ctx context.Context, p peer.ID) ([]*peer.AddrInfo, error) {
 			// For DHT query command
 			routing.PublishQueryEvent(ctx, &routing.QueryEvent{
@@ -621,7 +621,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash
 	)
 
 	if err != nil && ctx.Err() == nil {
-		dht.refreshRTIfNoShortcut(kb.ConvertKey(string(key)), queries)
+		dht.refreshRTIfNoShortcut(kb.ConvertKey(string(key)), lookupRes)
 	}
 }
 
@@ -682,7 +682,7 @@ func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo,
 			break
 		}
 	}
-	if discoveredPeerDuringQuery {
+	if discoveredPeerDuringQuery || lookupRes.completed {
 		dht.routingTable.ResetCplRefreshedAtForID(kb.ConvertPeerID(id), time.Now())
 	}
 }
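Finally, a small sketch of the shutdown pattern the @@ -242,13 +257,17 @@ hunk in query.go above switches to, with invented names rather than the library's API: after deciding to terminate, the event loop keeps draining in-flight workers instead of returning immediately, so no worker goroutine is left blocked on a channel send.

// Sketch only, not code from the patch.
package main

import "fmt"

func main() {
	const workers = 3
	results := make(chan int) // unbuffered, like a per-query update channel

	for i := 0; i < workers; i++ {
		go func(id int) {
			// Without the drain loop below, a worker finishing after the
			// coordinator returned would block here forever.
			results <- id
		}(i)
	}

	terminated := false
	inFlight := workers
	for inFlight > 0 {
		id := <-results
		inFlight--

		// Stand-in for readyToTerminate(): decide to stop after the first
		// result, but keep looping until every worker has reported in.
		if !terminated {
			terminated = true
			fmt.Printf("terminating after worker %d; draining %d more\n", id, inFlight)
		}
	}
	fmt.Println("all workers accounted for; safe to return")
}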