diff --git a/find.go b/find.go index 986b2e4..8a71446 100644 --- a/find.go +++ b/find.go @@ -88,7 +88,8 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) { return } - ctx := r.Context() + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() method := r.Method req := r.URL @@ -147,25 +148,17 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) { return } - var res []byte for md := range sg.gather(ctx) { - if len(md) == 0 { - continue - } // It's ok to return the first encountered metadata. This is because metadata is uniquely identified // by ValueKey (peerID + contextID). I.e. it's not possible to have different metadata records for the same ValueKey. // In comparison to regular find requests where it's perfectly normal to have different results returned by different IPNI // instances and hence they need to be aggregated. - res = md - // Continue to iterate to drain channels and avoid memory leak. - } - - if len(res) == 0 { - http.Error(w, "", http.StatusNotFound) - return + if len(md) > 0 { + httpserver.WriteJsonResponse(w, http.StatusOK, md) + return + } } - - httpserver.WriteJsonResponse(w, http.StatusOK, res) + http.Error(w, "", http.StatusNotFound) } func (s *server) find(w http.ResponseWriter, r *http.Request, mh multihash.Multihash) { @@ -244,6 +237,9 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL maxWait: config.Server.ResultMaxWait, } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + var count int32 if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (**model.FindResponse, error) { // Copy the URL from original request and override host/schema to point @@ -279,7 +275,7 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL atomic.AddInt32(&count, 1) providers, err := model.UnmarshalFindResponse(data) if err != nil { - return nil, err + return nil, circuitbreaker.MarkAsSuccess(err) } return &providers, nil case http.StatusNotFound: diff --git a/find_ndjson.go b/find_ndjson.go index b7bc217..d95b85d 100644 --- a/find_ndjson.go +++ b/find_ndjson.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "hash/crc32" + "io" "net/http" "net/url" "sync/atomic" @@ -81,6 +82,9 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source maxWait: maxWait, } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + resultsChan := make(chan *encryptedOrPlainResult, 1) var count int32 if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*any, error) { @@ -105,45 +109,59 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source } defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - var result encryptedOrPlainResult - line := scanner.Bytes() - if len(line) == 0 { - continue + switch resp.StatusCode { + case http.StatusOK: + case http.StatusNotFound: + atomic.AddInt32(&count, 1) + return nil, nil + default: + bb, _ := io.ReadAll(resp.Body) + body := string(bb) + log := log.With("status", resp.StatusCode, "body", body) + log.Warn("Request processing was not successful") + err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host) + if resp.StatusCode < http.StatusInternalServerError { + err = circuitbreaker.MarkAsSuccess(err) } - switch resp.StatusCode { - case http.StatusOK: - atomic.AddInt32(&count, 1) - if err := json.Unmarshal(line, &result); err != nil { - return nil, err - } - // Sanity check the results in case backends don't respect accept media types; - // see: https://github.com/ipni/storetheindex/issues/1209 - if len(result.EncryptedValueKey) == 0 && (result.Provider.ID == "" || len(result.Provider.Addrs) == 0) { - return nil, nil - } - resultsChan <- &result - continue - case http.StatusNotFound: - atomic.AddInt32(&count, 1) + return nil, err + } + + scanner := bufio.NewScanner(resp.Body) + for { + select { + case <-cctx.Done(): return nil, nil default: - body := string(line) - log := log.With("status", resp.StatusCode, "body", body) - log.Warn("Request processing was not successful") - err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host) - if resp.StatusCode < http.StatusInternalServerError { - err = circuitbreaker.MarkAsSuccess(err) + if scanner.Scan() { + var result encryptedOrPlainResult + line := scanner.Bytes() + if len(line) == 0 { + continue + } + atomic.AddInt32(&count, 1) + if err := json.Unmarshal(line, &result); err != nil { + return nil, circuitbreaker.MarkAsSuccess(err) + } + // Sanity check the results in case backends don't respect accept media types; + // see: https://github.com/ipni/storetheindex/issues/1209 + if len(result.EncryptedValueKey) == 0 && (result.Provider.ID == "" || len(result.Provider.Addrs) == 0) { + continue + } + + select { + case <-cctx.Done(): + return nil, nil + case resultsChan <- &result: + } + continue + } + if err := scanner.Err(); err != nil { + log.Warnw("Failed to read backend response", "err", err) + return nil, circuitbreaker.MarkAsSuccess(err) } - return nil, err + return nil, nil } } - if err := scanner.Err(); err != nil { - log.Warnw("Failed to read backend response", "err", err) - return nil, err - } - return nil, nil }); err != nil { log.Errorw("Failed to scatter HTTP find request", "err", err) http.Error(w, "", http.StatusInternalServerError) @@ -172,7 +190,6 @@ LOOP: case _, ok := <-sg.gather(ctx): if !ok { close(resultsChan) - break LOOP } case result, ok := <-resultsChan: if !ok { diff --git a/reframe.go b/reframe.go index 03eee6f..22fc0bf 100644 --- a/reframe.go +++ b/reframe.go @@ -79,6 +79,9 @@ func (x *ReframeService) FindProviders(ctx context.Context, key cid.Cid) (<-chan maxWait: config.Reframe.ResultMaxWait, } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + if err := sg.scatter(ctx, func(cctx context.Context, b *backendDelegatedRoutingClient) (*drclient.FindProvidersAsyncResult, error) { ch, err := b.FindProvidersAsync(cctx, key) if err != nil { @@ -133,14 +136,20 @@ func (x *ReframeService) FindProviders(ctx context.Context, key cid.Cid) (<-chan result.AddrInfo = append(result.AddrInfo, ai) } if len(result.AddrInfo) > 0 { - out <- result + select { + case <-ctx.Done(): + return + case out <- result: + } } } // If nothing is found then return the last returned error, if any. - if len(pids) == 0 { - out <- drclient.FindProvidersAsyncResult{ - Err: lastErr, + if len(pids) == 0 && lastErr != nil { + select { + case <-ctx.Done(): + return + case out <- drclient.FindProvidersAsyncResult{Err: lastErr}: } } }() diff --git a/scatter_gather.go b/scatter_gather.go index a059c2d..cc18bb6 100644 --- a/scatter_gather.go +++ b/scatter_gather.go @@ -21,12 +21,17 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context sg.start = time.Now() sg.out = make(chan R, 1) for i, t := range sg.targets { - if (len(sg.tcb) > 0) && !sg.tcb[i].Ready() { + + var cb *circuitbreaker.CircuitBreaker + if len(sg.tcb) > i { + cb = sg.tcb[i] + } + if cb != nil && !cb.Ready() { continue } sg.wg.Add(1) - go func(target T, i int) { + go func(target T, tcb *circuitbreaker.CircuitBreaker) { defer sg.wg.Done() select { @@ -36,28 +41,23 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context default: } - cctx, cncl := context.WithTimeout(ctx, sg.maxWait) + cctx, cancel := context.WithTimeout(ctx, sg.maxWait) sout, err := forEach(cctx, target) - cncl() - if len(sg.tcb) > 0 { - err = sg.tcb[i].Done(ctx, err) + cancel() + if tcb != nil { + err = tcb.Done(cctx, err) } if err != nil { - log.Errorw("failed to scatter on target", "target", target, "err", err) + log.Errorw("failed to scatter on target", "target", target, "err", err, "maxWait", sg.maxWait) return } - if sout != nil { - if ctx.Err() == nil { - select { - case <-ctx.Done(): - return - case sg.out <- *sout: - return - } + select { + case <-ctx.Done(): + case sg.out <- *sout: } } - }(t, i) + }(t, cb) } go func() { defer close(sg.out) @@ -71,16 +71,23 @@ func (sg *scatterGather[_, R]) gather(ctx context.Context) <-chan R { go func() { defer func() { close(gout) - elapsed := time.Since(sg.start) - log.Debugw("Completed scatter gather", "elapsed", elapsed.String()) + log.Debugw("Completed scatter gather", "elapsed", time.Since(sg.start)) }() - for r := range sg.out { + for { select { case <-ctx.Done(): - continue - case gout <- r: - continue + return + case r, ok := <-sg.out: + if !ok { + return + } + select { + case <-ctx.Done(): + return + case gout <- r: + continue + } } } }()