Skip to content

Commit

Permalink
Make backend CB individually configurable and separate cascade
Browse files Browse the repository at this point in the history
Refactor backend instantiation such that circuit breaker backends are
individually configurable. Reflect changes across the repo.

Add the ability to check if a backend matches a request before
attempting to call it. This avoids response slow-down caused by slow
backends early in the chain

Introduce a separate configuration to mark cascade lookup backends with
explicit matcher in order to avoid routing all requests to them.

Relates to:
 - #86
  • Loading branch information
masih committed Feb 23, 2023
1 parent 191abd3 commit 5ab5f8f
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 143 deletions.
84 changes: 84 additions & 0 deletions backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package main

import (
"net/http"
"net/url"

"github.com/mercari/go-circuitbreaker"
)

var Matchers struct {
Any HttpRequestMatcher
AnyOf func(...HttpRequestMatcher) HttpRequestMatcher
QueryParam func(key, value string) HttpRequestMatcher
}

type (
HttpRequestMatcher func(r *http.Request) bool
Backend interface {
URL() *url.URL
CB() *circuitbreaker.CircuitBreaker
Matches(r *http.Request) bool
}
SimpleBackend struct {
url *url.URL
cb *circuitbreaker.CircuitBreaker
matcher HttpRequestMatcher
}
)

func (b *SimpleBackend) URL() *url.URL {
return b.url
}

func (b *SimpleBackend) CB() *circuitbreaker.CircuitBreaker {
return b.cb
}

func init() {
Matchers.Any = func(*http.Request) bool { return true }
Matchers.AnyOf = func(ms ...HttpRequestMatcher) HttpRequestMatcher {
return func(r *http.Request) bool {
for _, m := range ms {
if m(r) {
return true
}
}
return false
}
}
Matchers.QueryParam = func(key, value string) HttpRequestMatcher {
return func(r *http.Request) bool {
if r == nil {
return false
}
values, ok := r.URL.Query()[key]
if !ok {
return false
}
for _, got := range values {
if value == got {
return true
}
}
return false
}
}
}

func NewBackend(u string, cb *circuitbreaker.CircuitBreaker, matcher HttpRequestMatcher) (Backend, error) {
burl, err := url.Parse(u)
if err != nil {
return nil, err
}

return &SimpleBackend{
url: burl,
cb: cb,
matcher: matcher,
}, nil
}

func (b *SimpleBackend) Matches(r *http.Request) bool {
return b.matcher(r)
}
32 changes: 17 additions & 15 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"encoding/json"
"errors"
"net/url"
"os"
"path/filepath"
"strconv"
Expand Down Expand Up @@ -36,6 +35,10 @@ const (
defaultCircuitOpenTimeout = 0
defaultCircuitCounterReset = 1 * time.Second

defaultCascadeCircuitHalfOpenSuccesses = 10
defaultCascadeCircuitOpenTimeout = 0
defaultCascadeCircuitCounterReset = 1 * time.Second

// DefaultPathName is the default config dir name.
DefaultPathName = ".indexstar"
// DefaultPathRoot is the path to the default config dir location.
Expand Down Expand Up @@ -73,6 +76,11 @@ var config struct {
OpenTimeout time.Duration
CounterReset time.Duration
}
CascadeCircuit struct {
HalfOpenSuccesses int
OpenTimeout time.Duration
CounterReset time.Duration
}
}

