Skip to content

Commit

Permalink
Fix a bug where wrong context is passed to circuit breaker
Browse files Browse the repository at this point in the history
On `Done`, the circuit breaker library checks the `ctx` error independently
of the error that is passed in along with the context. It is also
configured not to fail on context cancellation.

In scatter gather, a new cancellable context with a maximum wait time is
created and passed to the function that performs the action. The fix here
passes that same context to the circuit breaker's `Done` function so that
the configuration for context cancellation is respected.

The changes here also improve context handling in scatter gather so that
the context is cancelled as soon as processing finishes, and all channels
are closed properly without needing to be drained once the context is done.
  • Loading branch information
masih committed Feb 22, 2023
1 parent 9723ee6 commit 191abd3
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 75 deletions.
26 changes: 11 additions & 15 deletions find.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) {
return
}

ctx := r.Context()
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
method := r.Method
req := r.URL

Expand Down Expand Up @@ -147,25 +148,17 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) {
return
}

var res []byte
for md := range sg.gather(ctx) {
if len(md) == 0 {
continue
}
// It's ok to return the first encountered metadata. This is because metadata is uniquely identified
// by ValueKey (peerID + contextID). I.e. it's not possible to have different metadata records for the same ValueKey.
// In comparison to regular find requests where it's perfectly normal to have different results returned by different IPNI
// instances and hence they need to be aggregated.
res = md
// Continue to iterate to drain channels and avoid memory leak.
}

if len(res) == 0 {
http.Error(w, "", http.StatusNotFound)
return
if len(md) > 0 {
httpserver.WriteJsonResponse(w, http.StatusOK, md)
return
}
}

httpserver.WriteJsonResponse(w, http.StatusOK, res)
http.Error(w, "", http.StatusNotFound)
}

