From 81b43cd2cc1d9a3b965ac2a18724a2a74913b168 Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Tue, 30 Aug 2022 11:44:45 +0900 Subject: [PATCH] Add tests for bad HTTP and empty HTTP (#9) --- main.go | 50 ++++++++++++---------- main_test.go | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 23 deletions(-) diff --git a/main.go b/main.go index 10a0733da8751..901ca50dd61a1 100644 --- a/main.go +++ b/main.go @@ -121,13 +121,13 @@ func (ctx *httpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) t // TODO(anuraaga): Do these work with HTTP/1? path, err := proxywasm.GetHttpRequestHeader(":path") if err != nil { - proxywasm.LogCriticalf("failed to get path header: %v", err) + proxywasm.LogCriticalf("failed to get :path: %v", err) return types.ActionContinue } method, err := proxywasm.GetHttpRequestHeader(":method") if err != nil { - proxywasm.LogCriticalf("failed to get method header: %v", err) + proxywasm.LogCriticalf("failed to get :method: %v", err) return types.ActionContinue } @@ -155,16 +155,18 @@ func (ctx *httpContext) OnHttpRequestHeaders(numHeaders int, endOfStream bool) t func (ctx *httpContext) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { tx := ctx.tx - body, err := proxywasm.GetHttpRequestBody(0, bodySize) - if err != nil { - proxywasm.LogCriticalf("failed to get request body: %v", err) - return types.ActionContinue - } - - _, err = tx.RequestBodyBuffer.Write(body) - if err != nil { - proxywasm.LogCriticalf("failed to read request body: %v", err) - return types.ActionContinue + if bodySize > 0 { + body, err := proxywasm.GetHttpRequestBody(0, bodySize) + if err != nil { + proxywasm.LogCriticalf("failed to get request body: %v", err) + return types.ActionContinue + } + + _, err = tx.RequestBodyBuffer.Write(body) + if err != nil { + proxywasm.LogCriticalf("failed to read request body: %v", err) + return types.ActionContinue + } } if !endOfStream { @@ -189,7 +191,7 @@ func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) status, err := proxywasm.GetHttpResponseHeader(":status") if err != nil { - proxywasm.LogCriticalf("failed to get status header: %v", err) + proxywasm.LogCriticalf("failed to get :status: %v", err) return types.ActionContinue } code, err := strconv.Atoi(status) @@ -219,16 +221,18 @@ func (ctx *httpContext) OnHttpResponseHeaders(numHeaders int, endOfStream bool) func (ctx *httpContext) OnHttpResponseBody(bodySize int, endOfStream bool) types.Action { tx := ctx.tx - body, err := proxywasm.GetHttpResponseBody(0, bodySize) - if err != nil { - proxywasm.LogCriticalf("failed to get response body: %v", err) - return types.ActionContinue - } - - _, err = tx.ResponseBodyBuffer.Write(body) - if err != nil { - proxywasm.LogCriticalf("failed to read response body: %v", err) - return types.ActionContinue + if bodySize > 0 { + body, err := proxywasm.GetHttpResponseBody(0, bodySize) + if err != nil { + proxywasm.LogCriticalf("failed to get response body: %v", err) + return types.ActionContinue + } + + _, err = tx.ResponseBodyBuffer.Write(body) + if err != nil { + proxywasm.LogCriticalf("failed to read response body: %v", err) + return types.ActionContinue + } } if !endOfStream { diff --git a/main_test.go b/main_test.go index 27c145d42696a..aedf00d0da527 100644 --- a/main_test.go +++ b/main_test.go @@ -314,6 +314,124 @@ func TestBadConfig(t *testing.T) { }) } +func TestBadRequest(t *testing.T) { + tests := []struct { + name string + reqHdrs [][2]string + msg string + }{ + { + name: "missing path", + reqHdrs: [][2]string{ + {":method", "GET"}, + }, + msg: "failed to get :path", + }, + { + name: "missing method", + reqHdrs: [][2]string{ + {":path", "/hello"}, + }, + msg: "failed to get :method", + }, + } + + vmTest(t, func(t *testing.T, vm types.VMContext) { + for _, tc := range tests { + tt := tc + t.Run(tt.name, func(t *testing.T) { + opt := proxytest. + NewEmulatorOption(). + WithVMContext(vm). + WithPluginConfiguration([]byte{}) + + host, reset := proxytest.NewHostEmulator(opt) + defer reset() + + require.Equal(t, types.OnPluginStartStatusOK, host.StartPlugin()) + + id := host.InitializeHttpContext() + + action := host.CallOnRequestHeaders(id, tt.reqHdrs, false) + require.Equal(t, types.ActionContinue, action) + + logs := strings.Join(host.GetCriticalLogs(), "\n") + require.Contains(t, logs, tt.msg) + }) + } + }) +} + +func TestBadResponse(t *testing.T) { + tests := []struct { + name string + respHdrs [][2]string + msg string + }{ + { + name: "missing path", + respHdrs: [][2]string{ + {"content-length", "12"}, + }, + msg: "failed to get :status", + }, + } + + vmTest(t, func(t *testing.T, vm types.VMContext) { + for _, tc := range tests { + tt := tc + t.Run(tt.name, func(t *testing.T) { + opt := proxytest. + NewEmulatorOption(). + WithVMContext(vm). + WithPluginConfiguration([]byte{}) + + host, reset := proxytest.NewHostEmulator(opt) + defer reset() + + require.Equal(t, types.OnPluginStartStatusOK, host.StartPlugin()) + + id := host.InitializeHttpContext() + + action := host.CallOnResponseHeaders(id, tt.respHdrs, false) + require.Equal(t, types.ActionContinue, action) + + logs := strings.Join(host.GetCriticalLogs(), "\n") + require.Contains(t, logs, tt.msg) + }) + } + }) +} + +func TestEmptyBody(t *testing.T) { + vmTest(t, func(t *testing.T, vm types.VMContext) { + opt := proxytest. + NewEmulatorOption(). + WithVMContext(vm). + WithPluginConfiguration([]byte{}) + + host, reset := proxytest.NewHostEmulator(opt) + defer reset() + + require.Equal(t, types.OnPluginStartStatusOK, host.StartPlugin()) + + id := host.InitializeHttpContext() + + action := host.CallOnRequestBody(id, []byte{}, false) + require.Equal(t, types.ActionContinue, action) + action = host.CallOnRequestBody(id, []byte{}, true) + require.Equal(t, types.ActionContinue, action) + + action = host.CallOnResponseBody(id, []byte{}, false) + require.Equal(t, types.ActionContinue, action) + action = host.CallOnResponseBody(id, []byte{}, true) + require.Equal(t, types.ActionContinue, action) + + logs := strings.Join(host.GetCriticalLogs(), "\n") + require.Empty(t, logs) + }) +} + func vmTest(t *testing.T, f func(*testing.T, types.VMContext)) { t.Helper()