func init() {
Expand All @@ -98,6 +106,10 @@ func init() {
config.Circuit.HalfOpenSuccesses = getEnvOrDefault[int]("CIRCUIT_HALF_OPEN_SUCCESSES", defaultCircuitHalfOpenSuccesses)
config.Circuit.OpenTimeout = getEnvOrDefault[time.Duration]("CIRCUIT_OPEN_TIMEOUT", defaultCircuitOpenTimeout)
config.Circuit.CounterReset = getEnvOrDefault[time.Duration]("CIRCUIT_COUNTER_RESET", defaultCircuitCounterReset)

config.CascadeCircuit.HalfOpenSuccesses = getEnvOrDefault[int]("CASCADE_CIRCUIT_HALF_OPEN_SUCCESSES", defaultCascadeCircuitHalfOpenSuccesses)
config.CascadeCircuit.OpenTimeout = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_OPEN_TIMEOUT", defaultCascadeCircuitOpenTimeout)
config.CascadeCircuit.CounterReset = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_COUNTER_RESET", defaultCascadeCircuitCounterReset)
}

func getEnvOrDefault[T any](key string, def T) T {
Expand Down Expand Up @@ -175,7 +187,7 @@ func PathRoot() (string, error) {
return homedir.Expand(DefaultPathRoot)
}

func Load(filePath string) ([]*url.URL, error) {
func Load(filePath string) ([]string, error) {
var err error
if filePath == "" {
filePath, err = Filename("")
Expand All @@ -193,19 +205,9 @@ func Load(filePath string) ([]*url.URL, error) {
}
defer f.Close()

cfg := []string{}
if err = json.NewDecoder(f).Decode(&cfg); err != nil {
var urls []string
if err = json.NewDecoder(f).Decode(&urls); err != nil {
return nil, err
}

surls := make([]*url.URL, 0, len(cfg))
for _, s := range cfg {
surl, err := url.Parse(s)
if err != nil {
return nil, err
}
surls = append(surls, surl)
}

return surls, nil
return urls, nil
}
39 changes: 22 additions & 17 deletions find.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,18 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) {
method := r.Method
req := r.URL

sg := &scatterGather[*url.URL, []byte]{
targets: s.servers,
tcb: s.serverCallers,
maxWait: config.Server.ResultMaxWait,
sg := &scatterGather[Backend, []byte]{
backends: s.backends,
maxWait: config.Server.ResultMaxWait,
}

// TODO: wait for the first successful response instead
if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*[]byte, error) {
if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (*[]byte, error) {
// Copy the URL from original request and override host/schema to point
// to the server.
endpoint := *req
endpoint.Host = b.Host
endpoint.Scheme = b.Scheme
endpoint.Host = b.URL().Host
endpoint.Scheme = b.URL().Scheme
log := log.With("backend", endpoint.Host)

req, err := http.NewRequestWithContext(cctx, method, endpoint.String(), nil)
Expand All @@ -115,6 +114,9 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) {
}
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("Accept", mediaTypeJson)
if !b.Matches(req) {
return nil, nil
}
resp, err := s.Client.Do(req)
if err != nil {
log.Warnw("Failed to query backend for metadata", "err", err)
Expand All @@ -136,7 +138,7 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) {
body := string(data)
log := log.With("status", resp.StatusCode, "body", body)
log.Warn("Request processing was not successful")
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host)
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host)
if resp.StatusCode < http.StatusInternalServerError {
err = circuitbreaker.MarkAsSuccess(err)
}
Expand Down Expand Up @@ -231,33 +233,36 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL
stats.WithMeasurements(metrics.FindLoad.M(1)))
}()

sg := &scatterGather[*url.URL, *model.FindResponse]{
targets: s.servers,
tcb: s.serverCallers,
maxWait: config.Server.ResultMaxWait,
sg := &scatterGather[Backend, *model.FindResponse]{
backends: s.backends,
maxWait: config.Server.ResultMaxWait,
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

var count int32
if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (**model.FindResponse, error) {
if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (**model.FindResponse, error) {
// Copy the URL from original request and override host/schema to point
// to the server.
endpoint := *req
endpoint.Host = b.Host
endpoint.Scheme = b.Scheme
endpoint.Host = b.URL().Host
endpoint.Scheme = b.URL().Scheme
log := log.With("backend", endpoint.Host)

bodyReader := bytes.NewReader(body)

req, err := http.NewRequestWithContext(cctx, method, endpoint.String(), bodyReader)
if err != nil {
log.Warnw("Failed to construct backend query", "err", err)
return nil, err
}
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("Accept", mediaTypeJson)

if !b.Matches(req) {
return nil, nil
}

resp, err := s.Client.Do(req)
if err != nil {
log.Warnw("Failed to query backend", "err", err)
Expand Down Expand Up @@ -285,7 +290,7 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL
body := string(data)
log := log.With("status", resp.StatusCode, "body", body)
log.Warn("Request processing was not successful")
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host)
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host)
if resp.StatusCode < http.StatusInternalServerError {
err = circuitbreaker.MarkAsSuccess(err)
}
Expand Down
40 changes: 28 additions & 12 deletions find_ndjson.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,22 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
maxWait = config.Server.ResultStreamMaxWait
}

sg := &scatterGather[*url.URL, any]{
targets: s.servers,
tcb: s.serverCallers,
maxWait: maxWait,
sg := &scatterGather[Backend, any]{
backends: s.backends,
maxWait: maxWait,
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

resultsChan := make(chan *encryptedOrPlainResult, 1)
var count int32
if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*any, error) {
if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (*any, error) {
// Copy the URL from original request and override host/schema to point
// to the server.
endpoint := *req
endpoint.Host = b.Host
endpoint.Scheme = b.Scheme
endpoint.Host = b.URL().Host
endpoint.Scheme = b.URL().Scheme
log := log.With("backend", endpoint.Host)

req, err := http.NewRequestWithContext(cctx, http.MethodGet, endpoint.String(), nil)
Expand All @@ -102,6 +101,11 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
}
req.Header.Set("X-Forwarded-Host", req.Host)
req.Header.Set("Accept", mediaTypeNDJson)

if !b.Matches(req) {
return nil, nil
}

resp, err := s.Client.Do(req)
if err != nil {
log.Warnw("Failed to query backend", "err", err)
Expand All @@ -119,7 +123,7 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
body := string(bb)
log := log.With("status", resp.StatusCode, "body", body)
log.Warn("Request processing was not successful")
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host)
err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host)
if resp.StatusCode < http.StatusInternalServerError {
err = circuitbreaker.MarkAsSuccess(err)
}
Expand Down Expand Up @@ -182,15 +186,27 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source
encoder := json.NewEncoder(w)
results := newResultSet()

// Results chan is done when gathering is finished.
// Do this in a separate goroutine to avoid potentially closing results chan twice.
go func() {
for {
select {
case <-ctx.Done():
return
case _, ok := <-sg.gather(ctx):
if !ok {
close(resultsChan)
return
}
}
}
}()

LOOP:
for {
select {
case <-ctx.Done():
break LOOP
case _, ok := <-sg.gather(ctx):
if !ok {
close(resultsChan)
}
case result, ok := <-resultsChan:
if !ok {
break LOOP
Expand Down
14 changes: 9 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,28 @@ func main() {
Flags: []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "config file",
Usage: "Path to config file",
TakesFile: true,
},
&cli.StringFlag{
Name: "listen",
Usage: "listen address",
Usage: "HTTP server Listen address",
Value: ":8080",
},
&cli.StringFlag{
Name: "metrics",
Usage: "metrics address",
Usage: "Metrics server listen address",
Value: ":8081",
},
&cli.StringSliceFlag{
Name: "backends",
Usage: "backends to use",
Usage: "Backends to propagate requests to.",
Value: cli.NewStringSlice("https://cid.contact/"),
},
&cli.StringSliceFlag{
Name: "cascadeBackends",
Usage: "Backends to propagate lookup with SERVER_CASCADE_LABELS env var as query parameter",
},
&cli.BoolFlag{
Name: "translateReframe",
Usage: "translate reframe requests into find requests to backends",
Expand Down Expand Up @@ -104,7 +108,7 @@ func main() {
case err := <-done:
return err
case <-reloadSig:
err := s.Reload()
err := s.Reload(c)
if err != nil {
log.Warnf("couldn't reload servers: %s", err)
}
Expand Down
Loading

0 comments on commit 5ab5f8f

Please sign in to comment.