diff --git a/request_logger.go b/request_logger.go index 88b4c19..53edc4f 100644 --- a/request_logger.go +++ b/request_logger.go @@ -3,6 +3,7 @@ package xecho import ( "math" "net/http" + "strings" "time" "github.com/labstack/echo" @@ -18,6 +19,10 @@ func RequestLoggerMiddleware(timeFn TimeProvider) echo.MiddlewareFunc { } func RequestLogger(c *Context, next echo.HandlerFunc, time TimeProvider) error { + request := c.Request() + if request.URL.Path == "/health" && strings.Contains(request.UserAgent(), "HealthChecker") { + return next(c) + } before := time() lrw := &statefulResponseWriter{ResponseWriter: c.Response().Writer} c.Response().Writer = lrw @@ -25,12 +30,12 @@ func RequestLogger(c *Context, next echo.HandlerFunc, time TimeProvider) error { after := time() logger, ok := c.Logger().(*Logger) if !ok { - c.Logger().Infof("[%s] %s %d", c.Request().Method, c.Path(), lrw.statusCode) + c.Logger().Infof("[%s] %s %d", request.Method, c.Path(), lrw.statusCode) return err } logger. WithFields(createMap(c, after.Sub(before), lrw, err)). - Infof("[%s] %s %d", c.Request().Method, c.Path(), lrw.statusCode) + Infof("[%s] %s %d", request.Method, c.Path(), lrw.statusCode) return err } diff --git a/request_logger_test.go b/request_logger_test.go index f8bf14c..ee08009 100644 --- a/request_logger_test.go +++ b/request_logger_test.go @@ -89,6 +89,28 @@ func TestRequestLogger_LogTest(t *testing.T) { } +func TestRequestLogger_HealthNoLogTest(t *testing.T) { + buffer := &bytes.Buffer{} + URL, _ := url.Parse("http://somedomain/health") + writer, _ := NewWriter() + context := createTestContext(writer, URL, buffer) + now := time.Now() + provider := testTimeProvider{calls: []time.Time{now, now.Add(755 * time.Millisecond)}} + nextCalled := false + nextPtr := &nextCalled + var next echo.HandlerFunc = func(context echo.Context) error { + *nextPtr = true + context.Response().WriteHeader(200) + return nil + } + err := RequestLoggerMiddleware(provider.Next)(next)(context) + assert.Nil(t, err) + fields := getLogFields(buffer, err, t) + + assert.Equal(t, len(fields), 0) + +} + func createTestContext(writer *responseWriter, URL *url.URL, buffer *bytes.Buffer) *Context { return &Context{ Context: &testContext{ @@ -112,12 +134,11 @@ func createResponse(writer *responseWriter) *echo.Response { Size: 150, } } - func createRequest(URL *url.URL) *http.Request { return &http.Request{ Method: "GET", URL: URL, - Header: http.Header{"Correlation-Id": []string{"set_one"}}, + Header: http.Header{"Correlation-Id": []string{"set_one"}, "User-Agent": []string{"ELB-HealthChecker/2.0"}}, ContentLength: 34567, Host: "this.is.a.domain", RemoteAddr: "", @@ -128,8 +149,10 @@ func createRequest(URL *url.URL) *http.Request { func getLogFields(buffer *bytes.Buffer, err error, t *testing.T) logrus.Fields { fields := logrus.Fields{} logStatement := buffer.Bytes() - err = json.Unmarshal(logStatement, &fields) - assert.Nil(t, err) + if len(logStatement) > 0 { + err = json.Unmarshal(logStatement, &fields) + assert.Nil(t, err) + } return fields }