Skip to content

Commit

Permalink
refactored ws connection
Browse files Browse the repository at this point in the history
  • Loading branch information
jubeless committed Nov 9, 2020
1 parent b7f0534 commit 81f4290
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 121 deletions.
335 changes: 221 additions & 114 deletions rpc/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package rpc

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -30,108 +31,261 @@ import (
)

type result interface{}
type callBackInfo struct {
requestID uint64
subscription uint64
stream chan result
reflectType reflect.Type
}

type WSClient struct {
currentID uint64
rpcURL string
conn *websocket.Conn
lock sync.Mutex
callbacksByRequestID map[uint64]*callBackInfo
callbacksBySubscription map[uint64]*callBackInfo
lock sync.RWMutex
subscriptionByRequestID map[uint64]*Subscription
subscriptionByWSSubID map[uint64]*Subscription
reconnectOnErr bool
}

func NewWSClient(rpcURL string) (*WSClient, error) {
c := &WSClient{
currentID: 0,
func Dial(ctx context.Context, rpcURL string) (c *WSClient, err error) {
c = &WSClient{
rpcURL: rpcURL,
callbacksByRequestID: map[uint64]*callBackInfo{},
callbacksBySubscription: map[uint64]*callBackInfo{},
subscriptionByRequestID: map[uint64]*Subscription{},
subscriptionByWSSubID: map[uint64]*Subscription{},
}

conn, _, err := websocket.DefaultDialer.Dial(rpcURL, nil)
c.conn, _, err = websocket.DefaultDialer.DialContext(ctx, rpcURL, nil)
if err != nil {
return nil, fmt.Errorf("new ws client: dial: %w", err)
}
c.conn = conn

c.receiveMessages()

go c.receiveMessages()
return c, nil
}

func (c *WSClient) Close() {
c.conn.Close()
}

func (c *WSClient) receiveMessages() {
zlog.Info("ready to receive message")
go func() {
k := 0
for {
k++
_, message, err := c.conn.ReadMessage()
zlog.Info("")
if err != nil {
zlog.Error("message reception", zap.Error(err))
continue
}
for {
_, message, err := c.conn.ReadMessage()
if err != nil {
c.closeAllSubscription(err)
return
}
c.handleMessage(message)
}
}

// when receiving message with id. the result will be a subscription number.
// that number will be associated to all future message destine to this request
if gjson.GetBytes(message, "id").Exists() {
zlog.Info("received subscription for id")
id := uint64(gjson.GetBytes(message, "id").Int())
sub := uint64(gjson.GetBytes(message, "result").Int())

//moving pending callback info to the actual callbacksBySubscription list
c.lock.Lock()
callBack := c.callbacksByRequestID[id]
callBack.subscription = sub
c.callbacksBySubscription[sub] = callBack
c.lock.Unlock()

zlog.Info("move sub from pending to callback", zap.Uint64("id", uint64(id)), zap.Uint64("subscription", uint64(sub)))
continue
}
func (c *WSClient) handleMessage(message []byte) {
// when receiving message with id. the result will be a subscription number.
// that number will be associated to all future message destine to this request
if gjson.GetBytes(message, "id").Exists() {
requestID := uint64(gjson.GetBytes(message, "id").Int())
subID := uint64(gjson.GetBytes(message, "result").Int())
c.handleNewSubscriptionMessage(requestID, subID)
return
}

//getting the callback
sub := uint64(gjson.GetBytes(message, "params.subscription").Int())
callBack := c.callbacksBySubscription[sub]
c.handleSubscriptionMessage(uint64(gjson.GetBytes(message, "params.subscription").Int()), message)

//getting and instantiate the return type for the call back.
resultType := reflect.New(callBack.reflectType)
result := resultType.Interface()
}

err = decodeClientResponse(bytes.NewReader(message), &result)
if err != nil {
zlog.Error("failed to decode result", zap.Uint64("subscription", uint64(sub)), zap.Error(err))
}
callBack.stream <- result
}
}()
func (c *WSClient) handleNewSubscriptionMessage(requestID, subID uint64) {
c.lock.Lock()
defer c.lock.Unlock()

zlog.Info("received new subscription message",
zap.Uint64("message_id", requestID),
zap.Uint64("subscription_id", subID),
)
callBack := c.subscriptionByRequestID[requestID]
callBack.subID = subID
c.subscriptionByWSSubID[subID] = callBack
return
}

func (c *WSClient) handleSubscriptionMessage(subID uint64, message []byte) {
zlog.Info("received subscription message",
zap.Uint64("subscription_id", subID),
)

c.lock.RLock()
sub, found := c.subscriptionByWSSubID[subID]
c.lock.RUnlock()
if !found {
zlog.Warn("unbale to find subscription for ws message", zap.Uint64("subscription_id", subID))
return
}

//getting and instantiate the return type for the call back.
resultType := reflect.New(sub.reflectType)
result := resultType.Interface()
err := decodeClientResponse(bytes.NewReader(message), &result)
if err != nil {
c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err))
return
}

// this cannot be blocking or else
// we will no read any other message
if len(sub.stream) >= cap(sub.stream) {
c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream)))
return
}

sub.stream <- result
return
}

func (c *WSClient) closeAllSubscription(err error) {
c.lock.Lock()
defer c.lock.Unlock()

for _, sub := range c.subscriptionByRequestID {
sub.err <- err
}

c.subscriptionByRequestID = map[uint64]*Subscription{}
c.subscriptionByWSSubID = map[uint64]*Subscription{}
}