func (s *server) find(w http.ResponseWriter, r *http.Request, mh multihash.Multihash) {
Expand Down Expand Up @@ -244,6 +237,9 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL
maxWait: config.Server.ResultMaxWait,
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var count int32
if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (**model.FindResponse, error) {
// Copy the URL from original request and override host/schema to point
Expand Down Expand Up @@ -279,7 +275,7 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL
atomic.AddInt32(&count, 1)
providers, err := model.UnmarshalFindResponse(data)
if err != nil {
return nil, err
return nil, circuitbreaker.MarkAsSuccess(err)
}
return &providers, nil
case http.StatusNotFound:
Expand Down
85 changes: 51 additions & 34 deletions find_ndjson.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"hash/crc32"
"io"
"net/http"
"net/url"
"sync/atomic"
Expand Down Expand Up @@ -81,6 +82,9 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
maxWait: maxWait,
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

resultsChan := make(chan *encryptedOrPlainResult, 1)
var count int32
if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*any, error) {
Expand All @@ -105,45 +109,59 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
}
defer resp.Body.Close()

scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var result encryptedOrPlainResult
line := scanner.Bytes()
if len(line) == 0 {
continue
switch resp.StatusCode {
case http.StatusOK:
case http.StatusNotFound:
atomic.AddInt32(&count, 1)
return nil, nil
default:
bb, _ := io.ReadAll(resp.Body)
body := string(bb)
log := log.With("status", resp.StatusCode, "body", body)
log.Warn("Request processing was not successful")
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host)
if resp.StatusCode < http.StatusInternalServerError {
err = circuitbreaker.MarkAsSuccess(err)
}
switch resp.StatusCode {
case http.StatusOK:
atomic.AddInt32(&count, 1)
if err := json.Unmarshal(line, &result); err != nil {
return nil, err
}
// Sanity check the results in case backends don't respect accept media types;
// see: https://github.com/ipni/storetheindex/issues/1209
if len(result.EncryptedValueKey) == 0 && (result.Provider.ID == "" || len(result.Provider.Addrs) == 0) {
return nil, nil
}
resultsChan <- &result
continue
case http.StatusNotFound:
atomic.AddInt32(&count, 1)
return nil, err
}

scanner := bufio.NewScanner(resp.Body)
for {
select {
case <-cctx.Done():
return nil, nil
default:
body := string(line)
log := log.With("status", resp.StatusCode, "body", body)
log.Warn("Request processing was not successful")
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host)
if resp.StatusCode < http.StatusInternalServerError {
err = circuitbreaker.MarkAsSuccess(err)
if scanner.Scan() {
var result encryptedOrPlainResult
line := scanner.Bytes()
if len(line) == 0 {
continue
}
atomic.AddInt32(&count, 1)
if err := json.Unmarshal(line, &result); err != nil {
return nil, circuitbreaker.MarkAsSuccess(err)
}
// Sanity check the results in case backends don't respect accept media types;
// see: https://github.com/ipni/storetheindex/issues/1209
if len(result.EncryptedValueKey) == 0 && (result.Provider.ID == "" || len(result.Provider.Addrs) == 0) {
continue
}

select {
case <-cctx.Done():
return nil, nil
case resultsChan <- &result:
}
continue
}
if err := scanner.Err(); err != nil {
log.Warnw("Failed to read backend response", "err", err)
return nil, circuitbreaker.MarkAsSuccess(err)
}
return nil, err
return nil, nil
}
}
if err := scanner.Err(); err != nil {
log.Warnw("Failed to read backend response", "err", err)
return nil, err
}
return nil, nil
}); err != nil {
log.Errorw("Failed to scatter HTTP find request", "err", err)
http.Error(w, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -172,7 +190,6 @@ LOOP:
case _, ok := <-sg.gather(ctx):
if !ok {
close(resultsChan)
break LOOP
}
case result, ok := <-resultsChan:
if !ok {
Expand Down
17 changes: 13 additions & 4 deletions reframe.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ func (x *ReframeService) FindProviders(ctx context.Context, key cid.Cid) (<-chan
maxWait: config.Reframe.ResultMaxWait,
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

if err := sg.scatter(ctx, func(cctx context.Context, b *backendDelegatedRoutingClient) (*drclient.FindProvidersAsyncResult, error) {
ch, err := b.FindProvidersAsync(cctx, key)
if err != nil {
Expand Down Expand Up @@ -133,14 +136,20 @@ func (x *ReframeService) FindProviders(ctx context.Context, key cid.Cid) (<-chan
result.AddrInfo = append(result.AddrInfo, ai)
}
if len(result.AddrInfo) > 0 {
out <- result
select {
case <-ctx.Done():
return
case out <- result:
}
}
}

// If nothing is found then return the last returned error, if any.
if len(pids) == 0 {
out <- drclient.FindProvidersAsyncResult{
Err: lastErr,
if len(pids) == 0 && lastErr != nil {
select {
case <-ctx.Done():
return
case out <- drclient.FindProvidersAsyncResult{Err: lastErr}:
}
}
}()
Expand Down
51 changes: 29 additions & 22 deletions scatter_gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context
sg.start = time.Now()
sg.out = make(chan R, 1)
for i, t := range sg.targets {
if (len(sg.tcb) > 0) && !sg.tcb[i].Ready() {

var cb *circuitbreaker.CircuitBreaker
if len(sg.tcb) > i {
cb = sg.tcb[i]
}
if cb != nil && !cb.Ready() {
continue
}

sg.wg.Add(1)
go func(target T, i int) {
go func(target T, tcb *circuitbreaker.CircuitBreaker) {
defer sg.wg.Done()

select {
Expand All @@ -36,28 +41,23 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context
default:
}

cctx, cncl := context.WithTimeout(ctx, sg.maxWait)
cctx, cancel := context.WithTimeout(ctx, sg.maxWait)
sout, err := forEach(cctx, target)
cncl()
if len(sg.tcb) > 0 {
err = sg.tcb[i].Done(ctx, err)
cancel()
if tcb != nil {
err = tcb.Done(cctx, err)
}
if err != nil {
log.Errorw("failed to scatter on target", "target", target, "err", err)
log.Errorw("failed to scatter on target", "target", target, "err", err, "maxWait", sg.maxWait)
return
}

if sout != nil {
if ctx.Err() == nil {
select {
case <-ctx.Done():
return
case sg.out <- *sout:
return
}
select {
case <-ctx.Done():
case sg.out <- *sout:
}
}
}(t, i)
}(t, cb)
}
go func() {
defer close(sg.out)
Expand All @@ -71,16 +71,23 @@ func (sg *scatterGather[_, R]) gather(ctx context.Context) <-chan R {
go func() {
defer func() {
close(gout)
elapsed := time.Since(sg.start)
log.Debugw("Completed scatter gather", "elapsed", elapsed.String())
log.Debugw("Completed scatter gather", "elapsed", time.Since(sg.start))
}()

for r := range sg.out {
for {
select {
case <-ctx.Done():
continue
case gout <- r:
continue
return
case r, ok := <-sg.out:
if !ok {
return
}
select {
case <-ctx.Done():
return
case gout <- r:
continue
}
}
}
}()
Expand Down

0 comments on commit 191abd3

Please sign in to comment.