Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket retry connect #442

Merged
merged 11 commits into from
Sep 27, 2023
15 changes: 15 additions & 0 deletions mqtt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ func (c *Client) Publish(qos QOS, topic string, payload []byte, pid ID, retain b
return c.Send(publish)
}

// Publish sends a publish packet out cache size will drop
func (c *Client) PublishWithDrop(qos QOS, topic string, payload []byte, pid ID, retain bool, dup bool) error {
publish := NewPublish()
publish.ID = pid
publish.Dup = dup
publish.Message.QOS = qos
publish.Message.Topic = topic
publish.Message.Payload = payload
publish.Message.Retain = retain
if qos != 0 && pid == 0 {
publish.ID = c.ids.NextID()
}
return c.SendOrDrop(publish)
}

// Send sends a generic packet
func (c *Client) Send(pkt Packet) error {
select {
Expand Down
54 changes: 54 additions & 0 deletions mqtt/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,60 @@ func TestMqttClientPublishSubscribeQOS0(t *testing.T) {
safeReceive(done)
}

func TestMqttClientPublishSubscribeQOS0WithDrop(t *testing.T) {
subscribe := NewSubscribe()
subscribe.Subscriptions = []Subscription{{Topic: "test"}}
subscribe.ID = subscribeId

suback := NewSuback()
suback.ReturnCodes = []QOS{0}
suback.ID = subscribeId

publish := NewPublish()
publish.Message.Topic = "test"
publish.Message.Payload = []byte("test")

broker := mock.NewFlow().Debug().
Receive(connectPacket()).
Send(connackPacket()).
Receive(subscribe).
Send(suback).
Receive(publish).
Send(publish).
Receive(disconnectPacket()).
End()

done, port := initMockBroker(t, broker)

ops := newClientOptions(t, port, []Subscription{{Topic: "test"}})
cli := NewClient(ops)
assert.NotNil(t, cli)

obs := newMockObserver(t)
err := cli.Start(obs)
assert.NoError(t, err)

err = cli.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup)
assert.NoError(t, err)
obs.assertPkts(publish)

time.Sleep(time.Second)

assert.NoError(t, cli.Close())
safeReceive(done)

ops = newClientOptions(t, "222", []Subscription{{Topic: "test"}})
newCLi := NewClient(ops)
for i := 0; i < 30; i++ {
err = newCLi.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup)
assert.NoError(t, err)
}

assert.NoError(t, newCLi.Close())
safeReceive(done)

}

