diff --git a/backend.go b/backend.go new file mode 100644 index 0000000..129dc5f --- /dev/null +++ b/backend.go @@ -0,0 +1,84 @@ +package main + +import ( + "net/http" + "net/url" + + "github.com/mercari/go-circuitbreaker" +) + +var Matchers struct { + Any HttpRequestMatcher + AnyOf func(...HttpRequestMatcher) HttpRequestMatcher + QueryParam func(key, value string) HttpRequestMatcher +} + +type ( + HttpRequestMatcher func(r *http.Request) bool + Backend interface { + URL() *url.URL + CB() *circuitbreaker.CircuitBreaker + Matches(r *http.Request) bool + } + SimpleBackend struct { + url *url.URL + cb *circuitbreaker.CircuitBreaker + matcher HttpRequestMatcher + } +) + +func (b *SimpleBackend) URL() *url.URL { + return b.url +} + +func (b *SimpleBackend) CB() *circuitbreaker.CircuitBreaker { + return b.cb +} + +func init() { + Matchers.Any = func(*http.Request) bool { return true } + Matchers.AnyOf = func(ms ...HttpRequestMatcher) HttpRequestMatcher { + return func(r *http.Request) bool { + for _, m := range ms { + if m(r) { + return true + } + } + return false + } + } + Matchers.QueryParam = func(key, value string) HttpRequestMatcher { + return func(r *http.Request) bool { + if r == nil { + return false + } + values, ok := r.URL.Query()[key] + if !ok { + return false + } + for _, got := range values { + if value == got { + return true + } + } + return false + } + } +} + +func NewBackend(u string, cb *circuitbreaker.CircuitBreaker, matcher HttpRequestMatcher) (Backend, error) { + burl, err := url.Parse(u) + if err != nil { + return nil, err + } + + return &SimpleBackend{ + url: burl, + cb: cb, + matcher: matcher, + }, nil +} + +func (b *SimpleBackend) Matches(r *http.Request) bool { + return b.matcher(r) +} diff --git a/config.go b/config.go index daee143..c3e7f53 100644 --- a/config.go +++ b/config.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "errors" - "net/url" "os" "path/filepath" "strconv" @@ -36,6 +35,10 @@ const ( defaultCircuitOpenTimeout = 0 defaultCircuitCounterReset = 1 * time.Second + defaultCascadeCircuitHalfOpenSuccesses = 10 + defaultCascadeCircuitOpenTimeout = 0 + defaultCascadeCircuitCounterReset = 1 * time.Second + // DefaultPathName is the default config dir name. DefaultPathName = ".indexstar" // DefaultPathRoot is the path to the default config dir location. @@ -73,6 +76,11 @@ var config struct { OpenTimeout time.Duration CounterReset time.Duration } + CascadeCircuit struct { + HalfOpenSuccesses int + OpenTimeout time.Duration + CounterReset time.Duration + } } func init() { @@ -98,6 +106,10 @@ func init() { config.Circuit.HalfOpenSuccesses = getEnvOrDefault[int]("CIRCUIT_HALF_OPEN_SUCCESSES", defaultCircuitHalfOpenSuccesses) config.Circuit.OpenTimeout = getEnvOrDefault[time.Duration]("CIRCUIT_OPEN_TIMEOUT", defaultCircuitOpenTimeout) config.Circuit.CounterReset = getEnvOrDefault[time.Duration]("CIRCUIT_COUNTER_RESET", defaultCircuitCounterReset) + + config.CascadeCircuit.HalfOpenSuccesses = getEnvOrDefault[int]("CASCADE_CIRCUIT_HALF_OPEN_SUCCESSES", defaultCascadeCircuitHalfOpenSuccesses) + config.CascadeCircuit.OpenTimeout = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_OPEN_TIMEOUT", defaultCascadeCircuitOpenTimeout) + config.CascadeCircuit.CounterReset = getEnvOrDefault[time.Duration]("CASCADE_CIRCUIT_COUNTER_RESET", defaultCascadeCircuitCounterReset) } func getEnvOrDefault[T any](key string, def T) T { @@ -175,7 +187,7 @@ func PathRoot() (string, error) { return homedir.Expand(DefaultPathRoot) } -func Load(filePath string) ([]*url.URL, error) { +func Load(filePath string) ([]string, error) { var err error if filePath == "" { filePath, err = Filename("") @@ -193,19 +205,9 @@ func Load(filePath string) ([]*url.URL, error) { } defer f.Close() - cfg := []string{} - if err = json.NewDecoder(f).Decode(&cfg); err != nil { + var urls []string + if err = json.NewDecoder(f).Decode(&urls); err != nil { return nil, err } - - surls := make([]*url.URL, 0, len(cfg)) - for _, s := range cfg { - surl, err := url.Parse(s) - if err != nil { - return nil, err - } - surls = append(surls, surl) - } - - return surls, nil + return urls, nil } diff --git a/find.go b/find.go index 8a71446..0d5cce1 100644 --- a/find.go +++ b/find.go @@ -93,19 +93,18 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) { method := r.Method req := r.URL - sg := &scatterGather[*url.URL, []byte]{ - targets: s.servers, - tcb: s.serverCallers, - maxWait: config.Server.ResultMaxWait, + sg := &scatterGather[Backend, []byte]{ + backends: s.backends, + maxWait: config.Server.ResultMaxWait, } // TODO: wait for the first successful response instead - if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*[]byte, error) { + if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (*[]byte, error) { // Copy the URL from original request and override host/schema to point // to the server. endpoint := *req - endpoint.Host = b.Host - endpoint.Scheme = b.Scheme + endpoint.Host = b.URL().Host + endpoint.Scheme = b.URL().Scheme log := log.With("backend", endpoint.Host) req, err := http.NewRequestWithContext(cctx, method, endpoint.String(), nil) @@ -115,6 +114,9 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) { } req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("Accept", mediaTypeJson) + if !b.Matches(req) { + return nil, nil + } resp, err := s.Client.Do(req) if err != nil { log.Warnw("Failed to query backend for metadata", "err", err) @@ -136,7 +138,7 @@ func (s *server) findMetadataSubtree(w http.ResponseWriter, r *http.Request) { body := string(data) log := log.With("status", resp.StatusCode, "body", body) log.Warn("Request processing was not successful") - err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host) + err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host) if resp.StatusCode < http.StatusInternalServerError { err = circuitbreaker.MarkAsSuccess(err) } @@ -231,26 +233,24 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL stats.WithMeasurements(metrics.FindLoad.M(1))) }() - sg := &scatterGather[*url.URL, *model.FindResponse]{ - targets: s.servers, - tcb: s.serverCallers, - maxWait: config.Server.ResultMaxWait, + sg := &scatterGather[Backend, *model.FindResponse]{ + backends: s.backends, + maxWait: config.Server.ResultMaxWait, } ctx, cancel := context.WithCancel(ctx) defer cancel() var count int32 - if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (**model.FindResponse, error) { + if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (**model.FindResponse, error) { // Copy the URL from original request and override host/schema to point // to the server. endpoint := *req - endpoint.Host = b.Host - endpoint.Scheme = b.Scheme + endpoint.Host = b.URL().Host + endpoint.Scheme = b.URL().Scheme log := log.With("backend", endpoint.Host) bodyReader := bytes.NewReader(body) - req, err := http.NewRequestWithContext(cctx, method, endpoint.String(), bodyReader) if err != nil { log.Warnw("Failed to construct backend query", "err", err) @@ -258,6 +258,11 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL } req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("Accept", mediaTypeJson) + + if !b.Matches(req) { + return nil, nil + } + resp, err := s.Client.Do(req) if err != nil { log.Warnw("Failed to query backend", "err", err) @@ -285,7 +290,7 @@ func (s *server) doFind(ctx context.Context, method, source string, req *url.URL body := string(data) log := log.With("status", resp.StatusCode, "body", body) log.Warn("Request processing was not successful") - err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host) + err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host) if resp.StatusCode < http.StatusInternalServerError { err = circuitbreaker.MarkAsSuccess(err) } diff --git a/find_ndjson.go b/find_ndjson.go index d95b85d..b942492 100644 --- a/find_ndjson.go +++ b/find_ndjson.go @@ -76,10 +76,9 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source maxWait = config.Server.ResultStreamMaxWait } - sg := &scatterGather[*url.URL, any]{ - targets: s.servers, - tcb: s.serverCallers, - maxWait: maxWait, + sg := &scatterGather[Backend, any]{ + backends: s.backends, + maxWait: maxWait, } ctx, cancel := context.WithCancel(ctx) @@ -87,12 +86,12 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source resultsChan := make(chan *encryptedOrPlainResult, 1) var count int32 - if err := sg.scatter(ctx, func(cctx context.Context, b *url.URL) (*any, error) { + if err := sg.scatter(ctx, func(cctx context.Context, b Backend) (*any, error) { // Copy the URL from original request and override host/schema to point // to the server. endpoint := *req - endpoint.Host = b.Host - endpoint.Scheme = b.Scheme + endpoint.Host = b.URL().Host + endpoint.Scheme = b.URL().Scheme log := log.With("backend", endpoint.Host) req, err := http.NewRequestWithContext(cctx, http.MethodGet, endpoint.String(), nil) @@ -102,6 +101,11 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source } req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("Accept", mediaTypeNDJson) + + if !b.Matches(req) { + return nil, nil + } + resp, err := s.Client.Do(req) if err != nil { log.Warnw("Failed to query backend", "err", err) @@ -119,7 +123,7 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source body := string(bb) log := log.With("status", resp.StatusCode, "body", body) log.Warn("Request processing was not successful") - err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.Host) + err := fmt.Errorf("status %d response from backend %s", resp.StatusCode, b.URL().Host) if resp.StatusCode < http.StatusInternalServerError { err = circuitbreaker.MarkAsSuccess(err) } @@ -182,15 +186,27 @@ func (s *server) doFindNDJson(ctx context.Context, w http.ResponseWriter, source encoder := json.NewEncoder(w) results := newResultSet() + // Results chan is done when gathering is finished. + // Do this in a separate goroutine to avoid potentially closing results chan twice. + go func() { + for { + select { + case <-ctx.Done(): + return + case _, ok := <-sg.gather(ctx): + if !ok { + close(resultsChan) + return + } + } + } + }() + LOOP: for { select { case <-ctx.Done(): break LOOP - case _, ok := <-sg.gather(ctx): - if !ok { - close(resultsChan) - } case result, ok := <-resultsChan: if !ok { break LOOP diff --git a/main.go b/main.go index 59d1a08..9b474c6 100644 --- a/main.go +++ b/main.go @@ -23,24 +23,28 @@ func main() { Flags: []cli.Flag{ &cli.StringFlag{ Name: "config", - Usage: "config file", + Usage: "Path to config file", TakesFile: true, }, &cli.StringFlag{ Name: "listen", - Usage: "listen address", + Usage: "HTTP server Listen address", Value: ":8080", }, &cli.StringFlag{ Name: "metrics", - Usage: "metrics address", + Usage: "Metrics server listen address", Value: ":8081", }, &cli.StringSliceFlag{ Name: "backends", - Usage: "backends to use", + Usage: "Backends to propagate requests to.", Value: cli.NewStringSlice("https://cid.contact/"), }, + &cli.StringSliceFlag{ + Name: "cascadeBackends", + Usage: "Backends to propagate lookup with SERVER_CASCADE_LABELS env var as query parameter", + }, &cli.BoolFlag{ Name: "translateReframe", Usage: "translate reframe requests into find requests to backends", @@ -104,7 +108,7 @@ func main() { case err := <-done: return err case <-reloadSig: - err := s.Reload() + err := s.Reload(c) if err != nil { log.Warnf("couldn't reload servers: %s", err) } diff --git a/providers.go b/providers.go index 2ce9040..55b3354 100644 --- a/providers.go +++ b/providers.go @@ -30,7 +30,7 @@ func (s *server) providers(w http.ResponseWriter, r *http.Request) { return } - for _, server := range s.servers { + for _, backend := range s.backends { wg.Add(1) go func(server *url.URL) { defer wg.Done() @@ -72,7 +72,7 @@ func (s *server) providers(w http.ResponseWriter, r *http.Request) { default: log.Warn("unexpected response while getting providers") } - }(server) + }(backend.URL()) } go func() { wg.Wait() @@ -122,7 +122,7 @@ func (s *server) provider(w http.ResponseWriter, r *http.Request) { return } - for _, server := range s.servers { + for _, backend := range s.backends { wg.Add(1) go func(server *url.URL) { defer wg.Done() @@ -162,7 +162,7 @@ func (s *server) provider(w http.ResponseWriter, r *http.Request) { default: log.Warn("unexpected response while getting provider") } - }(server) + }(backend.URL()) } go func() { wg.Wait() diff --git a/reframe.go b/reframe.go index 22fc0bf..2ced837 100644 --- a/reframe.go +++ b/reframe.go @@ -4,7 +4,6 @@ import ( "context" "net" "net/http" - "net/url" "path" "github.com/ipfs/go-cid" @@ -15,7 +14,7 @@ import ( "github.com/libp2p/go-libp2p/core/routing" ) -func NewReframeHTTPHandler(backends []*url.URL) (http.HandlerFunc, error) { +func NewReframeHTTPHandler(backends []Backend) (http.HandlerFunc, error) { svc, err := NewReframeService(backends) if err != nil { return nil, err @@ -23,7 +22,7 @@ func NewReframeHTTPHandler(backends []*url.URL) (http.HandlerFunc, error) { return drserver.DelegatedRoutingAsyncHandler(svc), nil } -func NewReframeService(backends []*url.URL) (*ReframeService, error) { +func NewReframeService(backends []Backend) (*ReframeService, error) { httpClient := http.Client{ Timeout: config.Reframe.HttpClientTimeout, Transport: reframeRoundTripper(), @@ -32,7 +31,7 @@ func NewReframeService(backends []*url.URL) (*ReframeService, error) { clients := make([]*backendDelegatedRoutingClient, 0, len(backends)) for _, b := range backends { // TODO: replace with URL.JoinPath once upgraded to go 1.19 - endpoint := path.Join(b.String(), "reframe") + endpoint := path.Join(b.URL().String(), "reframe") q, err := drproto.New_DelegatedRouting_Client(endpoint, drproto.DelegatedRouting_Client_WithHTTPClient(&httpClient)) if err != nil { return nil, err @@ -42,8 +41,8 @@ func NewReframeService(backends []*url.URL) (*ReframeService, error) { return nil, err } clients = append(clients, &backendDelegatedRoutingClient{ - DelegatedRoutingClient: drc, - url: b, + Backend: b, + client: drc, }) } return &ReframeService{clients}, nil @@ -69,21 +68,24 @@ type ReframeService struct { } type backendDelegatedRoutingClient struct { - drclient.DelegatedRoutingClient - url *url.URL + Backend + client drclient.DelegatedRoutingClient } func (x *ReframeService) FindProviders(ctx context.Context, key cid.Cid) (<-chan drclient.FindProvidersAsyncResult, error) { sg := &scatterGather[*backendDelegatedRoutingClient, drclient.FindProvidersAsyncResult]{ - targets: x.backends, - maxWait: config.Reframe.ResultMaxWait, + backends: x.backends, + maxWait: config.Reframe.ResultMaxWait, } ctx, cancel := context.WithCancel(ctx) defer cancel() if err := sg.scatter(ctx, func(cctx context.Context, b *backendDelegatedRoutingClient) (*drclient.FindProvidersAsyncResult, error) { - ch, err := b.FindProvidersAsync(cctx, key) + if !b.Matches(nil) { + return nil, nil + } + ch, err := b.client.FindProvidersAsync(cctx, key) if err != nil { return nil, err } diff --git a/reframe_translate_test.go b/reframe_translate_test.go index 6616a2e..02e8ff5 100644 --- a/reframe_translate_test.go +++ b/reframe_translate_test.go @@ -5,23 +5,19 @@ import ( "net" "net/http" "net/http/httputil" - "net/url" "testing" "github.com/ipfs/go-cid" drp "github.com/ipfs/go-delegated-routing/gen/proto" finderhttpclient "github.com/ipni/storetheindex/api/v0/finder/client/http" - "github.com/mercari/go-circuitbreaker" "github.com/stretchr/testify/require" ) func doServe(ctx context.Context, bound net.Listener) { - surls := make([]*url.URL, 0, 1) - surl, err := url.Parse("https://cid.contact/") + backends, err := loadBackends([]string{"https://cid.contact/"}, nil) if err != nil { return } - surls = append(surls, surl) b2, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -33,9 +29,8 @@ func doServe(ctx context.Context, bound net.Listener) { Client: *http.DefaultClient, Listener: bound, metricsListener: b2, - servers: surls, - serverCallers: []*circuitbreaker.CircuitBreaker{circuitbreaker.New()}, - base: httputil.NewSingleHostReverseProxy(surls[0]), + backends: backends, + base: httputil.NewSingleHostReverseProxy(backends[0].URL()), translateReframe: true, } s.Serve() diff --git a/scatter_gather.go b/scatter_gather.go index cc18bb6..f50ff42 100644 --- a/scatter_gather.go +++ b/scatter_gather.go @@ -4,34 +4,27 @@ import ( "context" "sync" "time" - - "github.com/mercari/go-circuitbreaker" ) -type scatterGather[T, R any] struct { - targets []T - tcb []*circuitbreaker.CircuitBreaker - start time.Time - wg sync.WaitGroup - out chan R - maxWait time.Duration +type scatterGather[B Backend, R any] struct { + backends []B + start time.Time + wg sync.WaitGroup + out chan R + maxWait time.Duration } -func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context.Context, T) (*R, error)) error { +func (sg *scatterGather[B, R]) scatter(ctx context.Context, forEach func(context.Context, B) (*R, error)) error { sg.start = time.Now() sg.out = make(chan R, 1) - for i, t := range sg.targets { + for _, backend := range sg.backends { - var cb *circuitbreaker.CircuitBreaker - if len(sg.tcb) > i { - cb = sg.tcb[i] - } - if cb != nil && !cb.Ready() { + if backend.CB() != nil && !backend.CB().Ready() { continue } sg.wg.Add(1) - go func(target T, tcb *circuitbreaker.CircuitBreaker) { + go func(target B) { defer sg.wg.Done() select { @@ -44,11 +37,11 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context cctx, cancel := context.WithTimeout(ctx, sg.maxWait) sout, err := forEach(cctx, target) cancel() - if tcb != nil { - err = tcb.Done(cctx, err) + if target.CB() != nil { + err = target.CB().Done(cctx, err) } if err != nil { - log.Errorw("failed to scatter on target", "target", target, "err", err, "maxWait", sg.maxWait) + log.Errorw("failed to scatter on target", "target", target.URL().Host, "err", err, "maxWait", sg.maxWait) return } if sout != nil { @@ -57,7 +50,7 @@ func (sg *scatterGather[T, R]) scatter(ctx context.Context, forEach func(context case sg.out <- *sout: } } - }(t, cb) + }(backend) } go func() { defer close(sg.out) diff --git a/scatter_gather_test.go b/scatter_gather_test.go index 98f9767..7fbfd24 100644 --- a/scatter_gather_test.go +++ b/scatter_gather_test.go @@ -4,20 +4,39 @@ import ( "context" "errors" "fmt" + "net/http" + "net/url" "testing" "time" + "github.com/mercari/go-circuitbreaker" "github.com/stretchr/testify/require" ) +var _ (Backend) = (*testBackend)(nil) + +type testBackend int + +func (t testBackend) URL() *url.URL { + u, err := url.Parse("http://test.invalid") + if err != nil { + panic(err) + } + return u +} + +func (t testBackend) CB() *circuitbreaker.CircuitBreaker { return nil } + +func (t testBackend) Matches(*http.Request) bool { return false } + func TestScatterGather_GathersExpectedResults(t *testing.T) { - subject := scatterGather[int, string]{ - targets: []int{1, 2, 3, 4, 5}, - maxWait: 2 * time.Second, + subject := scatterGather[testBackend, string]{ + backends: []testBackend{testBackend(1), testBackend(2), testBackend(3), testBackend(4), testBackend(5)}, + maxWait: 2 * time.Second, } ctx := context.Background() - err := subject.scatter(ctx, func(cctx context.Context, i int) (*string, error) { + err := subject.scatter(ctx, func(cctx context.Context, i testBackend) (*string, error) { if cctx.Err() == nil { str := fmt.Sprintf("%d fish", i) return &str, nil @@ -39,12 +58,12 @@ func TestScatterGather_GathersExpectedResults(t *testing.T) { } func TestScatterGather_ExcludesScatterErrors(t *testing.T) { - subject := scatterGather[int, string]{ - targets: []int{1, 2, 3}, - maxWait: 2 * time.Second, + subject := scatterGather[testBackend, string]{ + backends: []testBackend{testBackend(1), testBackend(2), testBackend(3)}, + maxWait: 2 * time.Second, } ctx := context.Background() - err := subject.scatter(ctx, func(cctx context.Context, i int) (*string, error) { + err := subject.scatter(ctx, func(cctx context.Context, i testBackend) (*string, error) { if i == 2 { return nil, errors.New("fish says no") } @@ -67,12 +86,12 @@ func TestScatterGather_ExcludesScatterErrors(t *testing.T) { } func TestScatterGather_DoesNotWaitLongerThanExpected(t *testing.T) { - subject := scatterGather[int, string]{ - targets: []int{1}, - maxWait: 100 * time.Millisecond, + subject := scatterGather[testBackend, string]{ + backends: []testBackend{testBackend(1)}, + maxWait: 100 * time.Millisecond, } ctx := context.Background() - err := subject.scatter(ctx, func(cctx context.Context, i int) (*string, error) { + err := subject.scatter(ctx, func(cctx context.Context, i testBackend) (*string, error) { time.Sleep(2 * time.Second) if cctx.Err() == nil { str := fmt.Sprintf("%d fish", i) @@ -90,14 +109,14 @@ func TestScatterGather_DoesNotWaitLongerThanExpected(t *testing.T) { } func TestScatterGather_GathersNothingWhenContextIsCancelled(t *testing.T) { - subject := scatterGather[int, string]{ - targets: []int{1, 2, 3}, - maxWait: 2 * time.Second, + subject := scatterGather[testBackend, string]{ + backends: []testBackend{testBackend(1), testBackend(2), testBackend(3)}, + maxWait: 2 * time.Second, } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) cancel() - err := subject.scatter(ctx, func(cctx context.Context, i int) (*string, error) { + err := subject.scatter(ctx, func(cctx context.Context, i testBackend) (*string, error) { if cctx.Err() == nil { str := fmt.Sprintf("%d fish", i) return &str, nil diff --git a/server.go b/server.go index 7b66cfc..feddac5 100644 --- a/server.go +++ b/server.go @@ -6,12 +6,11 @@ import ( "net" "net/http" "net/http/httputil" - "net/url" + "strings" logging "github.com/ipfs/go-log/v2" "github.com/ipni/indexstar/httpserver" "github.com/ipni/indexstar/metrics" - "github.com/mercari/go-circuitbreaker" "github.com/urfave/cli/v2" ) @@ -24,8 +23,7 @@ type server struct { net.Listener metricsListener net.Listener cfgBase string - servers []*url.URL - serverCallers []*circuitbreaker.CircuitBreaker + backends []Backend base http.Handler translateReframe bool translateNonStreaming bool @@ -41,36 +39,21 @@ func NewServer(c *cli.Context) (*server, error) { return nil, err } servers := c.StringSlice("backends") + cascadeServers := c.StringSlice("cascadeBackends") - var surls []*url.URL if len(servers) == 0 { if !c.IsSet("config") { return nil, fmt.Errorf("no backends specified") } - surls, err = Load(c.String("config")) + servers, err = Load(c.String("config")) if err != nil { return nil, fmt.Errorf("could not load backends from config: %w", err) } - } else { - surls = make([]*url.URL, len(servers)) - for i, s := range servers { - surls[i], err = url.Parse(s) - if err != nil { - return nil, err - } - } } - scallers := make([]*circuitbreaker.CircuitBreaker, len(surls)) - for i, surl := range surls { - scallers[i] = circuitbreaker.New( - circuitbreaker.WithFailOnContextCancel(false), - circuitbreaker.WithHalfOpenMaxSuccesses(int64(config.Circuit.HalfOpenSuccesses)), - circuitbreaker.WithOpenTimeout(config.Circuit.OpenTimeout), - circuitbreaker.WithCounterResetInterval(config.Circuit.CounterReset), - circuitbreaker.WithOnStateChangeHookFn(func(from, to circuitbreaker.State) { - log.Infof("circuit state for %s changed from %s to %s", surl.String(), from, to) - })) + backends, err := loadBackends(servers, cascadeServers) + if err != nil { + return nil, err } t := http.DefaultTransport.(*http.Transport).Clone() @@ -94,21 +77,73 @@ func NewServer(c *cli.Context) (*server, error) { cfgBase: c.String("config"), Listener: bound, metricsListener: mb, - servers: surls, - serverCallers: scallers, - base: httputil.NewSingleHostReverseProxy(surls[0]), + backends: backends, + base: httputil.NewSingleHostReverseProxy(backends[0].URL()), translateReframe: c.Bool("translateReframe"), translateNonStreaming: c.Bool("translateNonStreaming"), }, nil } -func (s *server) Reload() error { +func loadBackends(servers, cascadeServers []string) ([]Backend, error) { + var backends []Backend + for _, s := range servers { + b, err := NewBackend(s, circuitbreaker.New( + circuitbreaker.WithFailOnContextCancel(false), + circuitbreaker.WithHalfOpenMaxSuccesses(int64(config.Circuit.HalfOpenSuccesses)), + circuitbreaker.WithOpenTimeout(config.Circuit.OpenTimeout), + circuitbreaker.WithCounterResetInterval(config.Circuit.CounterReset), + circuitbreaker.WithOnStateChangeHookFn(func(from, to circuitbreaker.State) { + log.Infof("circuit state for %s changed from %s to %s", s, from, to) + })), Matchers.Any) + if err != nil { + return nil, fmt.Errorf("failed to instantiate backend: %w", err) + } + backends = append(backends, b) + } + + for _, cs := range cascadeServers { + matcher := Matchers.Any + if config.Server.CascadeLabels != "" { + labels := strings.Split(config.Server.CascadeLabels, ",") + if len(labels) > 0 { + labelMatchers := make([]HttpRequestMatcher, 0, len(labels)) + for _, label := range labels { + labelMatchers = append(labelMatchers, Matchers.QueryParam("cascade", label)) + } + matcher = Matchers.AnyOf(labelMatchers...) + } + } + b, err := NewBackend(cs, circuitbreaker.New( + circuitbreaker.WithFailOnContextCancel(false), + circuitbreaker.WithHalfOpenMaxSuccesses(int64(config.CascadeCircuit.HalfOpenSuccesses)), + circuitbreaker.WithOpenTimeout(config.CascadeCircuit.OpenTimeout), + circuitbreaker.WithCounterResetInterval(config.CascadeCircuit.CounterReset), + circuitbreaker.WithOnStateChangeHookFn(func(from, to circuitbreaker.State) { + log.Infof("cascade circuit state for %s changed from %s to %s", cs, from, to) + })), matcher) + if err != nil { + return nil, fmt.Errorf("failed to instantiate cascade backend: %w", err) + } + backends = append(backends, b) + } + + if len(backends) == 0 { + return nil, fmt.Errorf("no backends specified") + } + return backends, nil +} + +func (s *server) Reload(cctx *cli.Context) error { surls, err := Load(s.cfgBase) if err != nil { return err } - s.servers = surls - s.base = httputil.NewSingleHostReverseProxy(surls[0]) + b, err := loadBackends(surls, cctx.StringSlice("cascadeBackends")) + if err != nil { + return err + } + s.backends = b + s.base = httputil.NewSingleHostReverseProxy(b[0].URL()) return nil } @@ -133,7 +168,7 @@ func (s *server) Serve() chan error { } mux.HandleFunc("/reframe", reframe) } else { - reframe, err := NewReframeHTTPHandler(s.servers) + reframe, err := NewReframeHTTPHandler(s.backends) if err != nil { ec <- err close(ec) @@ -202,9 +237,10 @@ func (s *server) health(w http.ResponseWriter, r *http.Request) { func (s *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // default behavior. - r.URL.Host = s.servers[0].Host - r.URL.Scheme = s.servers[0].Scheme + firstBackend := s.backends[0].URL() + r.URL.Host = firstBackend.Host + r.URL.Scheme = firstBackend.Scheme r.Header.Set("X-Forwarded-Host", r.Header.Get("Host")) - r.Header.Set("Host", s.servers[0].Host) + r.Header.Set("Host", firstBackend.Host) s.base.ServeHTTP(w, r) }