Skip to content

Commit

Permalink
[receiver/splunkehecreceiver] Align acking behavior with that of Splu… (
Browse files Browse the repository at this point in the history
#32996)

**Description:**
- Make the channelID header case-insensitive
- Make hecreceiver endpoints able to extract channelID from query params

**Link to tracking Issue:** #32995
  • Loading branch information
zpzhuSplunk authored May 13, 2024
1 parent 41d853d commit 43aff69
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 9 deletions.
27 changes: 27 additions & 0 deletions .chloggen/hec-ack-alignment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Use this changelog template to create an entry for release notes.

# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix'
change_type: enhancement

# The name of the component, or a single word describing the area of concern, (e.g. filelogreceiver)
component: splunkhecreceiver

# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`).
note: Make the channelID header check case-insensitive and allow hecreceiver endpoints able to extract channelID from query params

# Mandatory: One or more tracking issues related to the change. You can use the PR number here if no issue exists.
issues: [32995]

# (Optional) One or more lines of additional information to render under the primary note.
# These lines will be padded with 2 spaces and then inserted directly into the document.
# Use pipe (|) for multiline entries.
subtext:

# If your change doesn't affect end users or the exported elements of any package,
# you should instead start your pull request title with [chore] or use the "Skip Changelog" label.
# Optional: The change log or logs in which this entry should be included.
# e.g. '[user]' or '[user, api]'
# Include 'user' if the change is relevant to end users.
# Include 'api' if there is a change to a library API.
# Default: '[user]'
change_logs: [user]
21 changes: 15 additions & 6 deletions receiver/splunkhecreceiver/receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func (r *splunkReceiver) handleAck(resp http.ResponseWriter, req *http.Request)

var channelID string
var extracted bool
if channelID, extracted = r.extractChannelHeader(req); extracted {
if channelID, extracted = r.extractChannel(req); extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
return
Expand Down Expand Up @@ -327,7 +327,7 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques

var channelID string
var extracted bool
if channelID, extracted = r.extractChannelHeader(req); extracted {
if channelID, extracted = r.extractChannel(req); extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
return
Expand Down Expand Up @@ -391,9 +391,18 @@ func (r *splunkReceiver) handleRawReq(resp http.ResponseWriter, req *http.Reques
}
}

func (r *splunkReceiver) extractChannelHeader(req *http.Request) (string, bool) {
if headers, ok := req.Header[splunk.HTTPSplunkChannelHeader]; ok {
return headers[0], true
func (r *splunkReceiver) extractChannel(req *http.Request) (string, bool) {
// check header
for k, v := range req.Header {
if strings.EqualFold(k, splunk.HTTPSplunkChannelHeader) {
return strings.ToUpper(v[0]), true
}
}
// check query param
for k, v := range req.URL.Query() {
if strings.EqualFold(k, "channel") {
return strings.ToUpper(v[0]), true
}
}

return "", false
Expand Down Expand Up @@ -434,7 +443,7 @@ func (r *splunkReceiver) handleReq(resp http.ResponseWriter, req *http.Request)
return
}

channelID, extracted := r.extractChannelHeader(req)
channelID, extracted := r.extractChannel(req)
if extracted {
if channelErr := r.validateChannelHeader(channelID); channelErr != nil {
r.failRequest(ctx, resp, http.StatusBadRequest, []byte(channelErr.Error()), 0, channelErr)
Expand Down
158 changes: 155 additions & 3 deletions receiver/splunkhecreceiver/receiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,65 @@ func Test_splunkhecReceiver_handleAck(t *testing.T) {
}}, body)
},
},
{
name: "happy_path_with_case_insensitive_header",
req: func() *http.Request {
msgBytes, err := json.Marshal(buildSplunkHecAckMsg([]uint64{1, 2, 3}))
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/ack", bytes.NewReader(msgBytes))
req.Header.Set("x-splunk-request-channel", "fbd3036f-0f1c-4e98-b71c-d4cd61213f90")
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
queryAcks: func(_ string, _ []uint64) map[uint64]bool {
return map[uint64]bool{
1: true,
2: false,
3: true,
}
},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
status := resp.StatusCode
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, map[string]any{"acks": map[string]any{
"1": true,
"2": false,
"3": true,
}}, body)
},
},
{
name: "happy_path_with_query_param",
req: func() *http.Request {
msgBytes, err := json.Marshal(buildSplunkHecAckMsg([]uint64{1, 2, 3}))
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/ack?channel=fbd3036f-0f1c-4e98-b71c-d4cd61213f90", bytes.NewReader(msgBytes))
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
queryAcks: func(_ string, _ []uint64) map[uint64]bool {
return map[uint64]bool{
1: true,
2: false,
3: true,
}
},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
status := resp.StatusCode
assert.Equal(t, http.StatusOK, status)
assert.Equal(t, map[string]any{"acks": map[string]any{
"1": true,
"2": false,
"3": true,
}}, body)
},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -1340,7 +1399,7 @@ func Test_splunkhecReceiver_handleRawReq_WithAck(t *testing.T) {
config.Ack.Extension = &id
currentTime := float64(time.Now().UnixNano()) / 1e6
splunkMsg := buildSplunkHecMsg(currentTime, 3)

currAckID := uint64(0)
tests := []struct {
name string
req *http.Request
Expand Down Expand Up @@ -1410,13 +1469,57 @@ func Test_splunkhecReceiver_handleRawReq_WithAck(t *testing.T) {
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
processEvent: func(_ string) (ackID uint64) {
return uint64(1)
currAckID++
return currAckID
},
ack: func(_ string, _ uint64) {},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
assertHecSuccessResponseWithAckID(t, resp, body, 1)
assertHecSuccessResponseWithAckID(t, resp, body, currAckID)
},
},
{
name: "happy_path_with_case_insensitive_header",
req: func() *http.Request {
msgBytes, err := json.Marshal(splunkMsg)
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/foo", bytes.NewReader(msgBytes))
req.Header.Set("x-splunk-request-channel", "fbd3036f-0f1c-4e98-b71c-d4cd61213f90")
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
processEvent: func(_ string) (ackID uint64) {
currAckID++
return currAckID
},
ack: func(_ string, _ uint64) {},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
assertHecSuccessResponseWithAckID(t, resp, body, currAckID)
},
},
{
name: "happy_path_with_query_param",
req: func() *http.Request {
msgBytes, err := json.Marshal(splunkMsg)
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/foo?Channel=fbd3036f-0f1c-4e98-b71c-d4cd61213f90", bytes.NewReader(msgBytes))
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
processEvent: func(_ string) (ackID uint64) {
currAckID++
return currAckID
},
ack: func(_ string, _ uint64) {},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
assertHecSuccessResponseWithAckID(t, resp, body, currAckID)
},
},
}
Expand Down Expand Up @@ -1554,6 +1657,55 @@ func Test_splunkhecReceiver_handleReq_WithAck(t *testing.T) {
assert.Equal(t, 1, len(sink.AllLogs()))
},
},
{
name: "msg_accepted_with_case_insensitive_header",
req: func() *http.Request {
msgBytes, err := json.Marshal(splunkMsg)
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/foo", bytes.NewReader(msgBytes))
req.Header.Set("x-splunk-request-channel", "fbd3036f-0f1c-4e98-b71c-d4cd61213f90")
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
processEvent: func(_ string) (ackID uint64) {
return uint64(1)
},
ack: func(_ string, _ uint64) {
},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
assertHecSuccessResponseWithAckID(t, resp, body, 1)
},
assertSink: func(t *testing.T, sink *consumertest.LogsSink) {
assert.Equal(t, 1, len(sink.AllLogs()))
},
},
{
name: "msg_accepted_with_query_param",
req: func() *http.Request {
msgBytes, err := json.Marshal(splunkMsg)
require.NoError(t, err)
req := httptest.NewRequest("POST", "http://localhost/foo?channel=fbd3036f-0f1c-4e98-b71c-d4cd61213f90&isCheesy=true", bytes.NewReader(msgBytes))
return req
}(),
setupMockAckExtension: func() component.Component {
return &mockAckExtension{
processEvent: func(_ string) (ackID uint64) {
return uint64(1)
},
ack: func(_ string, _ uint64) {
},
}
},
assertResponse: func(t *testing.T, resp *http.Response, body any) {
assertHecSuccessResponseWithAckID(t, resp, body, 1)
},
assertSink: func(t *testing.T, sink *consumertest.LogsSink) {
assert.Equal(t, 1, len(sink.AllLogs()))
},
},
}

for _, tt := range tests {
Expand Down

0 comments on commit 43aff69

Please sign in to comment.