diff --git a/server/server.go b/server/server.go index 066edf42..885db052 100644 --- a/server/server.go +++ b/server/server.go @@ -42,6 +42,10 @@ type StartSettings struct { // Server's TLS configuration. TLSConfig *tls.Config + + // HTTPMiddlewareFunc specifies middleware for HTTP messages received by the server. + // This function is optional to set. + HTTPMiddlewareFunc func(handlerFunc http.HandlerFunc) http.HandlerFunc } type HTTPHandlerFunc func(http.ResponseWriter, *http.Request) diff --git a/server/serverimpl.go b/server/serverimpl.go index 10b159c3..e7fa1e9a 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -82,7 +82,11 @@ func (s *server) Start(settings StartSettings) error { path = defaultOpAMPPath } - mux.HandleFunc(path, s.httpHandler) + if settings.HTTPMiddlewareFunc != nil { + mux.HandleFunc(path, settings.HTTPMiddlewareFunc(s.httpHandler)) + } else { + mux.HandleFunc(path, s.httpHandler) + } hs := &http.Server{ Handler: mux, diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 071a99e1..081a1332 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -57,6 +57,31 @@ func TestServerStartStop(t *testing.T) { assert.NoError(t, err) } +func TestServerStartStopWithMiddleware(t *testing.T) { + var addedMiddleware atomic.Bool + assert.False(t, addedMiddleware.Load()) + + testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc { + addedMiddleware.Store(true) + return func(writer http.ResponseWriter, request *http.Request) { + handlerFunc(writer, request) + } + } + + startSettings := &StartSettings{ + HTTPMiddlewareFunc: testHTTPMiddlewareFunc, + } + + srv := startServer(t, startSettings) + assert.True(t, addedMiddleware.Load()) + + err := srv.Start(*startSettings) + assert.ErrorIs(t, err, errAlreadyStarted) + + err = srv.Stop(context.Background()) + assert.NoError(t, err) +} + func TestServerAddrWithNonZeroPort(t *testing.T) { srv := New(&sharedinternal.NopLogger{}) require.NotNil(t, srv) @@ -830,6 +855,105 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) { } } +func TestServerCallsHTTPMiddlewareOverWebsocket(t *testing.T) { + middlewareCalled := int32(0) + + testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + atomic.AddInt32(&middlewareCalled, 1) + handlerFunc(writer, request) + } + } + + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{ + Accept: true, + ConnectionCallbacks: ConnectionCallbacksStruct{}, + } + }, + } + + // Start a Server + settings := &StartSettings{ + HTTPMiddlewareFunc: testHTTPMiddlewareFunc, + Settings: Settings{Callbacks: callbacks}, + } + srv := startServer(t, settings) + defer func() { + err := srv.Stop(context.Background()) + assert.NoError(t, err) + }() + + // Connect to the server, ensuring successful connection + conn, resp, err := dialClient(settings) + assert.NoError(t, err) + assert.NotNil(t, conn) + require.NotNil(t, resp) + assert.EqualValues(t, 101, resp.StatusCode) + + // Verify middleware was called once for the websocket connection + eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(1) }) + assert.Equal(t, int32(1), atomic.LoadInt32(&middlewareCalled)) +} + +func TestServerCallsHTTPMiddlewareOverHTTP(t *testing.T) { + middlewareCalled := int32(0) + + testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc { + return func(writer http.ResponseWriter, request *http.Request) { + atomic.AddInt32(&middlewareCalled, 1) + handlerFunc(writer, request) + } + } + + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{ + Accept: true, + ConnectionCallbacks: ConnectionCallbacksStruct{}, + } + }, + } + + // Start a Server + settings := &StartSettings{ + HTTPMiddlewareFunc: testHTTPMiddlewareFunc, + Settings: Settings{Callbacks: callbacks}, + } + srv := startServer(t, settings) + defer func() { + err := srv.Stop(context.Background()) + assert.NoError(t, err) + }() + + // Send an AgentToServer message to the Server + sendMsg1 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVS1"} + serializedProtoBytes1, err := proto.Marshal(&sendMsg1) + require.NoError(t, err) + _, err = http.Post( + "http://"+settings.ListenEndpoint+settings.ListenPath, + contentTypeProtobuf, + bytes.NewReader(serializedProtoBytes1), + ) + require.NoError(t, err) + + // Send another AgentToServer message to the Server + sendMsg2 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVRZ"} + serializedProtoBytes2, err := proto.Marshal(&sendMsg2) + require.NoError(t, err) + _, err = http.Post( + "http://"+settings.ListenEndpoint+settings.ListenPath, + contentTypeProtobuf, + bytes.NewReader(serializedProtoBytes2), + ) + require.NoError(t, err) + + // Verify middleware was triggered for each HTTP call + eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(2) }) + assert.Equal(t, int32(2), atomic.LoadInt32(&middlewareCalled)) +} + func BenchmarkSendToClient(b *testing.B) { clientConnections := []*websocket.Conn{} serverConnections := []types.Connection{}