Skip to content

Commit

Permalink
feat: multiple subscriptions on a single websocket
Browse files Browse the repository at this point in the history
- previously a single websocket only supported a single connection not anymore
- various bug fixes and code cleanup around websockets
  • Loading branch information
dosco committed Jan 20, 2023
1 parent e4b158f commit 43e619b
Showing 1 changed file with 82 additions and 64 deletions.
146 changes: 82 additions & 64 deletions serv/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,24 @@ func init() {
}
}

type wsConn struct {
c context.Context
sessions map[string]wsState
conn *websocket.Conn
done chan bool

w http.ResponseWriter
r *http.Request
ah auth.HandlerFunc
}

type wsState struct {
c context.Context
conn *websocket.Conn
req wsReq
ah auth.HandlerFunc
exit bool
ID string
m *core.Member
done chan bool

w http.ResponseWriter
r *http.Request
}

func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.HandlerFunc) {
var m *core.Member
var err error

conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
renderErr(w, err)
Expand All @@ -79,42 +81,42 @@ func (s *service) apiV1Ws(w http.ResponseWriter, r *http.Request, ah auth.Handle
defer conn.Close()
conn.SetReadLimit(2048)

st := wsState{
c: r.Context(),
done: make(chan bool),
conn: conn,
ah: ah,
w: w,
r: r,
wc := wsConn{
c: r.Context(),
sessions: make(map[string]wsState),
conn: conn,
done: make(chan bool),
w: w,
r: r,
ah: ah,
}

for {
var b []byte
var req wsReq

if _, b, err = conn.ReadMessage(); err != nil {
break
}

if err = json.Unmarshal(b, &st.req); err != nil {
if err = json.Unmarshal(b, &req); err != nil {
break
}

if err = s.subSwitch(&st); err != nil {
break
}

if st.exit {
if err = s.subSwitch(&wc, req); err != nil {
break
}
}

if err != nil {
if e, ok := err.(*websocket.CloseError); !ok ||
(e.Code != websocket.CloseNormalClosure && e.Code != websocket.CloseGoingAway) {
s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...)
sendError(&st, err) //nolint:errcheck
}

m.Unsubscribe()
st.done <- true
for _, st := range wc.sessions {
st.m.Unsubscribe()
}
wc.done <- true
}

type authHeaders struct {
Expand All @@ -123,68 +125,79 @@ type authHeaders struct {
UserID interface{} `json:"X-User-ID"`
}

func (s *service) subSwitch(st *wsState) (err error) {
switch st.req.Type {
func (s *service) subSwitch(wc *wsConn, req wsReq) (err error) {
switch req.Type {
case "connection_init":
if err = setHeaders(st); err != nil {
if err = setHeaders(req, wc.r); err != nil {
return
}
if st.c, err = st.ah(st.w, st.r); err != nil {
if wc.c, err = wc.ah(wc.w, wc.r); err != nil {
return
}
if s.conf.Serv.AuthFailBlock && !auth.IsAuth(st.c) {
if s.conf.Serv.AuthFailBlock && !auth.IsAuth(wc.c) {
err = auth.Err401
return
}
if err = st.conn.WritePreparedMessage(initMsg); err != nil {
if err = wc.conn.WritePreparedMessage(initMsg); err != nil {
return
}

case "start", "subscribe":
var p gqlReq
if err = json.Unmarshal(st.req.Payload, &p); err != nil {
return
if err = json.Unmarshal(req.Payload, &p); err != nil {
break
}

c := wc.c
if s.conf.Serv.Auth.Development {
var x authHeaders
if err = json.Unmarshal(p.Vars, &x); err != nil {
return
break
}
if x.UserIDProvider != "" {
st.c = context.WithValue(st.c, core.UserIDProviderKey, x.UserIDProvider)
c = context.WithValue(c, core.UserIDProviderKey, x.UserIDProvider)
}
if x.UserRole != "" {
st.c = context.WithValue(st.c, core.UserRoleKey, x.UserRole)
c = context.WithValue(c, core.UserRoleKey, x.UserRole)
}
if x.UserID != nil {
st.c = context.WithValue(st.c, core.UserIDKey, x.UserID)
c = context.WithValue(c, core.UserIDKey, x.UserID)
}
}

var m *core.Member
if m, err = s.gj.Subscribe(st.c, p.Query, p.Vars, nil); err != nil {
return
st := wsState{ID: req.ID, done: make(chan bool)}
if st.m, err = s.gj.Subscribe(c, p.Query, p.Vars, nil); err != nil {
break
}
go s.waitForData(st, m)
return
wc.sessions[st.ID] = st
useNext := req.Type == "subscribe"

go s.waitForData(wc, &st, useNext)

case "complete", "connection_terminate", "stop":
st.exit = true
if st, ok := wc.sessions[req.ID]; ok {
st.done <- true
st.m.Unsubscribe()
delete(wc.sessions, req.ID)
}

default:
err = fmt.Errorf("unknown message type: %s", st.req.Type)
err = fmt.Errorf("unknown message type: %s", req.Type)
}

if err != nil {
sendError(wc, req.ID, err) //nolint:errcheck
}
return
}

func (s *service) waitForData(st *wsState, m *core.Member) {
func (s *service) waitForData(wc *wsConn, st *wsState, useNext bool) {
var buf bytes.Buffer

var ptype string
var err error

if st.req.Type == "subscribe" {
if useNext {
ptype = "next"
} else {
ptype = "data"
Expand All @@ -193,59 +206,64 @@ func (s *service) waitForData(st *wsState, m *core.Member) {
enc := json.NewEncoder(&buf)
for {
select {
case v := <-m.Result:
m := wsRes{ID: st.req.ID, Type: ptype}
m.Payload.Data = v.Data
m.Payload.Errors = v.Errors
case v := <-st.m.Result:
res := wsRes{ID: st.ID, Type: ptype}
res.Payload.Data = v.Data
res.Payload.Errors = v.Errors

if err = enc.Encode(m); err != nil {
if err = enc.Encode(res); err != nil {
break
}
msg := buf.Bytes()
buf.Reset()
err = st.conn.WriteMessage(websocket.TextMessage, msg)
err = wc.conn.WriteMessage(websocket.TextMessage, msg)

case v := <-st.done:
if v {
return
}

case v := <-wc.done:
if v {
return
}
}

if err != nil {
s.zlog.Error("Subscription", []zapcore.Field{zap.Error(err)}...)
sendError(st, err) //nolint:errcheck
sendError(wc, st.ID, err) //nolint:errcheck
break
}
}
}

func setHeaders(st *wsState) (err error) {
if len(st.req.Payload) == 0 {
func setHeaders(req wsReq, r *http.Request) (err error) {
if len(req.Payload) == 0 {
return
}
var p map[string]interface{}
if err = json.Unmarshal(st.req.Payload, &p); err != nil {
if err = json.Unmarshal(req.Payload, &p); err != nil {
return
}
for k, v := range p {
switch v1 := v.(type) {
case string:
st.r.Header.Set(k, v1)
r.Header.Set(k, v1)
case json.Number:
st.r.Header.Set(k, v1.String())
r.Header.Set(k, v1.String())
}
}
return
}

func sendError(st *wsState, cerr error) (err error) {
m := wsRes{ID: st.req.ID, Type: "error"}
func sendError(wc *wsConn, id string, cerr error) (err error) {
m := wsRes{ID: id, Type: "error"}
m.Payload.Errors = []core.Error{{Message: cerr.Error()}}

msg, err := json.Marshal(m)
if err != nil {
return
}
err = st.conn.WriteMessage(websocket.TextMessage, msg)
err = wc.conn.WriteMessage(websocket.TextMessage, msg)
return
}

1 comment on commit 43e619b

@vercel
Copy link

@vercel vercel bot commented on 43e619b Jan 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.