diff --git a/chiware/stats.go b/chiware/stats.go index 9f8bd824..7adac563 100644 --- a/chiware/stats.go +++ b/chiware/stats.go @@ -12,7 +12,28 @@ import ( "github.com/rudderlabs/rudder-go-kit/stats" ) -func StatMiddleware(ctx context.Context, router chi.Router, s stats.Stats, component string) func(http.Handler) http.Handler { +type config struct { + redactUnknownPaths bool +} + +type Option func(*config) + +// RedactUnknownPaths sets the redactUnknownPaths flag. +// If set to true, the path will be redacted if the route is not found. +// If set to false, the path will be used as is. +func RedactUnknownPaths(redactUnknownPaths bool) Option { + return func(c *config) { + c.redactUnknownPaths = redactUnknownPaths + } +} + +func StatMiddleware(ctx context.Context, router chi.Router, s stats.Stats, component string, options ...Option) func(http.Handler) http.Handler { + conf := config{ + redactUnknownPaths: true, + } + for _, option := range options { + option(&conf) + } var concurrentRequests int32 activeClientCount := s.NewStat(fmt.Sprintf("%s.concurrent_requests_count", component), stats.GaugeType) go func() { @@ -33,6 +54,9 @@ func StatMiddleware(ctx context.Context, router chi.Router, s stats.Stats, compo if path := chi.RouteContext(r.Context()).RoutePattern(); path != "" { return path } + if conf.redactUnknownPaths { + return "/redacted" + } return r.URL.Path } return func(next http.Handler) http.Handler { @@ -43,7 +67,6 @@ func StatMiddleware(ctx context.Context, router chi.Router, s stats.Stats, compo defer atomic.AddInt32(&concurrentRequests, -1) next.ServeHTTP(sw, r) - s.NewSampledTaggedStat( fmt.Sprintf("%s.response_time", component), stats.TimerType, diff --git a/chiware/stats_test.go b/chiware/stats_test.go index 2a3d5135..a340d493 100644 --- a/chiware/stats_test.go +++ b/chiware/stats_test.go @@ -18,7 +18,7 @@ import ( func TestStatsMiddleware(t *testing.T) { component := "test" - testCase := func(expectedStatusCode int, pathTemplate, requestPath, expectedMethod string) func(t *testing.T) { + testCase := func(expectedStatusCode int, pathTemplate, requestPath, expectedMethod string, options ...chiware.Option) func(t *testing.T) { return func(t *testing.T) { ctrl := gomock.NewController(t) mockStats := mock_stats.NewMockStats(ctrl) @@ -40,7 +40,7 @@ func TestStatsMiddleware(t *testing.T) { defer cancel() router := chi.NewRouter() router.Use( - chiware.StatMiddleware(ctx, router, mockStats, component), + chiware.StatMiddleware(ctx, router, mockStats, component, options...), ) router.MethodFunc(expectedMethod, pathTemplate, handler) @@ -52,6 +52,8 @@ func TestStatsMiddleware(t *testing.T) { } t.Run("template with param in path", testCase(http.StatusNotFound, "/v1/{param}", "/v1/abc", "GET")) - t.Run("template without param in path", testCase(http.StatusNotFound, "/v1/some-other/key", "/v1/some-other/key", "GET")) + t.Run("template with unknown path ", testCase(http.StatusNotFound, "/a/b/c", "/a/b/c", "GET", chiware.RedactUnknownPaths(false))) + t.Run("template with unknown path ", testCase(http.StatusNotFound, "/redacted", "/a/b/c", "GET", chiware.RedactUnknownPaths(true))) + t.Run("template with unknown path ", testCase(http.StatusNotFound, "/redacted", "/a/b/c", "GET")) } diff --git a/go.mod b/go.mod index 482fb322..d0f87b33 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/go-chi/chi/v5 v5.0.8 github.com/go-redis/redis/v8 v8.11.5 github.com/golang/mock v1.6.0 - github.com/gorilla/mux v1.8.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/ory/dockertest/v3 v3.10.0 diff --git a/go.sum b/go.sum index 0217744c..68773550 100644 --- a/go.sum +++ b/go.sum @@ -199,8 +199,6 @@ github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= diff --git a/gorillaware/stats.go b/gorillaware/stats.go deleted file mode 100644 index 3767a432..00000000 --- a/gorillaware/stats.go +++ /dev/null @@ -1,81 +0,0 @@ -package gorillaware - -import ( - "context" - "fmt" - "net/http" - "strconv" - "sync/atomic" - "time" - - "github.com/gorilla/mux" - "github.com/rudderlabs/rudder-go-kit/stats" -) - -func StatMiddleware(ctx context.Context, router *mux.Router, s stats.Stats, component string) func(http.Handler) http.Handler { - var concurrentRequests int32 - activeClientCount := s.NewStat(fmt.Sprintf("%s.concurrent_requests_count", component), stats.GaugeType) - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(10 * time.Second): - activeClientCount.Gauge(atomic.LoadInt32(&concurrentRequests)) - } - } - }() - - // getPath retrieves the path from the request. - // The matched route's template is used if a match is found, - // otherwise the request's URL path is used instead. - getPath := func(r *http.Request) string { - var match mux.RouteMatch - if router.Match(r, &match) { - if path, err := match.Route.GetPathTemplate(); err == nil { - return path - } - } - return r.URL.Path - } - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sw := newStatusCapturingWriter(w) - path := getPath(r) - start := time.Now() - atomic.AddInt32(&concurrentRequests, 1) - defer atomic.AddInt32(&concurrentRequests, -1) - - next.ServeHTTP(sw, r) - - s.NewSampledTaggedStat( - fmt.Sprintf("%s.response_time", component), - stats.TimerType, - map[string]string{ - "reqType": path, - "method": r.Method, - "code": strconv.Itoa(sw.status), - }).Since(start) - }) - } -} - -// newStatusCapturingWriter returns a new, properly initialized statusCapturingWriter -func newStatusCapturingWriter(w http.ResponseWriter) *statusCapturingWriter { - return &statusCapturingWriter{ - ResponseWriter: w, - status: http.StatusOK, - } -} - -// statusCapturingWriter is a response writer decorator that captures the status code. -type statusCapturingWriter struct { - http.ResponseWriter - status int -} - -// WriteHeader override the http.ResponseWriter's `WriteHeader` method -func (w *statusCapturingWriter) WriteHeader(status int) { - w.status = status - w.ResponseWriter.WriteHeader(status) -} diff --git a/gorillaware/stats_test.go b/gorillaware/stats_test.go deleted file mode 100644 index 2cb9f29c..00000000 --- a/gorillaware/stats_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package gorillaware_test - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "strconv" - "testing" - - "github.com/golang/mock/gomock" - "github.com/gorilla/mux" - "github.com/rudderlabs/rudder-go-kit/gorillaware" - "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-go-kit/stats/mock_stats" - "github.com/stretchr/testify/require" -) - -func TestStatsMiddleware(t *testing.T) { - component := "test" - testCase := func(expectedStatusCode int, pathTemplate, requestPath, expectedReqType, expectedMethod string) func(t *testing.T) { - return func(t *testing.T) { - ctrl := gomock.NewController(t) - mockStats := mock_stats.NewMockStats(ctrl) - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(expectedStatusCode) - }) - - measurement := mock_stats.NewMockMeasurement(ctrl) - mockStats.EXPECT().NewStat(fmt.Sprintf("%s.concurrent_requests_count", component), stats.GaugeType).Return(measurement).Times(1) - mockStats.EXPECT().NewSampledTaggedStat(fmt.Sprintf("%s.response_time", component), stats.TimerType, - map[string]string{ - "reqType": expectedReqType, - "method": expectedMethod, - "code": strconv.Itoa(expectedStatusCode), - }).Return(measurement).Times(1) - measurement.EXPECT().Since(gomock.Any()).Times(1) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - router := mux.NewRouter() - router.Use( - gorillaware.StatMiddleware(ctx, router, mockStats, component), - ) - router.HandleFunc(pathTemplate, handler).Methods(expectedMethod) - - response := httptest.NewRecorder() - request := httptest.NewRequest("GET", "http://example.com"+requestPath, http.NoBody) - router.ServeHTTP(response, request) - require.Equal(t, expectedStatusCode, response.Code) - } - } - - t.Run("template with param in path", testCase(http.StatusNotFound, "/v1/{param}", "/v1/abc", "/v1/{param}", "GET")) - - t.Run("template without param in path", testCase(http.StatusNotFound, "/v1/some-other/key", "/v1/some-other/key", "/v1/some-other/key", "GET")) -}