From 81f4290bc372ae8eb5f6a7a9ffe64652dfd281e6 Mon Sep 17 00:00:00 2001 From: Julien Cassis Date: Mon, 9 Nov 2020 16:00:13 -0500 Subject: [PATCH] refactored ws connection --- rpc/ws.go | 335 ++++++++++++++++++++++++++++++++----------------- rpc/ws_test.go | 22 ++-- 2 files changed, 236 insertions(+), 121 deletions(-) diff --git a/rpc/ws.go b/rpc/ws.go index 4b8e53f..fc8e3ee 100644 --- a/rpc/ws.go +++ b/rpc/ws.go @@ -16,6 +16,7 @@ package rpc import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -30,108 +31,261 @@ import ( ) type result interface{} -type callBackInfo struct { - requestID uint64 - subscription uint64 - stream chan result - reflectType reflect.Type -} type WSClient struct { - currentID uint64 rpcURL string conn *websocket.Conn - lock sync.Mutex - callbacksByRequestID map[uint64]*callBackInfo - callbacksBySubscription map[uint64]*callBackInfo + lock sync.RWMutex + subscriptionByRequestID map[uint64]*Subscription + subscriptionByWSSubID map[uint64]*Subscription + reconnectOnErr bool } -func NewWSClient(rpcURL string) (*WSClient, error) { - c := &WSClient{ - currentID: 0, +func Dial(ctx context.Context, rpcURL string) (c *WSClient, err error) { + c = &WSClient{ rpcURL: rpcURL, - callbacksByRequestID: map[uint64]*callBackInfo{}, - callbacksBySubscription: map[uint64]*callBackInfo{}, + subscriptionByRequestID: map[uint64]*Subscription{}, + subscriptionByWSSubID: map[uint64]*Subscription{}, } - conn, _, err := websocket.DefaultDialer.Dial(rpcURL, nil) + c.conn, _, err = websocket.DefaultDialer.DialContext(ctx, rpcURL, nil) if err != nil { return nil, fmt.Errorf("new ws client: dial: %w", err) } - c.conn = conn - - c.receiveMessages() + go c.receiveMessages() return c, nil } +func (c *WSClient) Close() { + c.conn.Close() +} + func (c *WSClient) receiveMessages() { - zlog.Info("ready to receive message") - go func() { - k := 0 - for { - k++ - _, message, err := c.conn.ReadMessage() - zlog.Info("") - if err != nil { - zlog.Error("message reception", zap.Error(err)) - continue - } + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + c.closeAllSubscription(err) + return + } + c.handleMessage(message) + } +} - // when receiving message with id. the result will be a subscription number. - // that number will be associated to all future message destine to this request - if gjson.GetBytes(message, "id").Exists() { - zlog.Info("received subscription for id") - id := uint64(gjson.GetBytes(message, "id").Int()) - sub := uint64(gjson.GetBytes(message, "result").Int()) - - //moving pending callback info to the actual callbacksBySubscription list - c.lock.Lock() - callBack := c.callbacksByRequestID[id] - callBack.subscription = sub - c.callbacksBySubscription[sub] = callBack - c.lock.Unlock() - - zlog.Info("move sub from pending to callback", zap.Uint64("id", uint64(id)), zap.Uint64("subscription", uint64(sub))) - continue - } +func (c *WSClient) handleMessage(message []byte) { + // when receiving message with id. the result will be a subscription number. + // that number will be associated to all future message destine to this request + if gjson.GetBytes(message, "id").Exists() { + requestID := uint64(gjson.GetBytes(message, "id").Int()) + subID := uint64(gjson.GetBytes(message, "result").Int()) + c.handleNewSubscriptionMessage(requestID, subID) + return + } - //getting the callback - sub := uint64(gjson.GetBytes(message, "params.subscription").Int()) - callBack := c.callbacksBySubscription[sub] + c.handleSubscriptionMessage(uint64(gjson.GetBytes(message, "params.subscription").Int()), message) - //getting and instantiate the return type for the call back. - resultType := reflect.New(callBack.reflectType) - result := resultType.Interface() +} - err = decodeClientResponse(bytes.NewReader(message), &result) - if err != nil { - zlog.Error("failed to decode result", zap.Uint64("subscription", uint64(sub)), zap.Error(err)) - } - callBack.stream <- result - } - }() +func (c *WSClient) handleNewSubscriptionMessage(requestID, subID uint64) { + c.lock.Lock() + defer c.lock.Unlock() + + zlog.Info("received new subscription message", + zap.Uint64("message_id", requestID), + zap.Uint64("subscription_id", subID), + ) + callBack := c.subscriptionByRequestID[requestID] + callBack.subID = subID + c.subscriptionByWSSubID[subID] = callBack + return +} + +func (c *WSClient) handleSubscriptionMessage(subID uint64, message []byte) { + zlog.Info("received subscription message", + zap.Uint64("subscription_id", subID), + ) + + c.lock.RLock() + sub, found := c.subscriptionByWSSubID[subID] + c.lock.RUnlock() + if !found { + zlog.Warn("unbale to find subscription for ws message", zap.Uint64("subscription_id", subID)) + return + } + + //getting and instantiate the return type for the call back. + resultType := reflect.New(sub.reflectType) + result := resultType.Interface() + err := decodeClientResponse(bytes.NewReader(message), &result) + if err != nil { + c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err)) + return + } + + // this cannot be blocking or else + // we will no read any other message + if len(sub.stream) >= cap(sub.stream) { + c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream))) + return + } + + sub.stream <- result + return +} + +func (c *WSClient) closeAllSubscription(err error) { + c.lock.Lock() + defer c.lock.Unlock() + + for _, sub := range c.subscriptionByRequestID { + sub.err <- err + } + + c.subscriptionByRequestID = map[uint64]*Subscription{} + c.subscriptionByWSSubID = map[uint64]*Subscription{} +} + +func (c *WSClient) closeSubscription(reqID uint64, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + sub, found := c.subscriptionByRequestID[reqID] + if !found { + return + } + + sub.err <- err + + err = c.rpcUnsubscribe(sub.subID, sub.unsubscriptionMethod) + if err != nil { + zlog.Warn("unable to send rpc unsubscribe call", + zap.Error(err), + ) + } + + delete(c.subscriptionByRequestID, sub.req.ID) + delete(c.subscriptionByWSSubID, sub.subID) +} + +func (c *WSClient) rpcUnsubscribe(subID uint64, method string) error { + req := newClientRequest([]interface{}{subID}, method, map[string]interface{}{}) + data, err := req.encode() + if err != nil { + return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method) + } + + err = c.conn.WriteMessage(websocket.TextMessage, data) + if err != nil { + return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method) + } + return nil +} + +type Subscription struct { + req *clientRequest + subID uint64 + stream chan result + err chan error + reflectType reflect.Type + closeFunc func(err error) + unsubscriptionMethod string +} + +func newSubscription(req *clientRequest, reflectType reflect.Type, closeFunc func(err error)) *Subscription { + return &Subscription{ + req: req, + reflectType: reflectType, + stream: make(chan result, 200), + err: make(chan error, 1), + closeFunc: closeFunc, + } +} + +func (s *Subscription) Recv() (interface{}, error) { + select { + case d := <-s.stream: + return d, nil + case err := <-s.err: + return nil, err + } +} + +func (s *Subscription) Unsubscribe() { + s.unsubscribe(nil) +} + +func (s *Subscription) unsubscribe(err error) { + s.closeFunc(err) + +} + +func (c *WSClient) ProgramSubscribe(programID string, commitment CommitmentType) (*Subscription, error) { + return c.subscribe([]interface{}{programID}, "programSubscribe", "programUnsubscribe", commitment, ProgramWSResult{}) +} + +func (c *WSClient) subscribe(params []interface{}, subscriptionMethod, unsubscriptionMethod string, commitment CommitmentType, resultType interface{}) (*Subscription, error) { + conf := map[string]interface{}{ + "encoding": "jsonParsed", + } + if commitment != "" { + conf["commitment"] = string(commitment) + } + + req := newClientRequest(params, subscriptionMethod, conf) + data, err := req.encode() + if err != nil { + return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err) + } + + sub := newSubscription(req, reflect.TypeOf(resultType), func(err error) { + c.closeSubscription(req.ID, err) + }) + + c.lock.Lock() + c.subscriptionByRequestID[req.ID] = sub + zlog.Info("added new subscription to websocket client", zap.Int("count", len(c.subscriptionByRequestID))) + c.lock.Unlock() + + err = c.conn.WriteMessage(websocket.TextMessage, data) + if err != nil { + return nil, fmt.Errorf("unable to write request: %w", err) + } + + return sub, nil +} + +type ProgramWSResult struct { + Context struct { + Slot uint64 + } `json:"context"` + Value struct { + Account Account `json:"account"` + } `json:"value"` } type clientRequest struct { Version string `json:"jsonrpc"` Method string `json:"method"` Params interface{} `json:"params"` - Id uint64 `json:"id"` + ID uint64 `json:"id"` } -func encodeClientRequest(method string, args interface{}) ([]byte, uint64, error) { - c := &clientRequest{ +func newClientRequest(params []interface{}, method string, configuration map[string]interface{}) *clientRequest { + params = append(params, configuration) + return &clientRequest{ Version: "2.0", Method: method, - Params: args, - Id: uint64(rand.Int63()), + Params: params, + ID: uint64(rand.Int63()), } +} + +func (c *clientRequest) encode() ([]byte, error) { data, err := json.Marshal(c) if err != nil { - return nil, 0, fmt.Errorf("encode request: json marshall: %w", err) + return nil, fmt.Errorf("encode request: json marshall: %w", err) } - return data, c.Id, nil + return data, nil } type wsClientResponse struct { @@ -168,50 +322,3 @@ func decodeClientResponse(r io.Reader, reply interface{}) (err error) { return json.Unmarshal(*c.Params.Result, &reply) } - -type ProgramWSResult struct { - Context struct { - Slot uint64 - } `json:"context"` - Value struct { - Account Account `json:"account"` - } `json:"value"` -} - -func (c *WSClient) ProgramSubscribe(programID string, commitment CommitmentType) (stream chan result, id uint64, err error) { - c.lock.Lock() - defer c.lock.Unlock() - - stream = make(chan result, 200) - - params := []interface{}{programID} - conf := map[string]interface{}{ - "encoding": "jsonParsed", - } - if commitment != "" { - conf["commitment"] = string(commitment) - } - - params = append(params, conf) - data, id, err := encodeClientRequest("programSubscribe", params) - if err != nil { - return nil, 0, fmt.Errorf("program subscribe: encode request: %c", err) - } - - err = c.conn.WriteMessage(websocket.TextMessage, data) - if err != nil { - return nil, 0, fmt.Errorf("program subscribe: write message: %c", err) - } - - c.callbacksByRequestID[id] = &callBackInfo{ - requestID: id, - stream: stream, - reflectType: reflect.TypeOf(ProgramWSResult{}), - } - - return stream, id, nil -} - -func (c *WSClient) ProgramUnsubscribe(reqID int) { - -} diff --git a/rpc/ws_test.go b/rpc/ws_test.go index 71ae066..dffcd85 100644 --- a/rpc/ws_test.go +++ b/rpc/ws_test.go @@ -15,23 +15,31 @@ package rpc import ( + "context" + "fmt" "testing" - "time" + + "go.uber.org/zap" "github.com/stretchr/testify/require" ) func TestWSClient_ProgramSubscribe(t *testing.T) { + zlog, _ = zap.NewDevelopment() - c, err := NewWSClient("ws://api.mainnet-beta.solana.com:80/rpc") + c, err := Dial(context.Background(), "ws://api.mainnet-beta.solana.com:80/rpc") + defer c.Close() require.NoError(t, err) - stream, _, err := c.ProgramSubscribe("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o", "") + sub, err := c.ProgramSubscribe("EUqojwWA2rd19FZrzeBncJsm38Jm1hEhE3zsmX3bRc2o", "") require.NoError(t, err) - select { - case <-stream: - case <-time.After(2000 * time.Millisecond): - t.Error("failed to run the giving time") + data, err := sub.Recv() + if err != nil { + fmt.Println("receive an error: ", err) + return } + fmt.Println("data received: ", data.(*ProgramWSResult).Value.Account.Owner) + return + }