diff --git a/shuttle/rpc/engines/websocket/websocket.go b/shuttle/rpc/engines/websocket/websocket.go index be739073..8db75850 100644 --- a/shuttle/rpc/engines/websocket/websocket.go +++ b/shuttle/rpc/engines/websocket/websocket.go @@ -2,6 +2,7 @@ package websocket import ( "context" + "encoding/json" "fmt" "net/url" "sync" @@ -13,6 +14,7 @@ import ( "github.com/filecoin-project/go-address" "github.com/libp2p/go-libp2p/core/peer" "golang.org/x/net/websocket" + gwebsocket "golang.org/x/net/websocket" "github.com/application-research/estuary/config" "go.uber.org/zap" @@ -39,7 +41,7 @@ type Connection struct { } type IEstuaryRpcEngine interface { - Connect(ws *websocket.Conn, handle string, done chan struct{}) (func() error, func(), error) + Connect(ws *gwebsocket.Conn, handle string, done chan struct{}) (func() error, func(), error) GetShuttleConnections() []*Connection GetShuttleConnection(handle string) (*Connection, bool) } @@ -67,18 +69,28 @@ func NewEstuaryRpcEngine(ctx context.Context, db *gorm.DB, cfg *config.Estuary, return wbsMgr } -func (m *manager) Connect(ws *websocket.Conn, handle string, done chan struct{}) (func() error, func(), error) { +func (m *manager) Connect(ws *gwebsocket.Conn, handle string, done chan struct{}) (func() error, func(), error) { + var helloBytes []byte + if err := gwebsocket.Message.Receive(ws, &helloBytes); err != nil { + return nil, nil, err + } + var hello rpcevent.Hello - if err := websocket.JSON.Receive(ws, &hello); err != nil { + if err := json.Unmarshal(helloBytes, &hello); err != nil { + return nil, nil, err + } + + b, err := json.Marshal(&rpcevent.Hi{QueueEngEnabled: m.cfg.RpcEngine.Queue.Enabled}) + if err != nil { return nil, nil, err } - // tell shuttle if api support queue engine - if err := websocket.JSON.Send(ws, &rpcevent.Hi{QueueEngEnabled: m.cfg.RpcEngine.Queue.Enabled}); err != nil { + // tell shuttle if api supports queue engine + if err := gwebsocket.Message.Send(ws, b); err != nil { return nil, nil, err } - _, err := url.Parse(hello.Host) + _, err = url.Parse(hello.Host) if err != nil { m.log.Errorf("shuttle had invalid hostname %q: %s", hello.Host, err) hello.Host = "" @@ -128,14 +140,17 @@ func (m *manager) Connect(ws *websocket.Conn, handle string, done chan struct{}) select { case msg := <-sc.cmds: go func() { - // Write - err := websocket.JSON.Send(ws, msg) + msgBytes, err := json.Marshal(msg) if err != nil { + m.log.Errorf("failed to serialize message: %s", err) + return + } + + if err = websocket.Message.Send(ws, msgBytes); err != nil { m.log.Errorf("failed to write command to shuttle: %s", err) return } }() - case <-done: return } @@ -143,10 +158,16 @@ func (m *manager) Connect(ws *websocket.Conn, handle string, done chan struct{}) } readWebsocket := func() error { + var msgBytes []byte + if err := websocket.Message.Receive(ws, &msgBytes); err != nil { + return err + } + var msg *rpcevent.Message - if err := websocket.JSON.Receive(ws, &msg); err != nil { + if err := json.Unmarshal(msgBytes, &msg); err != nil { return err } + msg.Handle = handle go func() { m.rpcWebsocket <- msg