Skip to content

Commit

Permalink
fix: source payloads with query params to be allowed without body (#4677
Browse files Browse the repository at this point in the history
)

* fix: source payloads with query params to be allowed without body

* fix: adding adjust in sourceListForParsingParams

* fix: small edit

* fix: review comment addressed

* fix: query params not added with iteration (#4678)

* fix: review comment addressed

* fix: edit with sjson

* fix: refactor and adding unit test case (#4679)


* fix: refactor and addition of test cases

* fix: small change

* Apply suggestions from code review

Co-authored-by: Akash Chetty <achetty.iitr@gmail.com>

* fix: review comments addressed

* fix: small fix

* Apply suggestions from code review

Co-authored-by: Akash Chetty <achetty.iitr@gmail.com>

* fix: review comments addressed

* fix: review comments addressed

* fix: logs added

* Revert "fix: logs added"

This reverts commit e3b4afe.

---------

Co-authored-by: Akash Chetty <achetty.iitr@gmail.com>
  • Loading branch information
shrouti1507 and achettyiitr authored May 23, 2024
1 parent 96a0180 commit affd6bc
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 24 deletions.
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Gateway:
maxRetryTime: 10s
sourceListForParsingParams:
- shopify
- adjust
EventSchemas:
enableEventSchemasFeature: false
syncInterval: 240s
Expand Down
56 changes: 32 additions & 24 deletions gateway/webhook/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/hashicorp/go-retryablehttp"
"github.com/samber/lo"
"github.com/tidwall/sjson"

"github.com/rudderlabs/rudder-go-kit/config"
kithttputil "github.com/rudderlabs/rudder-go-kit/httputil"
Expand Down Expand Up @@ -261,6 +262,35 @@ func (webhook *HandleT) batchRequests(sourceDef string, requestQ chan *webhookT)
}
}

func prepareRequestBody(req *http.Request, includeQueryParams bool, sourceType string, sourceListForParsingParams []string) ([]byte, error) {
defer func() {
_ = req.Body.Close()
}()

body, err := io.ReadAll(req.Body)
if err != nil {
return nil, errors.New(response.RequestBodyReadFailed)
}

if len(body) == 0 {
body = []byte("{}") // If body is empty, set it to an empty JSON object
}

if includeQueryParams && slices.Contains(sourceListForParsingParams, strings.ToLower(sourceType)) {
queryParams := req.URL.Query()
paramsBytes, err := json.Marshal(queryParams)
if err != nil {
return nil, errors.New(response.ErrorInMarshal)
}

body, err = sjson.SetRawBytes(body, "query_parameters", paramsBytes)
if err != nil {
return nil, errors.New(response.InvalidJSON)
}
}
return body, nil
}

// TODO : return back immediately for blank request body. its waiting till timeout
func (bt *batchWebhookTransformerT) batchTransformLoop() {
for breq := range bt.webhook.batchRequestQ {
Expand Down Expand Up @@ -289,33 +319,11 @@ func (bt *batchWebhookTransformerT) batchTransformLoop() {
var payloadArr [][]byte
var webRequests []*webhookT
for _, req := range breq.batchRequest {
body, err := io.ReadAll(req.request.Body)
_ = req.request.Body.Close()

body, err := prepareRequestBody(req.request, slices.Contains(bt.webhook.config.sourceListForParsingParams, breq.sourceType), breq.sourceType, bt.webhook.config.sourceListForParsingParams)
if err != nil {
req.done <- transformerResponse{Err: response.GetStatus(response.RequestBodyReadFailed)}
req.done <- transformerResponse{Err: response.GetStatus(err.Error())}
continue
}

if slices.Contains(bt.webhook.config.sourceListForParsingParams, strings.ToLower(breq.sourceType)) {
queryParams := req.request.URL.Query()
paramsBytes, err := json.Marshal(queryParams)
if err != nil {
req.done <- transformerResponse{Err: response.GetStatus(response.ErrorInMarshal)}
continue
}

closingBraceIdx := bytes.LastIndexByte(body, '}')
if closingBraceIdx == -1 {
req.done <- transformerResponse{Err: response.GetStatus(response.InvalidJSON)}
continue
}
appendData := []byte(`, "query_parameters": `)
appendData = append(appendData, paramsBytes...)
body = append(body[:closingBraceIdx], appendData...)
body = append(body, '}')
}

if !json.Valid(body) {
req.done <- transformerResponse{Err: response.GetStatus(response.InvalidJSON)}
continue
Expand Down
66 changes: 66 additions & 0 deletions gateway/webhook/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"testing/iotest"

"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -433,3 +435,67 @@ func TestRecordWebhookErrors(t *testing.T) {
})
require.EqualValues(t, m.LastValue(), 1)
}

func TestPrepareRequestBody(t *testing.T) {
createRequest := func(method, target string, body io.Reader, params map[string]string) *http.Request {
r := httptest.NewRequest(method, target, body)
q := r.URL.Query()
for k, v := range params {
q.Add(k, v)
}
r.URL.RawQuery = q.Encode()
return r
}

testCases := []struct {
name string
req *http.Request
includeQueryParams bool
wantError bool
expectedResponse []byte
}{
{
name: "Empty request body with no query parameters",
req: createRequest(http.MethodPost, "http://example.com", nil, nil),
includeQueryParams: false,
expectedResponse: []byte("{}"),
},
{
name: "Empty request body with query parameters",
req: createRequest(http.MethodPost, "http://example.com", nil, map[string]string{"key": "value"}),
includeQueryParams: true,
expectedResponse: []byte(`{"query_parameters":{"key":["value"]}}`),
},
{
name: "Error reading request body",
req: createRequest(http.MethodPost, "http://example.com", iotest.ErrReader(errors.New("some error")), nil),
includeQueryParams: false,
wantError: true,
expectedResponse: nil,
},
{
name: "Some payload with no query parameters",
req: createRequest(http.MethodPost, "http://example.com", strings.NewReader(`{"key":"value"}`), nil),
includeQueryParams: false,
expectedResponse: []byte(`{"key":"value"}`),
},
{
name: "Some payload with query parameters",
req: createRequest(http.MethodPost, "http://example.com", strings.NewReader(`{"key1":"value1"}`), map[string]string{"key2": "value2"}),
includeQueryParams: true,
expectedResponse: []byte(`{"key1":"value1","query_parameters":{"key2":["value2"]}}`),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := prepareRequestBody(tc.req, tc.includeQueryParams, "webhook", []string{"webhook"})
if tc.wantError {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.expectedResponse, result)
})
}
}

0 comments on commit affd6bc

Please sign in to comment.