diff --git a/libbeat/esleg/eslegclient/bulkapi.go b/libbeat/esleg/eslegclient/bulkapi.go index c70fad12a14..965215fe310 100644 --- a/libbeat/esleg/eslegclient/bulkapi.go +++ b/libbeat/esleg/eslegclient/bulkapi.go @@ -75,8 +75,6 @@ func (conn *Connection) Bulk( return 0, nil, nil } - mergedParams := mergeParams(conn.ConnectionSettings.Parameters, params) - enc := conn.Encoder enc.Reset() if err := bulkEncode(conn.log, enc, body); err != nil { @@ -84,6 +82,8 @@ func (conn *Connection) Bulk( return 0, nil, err } + mergedParams := mergeParams(conn.ConnectionSettings.Parameters, params) + requ, err := newBulkRequest(conn.URL, index, docType, mergedParams, enc) if err != nil { apm.CaptureError(ctx, err).Send() @@ -105,8 +105,6 @@ func (conn *Connection) SendMonitoringBulk( return nil, nil } - mergedParams := mergeParams(conn.ConnectionSettings.Parameters, params) - enc := conn.Encoder enc.Reset() if err := bulkEncode(conn.log, enc, body); err != nil { @@ -119,6 +117,8 @@ func (conn *Connection) SendMonitoringBulk( } } + mergedParams := mergeParams(conn.ConnectionSettings.Parameters, params) + requ, err := newMonitoringBulkRequest(conn.GetVersion(), conn.URL, mergedParams, enc) if err != nil { return nil, err diff --git a/libbeat/esleg/eslegclient/bulkapi_mock_test.go b/libbeat/esleg/eslegclient/bulkapi_mock_test.go index 44430eac614..3d4e33c4271 100644 --- a/libbeat/esleg/eslegclient/bulkapi_mock_test.go +++ b/libbeat/esleg/eslegclient/bulkapi_mock_test.go @@ -24,11 +24,13 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/elastic/beats/v7/libbeat/logp" ) @@ -151,21 +153,7 @@ func TestOneHost503Resp_Bulk(t *testing.T) { } func TestEnforceParameters(t *testing.T) { - client, _ := NewConnection(ConnectionSettings{ - Parameters: map[string]string{"hello": "world"}, - URL: "http://localhost", - Timeout: 0, - }) - client.Encoder = NewJSONEncoder(nil, false) - - client.HTTP = &reqInspector{ - assert: func(req *http.Request) (*http.Response, error) { - assert.Equal(t, "world", req.URL.Query().Get("hello")) - // short circuit others logic. - return nil, errors.New("shortcut") - }, - } - + // Prepare the test bulk request. index := "what" ops := []map[string]interface{}{ @@ -186,11 +174,84 @@ func TestEnforceParameters(t *testing.T) { body = append(body, op) } - params := map[string]string{ - "refresh": "true", + tests := map[string]struct { + preconfigured map[string]string + reqParams map[string]string + expected map[string]string + }{ + "Preconfigured parameters are applied to bulk requests": { + preconfigured: map[string]string{ + "hello": "world", + }, + expected: map[string]string{ + "hello": "world", + }, + }, + "Preconfigured and local parameters are merged": { + preconfigured: map[string]string{ + "hello": "world", + }, + reqParams: map[string]string{ + "foo": "bar", + }, + expected: map[string]string{ + "hello": "world", + "foo": "bar", + }, + }, + "Local parameters only": { + reqParams: map[string]string{ + "foo": "bar", + }, + expected: map[string]string{ + "foo": "bar", + }, + }, + "no parameters": { + expected: map[string]string{}, + }, + "Local overrides preconfigured parameters": { + preconfigured: map[string]string{ + "foo": "world", + }, + reqParams: map[string]string{ + "foo": "bar", + }, + expected: map[string]string{ + "foo": "bar", + }, + }, } - client.Bulk(context.Background(), index, "type1", params, body) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + client, _ := NewConnection(ConnectionSettings{ + Parameters: test.preconfigured, + URL: "http://localhost", + Timeout: 0, + }) + + client.Encoder = NewJSONEncoder(nil, false) + + var recParams url.Values + errShort := errors.New("shortcut") + + client.HTTP = &reqInspector{ + assert: func(req *http.Request) (*http.Response, error) { + recParams = req.URL.Query() + return nil, errShort + }, + } + + _, _, err := client.Bulk(context.Background(), index, "type1", test.reqParams, body) + require.Equal(t, errShort, err) + require.Equal(t, len(recParams), len(test.expected)) + + for k, v := range test.expected { + assert.Equal(t, recParams.Get(k), v) + } + }) + } } type reqInspector struct {