diff --git a/injectproxy/routes.go b/injectproxy/routes.go index 6d1ef38e..4a6789d0 100644 --- a/injectproxy/routes.go +++ b/injectproxy/routes.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/http/httputil" "net/url" @@ -48,6 +49,8 @@ type routes struct { modifiers map[string]func(*http.Response) error errorOnReplace bool regexMatch bool + + logger *log.Logger } type options struct { @@ -297,6 +300,7 @@ func NewRoutes(upstream *url.URL, label string, extractLabeler ExtractLabeler, o el: extractLabeler, errorOnReplace: opt.errorOnReplace, regexMatch: opt.regexMatch, + logger: log.Default(), } mux := newStrictMux(newInstrumentedMux(http.NewServeMux(), opt.registerer)) @@ -378,10 +382,10 @@ func NewRoutes(upstream *url.URL, label string, extractLabeler ExtractLabeler, o "/api/v1/rules": modifyAPIResponse(r.filterRules), "/api/v1/alerts": modifyAPIResponse(r.filterAlerts), } - //FIXME: when ModifyResponse returns an error, the default ErrorHandler is - //called which returns 502 Bad Gateway. It'd be more appropriate to treat - //the error and return 400 in case of bad input for instance. proxy.ModifyResponse = r.ModifyResponse + proxy.ErrorHandler = r.errorHandler + proxy.ErrorLog = log.Default() + return r, nil } @@ -395,9 +399,19 @@ func (r *routes) ModifyResponse(resp *http.Response) error { // Return the server's response unmodified. return nil } + return m(resp) } +func (r *routes) errorHandler(rw http.ResponseWriter, _ *http.Request, err error) { + r.logger.Printf("http: proxy error: %v", err) + if errors.Is(err, errModifyResponseFailed) { + rw.WriteHeader(http.StatusBadRequest) + } + + rw.WriteHeader(http.StatusBadGateway) +} + func enforceMethods(h http.HandlerFunc, methods ...string) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { for _, m := range methods { diff --git a/injectproxy/rules.go b/injectproxy/rules.go index ef93cb96..7b39ecf5 100644 --- a/injectproxy/rules.go +++ b/injectproxy/rules.go @@ -17,6 +17,7 @@ import ( "bytes" "compress/gzip" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -166,6 +167,10 @@ type alert struct { Value string `json:"value"` } +// errModifyResponseFailed is returned when the proxy failed to modify the +// response from the backend. +var errModifyResponseFailed = errors.New("failed to process the API response") + // modifyAPIResponse unwraps the Prometheus API response, passes the enforced // label value and the response to the given function and finally replaces the // result in the response. @@ -178,23 +183,24 @@ func modifyAPIResponse(f func([]string, *apiResponse) (interface{}, error)) func apir, err := getAPIResponse(resp) if err != nil { - return fmt.Errorf("can't decode API response: %w", err) + return fmt.Errorf("can't decode the response: %w", err) } v, err := f(MustLabelValues(resp.Request.Context()), apir) if err != nil { - return err + return fmt.Errorf("%w: %w", errModifyResponseFailed, err) } b, err := json.Marshal(v) if err != nil { - return fmt.Errorf("can't replace data: %w", err) + return fmt.Errorf("can't encode the data: %w", err) } + apir.Data = json.RawMessage(b) var buf bytes.Buffer if err = json.NewEncoder(&buf).Encode(apir); err != nil { - return fmt.Errorf("can't encode API response: %w", err) + return fmt.Errorf("can't encode the response: %w", err) } resp.Body = io.NopCloser(&buf) resp.Header["Content-Length"] = []string{fmt.Sprint(buf.Len())} diff --git a/injectproxy/rules_test.go b/injectproxy/rules_test.go index 421d278d..a0b7f9e0 100644 --- a/injectproxy/rules_test.go +++ b/injectproxy/rules_test.go @@ -469,7 +469,7 @@ func TestRules(t *testing.T) { upstream: validRules(), opts: []Option{WithRegexMatch()}, - expCode: http.StatusBadGateway, + expCode: http.StatusBadRequest, golden: "rules_invalid_upstream_response.golden", }, } { @@ -513,7 +513,10 @@ func TestRules(t *testing.T) { t.Fatalf("expected status code %d, got %d", tc.expCode, resp.StatusCode) } - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expected no error, got %s", err) + } if resp.StatusCode != http.StatusOK { golden.Assert(t, string(body), tc.golden) return @@ -613,7 +616,7 @@ func TestAlerts(t *testing.T) { upstream: validAlerts(), opts: []Option{WithRegexMatch()}, - expCode: http.StatusBadGateway, + expCode: http.StatusBadRequest, golden: "alerts_invalid_upstream_response.golden", }, } { @@ -650,7 +653,10 @@ func TestAlerts(t *testing.T) { t.Fatalf("expected status code %d, got %d", tc.expCode, resp.StatusCode) } - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("expected no error, got %s", err) + } if resp.StatusCode != http.StatusOK { golden.Assert(t, string(body), tc.golden) return