diff --git a/serv/ws.go b/serv/ws.go index 450d3d75..2a6b5b74 100644 --- a/serv/ws.go +++ b/serv/ws.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "sync" "time" "github.com/dosco/graphjin/auth/v3" @@ -56,10 +57,11 @@ func init() { } type wsConn struct { - c context.Context - sessions map[string]wsState - conn *websocket.Conn - done chan bool + c context.Context + sessions map[string]wsState + conn *websocket.Conn + connMutex sync.Mutex + done chan bool w http.ResponseWriter r *http.Request @@ -138,7 +140,12 @@ func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) { err = auth.Err401 return } - if err = wc.conn.WritePreparedMessage(initMsg); err != nil { + + wc.connMutex.Lock() + err = wc.conn.WritePreparedMessage(initMsg) + wc.connMutex.Unlock() + + if err != nil { return } @@ -216,7 +223,16 @@ func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { } msg := buf.Bytes() buf.Reset() + + wc.connMutex.Lock() err = wc.conn.WriteMessage(websocket.TextMessage, msg) + wc.connMutex.Unlock() + + if err != nil { + s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...) + sendError(wc, st.ID, err) //nolint:errcheck + break + } case v := <-st.done: if v { @@ -228,12 +244,6 @@ func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) { return } } - - if err != nil { - s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...) - sendError(wc, st.ID, err) //nolint:errcheck - break - } } } @@ -264,6 +274,9 @@ func sendError(wc *wsConn, id string, cerr error) (err error) { if err != nil { return } + + wc.connMutex.Lock() + defer wc.connMutex.Unlock() err = wc.conn.WriteMessage(websocket.TextMessage, msg) return }