func (c *WSClient) closeSubscription(reqID uint64, err error) {
c.lock.Lock()
defer c.lock.Unlock()

sub, found := c.subscriptionByRequestID[reqID]
if !found {
return
}

sub.err <- err

err = c.rpcUnsubscribe(sub.subID, sub.unsubscriptionMethod)
if err != nil {
zlog.Warn("unable to send rpc unsubscribe call",
zap.Error(err),
)
}

delete(c.subscriptionByRequestID, sub.req.ID)
delete(c.subscriptionByWSSubID, sub.subID)
}

func (c *WSClient) rpcUnsubscribe(subID uint64, method string) error {
req := newClientRequest([]interface{}{subID}, method, map[string]interface{}{})
data, err := req.encode()
if err != nil {
return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method)
}

err = c.conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method)
}
return nil
}

type Subscription struct {
req *clientRequest
subID uint64
stream chan result
err chan error
reflectType reflect.Type
closeFunc func(err error)
unsubscriptionMethod string
}

func newSubscription(req *clientRequest, reflectType reflect.Type, closeFunc func(err error)) *Subscription {
return &Subscription{
req: req,
reflectType: reflectType,
stream: make(chan result, 200),
err: make(chan error, 1),
closeFunc: closeFunc,
}
}

func (s *Subscription) Recv() (interface{}, error) {
select {
case d := <-s.stream:
return d, nil
case err := <-s.err:
return nil, err
}
}

func (s *Subscription) Unsubscribe() {
s.unsubscribe(nil)
}

func (s *Subscription) unsubscribe(err error) {
s.closeFunc(err)

}

func (c *WSClient) ProgramSubscribe(programID string, commitment CommitmentType) (*Subscription, error) {
return c.subscribe([]interface{}{programID}, "programSubscribe", "programUnsubscribe", commitment, ProgramWSResult{})
}

func (c *WSClient) subscribe(params []interface{}, subscriptionMethod, unsubscriptionMethod string, commitment CommitmentType, resultType interface{}) (*Subscription, error) {
conf := map[string]interface{}{
"encoding": "jsonParsed",
}
if commitment != "" {
conf["commitment"] = string(commitment)
}

req := newClientRequest(params, subscriptionMethod, conf)
data, err := req.encode()
if err != nil {
return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err)
}

sub := newSubscription(req, reflect.TypeOf(resultType), func(err error) {
c.closeSubscription(req.ID, err)
})

c.lock.Lock()
c.subscriptionByRequestID[req.ID] = sub
zlog.Info("added new subscription to websocket client", zap.Int("count", len(c.subscriptionByRequestID)))
c.lock.Unlock()

err = c.conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
return nil, fmt.Errorf("unable to write request: %w", err)
}

return sub, nil
}

type ProgramWSResult struct {
Context struct {
Slot uint64
} `json:"context"`
Value struct {
Account Account `json:"account"`
} `json:"value"`
}

type clientRequest struct {
Version string `json:"jsonrpc"`
Method string `json:"method"`
Params interface{} `json:"params"`
Id uint64 `json:"id"`
ID uint64 `json:"id"`
}

func encodeClientRequest(method string, args interface{}) ([]byte, uint64, error) {
c := &clientRequest{
func newClientRequest(params []interface{}, method string, configuration map[string]interface{}) *clientRequest {
params = append(params, configuration)
return &clientRequest{
Version: "2.0",
Method: method,
Params: args,
Id: uint64(rand.Int63()),
Params: params,
ID: uint64(rand.Int63()),
}
}

func (c *clientRequest) encode() ([]byte, error) {
data, err := json.Marshal(c)
if err != nil {
return nil, 0, fmt.Errorf("encode request: json marshall: %w", err)
return nil, fmt.Errorf("encode request: json marshall: %w", err)
}
return data, c.Id, nil
return data, nil
}

type wsClientResponse struct {
Expand Down Expand Up @@ -168,50 +322,3 @@ func decodeClientResponse(r io.Reader, reply interface{}) (err error) {

return json.Unmarshal(*c.Params.Result, &reply)
}

type ProgramWSResult struct {
Context struct {
Slot uint64
} `json:"context"`
Value struct {
Account Account `json:"account"`
} `json:"value"`
}

func (c *WSClient) ProgramSubscribe(programID string, commitment CommitmentType) (stream chan result, id uint64, err error) {
c.lock.Lock()
defer c.lock.Unlock()

stream = make(chan result, 200)

params := []interface{}{programID}
conf := map[string]interface{}{
"encoding": "jsonParsed",
}
if commitment != "" {
conf["commitment"] = string(commitment)
}

params = append(params, conf)
data, id, err := encodeClientRequest("programSubscribe", params)
if err != nil {
return nil, 0, fmt.Errorf("program subscribe: encode request: %c", err)
}

err = c.conn.WriteMessage(websocket.TextMessage, data)
if err != nil {
return nil, 0, fmt.Errorf("program subscribe: write message: %c", err)
}

c.callbacksByRequestID[id] = &callBackInfo{
requestID: id,
stream: stream,
reflectType: reflect.TypeOf(ProgramWSResult{}),
}

return stream, id, nil
}

func (c *WSClient) ProgramUnsubscribe(reqID int) {

}
Loading

0 comments on commit 81f4290

Please sign in to comment.