From 79f1b329d18faf5ad3502aefecbe3667b2a7f34e Mon Sep 17 00:00:00 2001 From: Gabor Retvari Date: Fri, 15 Nov 2024 18:17:16 +0100 Subject: [PATCH] feature: Implement ICE tester server and client --- Makefile | 6 +- go.mod | 18 +- go.sum | 42 +++- pkg/icetester/dialer.go | 369 +++++++++++++++++++++++++++++++ pkg/icetester/icetester_test.go | 248 +++++++++++++++++++++ pkg/icetester/listener.go | 374 ++++++++++++++++++++++++++++++++ pkg/icetester/types.go | 49 +++++ 7 files changed, 1091 insertions(+), 15 deletions(-) create mode 100644 pkg/icetester/dialer.go create mode 100644 pkg/icetester/icetester_test.go create mode 100644 pkg/icetester/listener.go create mode 100644 pkg/icetester/types.go diff --git a/Makefile b/Makefile index ba85bf1..3da3495 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,6 @@ build-bin: .PHONY: clean clean: - echo 'Use "make generate` to autogenerate server code' > pkg/server/server.go - echo 'Use "make generate` to autogenerate client code' > pkg/client/client.go - echo 'Use "make generate` to autogenerate client code' > pkg/types/types.go + # echo 'Use "make generate" to autogenerate server code' > pkg/server/server.go + # echo 'Use "make generate" to autogenerate client code' > pkg/client/client.go + # echo 'Use "make generate" to autogenerate client code' > pkg/types/types.go diff --git a/go.mod b/go.mod index 58c9698..295026c 100644 --- a/go.mod +++ b/go.mod @@ -14,10 +14,12 @@ require ( github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.1 github.com/oapi-codegen/runtime v1.1.1 - github.com/pion/dtls/v3 v3.0.2 + github.com/pion/datachannel v1.5.9 + github.com/pion/dtls/v3 v3.0.3 github.com/pion/logging v0.2.2 github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v4 v4.0.0 + github.com/pion/webrtc/v4 v4.0.1 github.com/prometheus/client_golang v1.20.4 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 @@ -76,7 +78,15 @@ require ( github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect + github.com/pion/ice/v4 v4.0.2 // indirect + github.com/pion/interceptor v0.1.37 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.14 // indirect + github.com/pion/rtp v1.8.9 // indirect + github.com/pion/sctp v1.8.33 // indirect + github.com/pion/sdp/v3 v3.0.9 // indirect + github.com/pion/srtp/v3 v3.0.4 // indirect github.com/pion/stun/v3 v3.0.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -88,13 +98,13 @@ require ( go.opentelemetry.io/otel/trace v1.31.0 // indirect go.starlark.net v0.0.0-20240123142251-f86470692795 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/crypto v0.27.0 // indirect + golang.org/x/crypto v0.28.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.29.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/term v0.24.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/term v0.25.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect google.golang.org/protobuf v1.35.1 // indirect gopkg.in/evanphx/json-patch.v5 v5.9.0 // indirect diff --git a/go.sum b/go.sum index bc4e93a..343cffa 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,7 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= @@ -126,18 +127,38 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= -github.com/pion/dtls/v3 v3.0.2 h1:425DEeJ/jfuTTghhUDW0GtYZYIwwMtnKKJNMcWccTX0= -github.com/pion/dtls/v3 v3.0.2/go.mod h1:dfIXcFkKoujDQ+jtd8M6RgqKK3DuaUilm3YatAbGp5k= +github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA= +github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE= +github.com/pion/dtls/v3 v3.0.3 h1:j5ajZbQwff7Z8k3pE3S+rQ4STvKvXUdKsi/07ka+OWM= +github.com/pion/dtls/v3 v3.0.3/go.mod h1:weOTUyIV4z0bQaVzKe8kpaP17+us3yAuiQsEAG1STMU= +github.com/pion/ice/v4 v4.0.2 h1:1JhBRX8iQLi0+TfcavTjPjI6GO41MFn4CeTBX+Y9h5s= +github.com/pion/ice/v4 v4.0.2/go.mod h1:DCdqyzgtsDNYN6/3U8044j3U7qsJ9KFJC92VnOWHvXg= +github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= +github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE= +github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk= +github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/sctp v1.8.33 h1:dSE4wX6uTJBcNm8+YlMg7lw1wqyKHggsP5uKbdj+NZw= +github.com/pion/sctp v1.8.33/go.mod h1:beTnqSzewI53KWoG3nqB282oDMGrhNxBdb+JZnkCwRM= +github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= +github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M= +github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= +github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ= github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= +github.com/pion/webrtc/v4 v4.0.1 h1:6Unwc6JzoTsjxetcAIoWH81RUM4K5dBc1BbJGcF9WVE= +github.com/pion/webrtc/v4 v4.0.1/go.mod h1:SfNn8CcFxR6OUVjLXVslAQ3a3994JhyE3Hw1jAuqEto= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -161,10 +182,15 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= @@ -198,8 +224,8 @@ go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -225,12 +251,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= -golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/icetester/dialer.go b/pkg/icetester/dialer.go new file mode 100644 index 0000000..dc977a8 --- /dev/null +++ b/pkg/icetester/dialer.go @@ -0,0 +1,369 @@ +package icetester + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/pion/datachannel" + "github.com/pion/logging" + "github.com/pion/webrtc/v4" +) + +var ( + // Send pings to the CDS server with this period. Must be less than PongWait. + PingPeriod = 5 * time.Second + + // Time allowed to read the next pong message from the CDS server. + PongWait = 8 * time.Second + + // Time allowed to write a message to the CDS server. + WriteWait = 2 * time.Second + + // Period for retrying failed CDS connections. + RetryPeriod = 1 * time.Second +) + +var _ net.Conn = &dialerConn{} + +type Dialer struct { + iceConfig webrtc.Configuration + api *webrtc.API + logger logging.LoggerFactory + log logging.LeveledLogger +} + +func NewDialer(iceConfig webrtc.Configuration, logger logging.LoggerFactory) *Dialer { + e := webrtc.SettingEngine{} + e.DetachDataChannels() + + return &Dialer{ + iceConfig: iceConfig, + api: webrtc.NewAPI(webrtc.WithSettingEngine(e)), + logger: logger, + log: logger.NewLogger("tester-dialer"), + } +} + +func (d *Dialer) DialContext(ctx context.Context, addr string) (net.Conn, error) { + signalingServerURI, err := getURI(addr) + if err != nil { + return nil, fmt.Errorf("Failed to parse server address %q: %w", addr, err) + } + rawWsConn, _, err := websocket.DefaultDialer.DialContext(ctx, signalingServerURI.String(), makeHeader(signalingServerURI)) + if err != nil { + return nil, fmt.Errorf("Failed to connect to singlaing server at %q: %w", + signalingServerURI.String(), err) + } + // wrap with a locker to prevent concurrent writes + wsConn := &ThreadSafeWriter{Conn: rawWsConn} + + conn := &dialerConn{ + wsConn: wsConn, + log: d.logger.NewLogger(fmt.Sprintf("tester-client-%s-%s", + wsConn.LocalAddr(), wsConn.RemoteAddr())), + } + + conn.log.Debugf("Signaling connection successfully opened to tester server at %s", signalingServerURI.String()) + + // Start pinger thread: lifetime is the connection's lifetime, stops when connCtx is + // canceled + pingTicker := time.NewTicker(PingPeriod) + pingerCtx, pingerCancel := context.WithCancel(context.Background()) + conn.pingerCancel = pingerCancel + go func() { + defer pingTicker.Stop() + + for { + select { + case <-pingTicker.C: + wsConn.SetWriteDeadline(time.Now().Add(WriteWait)) //nolint:errcheck + if err := wsConn.WriteMessage(websocket.PingMessage, []byte("keepalive")); err != nil { + conn.log.Errorf("Could not ping tester signaling server at %q: %s", + wsConn.RemoteAddr(), err.Error()) + return + } + case <-pingerCtx.Done(): + conn.log.Tracef("Closing pinger thread WS connection %s-%s", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + return + } + } + }() + + conn.log.Tracef("Creating new PeerConnection for WS connection %s-%s", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + peerConn, err := d.api.NewPeerConnection(d.iceConfig) + if err != nil { + return nil, fmt.Errorf("Failed to create a PeerConnection: %w", err) + } + conn.peerConn = peerConn + + // Trickle ICE: emit server candidates to client, errors are not fatal + peerConn.OnICECandidate(func(i *webrtc.ICECandidate) { + if i == nil { + return + } + // When serializing a candidate use ToJSON, othwerwise json.Marshal will result in + // errors around `sdpMid` + candidateString, err := json.Marshal(i.ToJSON()) + if err != nil { + conn.log.Errorf("Failed to marshal candidate to json: %v", err) + return + } + + conn.log.Infof("Local candidate: %s", candidateString) + + if writeErr := wsConn.WriteJSON(&Message{ + Type: MessageTypeIceCandidate, + Data: string(candidateString), + }); writeErr != nil { + conn.log.Errorf("Failed to write JSON: %v", writeErr) + return + } + }) + + // If PeerConnection is closed, close the client + peerConn.OnConnectionStateChange(func(p webrtc.PeerConnectionState) { + conn.log.Infof("Connection State has changed: %s", p) + if p == webrtc.PeerConnectionStateFailed || p == webrtc.PeerConnectionStateClosed { + conn.Close() //nolint + } + }) + + // the next pong must arrive within the PongWait period + wsConn.SetReadDeadline(time.Now().Add(PongWait)) //nolint:errcheck + // reinit the deadline when we get a pong + wsConn.SetPongHandler(func(string) error { + // a.Tracef("Got PONG from server %q", url) + wsConn.SetReadDeadline(time.Now().Add(PongWait)) //nolint:errcheck + return nil + }) + + // Register data channel creation handling + connCh := make(chan any, 1) + defer close(connCh) + errCh := make(chan error) + + peerConn.OnDataChannel(func(dataChannel *webrtc.DataChannel) { + conn.log.Tracef("New DataChannel %s %d for WS connection %s-%s", dataChannel.Label(), + dataChannel.ID(), wsConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + + // Register channel opening handling + dataChannel.OnOpen(func() { + conn.log.Debugf("Data channel '%s'-'%d' open for WS connection %s-%s", + dataChannel.Label(), dataChannel.ID(), + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + + raw, err := dataChannel.Detach() + if err != nil { + errCh <- fmt.Errorf("Failed to detach DataChannel: %w", err) + return + } + conn.dataConn = raw + + connCh <- struct{}{} + }) + }) + + candidateCache := []webrtc.ICECandidateInit{} + // Start signaling client speaker: lifetime is the connection's lifetime, stops when wsConn + // is closed + go func() { + defer close(errCh) + + message := &Message{} + for { + // ping-pong deadline misses will end up being caught here as a read beyond + // the deadline + msgType, raw, err := wsConn.ReadMessage() + if err != nil { + errCh <- err + return + } + + // Decoding errors are not fatal + if msgType != websocket.TextMessage { + conn.log.Errorf("Unexpected message type (code: %d) from client %q", + msgType, wsConn.RemoteAddr().String()) + continue + } + + if err := json.Unmarshal(raw, &message); err != nil { + conn.log.Errorf("Failed to unmarshal json to message: %v", err) + continue + } + + conn.log.Tracef("Got signaling message WS connection %s-%s: %v", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String(), message) + + switch message.Type { + case MessageTypeIceCandidate: + candidate := webrtc.ICECandidateInit{} + if err := json.Unmarshal([]byte(message.Data), &candidate); err != nil { + conn.log.Errorf("Failed to unmarshal json to candidate: %v", err) + continue + } + + conn.log.Infof("Remote candidate: %s", message.Data) + + if peerConn.RemoteDescription() == nil { + // cannot set candidates yet: cache candidate + candidateCache = append(candidateCache, candidate) + } else if err := peerConn.AddICECandidate(candidate); err != nil { + errCh <- fmt.Errorf("Failed to add ICE candidate: %w", err) + } + + case MessageTypeOffer: + offer := webrtc.SessionDescription{} + if err := json.Unmarshal([]byte(message.Data), &offer); err != nil { + conn.log.Errorf("Failed to unmarshal json to Offer: %v", err) + continue + } + + conn.log.Debugf("Got Offer on WS connection %s-%s: %v", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String(), + offer) + + if err := peerConn.SetRemoteDescription(offer); err != nil { + errCh <- fmt.Errorf("Failed to set remote description: %w", err) + return + } + + // flush candidate cache + for _, candidate := range candidateCache { + if err := peerConn.AddICECandidate(candidate); err != nil { + errCh <- fmt.Errorf("Failed to add cached ICE candidate: %w", err) + } + } + candidateCache = []webrtc.ICECandidateInit{} + + // Create an offer to send to the other process + conn.log.Tracef("Creating Answer WS connection %s-%s", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String()) + answer, err := peerConn.CreateAnswer(nil) + if err != nil { + errCh <- fmt.Errorf("Failed to create Answer: %w", err) + return + } + + // Sets the LocalDescription, and starts our UDP listeners + // Note: this will start the gathering of ICE candidates + if err = peerConn.SetLocalDescription(answer); err != nil { + errCh <- fmt.Errorf("Failed to set local description: %w", err) + return + } + + payload, err := json.Marshal(answer) + if err != nil { + errCh <- fmt.Errorf("Failed to JSON encode Answer: %v", err) + return + } + + conn.log.Debugf("Sending Answer on WS connection %s-%s: %v", + wsConn.LocalAddr().String(), wsConn.RemoteAddr().String(), + answer) + + if writeErr := wsConn.WriteJSON(&Message{ + Type: MessageTypeAnswer, + Data: string(payload), + }); writeErr != nil { + errCh <- fmt.Errorf("Failed to write Answer: %v", writeErr) + } + default: + conn.log.Errorf("unknown message: %+v", message) + } + } + }() + + select { + case <-connCh: + d.log.Infof("Creating new connection %s", conn.String()) + return conn, nil + case err := <-errCh: + conn.Close() + return nil, err + } +} + +type dialerConn struct { + pingerCancel context.CancelFunc + wsConn *ThreadSafeWriter + peerConn *webrtc.PeerConnection + dataConn datachannel.ReadWriteCloser + closed bool + log logging.LeveledLogger +} + +func (c *dialerConn) Close() error { + c.log.Tracef("Closing tester client connection %s", c.String()) + + if c.closed { + return nil + } + c.closed = true + + // Close the pinger thread + c.pingerCancel() + + // Close the WebSocket signaling connection: closes the signaling thread + c.wsConn.WriteMessage(websocket.CloseMessage, []byte{}) //nolint:errcheck + c.wsConn.Close() + + // Close the peerconnection + if err := c.peerConn.Close(); err != nil { + return fmt.Errorf("Failed to close PeerConnection: %w", err) + } + + return nil +} + +func (c *dialerConn) Read(b []byte) (int, error) { + return c.dataConn.Read(b) +} + +func (c *dialerConn) Write(b []byte) (int, error) { + return c.dataConn.Write(b) +} + +// TODO: implement +func (c *dialerConn) LocalAddr() net.Addr { return nil } +func (c *dialerConn) RemoteAddr() net.Addr { return nil } +func (c *dialerConn) SetDeadline(t time.Time) error { return nil } +func (c *dialerConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dialerConn) SetWriteDeadline(t time.Time) error { return nil } + +// String returns a unique identifier for the connection based on the underlying signaling connection. +func (c *dialerConn) String() string { + return fmt.Sprintf("%s-%s", c.wsConn.LocalAddr(), c.wsConn.RemoteAddr()) +} + +// creates an origin header +func makeHeader(url *url.URL) http.Header { + header := http.Header{} + origin := *url + origin.Scheme = "http" + origin.Path = "" + header.Set("origin", origin.String()) + return header +} + +func getURI(addr string) (*url.URL, error) { + if !strings.HasPrefix(addr, "http://") && !strings.HasPrefix(addr, "https://") && + !strings.HasPrefix(addr, "ws://") && !strings.HasPrefix(addr, "wss://") { + addr = "ws://" + addr + } + + url, err := url.Parse(addr) + if err != nil { + return nil, err + } + url.Path = signalingServerPath + return url, nil +} diff --git a/pkg/icetester/icetester_test.go b/pkg/icetester/icetester_test.go new file mode 100644 index 0000000..0c8614e --- /dev/null +++ b/pkg/icetester/icetester_test.go @@ -0,0 +1,248 @@ +package icetester + +import ( + "context" + "errors" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/pion/logging" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/assert" + + slogger "github.com/l7mp/stunner/pkg/logger" +) + +var ( + testerLogLevel = "all:WARN" + // testerLogLevel = "all:TRACE" + // testerLogLevel = "all:INFO" + addr = "localhost:12345" + timeout = 5 * time.Second + interval = 50 * time.Millisecond + logger logging.LoggerFactory = slogger.NewLoggerFactory(testerLogLevel) + log logging.LeveledLogger = logger.NewLogger("test") +) + +// func init() { +// // setup a fast pinger so that we get a timely error notification +// PingPeriod = 500 * time.Millisecond +// PongWait = 800 * time.Millisecond +// WriteWait = 200 * time.Millisecond +// RetryPeriod = 250 * time.Millisecond +// } + +func echoTest(t *testing.T, conn net.Conn, content string) { + t.Helper() + + n, err := conn.Write([]byte(content)) + assert.NoError(t, err) + assert.Equal(t, len(content), n) + + buf := make([]byte, 2048) + n, err = conn.Read(buf) + assert.NoError(t, err) + assert.Equal(t, content, string(buf[:n])) +} + +var testerTestCases = []struct { + name string + tester func(t *testing.T, ctx context.Context, l *Listener) +}{ + { + name: "Basic connectivity", + tester: func(t *testing.T, ctx context.Context, l *Listener) { + log.Debug("Creating dialer") + d := NewDialer(webrtc.Configuration{}, logger) + assert.NotNil(t, d) + + log.Debug("Dialing") + clientConn, err := d.DialContext(ctx, addr) + assert.NoError(t, err) + + log.Debug("Echo test round 1") + echoTest(t, clientConn, "test1") + log.Debug("Echo test round 2") + echoTest(t, clientConn, "test2") + + assert.NoError(t, clientConn.Close(), "client conn close") + }, + }, { + name: "Closing dialer does not close client connection", + tester: func(t *testing.T, serverCtx context.Context, l *Listener) { + // a new context for the dialer + dialerCtx, dialerCancel := context.WithCancel(context.Background()) + + log.Debug("Creating dialer") + d := NewDialer(webrtc.Configuration{}, logger) + assert.NotNil(t, d) + + log.Debug("Dialing") + clientConn, err := d.DialContext(dialerCtx, addr) + assert.NoError(t, err) + + log.Debug("Echo test round 1") + echoTest(t, clientConn, "test1") + + log.Debug("Closing dialer") + dialerCancel() + + log.Debug("Echo test round 2") + echoTest(t, clientConn, "test2") + }, + }, { + name: "Client side close closes server", + tester: func(t *testing.T, serverCtx context.Context, l *Listener) { + log.Debug("Creating dialer") + d := NewDialer(webrtc.Configuration{}, logger) + assert.NotNil(t, d) + + log.Debug("Dialing") + clientConn, err := d.DialContext(serverCtx, addr) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { return l.activeConns == 1 }, timeout, interval) + + log.Debug("Closing client connection") + assert.NoError(t, clientConn.Close()) + + // should close the server conn too + assert.Eventually(t, func() bool { return l.activeConns == 0 }, timeout, interval) + }, + }, { + name: "Server side close closes client", + tester: func(t *testing.T, serverCtx context.Context, l *Listener) { + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + log.Debug("Creating dialer") + d := NewDialer(webrtc.Configuration{}, logger) + assert.NotNil(t, d) + + log.Debug("Dialing") + clientConn, err := d.DialContext(clientCtx, addr) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { return l.activeConns == 1 }, timeout, interval) + + log.Debug("Closing server connections") + for _, lConn := range l.Conns() { + // log.Infof("------------ %s", lConn.String()) + assert.NoError(t, lConn.Close()) + } + + log.Info("&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&") + + assert.Eventually(t, func() bool { return l.activeConns == 0 }, timeout, interval) + + // should close the client conn too + assert.Eventually(t, func() bool { return clientConn.(*dialerConn).closed == true }, timeout, interval) + }, + }, { + name: "Multiple connections", + tester: func(t *testing.T, ctx context.Context, l *Listener) { + log.Debug("Creating dialer") + d := NewDialer(webrtc.Configuration{}, logger) + assert.NotNil(t, d) + + log.Debug("Dialing: creating 5 connections") + var wg sync.WaitGroup + wg.Add(5) + connChan := make(chan net.Conn, 5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + + clientConn, err := d.DialContext(ctx, addr) + assert.NoError(t, err) + + log.Debug("Echo test round 1") + echoTest(t, clientConn, "test1111") + + log.Debug("Echo test round 2") + echoTest(t, clientConn, "test2222") + + connChan <- clientConn + }() + } + + wg.Wait() + close(connChan) + + assert.Eventually(t, func() bool { return l.activeConns == 5 }, timeout, interval) + + for c := range connChan { + c.Close() + } + + assert.Eventually(t, func() bool { return l.activeConns == 0 }, timeout, interval) + }, + }, +} + +func TestTesterConn(t *testing.T) { + for _, c := range testerTestCases { + var l *Listener + + t.Run(c.name, func(t *testing.T) { + log.Infof("--------------------- %s ----------------------", c.name) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + log.Debug("Creating listener") + listener, err := NewListener(addr, webrtc.Configuration{}, logger) + assert.NoError(t, err) + l = listener + assert.NotNil(t, l) + + log.Debug("Creating echo services") + go func() { + for { + conn, err := l.Accept() + if err != nil { + return + } + + log.Debug("Accepting server connection") + + // responder + go func() { + buf := make([]byte, 100) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + + _, err = conn.Write(buf[:n]) + assert.NoError(t, err) + } + }() + + // closer + go func() { + <-ctx.Done() + + if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && + !errors.Is(err, http.ErrServerClosed) { + t.Error("server conn close") + } + + // close listener + l.Close() //nolint + }() + } + }() + + c.tester(t, ctx, l) + }) + + log.Debug("Waiting for connections to close") + if l != nil { + assert.Eventually(t, func() bool { return l.activeConns == 0 }, timeout, interval) + } + } +} diff --git a/pkg/icetester/listener.go b/pkg/icetester/listener.go new file mode 100644 index 0000000..19a0f52 --- /dev/null +++ b/pkg/icetester/listener.go @@ -0,0 +1,374 @@ +package icetester + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/pion/datachannel" + "github.com/pion/logging" + "github.com/pion/webrtc/v4" +) + +const ( + messageSize = 2048 + signalingServerPath = "websocket" +) + +var _ net.Listener = &Listener{} +var _ net.Conn = &listenerConn{} + +// nolint +var ( + upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + serverDataPeriod = 100 * time.Millisecond // 10 pkt/sec +) + +type Listener struct { + *http.Server + addr string + iceConfig webrtc.Configuration + errCh chan error + connCh chan *listenerConn + api *webrtc.API + conns map[string]*listenerConn + lock sync.Mutex + activeConns int + logger logging.LoggerFactory + log logging.LeveledLogger +} + +func NewListener(addr string, iceConfig webrtc.Configuration, logger logging.LoggerFactory) (*Listener, error) { + e := webrtc.SettingEngine{} + e.DetachDataChannels() + l := &Listener{ + addr: addr, + iceConfig: iceConfig, + api: webrtc.NewAPI(webrtc.WithSettingEngine(e)), + errCh: make(chan error, 5), + connCh: make(chan *listenerConn, 128), + conns: map[string]*listenerConn{}, + logger: logger, + log: logger.NewLogger("tester-listener"), + } + + router := mux.NewRouter() + router.HandleFunc("/"+signalingServerPath, l.ServeHTTP) + l.Server = &http.Server{Addr: addr, Handler: router} + + c, err := net.Listen("tcp", addr) + if err != nil { + return nil, fmt.Errorf("WS tester signaling server on %s: %w", addr, err) + } + + go func() { + defer close(l.errCh) + defer close(l.connCh) + + if err := l.Server.Serve(c); err != nil { + l.errCh <- err + } + }() + + return l, nil +} + +func (l *Listener) Accept() (net.Conn, error) { + l.log.Trace("Accept: waiting for new connection") + + select { + case err := <-l.errCh: + return nil, err + case conn := <-l.connCh: + l.log.Infof("Accepting connection for WS connection %s-%s", + conn.wsConn.RemoteAddr(), conn.wsConn.LocalAddr()) + + l.lock.Lock() + l.activeConns += 1 + l.conns[conn.String()] = conn + l.lock.Unlock() + + return conn, nil + } +} + +func (l *Listener) Close() error { + l.log.Tracef("Closing tester server listener at address %s", l.addr) + defer l.Server.Close() + + select { + case err := <-l.errCh: + return err + default: + return nil + } +} + +func (l *Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // upgrade to webSocket + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + rawWsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + l.errCh <- fmt.Errorf("Failed to upgrade WebSocket connection: %w", err) + return + } + + wsConn := &ThreadSafeWriter{Conn: rawWsConn} + + l.log.Tracef("Tester listener received new connection from client %s", wsConn.RemoteAddr()) + connLog := l.logger.NewLogger(fmt.Sprintf("tester-lconn-%s-%s", + wsConn.RemoteAddr(), wsConn.LocalAddr())) + + wsConn.SetPingHandler(func(string) error { + return wsConn.WriteMessage(websocket.PongMessage, []byte("keepalive")) + }) + + connLog.Trace("Creating PeerConnection") + peerConn, err := l.api.NewPeerConnection(l.iceConfig) + if err != nil { + l.errCh <- fmt.Errorf("Failed to create a PeerConnection: %w", err) + return + } + + connLog.Tracef("Creating DataChannel") + dataChannel, err := peerConn.CreateDataChannel("data", nil) + if err != nil { + l.errCh <- fmt.Errorf("Failed to create DataChannel: %w", err) + return + } + + peerConn.OnConnectionStateChange(func(p webrtc.PeerConnectionState) { + connLog.Debugf("Connection State has changed: %s", p.String()) + if p == webrtc.PeerConnectionStateFailed || p == webrtc.PeerConnectionStateClosed { + l.errCh <- errors.New("ICE connection closed prematurely") + return + } + }) + + dataChannel.OnOpen(func() { + connLog.Tracef("Data channel '%s'-'%d' open", dataChannel.Label(), dataChannel.ID()) + + raw, dErr := dataChannel.Detach() + if dErr != nil { + l.errCh <- fmt.Errorf("Failed to detach DataChannel: %w", err) + return + } + + l.log.Infof("Creating new connection for WS connection %s-%s", + wsConn.RemoteAddr(), wsConn.LocalAddr()) + + conn := &listenerConn{ + listener: l, + wsConn: wsConn, + peerConn: peerConn, + dataChan: dataChannel, + dataConn: raw, + log: connLog, + } + + l.connCh <- conn + + // redo ICE state change callback to react to the peerconnection going away + peerConn.OnConnectionStateChange(func(p webrtc.PeerConnectionState) { + conn.Close() //nolint:errcheck + }) + }) + + peerConn.OnICECandidate(func(i *webrtc.ICECandidate) { + if i == nil { + return + } + + // When serializing a candidate use ToJSON, othwerwise json.Marshal will result in + // errors around `sdpMid` + candidateString, err := json.Marshal(i.ToJSON()) + if err != nil { + connLog.Errorf("Failed to marshal candidate to json: %v", err) + return + } + + l.log.Debugf("Sending candidate: %s", candidateString) + + if writeErr := wsConn.WriteJSON(&Message{ + Type: MessageTypeIceCandidate, + Data: string(candidateString), + }); writeErr != nil { + connLog.Errorf("Failed to write JSON: %v", writeErr) + } + }) + + connLog.Trace("Creating Offer for client") + offer, err := peerConn.CreateOffer(nil) + if err != nil { + l.errCh <- fmt.Errorf("Failed to create Offer: %w", err) + return + } + + // Note: this will start the gathering of ICE candidates + connLog.Tracef("Setting local descrition (Offer)") + if err = peerConn.SetLocalDescription(offer); err != nil { + l.errCh <- fmt.Errorf("Failed to set local description: %w", err) + return + } + + payload, err := json.Marshal(offer) + if err != nil { + l.errCh <- fmt.Errorf("Failed to JSON encode Offer: %v", err) + return + } + + connLog.Tracef("Sending Offer: %s", offer) + if writeErr := wsConn.WriteJSON(&Message{ + Type: MessageTypeOffer, + Data: string(payload), + }); writeErr != nil { + l.errCh <- fmt.Errorf("Failed to write Offer: %v", writeErr) + return + } + + message := &Message{} + for { + _, raw, err := wsConn.ReadMessage() + if err != nil { + l.errCh <- fmt.Errorf("Failed to read message: %v", err) + return + } + + connLog.Tracef("Got message: %s", raw) + + if err := json.Unmarshal(raw, &message); err != nil { + l.errCh <- fmt.Errorf("Failed to unmarshal json to message: %v", err) + return + } + + switch message.Type { + case MessageTypeIceCandidate: + candidate := webrtc.ICECandidateInit{} + if err := json.Unmarshal([]byte(message.Data), &candidate); err != nil { + l.errCh <- fmt.Errorf("Failed to unmarshal json to candidate: %v", err) + return + } + + connLog.Debugf("Got ICE candidate: %v", candidate) + + if err := peerConn.AddICECandidate(candidate); err != nil { + l.errCh <- fmt.Errorf("Failed to add ICE candidate: %v", err) + return + } + + case MessageTypeAnswer: + answer := webrtc.SessionDescription{} + if err := json.Unmarshal([]byte(message.Data), &answer); err != nil { + l.errCh <- fmt.Errorf("Failed to unmarshal json to answer: %v", err) + return + } + + connLog.Debugf("Got Answer: %v", answer) + + connLog.Tracef("Setting remote descrition (Answer)") + if err := peerConn.SetRemoteDescription(answer); err != nil { + l.errCh <- fmt.Errorf("Failed to set remote description: %v", err) + return + } + default: + l.log.Errorf("unknown message: %+v", message) + } + } +} + +func (_ *Listener) Addr() net.Addr { + //TODO + return nil +} + +func (l *Listener) Conns() []*listenerConn { + l.lock.Lock() + defer l.lock.Unlock() + ret := []*listenerConn{} + for _, c := range l.conns { + ret = append(ret, c) + } + return ret +} + +type listenerConn struct { + listener *Listener + wsConn *ThreadSafeWriter + peerConn *webrtc.PeerConnection + dataChan *webrtc.DataChannel + dataConn datachannel.ReadWriteCloser + closed bool + log logging.LeveledLogger +} + +func (c *listenerConn) Close() error { + c.log.Tracef("Closing tester server listener connection %s", c.String()) + + if c.closed { + return nil + } + c.closed = true + + // Close the datachannel + var err error + if c.dataChan.ReadyState() == webrtc.DataChannelStateOpen { + if err = c.dataConn.Close(); err != nil { + c.log.Debugf("Error closing DataChannel for client %s: %s", + c.wsConn.RemoteAddr().String(), err.Error()) + } + } + + // Close the peer connection too + err = c.peerConn.Close() + if err != nil { + c.log.Debugf("Error closing PeerConnection for client %s: %s", + c.wsConn.RemoteAddr().String(), err.Error()) + } + + // Close the websocket, this will exit the peerconnection and the http handler + err = c.wsConn.Close() + if err != nil { + c.log.Debugf("Error closing WS connection for client %s: %s", + c.wsConn.RemoteAddr().String(), err.Error()) + } + + c.listener.lock.Lock() + c.listener.activeConns -= 1 + delete(c.listener.conns, c.String()) + c.listener.lock.Unlock() + + // Return the last error + return err +} + +func (c *listenerConn) Read(b []byte) (int, error) { + return c.dataConn.Read(b) +} + +func (c *listenerConn) Write(b []byte) (int, error) { + return c.dataConn.Write(b) +} + +// TODO: implement +func (c *listenerConn) LocalAddr() net.Addr { return nil } +func (c *listenerConn) RemoteAddr() net.Addr { return nil } +func (c *listenerConn) SetDeadline(t time.Time) error { return nil } +func (c *listenerConn) SetReadDeadline(t time.Time) error { return nil } +func (c *listenerConn) SetWriteDeadline(t time.Time) error { return nil } + +// String returns a unique identifier for the connection based on the underlying signaling connection. +func (c *listenerConn) String() string { + return fmt.Sprintf("%s-%s", c.wsConn.RemoteAddr(), c.wsConn.LocalAddr()) +} diff --git a/pkg/icetester/types.go b/pkg/icetester/types.go new file mode 100644 index 0000000..495a717 --- /dev/null +++ b/pkg/icetester/types.go @@ -0,0 +1,49 @@ +package icetester + +import ( + "sync" + + "github.com/gorilla/websocket" +) + +type MessageType int + +const ( + MessageTypeUnknown MessageType = iota + MessageTypeIceCandidate + MessageTypeOffer + MessageTypeAnswer + MessageTypeClose +) + +type Message struct { + Type MessageType `json:"type"` + Data string `json:"data"` +} + +// ThreadSafeWriter represents a client WebSocket connection. An added lock guards the underlying connection +// from concurrent write to websocket connection errors. +type ThreadSafeWriter struct { + *websocket.Conn + readLock, writeLock sync.Mutex // for writemessage +} + +// WriteMessage writes a message to the client connection with proper locking. +func (c *ThreadSafeWriter) WriteMessage(messageType int, data []byte) error { + c.writeLock.Lock() + defer c.writeLock.Unlock() + return c.Conn.WriteMessage(messageType, data) +} + +func (c *ThreadSafeWriter) WriteJSON(v any) error { + c.writeLock.Lock() + defer c.writeLock.Unlock() + return c.Conn.WriteJSON(v) +} + +// ReadMessage reads a message from the client connection with proper locking. +func (c *ThreadSafeWriter) ReadMessage() (int, []byte, error) { + c.readLock.Lock() + defer c.readLock.Unlock() + return c.Conn.ReadMessage() +}