From 7cd8decaf2b8672d4a53cd3f4fca3cb7ccc5ac2f Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Wed, 6 Sep 2023 18:44:09 +0800 Subject: [PATCH 1/7] tls add key passphrase --- utils/cert.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/cert.go b/utils/cert.go index 099e16f6..d8c0a100 100644 --- a/utils/cert.go +++ b/utils/cert.go @@ -17,6 +17,7 @@ type Certificate struct { Key string `yaml:"key" json:"key"` Cert string `yaml:"cert" json:"cert"` Name string `yaml:"name" json:"name"` + Passphrase string `yaml:"passphrase" json:"passphrase"` InsecureSkipVerify bool `yaml:"insecureSkipVerify" json:"insecureSkipVerify"` // for client, for test purpose tls.ClientAuthType `yaml:"clientAuthType" json:"clientAuthType"` } @@ -29,6 +30,6 @@ func NewTLSConfigServer(c Certificate) (*tls.Config, error) { // NewTLSConfigClient loads tls config for client func NewTLSConfigClient(c Certificate) (*tls.Config, error) { - cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify}) + cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify, Passphrase: c.Passphrase}) return cfg, errors.Trace(err) } From 87d278d43ff5954fb1ac908e8e22ea995d6a5741 Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Wed, 6 Sep 2023 18:44:09 +0800 Subject: [PATCH 2/7] tls add key passphrase --- utils/cert.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/utils/cert.go b/utils/cert.go index 099e16f6..6ff3d74b 100644 --- a/utils/cert.go +++ b/utils/cert.go @@ -17,6 +17,7 @@ type Certificate struct { Key string `yaml:"key" json:"key"` Cert string `yaml:"cert" json:"cert"` Name string `yaml:"name" json:"name"` + Passphrase string `yaml:"passphrase" json:"passphrase"` InsecureSkipVerify bool `yaml:"insecureSkipVerify" json:"insecureSkipVerify"` // for client, for test purpose tls.ClientAuthType `yaml:"clientAuthType" json:"clientAuthType"` } @@ -32,3 +33,9 @@ func NewTLSConfigClient(c Certificate) (*tls.Config, error) { cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify}) return cfg, errors.Trace(err) } + +// NewTLSConfigClientWithPassphrase loads tls config for client with passphrase +func NewTLSConfigClientWithPassphrase(c Certificate) (*tls.Config, error) { + cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify, Passphrase: c.Passphrase}) + return cfg, errors.Trace(err) +} From b97eb2736d3ed041f189b009db2dc87e5706f1d4 Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Thu, 7 Sep 2023 10:15:59 +0800 Subject: [PATCH 3/7] tls add key passphrase --- utils/cert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/cert.go b/utils/cert.go index 15641a8e..6ff3d74b 100644 --- a/utils/cert.go +++ b/utils/cert.go @@ -30,7 +30,7 @@ func NewTLSConfigServer(c Certificate) (*tls.Config, error) { // NewTLSConfigClient loads tls config for client func NewTLSConfigClient(c Certificate) (*tls.Config, error) { - cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify, Passphrase: c.Passphrase}) + cfg, err := tlsconfig.Client(tlsconfig.Options{CAFile: c.CA, KeyFile: c.Key, CertFile: c.Cert, InsecureSkipVerify: c.InsecureSkipVerify}) return cfg, errors.Trace(err) } From b4b9e4a8cf5dfac1b4a27d575494d98c965ba20f Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Mon, 11 Sep 2023 18:40:03 +0800 Subject: [PATCH 4/7] tls add key passphrase --- utils/cert_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/utils/cert_test.go b/utils/cert_test.go index 206140f7..eba70716 100644 --- a/utils/cert_test.go +++ b/utils/cert_test.go @@ -40,3 +40,21 @@ func TestNewTLSConfigClient(t *testing.T) { assert.NoError(t, err) assert.NotEmpty(t, tl) } + +func TestNewTLSConfigClientWithPassphrase(t *testing.T) { + tl, err := NewTLSConfigClientWithPassphrase(Certificate{Key: "../example/var/lib/baetyl/testcert/client.key"}) + assert.Error(t, err) + + tl, err = NewTLSConfigClientWithPassphrase(Certificate{Cert: "../example/var/lib/baetyl/testcert/client.crt"}) + assert.Error(t, err) + assert.Empty(t, tl) + + c := Certificate{ + Key: "../example/var/lib/baetyl/testcert/client.key", + Cert: "../example/var/lib/baetyl/testcert/client.crt", + Passphrase: "1234", + } + tl, err = NewTLSConfigClientWithPassphrase(c) + assert.NoError(t, err) + assert.NotEmpty(t, tl) +} From b33aae5d476e03f2214ca27cbfd2d17857c13bbf Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Thu, 14 Sep 2023 15:10:22 +0800 Subject: [PATCH 5/7] mqtt websocket http add to option with passphrase --- http/options.go | 21 +++++++++++++++++++++ mqtt/options.go | 27 +++++++++++++++++++++++++++ websocket/options.go | 17 +++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/http/options.go b/http/options.go index 5e3de66f..ec3d2f1f 100644 --- a/http/options.go +++ b/http/options.go @@ -99,3 +99,24 @@ func (cc ClientConfig) ToClientOptions() (*ClientOptions, error) { SyncMaxConcurrency: cc.SyncMaxConcurrency, }, nil } + +// ToClientOptionsWithPassphrase converts client config to client options with passphrase +func (cc ClientConfig) ToClientOptionsWithPassphrase() (*ClientOptions, error) { + tlsConfig, err := utils.NewTLSConfigClientWithPassphrase(cc.Certificate) + if err != nil { + return nil, errors.Trace(err) + } + return &ClientOptions{ + Address: cc.Address, + Timeout: cc.Timeout, + TLSConfig: tlsConfig, + KeepAlive: cc.KeepAlive, + MaxIdleConns: cc.MaxIdleConns, + IdleConnTimeout: cc.IdleConnTimeout, + TLSHandshakeTimeout: cc.TLSHandshakeTimeout, + ExpectContinueTimeout: cc.ExpectContinueTimeout, + SpeedLimit: cc.SpeedLimit, + ByteUnit: cc.ByteUnit, + SyncMaxConcurrency: cc.SyncMaxConcurrency, + }, nil +} diff --git a/mqtt/options.go b/mqtt/options.go index fdecfe8a..532b6020 100644 --- a/mqtt/options.go +++ b/mqtt/options.go @@ -84,3 +84,30 @@ func (cc ClientConfig) ToClientOptions() (*ClientOptions, error) { } return ops, nil } + +// ToClientOptionsWithPassphrase converts client config to client options with passphrase +func (cc ClientConfig) ToClientOptionsWithPassphrase() (*ClientOptions, error) { + ops := &ClientOptions{ + Address: cc.Address, + Username: cc.Username, + Password: cc.Password, + ClientID: cc.ClientID, + CleanSession: cc.CleanSession, + Timeout: cc.Timeout, + KeepAlive: cc.KeepAlive, + MaxReconnectInterval: cc.MaxReconnectInterval, + MaxCacheMessages: cc.MaxCacheMessages, + DisableAutoAck: cc.DisableAutoAck, + } + if cc.Certificate.Key != "" || cc.Certificate.Cert != "" { + tlsconfig, err := utils.NewTLSConfigClientWithPassphrase(cc.Certificate) + if err != nil { + return nil, errors.Trace(err) + } + ops.TLSConfig = tlsconfig + } + for _, topic := range cc.Subscriptions { + ops.Subscriptions = append(ops.Subscriptions, Subscription{Topic: topic.Topic, QOS: QOS(topic.QOS)}) + } + return ops, nil +} diff --git a/websocket/options.go b/websocket/options.go index b4114968..cd41b9a6 100644 --- a/websocket/options.go +++ b/websocket/options.go @@ -44,6 +44,7 @@ type ClientConfig struct { utils.Certificate `yaml:",inline" json:",inline"` } +// ToClientOptions converts client config to client options func (cc ClientConfig) ToClientOptions() (*ClientOptions, error) { tlsConfig, err := utils.NewTLSConfigClient(cc.Certificate) if err != nil { @@ -58,3 +59,19 @@ func (cc ClientConfig) ToClientOptions() (*ClientOptions, error) { SyncMaxConcurrency: cc.SyncMaxConcurrency, }, nil } + +// NewTLSConfigClientWithPassphrase converts client config to client options with passphrase +func (cc ClientConfig) NewTLSConfigClientWithPassphrase() (*ClientOptions, error) { + tlsConfig, err := utils.NewTLSConfigClientWithPassphrase(cc.Certificate) + if err != nil { + return nil, errors.Trace(err) + } + return &ClientOptions{ + Address: cc.Address, + Path: cc.Path, + Schema: cc.Schema, + TLSConfig: tlsConfig, + TLSHandshakeTimeout: cc.TLSHandshakeTimeout, + SyncMaxConcurrency: cc.SyncMaxConcurrency, + }, nil +} From 056691e319cfb687dcc6f4ec761ef3ea15a8341e Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Wed, 27 Sep 2023 10:19:11 +0800 Subject: [PATCH 6/7] websocket retry connect --- mqtt/client.go | 15 +++++ mqtt/client_test.go | 54 ++++++++++++++++++ websocket/client.go | 120 ++++++++++++++++++++++++++++++--------- websocket/client_test.go | 55 ++++++++++-------- 4 files changed, 193 insertions(+), 51 deletions(-) diff --git a/mqtt/client.go b/mqtt/client.go index 8c2f1b37..165ac622 100644 --- a/mqtt/client.go +++ b/mqtt/client.go @@ -64,6 +64,21 @@ func (c *Client) Publish(qos QOS, topic string, payload []byte, pid ID, retain b return c.Send(publish) } +// Publish sends a publish packet out cache size will drop +func (c *Client) PublishWithDrop(qos QOS, topic string, payload []byte, pid ID, retain bool, dup bool) error { + publish := NewPublish() + publish.ID = pid + publish.Dup = dup + publish.Message.QOS = qos + publish.Message.Topic = topic + publish.Message.Payload = payload + publish.Message.Retain = retain + if qos != 0 && pid == 0 { + publish.ID = c.ids.NextID() + } + return c.SendOrDrop(publish) +} + // Send sends a generic packet func (c *Client) Send(pkt Packet) error { select { diff --git a/mqtt/client_test.go b/mqtt/client_test.go index 8b360a0e..5d6ba47a 100644 --- a/mqtt/client_test.go +++ b/mqtt/client_test.go @@ -264,6 +264,60 @@ func TestMqttClientPublishSubscribeQOS0(t *testing.T) { safeReceive(done) } +func TestMqttClientPublishSubscribeQOS0WithDrop(t *testing.T) { + subscribe := NewSubscribe() + subscribe.Subscriptions = []Subscription{{Topic: "test"}} + subscribe.ID = subscribeId + + suback := NewSuback() + suback.ReturnCodes = []QOS{0} + suback.ID = subscribeId + + publish := NewPublish() + publish.Message.Topic = "test" + publish.Message.Payload = []byte("test") + + broker := mock.NewFlow().Debug(). + Receive(connectPacket()). + Send(connackPacket()). + Receive(subscribe). + Send(suback). + Receive(publish). + Send(publish). + Receive(disconnectPacket()). + End() + + done, port := initMockBroker(t, broker) + + ops := newClientOptions(t, port, []Subscription{{Topic: "test"}}) + cli := NewClient(ops) + assert.NotNil(t, cli) + + obs := newMockObserver(t) + err := cli.Start(obs) + assert.NoError(t, err) + + err = cli.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup) + assert.NoError(t, err) + obs.assertPkts(publish) + + time.Sleep(time.Second) + + assert.NoError(t, cli.Close()) + safeReceive(done) + + ops = newClientOptions(t, "222", []Subscription{{Topic: "test"}}) + newCLi := NewClient(ops) + for i := 0; i < 30; i++ { + err = newCLi.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup) + assert.NoError(t, err) + } + + assert.NoError(t, newCLi.Close()) + safeReceive(done) + +} + func TestMqttClientPublishSubscribeQOS1(t *testing.T) { subscribe := NewSubscribe() subscribe.Subscriptions = []Subscription{{Topic: "test", QOS: 1}} diff --git a/websocket/client.go b/websocket/client.go index cfe9ad3c..a4330efa 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -1,24 +1,35 @@ package websocket import ( - v1 "github.com/baetyl/baetyl-go/v2/spec/v1" "net/http" "net/url" "time" "github.com/baetyl/baetyl-go/v2/errors" "github.com/baetyl/baetyl-go/v2/log" + v1 "github.com/baetyl/baetyl-go/v2/spec/v1" "github.com/gorilla/websocket" "github.com/panjf2000/ants/v2" ) +type WsConnect struct { + wscon *websocket.Conn + readMsgChan chan *v1.Message +} + type Client struct { - conn chan *websocket.Conn - u url.URL - dialer websocket.Dialer - antPool *ants.Pool + pool chan *WsConnect + connNum int + u url.URL + dialer websocket.Dialer + antPool *ants.Pool + ops *ClientOptions + readMsgChan []chan *v1.Message + log *log.Logger } +// NewClient 函数用于创建一个Client对象 +// readMsgChan 为读取信息的通道 需要配置和并发数量一致 func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, error) { u := url.URL{Scheme: ops.Schema, Host: ops.Address, Path: ops.Path} dialer := websocket.Dialer{ @@ -28,46 +39,99 @@ func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, err TLSClientConfig: ops.TLSConfig, HandshakeTimeout: ops.TLSHandshakeTimeout, } - p, _ := ants.NewPool(1) - if ops.SyncMaxConcurrency != 0 { - p, _ = ants.NewPool(ops.SyncMaxConcurrency) + // 最少为1条 + if ops.SyncMaxConcurrency <= 0 { + ops.SyncMaxConcurrency = 1 + } + + p, err := ants.NewPool(ops.SyncMaxConcurrency) + if err != nil { + return nil, err } + if readMsgChan != nil && cap(readMsgChan) < ops.SyncMaxConcurrency { return nil, errors.New("read msg cap must > SyncMaxConcurrency") } - connect := make(chan *websocket.Conn, ops.SyncMaxConcurrency) + connect := make(chan *WsConnect, ops.SyncMaxConcurrency) + client := &Client{ + pool: connect, + connNum: 0, + u: u, + dialer: dialer, + antPool: p, + ops: ops, + readMsgChan: readMsgChan, + log: log.L().With(log.Any("link", "websocket link")), + } + go client.initLink() + return client, nil +} +func (c *Client) initLink() { // 根据设置创建连接池 - for i := 0; i < ops.SyncMaxConcurrency; i++ { - con, _, err := dialer.Dial(u.String(), nil) - if err != nil { - return nil, err + for i := 0; i < c.ops.SyncMaxConcurrency; i++ { + var connectReadMsgChan chan *v1.Message = nil + if c.readMsgChan != nil { + connectReadMsgChan = c.readMsgChan[i] } - // 每个链接创建一个协程readMsg - if readMsgChan != nil { - go ReadConMsg(con, readMsgChan[i]) + ws, err := c.Connect(connectReadMsgChan) + if err != nil { + c.log.Error("link websocket error", log.Any("err", err)) } - connect <- con + // 为了保证连接池数量 失败wscon 以nil方式放入连接池 每次发送的时候重新连接 + c.pool <- ws } +} - return &Client{ - conn: connect, - u: u, - dialer: dialer, - antPool: p, - }, nil +func (c *Client) Connect(readMsgChan chan *v1.Message) (*WsConnect, error) { + con, _, err := c.dialer.Dial(c.u.String(), nil) + ws := &WsConnect{ + readMsgChan: readMsgChan, + } + if err != nil { + ws.wscon = nil + c.log.Error("websocket link error", log.Any("err", err)) + return ws, err + } else { + ws.wscon = con + if c.readMsgChan != nil { + go ws.ReadConMsg(readMsgChan) + } + } + return ws, nil } func (c *Client) SendMsg(msg []byte) error { - con := <-c.conn - err := con.WriteMessage(websocket.TextMessage, msg) - c.conn <- con + con := <-c.pool + var err error + if con.wscon == nil { + con, err = c.Connect(con.readMsgChan) + if err != nil { + c.pool <- con + c.log.Error("retry link websocket error", log.Any("err", err)) + return err + } + } + err = con.wscon.WriteMessage(websocket.TextMessage, msg) + if err != nil { + c.log.Error("websocket write msg error", log.Any("err", err)) + con, err = c.Connect(con.readMsgChan) + if err != nil { + c.pool <- con + c.log.Error("retry link websocket error", log.Any("err", err)) + return err + } + } + c.pool <- con return err } -func ReadConMsg(con *websocket.Conn, readMsg chan *v1.Message) { +func (w *WsConnect) ReadConMsg(readMsg chan *v1.Message) { for { - msgType, data, err := con.ReadMessage() + if w.wscon == nil { + return + } + msgType, data, err := w.wscon.ReadMessage() msg := &v1.Message{} if err != nil { msg = &v1.Message{ diff --git a/websocket/client_test.go b/websocket/client_test.go index b8987618..b34c8498 100644 --- a/websocket/client_test.go +++ b/websocket/client_test.go @@ -1,40 +1,49 @@ package websocket import ( - v1 "github.com/baetyl/baetyl-go/v2/spec/v1" + "log" "net/http" - "net/http/httptest" "testing" "time" + v1 "github.com/baetyl/baetyl-go/v2/spec/v1" "github.com/baetyl/baetyl-go/v2/utils" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) -func Test_client(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 升级连接为websocket协议 - upgrader := websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Fatalf("Failed to upgrade HTTP connection to WebSocket: %v", err) - } - _, msg, err := conn.ReadMessage() +func echo(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + for { + _, msg, err := c.ReadMessage() if err != nil { - t.Fatalf("Failed to read message from WebSocket: %v", err) + log.Println("read:", err) + break } - err = conn.WriteMessage(websocket.TextMessage, msg) + err = c.WriteMessage(websocket.TextMessage, msg) if err != nil { - t.Fatalf("Failed to write message to WebSocket: %v", err) + log.Println("read:", err) } - })) - defer server.Close() + } +} +func WsServer() { + http.HandleFunc("/echo", echo) + log.Fatal(http.ListenAndServe("127.0.0.1:9341", nil)) +} +func Test_client(t *testing.T) { + go WsServer() cfg := ClientConfig{ - Address: server.URL[len("http://"):], - Path: "", + Address: "127.0.0.1:9341", + Path: "echo", Schema: "ws", IdleConnTimeout: 0, TLSHandshakeTimeout: 0, @@ -51,18 +60,18 @@ func Test_client(t *testing.T) { client, err := NewClient(options, msg) assert.NoError(t, err) - result := make(chan *SyncResults, 100) + result := make(chan *SyncResults, 1000) extra := map[string]interface{}{"a": 1} - for i := 0; i < 100; i++ { + time.Sleep(time.Second * 2) + for i := 0; i < 20; i++ { client.SyncSendMsg([]byte("hello"), result, extra) } - time.Sleep(time.Second) - + time.Sleep(time.Second * 2) re := <-result assert.NoError(t, re.Err) assert.Equal(t, re.Extra["a"], 1) - assert.Equal(t, 99, len(result)) + assert.Equal(t, 19, len(result)) for _, m := range msg { r := <-m From 75a836c170bdfd4d168079e03d12b16367789a69 Mon Sep 17 00:00:00 2001 From: miaodanyang Date: Wed, 27 Sep 2023 10:19:11 +0800 Subject: [PATCH 7/7] websocket retry connect --- mqtt/client.go | 15 +++++ mqtt/client_test.go | 54 +++++++++++++++++ websocket/client.go | 121 ++++++++++++++++++++++++++++++--------- websocket/client_test.go | 55 ++++++++++-------- 4 files changed, 194 insertions(+), 51 deletions(-) diff --git a/mqtt/client.go b/mqtt/client.go index 8c2f1b37..165ac622 100644 --- a/mqtt/client.go +++ b/mqtt/client.go @@ -64,6 +64,21 @@ func (c *Client) Publish(qos QOS, topic string, payload []byte, pid ID, retain b return c.Send(publish) } +// Publish sends a publish packet out cache size will drop +func (c *Client) PublishWithDrop(qos QOS, topic string, payload []byte, pid ID, retain bool, dup bool) error { + publish := NewPublish() + publish.ID = pid + publish.Dup = dup + publish.Message.QOS = qos + publish.Message.Topic = topic + publish.Message.Payload = payload + publish.Message.Retain = retain + if qos != 0 && pid == 0 { + publish.ID = c.ids.NextID() + } + return c.SendOrDrop(publish) +} + // Send sends a generic packet func (c *Client) Send(pkt Packet) error { select { diff --git a/mqtt/client_test.go b/mqtt/client_test.go index 8b360a0e..5d6ba47a 100644 --- a/mqtt/client_test.go +++ b/mqtt/client_test.go @@ -264,6 +264,60 @@ func TestMqttClientPublishSubscribeQOS0(t *testing.T) { safeReceive(done) } +func TestMqttClientPublishSubscribeQOS0WithDrop(t *testing.T) { + subscribe := NewSubscribe() + subscribe.Subscriptions = []Subscription{{Topic: "test"}} + subscribe.ID = subscribeId + + suback := NewSuback() + suback.ReturnCodes = []QOS{0} + suback.ID = subscribeId + + publish := NewPublish() + publish.Message.Topic = "test" + publish.Message.Payload = []byte("test") + + broker := mock.NewFlow().Debug(). + Receive(connectPacket()). + Send(connackPacket()). + Receive(subscribe). + Send(suback). + Receive(publish). + Send(publish). + Receive(disconnectPacket()). + End() + + done, port := initMockBroker(t, broker) + + ops := newClientOptions(t, port, []Subscription{{Topic: "test"}}) + cli := NewClient(ops) + assert.NotNil(t, cli) + + obs := newMockObserver(t) + err := cli.Start(obs) + assert.NoError(t, err) + + err = cli.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup) + assert.NoError(t, err) + obs.assertPkts(publish) + + time.Sleep(time.Second) + + assert.NoError(t, cli.Close()) + safeReceive(done) + + ops = newClientOptions(t, "222", []Subscription{{Topic: "test"}}) + newCLi := NewClient(ops) + for i := 0; i < 30; i++ { + err = newCLi.PublishWithDrop(publish.Message.QOS, publish.Message.Topic, publish.Message.Payload, publish.ID, publish.Message.Retain, publish.Dup) + assert.NoError(t, err) + } + + assert.NoError(t, newCLi.Close()) + safeReceive(done) + +} + func TestMqttClientPublishSubscribeQOS1(t *testing.T) { subscribe := NewSubscribe() subscribe.Subscriptions = []Subscription{{Topic: "test", QOS: 1}} diff --git a/websocket/client.go b/websocket/client.go index cfe9ad3c..6438af13 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -1,24 +1,36 @@ package websocket import ( - v1 "github.com/baetyl/baetyl-go/v2/spec/v1" "net/http" "net/url" "time" "github.com/baetyl/baetyl-go/v2/errors" "github.com/baetyl/baetyl-go/v2/log" + v1 "github.com/baetyl/baetyl-go/v2/spec/v1" + "github.com/gorilla/websocket" "github.com/panjf2000/ants/v2" ) +type WsConnect struct { + wscon *websocket.Conn + readMsgChan chan *v1.Message +} + type Client struct { - conn chan *websocket.Conn - u url.URL - dialer websocket.Dialer - antPool *ants.Pool + pool chan *WsConnect + connNum int + u url.URL + dialer websocket.Dialer + antPool *ants.Pool + ops *ClientOptions + readMsgChan []chan *v1.Message + log *log.Logger } +// NewClient 函数用于创建一个Client对象 +// readMsgChan 为读取信息的通道 需要配置和并发数量一致 func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, error) { u := url.URL{Scheme: ops.Schema, Host: ops.Address, Path: ops.Path} dialer := websocket.Dialer{ @@ -28,46 +40,99 @@ func NewClient(ops *ClientOptions, readMsgChan []chan *v1.Message) (*Client, err TLSClientConfig: ops.TLSConfig, HandshakeTimeout: ops.TLSHandshakeTimeout, } - p, _ := ants.NewPool(1) - if ops.SyncMaxConcurrency != 0 { - p, _ = ants.NewPool(ops.SyncMaxConcurrency) + // 最少为1条 + if ops.SyncMaxConcurrency <= 0 { + ops.SyncMaxConcurrency = 1 } + + p, err := ants.NewPool(ops.SyncMaxConcurrency) + if err != nil { + return nil, err + } + if readMsgChan != nil && cap(readMsgChan) < ops.SyncMaxConcurrency { return nil, errors.New("read msg cap must > SyncMaxConcurrency") } - connect := make(chan *websocket.Conn, ops.SyncMaxConcurrency) + connect := make(chan *WsConnect, ops.SyncMaxConcurrency) + client := &Client{ + pool: connect, + connNum: 0, + u: u, + dialer: dialer, + antPool: p, + ops: ops, + readMsgChan: readMsgChan, + log: log.L().With(log.Any("link", "websocket link")), + } + go client.initLink() + return client, nil +} +func (c *Client) initLink() { // 根据设置创建连接池 - for i := 0; i < ops.SyncMaxConcurrency; i++ { - con, _, err := dialer.Dial(u.String(), nil) - if err != nil { - return nil, err + for i := 0; i < c.ops.SyncMaxConcurrency; i++ { + var connectReadMsgChan chan *v1.Message = nil + if c.readMsgChan != nil { + connectReadMsgChan = c.readMsgChan[i] } - // 每个链接创建一个协程readMsg - if readMsgChan != nil { - go ReadConMsg(con, readMsgChan[i]) + ws, err := c.Connect(connectReadMsgChan) + if err != nil { + c.log.Error("link websocket error", log.Any("err", err)) } - connect <- con + // 为了保证连接池数量 失败wscon 以nil方式放入连接池 每次发送的时候重新连接 + c.pool <- ws } +} - return &Client{ - conn: connect, - u: u, - dialer: dialer, - antPool: p, - }, nil +func (c *Client) Connect(readMsgChan chan *v1.Message) (*WsConnect, error) { + con, _, err := c.dialer.Dial(c.u.String(), nil) + ws := &WsConnect{ + readMsgChan: readMsgChan, + } + if err != nil { + ws.wscon = nil + c.log.Error("websocket link error", log.Any("err", err)) + return ws, err + } else { + ws.wscon = con + if c.readMsgChan != nil { + go ws.ReadConMsg(readMsgChan) + } + } + return ws, nil } func (c *Client) SendMsg(msg []byte) error { - con := <-c.conn - err := con.WriteMessage(websocket.TextMessage, msg) - c.conn <- con + con := <-c.pool + var err error + if con.wscon == nil { + con, err = c.Connect(con.readMsgChan) + if err != nil { + c.pool <- con + c.log.Error("retry link websocket error", log.Any("err", err)) + return err + } + } + err = con.wscon.WriteMessage(websocket.TextMessage, msg) + if err != nil { + c.log.Error("websocket write msg error", log.Any("err", err)) + con, err = c.Connect(con.readMsgChan) + if err != nil { + c.pool <- con + c.log.Error("retry link websocket error", log.Any("err", err)) + return err + } + } + c.pool <- con return err } -func ReadConMsg(con *websocket.Conn, readMsg chan *v1.Message) { +func (w *WsConnect) ReadConMsg(readMsg chan *v1.Message) { for { - msgType, data, err := con.ReadMessage() + if w.wscon == nil { + return + } + msgType, data, err := w.wscon.ReadMessage() msg := &v1.Message{} if err != nil { msg = &v1.Message{ diff --git a/websocket/client_test.go b/websocket/client_test.go index b8987618..b34c8498 100644 --- a/websocket/client_test.go +++ b/websocket/client_test.go @@ -1,40 +1,49 @@ package websocket import ( - v1 "github.com/baetyl/baetyl-go/v2/spec/v1" + "log" "net/http" - "net/http/httptest" "testing" "time" + v1 "github.com/baetyl/baetyl-go/v2/spec/v1" "github.com/baetyl/baetyl-go/v2/utils" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) -func Test_client(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // 升级连接为websocket协议 - upgrader := websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Fatalf("Failed to upgrade HTTP connection to WebSocket: %v", err) - } - _, msg, err := conn.ReadMessage() +func echo(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + for { + _, msg, err := c.ReadMessage() if err != nil { - t.Fatalf("Failed to read message from WebSocket: %v", err) + log.Println("read:", err) + break } - err = conn.WriteMessage(websocket.TextMessage, msg) + err = c.WriteMessage(websocket.TextMessage, msg) if err != nil { - t.Fatalf("Failed to write message to WebSocket: %v", err) + log.Println("read:", err) } - })) - defer server.Close() + } +} +func WsServer() { + http.HandleFunc("/echo", echo) + log.Fatal(http.ListenAndServe("127.0.0.1:9341", nil)) +} +func Test_client(t *testing.T) { + go WsServer() cfg := ClientConfig{ - Address: server.URL[len("http://"):], - Path: "", + Address: "127.0.0.1:9341", + Path: "echo", Schema: "ws", IdleConnTimeout: 0, TLSHandshakeTimeout: 0, @@ -51,18 +60,18 @@ func Test_client(t *testing.T) { client, err := NewClient(options, msg) assert.NoError(t, err) - result := make(chan *SyncResults, 100) + result := make(chan *SyncResults, 1000) extra := map[string]interface{}{"a": 1} - for i := 0; i < 100; i++ { + time.Sleep(time.Second * 2) + for i := 0; i < 20; i++ { client.SyncSendMsg([]byte("hello"), result, extra) } - time.Sleep(time.Second) - + time.Sleep(time.Second * 2) re := <-result assert.NoError(t, re.Err) assert.Equal(t, re.Extra["a"], 1) - assert.Equal(t, 99, len(result)) + assert.Equal(t, 19, len(result)) for _, m := range msg { r := <-m