Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add optional HTTP Middleware function to StartSettings for serverimpl #263

Merged
4 changes: 3 additions & 1 deletion internal/examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/open-telemetry/opamp-go v0.1.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/stretchr/testify v1.8.4
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0
go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.24.0
go.opentelemetry.io/otel/metric v1.24.0
Expand All @@ -19,12 +20,13 @@ require (

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/mapstructure v1.4.1 // indirect
Expand Down
4 changes: 4 additions & 0 deletions internal/examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/go-ldap/ldap v3.0.2+incompatible/go.mod h1:qfd9rJvER9Q0/D/Sqn1DfHRoBp40uXYvFoEVrNEPqRc=
Expand Down Expand Up @@ -127,6 +129,8 @@ github.com/tklauser/numcpus v0.3.0 h1:ILuRUQBtssgnxw0XXIjKUC56fgnOrFoQQ/4+DeU2bi
github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8=
github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg=
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.24.0 h1:mM8nKi6/iFQ0iqst80wDHU2ge198Ye/TfN0WBS5U24Y=
Expand Down
3 changes: 3 additions & 0 deletions internal/examples/server/opampsrv/opampsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"net/http"
"os"

"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"

"github.com/open-telemetry/opamp-go/internal"
"github.com/open-telemetry/opamp-go/internal/examples/server/data"
"github.com/open-telemetry/opamp-go/protobufs"
Expand Down Expand Up @@ -54,6 +56,7 @@ func (srv *Server) Start() {
},
},
ListenEndpoint: "127.0.0.1:4320",
HTTPMiddleware: otelhttp.NewMiddleware("/v1/opamp"),
}
tlsConfig, err := internal.CreateServerTLSConfig(
"../../certs/certs/ca.cert.pem",
Expand Down
5 changes: 5 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ type StartSettings struct {

// Server's TLS configuration.
TLSConfig *tls.Config

// HTTPMiddleware specifies middleware for HTTP messages received by the server.
// Note that the function will be called once for websockets upon connecting and will
// be called for every HTTP request. This function is optional to set.
HTTPMiddleware func(handler http.Handler) http.Handler
tigrannajaryan marked this conversation as resolved.
Show resolved Hide resolved
}

type HTTPHandlerFunc func(http.ResponseWriter, *http.Request)
Expand Down
18 changes: 17 additions & 1 deletion server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ type server struct {

var _ OpAMPServer = (*server)(nil)

// innerHTTPHandler implements the http.Handler interface so it can be used by functions
// that require the type (like Middleware) without exposing ServeHTTP directly on server.
type innerHTTPHander struct {
httpHandlerFunc http.HandlerFunc
}

func (i innerHTTPHander) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
i.httpHandlerFunc(writer, request)
}

// New creates a new OpAMP Server.
func New(logger types.Logger) *server {
if logger == nil {
Expand Down Expand Up @@ -82,7 +92,13 @@ func (s *server) Start(settings StartSettings) error {
path = defaultOpAMPPath
}

mux.HandleFunc(path, s.httpHandler)
handler := innerHTTPHander{s.httpHandler}

if settings.HTTPMiddleware != nil {
mux.Handle(path, settings.HTTPMiddleware(handler))
} else {
mux.Handle(path, handler)
}

hs := &http.Server{
Handler: mux,
Expand Down
130 changes: 130 additions & 0 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,33 @@ func TestServerStartStop(t *testing.T) {
assert.NoError(t, err)
}

func TestServerStartStopWithMiddleware(t *testing.T) {
var addedMiddleware atomic.Bool
assert.False(t, addedMiddleware.Load())

testHTTPMiddleware := func(handler http.Handler) http.Handler {
addedMiddleware.Store(true)
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
},
)
}

startSettings := &StartSettings{
HTTPMiddleware: testHTTPMiddleware,
}

srv := startServer(t, startSettings)
assert.True(t, addedMiddleware.Load())

err := srv.Start(*startSettings)
assert.ErrorIs(t, err, errAlreadyStarted)

err = srv.Stop(context.Background())
assert.NoError(t, err)
}

func TestServerAddrWithNonZeroPort(t *testing.T) {
srv := New(&sharedinternal.NopLogger{})
require.NotNil(t, srv)
Expand Down Expand Up @@ -830,6 +857,109 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) {
}
}

func TestServerCallsHTTPMiddlewareOverWebsocket(t *testing.T) {
middlewareCalled := int32(0)

testHTTPMiddleware := func(handler http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&middlewareCalled, 1)
handler.ServeHTTP(w, r)
},
)
}

callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
ConnectionCallbacks: ConnectionCallbacksStruct{},
}
},
}

// Start a Server
settings := &StartSettings{
HTTPMiddleware: testHTTPMiddleware,
Settings: Settings{Callbacks: callbacks},
}
srv := startServer(t, settings)
defer func() {
err := srv.Stop(context.Background())
assert.NoError(t, err)
}()

// Connect to the server, ensuring successful connection
conn, resp, err := dialClient(settings)
assert.NoError(t, err)
assert.NotNil(t, conn)
require.NotNil(t, resp)
assert.EqualValues(t, 101, resp.StatusCode)

// Verify middleware was called once for the websocket connection
eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(1) })
assert.Equal(t, int32(1), atomic.LoadInt32(&middlewareCalled))
}

func TestServerCallsHTTPMiddlewareOverHTTP(t *testing.T) {
middlewareCalled := int32(0)

testHTTPMiddleware := func(handler http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&middlewareCalled, 1)
handler.ServeHTTP(w, r)
},
)
}

callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
ConnectionCallbacks: ConnectionCallbacksStruct{},
}
},
}

// Start a Server
settings := &StartSettings{
HTTPMiddleware: testHTTPMiddleware,
Settings: Settings{Callbacks: callbacks},
}
srv := startServer(t, settings)
defer func() {
err := srv.Stop(context.Background())
assert.NoError(t, err)
}()

// Send an AgentToServer message to the Server
sendMsg1 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVS1"}
serializedProtoBytes1, err := proto.Marshal(&sendMsg1)
require.NoError(t, err)
_, err = http.Post(
"http://"+settings.ListenEndpoint+settings.ListenPath,
contentTypeProtobuf,
bytes.NewReader(serializedProtoBytes1),
)
require.NoError(t, err)

// Send another AgentToServer message to the Server
sendMsg2 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVRZ"}
serializedProtoBytes2, err := proto.Marshal(&sendMsg2)
require.NoError(t, err)
_, err = http.Post(
"http://"+settings.ListenEndpoint+settings.ListenPath,
contentTypeProtobuf,
bytes.NewReader(serializedProtoBytes2),
)
require.NoError(t, err)

// Verify middleware was triggered for each HTTP call
eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(2) })
assert.Equal(t, int32(2), atomic.LoadInt32(&middlewareCalled))
}

func BenchmarkSendToClient(b *testing.B) {
clientConnections := []*websocket.Conn{}
serverConnections := []types.Connection{}
Expand Down
Loading