func TestMqttClientPublishSubscribeQOS1(t *testing.T) {
subscribe := NewSubscribe()
subscribe.Subscriptions = []Subscription{{Topic: "test", QOS: 1}}
Expand Down
121 changes: 93 additions & 28 deletions websocket/client.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
package websocket

import (
v1 "github.com/baetyl/baetyl-go/v2/spec/v1"
"net/http"
"net/url"
"time"

"github.com/baetyl/baetyl-go/v2/errors"
"github.com/baetyl/baetyl-go/v2/log"
v1 "github.com/baetyl/baetyl-go/v2/spec/v1"

"github.com/gorilla/websocket"
"github.com/panjf2000/ants/v2"
)

type WsConnect struct {
wscon *websocket.Conn
readMsgChan chan *v1.Message
}

type Client struct {
conn chan *websocket.Conn
u url.URL
dialer websocket.Dialer
antPool *ants.Pool
pool chan *WsConnect
connNum int
u url.URL
dialer websocket.Dialer
antPool *ants.Pool
ops *ClientOptions
readMsgChan []chan *v1.Message
log *log.Logger
}

// NewClient 函数用于创建一个Client对象
// readMsgChan 为读取信息的通道 需要配置和并发数量一致
func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, error) {
u := url.URL{Scheme: ops.Schema, Host: ops.Address, Path: ops.Path}
dialer := websocket.Dialer{
Expand All @@ -28,46 +40,99 @@ func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, err
TLSClientConfig: ops.TLSConfig,
HandshakeTimeout: ops.TLSHandshakeTimeout,
}
p, _ := ants.NewPool(1)
if ops.SyncMaxConcurrency != 0 {
p, _ = ants.NewPool(ops.SyncMaxConcurrency)
// 最少为1条
if ops.SyncMaxConcurrency <= 0 {
ops.SyncMaxConcurrency = 1
}

p, err := ants.NewPool(ops.SyncMaxConcurrency)
if err != nil {
return nil, err
}

if readMsgChan != nil && cap(readMsgChan) < ops.SyncMaxConcurrency {
return nil, errors.New("read msg cap must > SyncMaxConcurrency")
}
connect := make(chan *websocket.Conn, ops.SyncMaxConcurrency)
connect := make(chan *WsConnect, ops.SyncMaxConcurrency)
client := &Client{
pool: connect,
connNum: 0,
u: u,
dialer: dialer,
antPool: p,
ops: ops,
readMsgChan: readMsgChan,
log: log.L().With(log.Any("link", "websocket link")),
}
go client.initLink()
return client, nil
}

func (c *Client) initLink() {
// 根据设置创建连接池
for i := 0; i < ops.SyncMaxConcurrency; i++ {
con, _, err := dialer.Dial(u.String(), nil)
if err != nil {
return nil, err
for i := 0; i < c.ops.SyncMaxConcurrency; i++ {
var connectReadMsgChan chan *v1.Message = nil
if c.readMsgChan != nil {
connectReadMsgChan = c.readMsgChan[i]
}
// 每个链接创建一个协程readMsg
if readMsgChan != nil {
go ReadConMsg(con, readMsgChan[i])
ws, err := c.Connect(connectReadMsgChan)
if err != nil {
c.log.Error("link websocket error", log.Any("err", err))
}
connect <- con
// 为了保证连接池数量 失败wscon 以nil方式放入连接池 每次发送的时候重新连接
c.pool <- ws
}
}

return &Client{
conn: connect,
u: u,
dialer: dialer,
antPool: p,
}, nil
func (c *Client) Connect(readMsgChan chan *v1.Message) (*WsConnect, error) {
con, _, err := c.dialer.Dial(c.u.String(), nil)
ws := &WsConnect{
readMsgChan: readMsgChan,
}
if err != nil {
ws.wscon = nil
c.log.Error("websocket link error", log.Any("err", err))
return ws, err
} else {
ws.wscon = con
if c.readMsgChan != nil {
go ws.ReadConMsg(readMsgChan)
}
}
return ws, nil
}

func (c *Client) SendMsg(msg []byte) error {
con := <-c.conn
err := con.WriteMessage(websocket.TextMessage, msg)
c.conn <- con
con := <-c.pool
var err error
if con.wscon == nil {
con, err = c.Connect(con.readMsgChan)
if err != nil {
c.pool <- con
c.log.Error("retry link websocket error", log.Any("err", err))
return err
}
}
err = con.wscon.WriteMessage(websocket.TextMessage, msg)
if err != nil {
c.log.Error("websocket write msg error", log.Any("err", err))
con, err = c.Connect(con.readMsgChan)
if err != nil {
c.pool <- con
c.log.Error("retry link websocket error", log.Any("err", err))
return err
}
}
c.pool <- con
return err
}

func ReadConMsg(con *websocket.Conn, readMsg chan *v1.Message) {
func (w *WsConnect) ReadConMsg(readMsg chan *v1.Message) {
for {
msgType, data, err := con.ReadMessage()
if w.wscon == nil {
return
}
msgType, data, err := w.wscon.ReadMessage()
msg := &v1.Message{}
if err != nil {
msg = &v1.Message{
Expand Down
55 changes: 32 additions & 23 deletions websocket/client_test.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,49 @@
package websocket

import (
v1 "github.com/baetyl/baetyl-go/v2/spec/v1"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"

v1 "github.com/baetyl/baetyl-go/v2/spec/v1"
"github.com/baetyl/baetyl-go/v2/utils"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

func Test_client(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 升级连接为websocket协议
upgrader := websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Fatalf("Failed to upgrade HTTP connection to WebSocket: %v", err)
}
_, msg, err := conn.ReadMessage()
func echo(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{}

c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
}
defer c.Close()
for {
_, msg, err := c.ReadMessage()
if err != nil {
t.Fatalf("Failed to read message from WebSocket: %v", err)
log.Println("read:", err)
break
}
err = conn.WriteMessage(websocket.TextMessage, msg)
err = c.WriteMessage(websocket.TextMessage, msg)
if err != nil {
t.Fatalf("Failed to write message to WebSocket: %v", err)
log.Println("read:", err)
}

}))
defer server.Close()
}
}
func WsServer() {
http.HandleFunc("/echo", echo)
log.Fatal(http.ListenAndServe("127.0.0.1:9341", nil))
}
func Test_client(t *testing.T) {
go WsServer()

cfg := ClientConfig{
Address: server.URL[len("http://"):],
Path: "",
Address: "127.0.0.1:9341",
Path: "echo",
Schema: "ws",
IdleConnTimeout: 0,
TLSHandshakeTimeout: 0,
Expand All @@ -51,18 +60,18 @@ func Test_client(t *testing.T) {

client, err := NewClient(options, msg)
assert.NoError(t, err)
result := make(chan *SyncResults, 100)
result := make(chan *SyncResults, 1000)
extra := map[string]interface{}{"a": 1}

for i := 0; i < 100; i++ {
time.Sleep(time.Second * 2)
for i := 0; i < 20; i++ {
client.SyncSendMsg([]byte("hello"), result, extra)
}
time.Sleep(time.Second)

time.Sleep(time.Second * 2)
re := <-result
assert.NoError(t, re.Err)
assert.Equal(t, re.Extra["a"], 1)
assert.Equal(t, 99, len(result))
assert.Equal(t, 19, len(result))

for _, m := range msg {
r := <-m
Expand Down
Loading