From fe7e2a1adf9b84a44991f03fadfec730d6934b3c Mon Sep 17 00:00:00 2001 From: achettyiitr Date: Sun, 4 Aug 2024 18:29:58 +0530 Subject: [PATCH] chore: enforce max limit for webhook --- gateway/webhook/setup.go | 3 ++ gateway/webhook/webhook.go | 8 +++++ gateway/webhook/webhook_test.go | 62 ++++++++++++++++++++++++++++----- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/gateway/webhook/setup.go b/gateway/webhook/setup.go index f56df783b6..c5aa1ecaf0 100644 --- a/gateway/webhook/setup.go +++ b/gateway/webhook/setup.go @@ -17,6 +17,7 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" + gwstats "github.com/rudderlabs/rudder-server/gateway/internal/stats" gwtypes "github.com/rudderlabs/rudder-server/gateway/internal/types" "github.com/rudderlabs/rudder-server/gateway/webhook/model" @@ -56,6 +57,8 @@ func Setup(gwHandle Gateway, transformerFeaturesService transformer.FeaturesServ maxTransformerProcess := config.GetIntVar(64, 1, "Gateway.webhook.maxTransformerProcess") // Parse all query params from sources mentioned in this list webhook.config.sourceListForParsingParams = config.GetStringSliceVar([]string{"Shopify", "adjust"}, "Gateway.webhook.sourceListForParsingParams") + // Maximum webhook request body size in bytes (Gateway.maxReqSizeInKB scaled by 1024) + webhook.config.maxReqSize = config.GetReloadableIntVar(4000, 1024, "Gateway.maxReqSizeInKB") webhook.config.forwardGetRequestForSrcMap = lo.SliceToMap( config.GetStringSliceVar([]string{"adjust"}, "Gateway.webhook.forwardGetRequestForSrcs"), diff --git a/gateway/webhook/webhook.go b/gateway/webhook/webhook.go index 2a4e96217c..14bbb971ce 100644 --- a/gateway/webhook/webhook.go +++ b/gateway/webhook/webhook.go @@ -66,6 +66,7 @@ type HandleT struct { backgroundCancel context.CancelFunc config struct { + maxReqSize config.ValueLoader[int] webhookBatchTimeout config.ValueLoader[time.Duration] maxWebhookBatchSize config.ValueLoader[int] sourceListForParsingParams []string @@ -334,6 +335,13 @@ func (bt 
*batchWebhookTransformerT) batchTransformLoop() { req.done <- transformerResponse{Err: response.GetStatus(response.InvalidJSON)} continue } + if len(body) > bt.webhook.config.maxReqSize.Load() { + req.done <- transformerResponse{ + StatusCode: response.GetErrorStatusCode(response.RequestBodyTooLarge), + Err: response.GetStatus(response.RequestBodyTooLarge), + } + continue + } payload, err := sourceTransformAdapter.getTransformerEvent(req.authContext, body) if err != nil { diff --git a/gateway/webhook/webhook_test.go b/gateway/webhook/webhook_test.go index 528bdbab5e..aace19adab 100644 --- a/gateway/webhook/webhook_test.go +++ b/gateway/webhook/webhook_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -15,6 +16,8 @@ import ( "go.uber.org/mock/gomock" + "github.com/rudderlabs/rudder-go-kit/bytesize" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +25,7 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-go-kit/stats/memstats" + gwStats "github.com/rudderlabs/rudder-server/gateway/internal/stats" gwtypes "github.com/rudderlabs/rudder-server/gateway/internal/types" mockWebhook "github.com/rudderlabs/rudder-server/gateway/mocks" @@ -61,11 +65,11 @@ type mockSourceTransformAdapter struct { url string } -func (v0 *mockSourceTransformAdapter) getTransformerEvent(authCtx *gwtypes.AuthRequestContext, body []byte) ([]byte, error) { +func (v0 *mockSourceTransformAdapter) getTransformerEvent(_ *gwtypes.AuthRequestContext, body []byte) ([]byte, error) { return body, nil } -func (v0 *mockSourceTransformAdapter) getTransformerURL(sourceType string) (string, error) { +func (v0 *mockSourceTransformAdapter) getTransformerURL(string) (string, error) { return v0.url, nil } @@ -77,13 +81,53 @@ func getMockSourceTransformAdapterFunc(url string) func(ctx context.Context) (so } } +func 
TestWebhookMaxRequestSize(t *testing.T) { + initWebhook() + + ctrl := gomock.NewController(t) + + mockGW := mockWebhook.NewMockGateway(ctrl) + mockGW.EXPECT().TrackRequestMetrics(gomock.Any()).Times(1) + mockGW.EXPECT().NewSourceStat(gomock.Any(), gomock.Any()).Return(&gwStats.SourceStat{}).Times(1) + + mockTransformerFeaturesService := mock_features.NewMockFeaturesService(ctrl) + + maxReqSizeInKB := 1 + + webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.NOP, func(bt *batchWebhookTransformerT) { + bt.sourceTransformAdapter = func(ctx context.Context) (sourceTransformAdapter, error) { + return &mockSourceTransformAdapter{}, nil + } + }) + webhookHandler.config.maxReqSize = config.SingleValueLoader(maxReqSizeInKB * int(bytesize.KB)) // maxReqSize is compared against len(body), so it must be in bytes + t.Cleanup(func() { + _ = webhookHandler.Shutdown() + }) + + webhookHandler.Register(sourceDefName) + + payload := fmt.Sprintf(`{"hello":"world", "data": %q}`, strings.Repeat("a", 2*maxReqSizeInKB*int(bytesize.KB))) + require.Greater(t, len(payload), maxReqSizeInKB*int(bytesize.KB)) + + req := httptest.NewRequest(http.MethodPost, "/v1/webhook", bytes.NewBufferString(payload)) + resp := httptest.NewRecorder() + + reqCtx := context.WithValue(req.Context(), gwtypes.CtxParamCallType, "webhook") + reqCtx = context.WithValue(reqCtx, gwtypes.CtxParamAuthRequestContext, &gwtypes.AuthRequestContext{ + SourceDefName: sourceDefName, + }) + + webhookHandler.RequestHandler(resp, req.WithContext(reqCtx)) + require.Equal(t, http.StatusRequestEntityTooLarge, resp.Result().StatusCode) +} + func TestWebhookBlockTillFeaturesAreFetched(t *testing.T) { initWebhook() ctrl := gomock.NewController(t) mockGW := mockWebhook.NewMockGateway(ctrl) mockTransformerFeaturesService := mock_features.NewMockFeaturesService(ctrl) mockTransformerFeaturesService.EXPECT().Wait().Return(make(chan struct{})).Times(1) - webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.Default) + webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.NOP) 
mockGW.EXPECT().TrackRequestMetrics(gomock.Any()).Times(1) mockGW.EXPECT().NewSourceStat(gomock.Any(), gomock.Any()).Return(&gwStats.SourceStat{}).Times(1) @@ -112,7 +156,7 @@ func TestWebhookRequestHandlerWithTransformerBatchGeneralError(t *testing.T) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, sampleError, http.StatusBadRequest) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) @@ -157,7 +201,7 @@ func TestWebhookRequestHandlerWithTransformerBatchPayloadLengthMismatchError(t * respBody, _ := json.Marshal(responses) _, _ = w.Write(respBody) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) @@ -200,7 +244,7 @@ func TestWebhookRequestHandlerWithTransformerRequestError(t *testing.T) { respBody, _ := json.Marshal(responses) _, _ = w.Write(respBody) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) @@ -243,7 +287,7 @@ func TestWebhookRequestHandlerWithOutputToSource(t *testing.T) { respBody, _ := json.Marshal(responses) _, _ = w.Write(respBody) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, 
func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) mockGW.EXPECT().TrackRequestMetrics("").Times(1) @@ -285,7 +329,7 @@ func TestWebhookRequestHandlerWithOutputToGateway(t *testing.T) { respBody, _ := json.Marshal(responses) _, _ = w.Write(respBody) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) mockGW.EXPECT().TrackRequestMetrics("").Times(1) @@ -332,7 +376,7 @@ func TestWebhookRequestHandlerWithOutputToGatewayAndSource(t *testing.T) { respBody, _ := json.Marshal(responses) _, _ = w.Write(respBody) })) - webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) { + webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) { bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL) }) mockGW.EXPECT().TrackRequestMetrics("").Times(1)