diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..64ae7f3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..a55e7a1 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..ab63989 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/realtime-pubsub-client.iml b/.idea/realtime-pubsub-client.iml new file mode 100644 index 0000000..5e764c4 --- /dev/null +++ b/.idea/realtime-pubsub-client.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0bc8101 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 21no.de + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..45ef99e --- /dev/null +++ b/README.md @@ -0,0 +1,548 @@ +# Realtime Pub/Sub Client for Go + +The `realtime-pubsub-client-go` is a Go client library for interacting with [Realtime Pub/Sub](https://realtime.21no.de) +applications. It enables developers to manage real-time WebSocket connections, handle subscriptions, and process +messages efficiently. The library provides a simple and flexible API to interact with realtime applications, supporting +features like publishing/sending messages, subscribing to topics, handling acknowledgements, and waiting for replies +with timeout support. + +## Features + +- **WebSocket Connection Management**: Seamlessly connect and disconnect from the Realtime Pub/Sub service with + automatic reconnection support. +- **Topic Subscription**: Subscribe and unsubscribe to topics for receiving messages. +- **Topic Publishing**: Publish messages to specific topics with optional message types and compression. +- **Message Sending**: Send messages to backend applications with optional message types and compression. +- **Event Handling**: Handle incoming messages with custom event listeners. +- **Acknowledgements and Replies**: Wait for gateway acknowledgements or replies to messages with timeout support. +- **Error Handling**: Robust error handling and logging capabilities. +- **Concurrency Support**: Safe for use in concurrent environments with thread-safe operations. + +## Installation + +Install the `realtime-pubsub-client-go` library using `go get`: + +```bash +go get github.com/backendstack21/realtime-pubsub-client-go +``` + +Or, if you're using Go modules, simply import the package in your project, and Go will handle the installation. + +## Getting Started + +This guide will help you set up and use the `realtime-pubsub-client-go` library in your Go project. + +### Prerequisites + +- Go 1.16 or later +- A Realtime Pub/Sub account and an application set up at [Realtime Pub/Sub](https://realtime.21no.de) +- Access token for authentication with sufficient permissions, see: [Subscribers](https://realtime.21no.de/documentation/#subscribers) + and [Publishers](https://realtime.21no.de/documentation/#publishers) + +### Working Examples + +The following are working examples demonstrating how to use the `realtime-pubsub-client-go` library: + +* [basic.go](main/basic.go) +* [rpc/client](main/rpc/client.go) +* [rpc/server](main/rpc/server.go) + +### Subscribing to Incoming Messages + +You can handle messages for specific topics and message types by registering event listeners. +In this example we handle `text-message` messages under the `chat` topic: + +```go +client.On("chat.text-message", func (args ...interface{}) { + // Extract message and reply function + message, _ := args[0].(realtime_pubsub.IncomingMessage) + replyFunc, _ := args[1].(realtime_pubsub.ReplyFunc) + + log.Printf("Received message: %v", message) + + // Reply to the message + if err := replyFunc("Got it!", "ok"); err != nil { + log.Printf("Failed to reply to message: %v", err) + } +}) + +``` + +#### Wildcard Subscriptions + +Wildcard subscriptions are also supported, allowing you to handle multiple message types under a topic: + +```go +client.On("chat.*", func (args ...interface{}) { + // Extract message and reply function + message, _ := args[0].(realtime_pubsub.IncomingMessage) + replyFunc, _ := args[1].(realtime_pubsub.ReplyFunc) + + //... +}) +``` + +### Publishing Messages + +Publish messages to a topic using the `Publish` method: + +```go +topic:= "chat" +payload := "Hello, World!" +messageType := "text-message" + +waitFor, err := client.Publish(topic, payload, + realtime_pubsub.WithPublishMessageType(messageType), + realtime_pubsub.WithPublishCompress(true), +) +if err != nil { + log.Fatalf("Failed to publish message: %v", err) +} + +// Wait for acknowledgment +if _, err := waitFor.WaitForAck(100 * time.Millisecond); err != nil { + log.Printf("Failed to receive acknowledgment: %v", err) +} + +// Wait for a reply +reply, err := waitFor.WaitForReply(500 * time.Millisecond) +if err != nil { + log.Printf("Failed to receive reply: %v", err) +} else { + log.Printf("Received reply: %v", reply["data"]) +} +``` + +### Sending Messages + +Send messages to the server using the `Send` method: + +```go +waitFor, err := client.Send(map[string]interface{}{ + "action": "create", + "data": map[string]interface{}{ + "name": "John Doe", + }, +}, realtime_pubsub.WithSendMessageType("create-user")) +if err != nil { + log.Fatalf("Failed to send message: %v", err) +} + +// Wait for acknowledgment +if _, err := waitFor.WaitForAck(100 * time.Millisecond); err != nil { + log.Printf("Failed to receive acknowledgment: %v", err) +} + +// Wait for a reply +reply, err := waitFor.WaitForReply(500 * time.Millisecond) +if err != nil { + log.Printf("Failed to receive reply: %v", err) +} else { + log.Printf("Received reply: %v", reply.Data()) +} +``` + +### Handling Replies + +Set up event listeners to handle incoming replies, use the ReplyFunc param to send replies to messages: + +```go +client.On("chat.text-message", func (args ...interface{}) { + // Extract message and reply function + message, _ := args[0].(realtime_pubsub.IncomingMessage) + replyFunc, _ := args[1].(realtime_pubsub.ReplyFunc) + + log.Printf("Received message: %v", message) + + // Reply to the message + replyPayload := "Got it!" + replyStatus := "ok" + if err := replyFunc(replyPayload, replyStatus); err != nil { + log.Printf("Failed to reply to message: %v", err) + } +}) +``` + +### Waiting for Acknowledgements and Replies + +- **Wait for Acknowledgement**: Use `WaitForAck` to wait for a gateway acknowledgement after publishing or sending a + message. + + ```go + if _, err := waitFor.WaitForAck(500 * time.Millisecond); err != nil { + log.Printf("Failed to receive acknowledgment: %v", err) + } + ``` + +- **Wait for Reply**: Use `WaitForReply` to wait for a reply to your message with a specified timeout. + + ```go + reply, err := waitFor.WaitForReply(500 * time.Millisecond) + if err != nil { + log.Printf("Failed to receive reply: %v", err) + } else { + log.Printf("Received reply: %v", reply.Data()) + } + ``` + +### Error Handling + +Handle errors and disconnections gracefully: + +```go +client.On("error", func (args ...interface{}) { + log.Printf("Received error: %v", args) +}) + +client.On("close", func (args ...interface{}) { + log.Printf("Connection closed: %v", args) +}) +``` + +## API Reference + +### RealtimeClient + +#### Constructor + +```go +func NewClient(config Config) *Client +``` + +Creates a new `Client` instance. + +- **config**: Configuration options for the client. + +#### Methods + +- **Connect()** + + ```go + func (c *Client) Connect() + ``` + + Establishes a connection to the WebSocket server. + +- **Disconnect() error** + + ```go + func (c *Client) Disconnect() error + ``` + + Terminates the WebSocket connection gracefully. + +- **On(event string, listener ListenerFunc)** + + ```go + func (c *Client) On(event string, listener ListenerFunc) + ``` + + Registers a listener for a specific event. Supports wildcard patterns. + +- **Off(event string, id int)** + + ```go + func (c *Client) Off(event string, id int) + ``` + + Removes a listener for a specific event using the listener ID. + +- **Once(event string, listener ListenerFunc) int** + + ```go + func (c *Client) Once(event string, listener ListenerFunc) int + ``` + + Registers a listener that will be called at most once for the specified event. + +- **SubscribeRemoteTopic(topic string) error** + + ```go + func (c *Client) SubscribeRemoteTopic(topic string) error + ``` + + Subscribes the client to a remote topic to receive messages. + +- **UnsubscribeRemoteTopic(topic string) error** + + ```go + func (c *Client) UnsubscribeRemoteTopic(topic string) error + ``` + + Unsubscribes the client from a remote topic. + +- **Publish(topic string, payload interface{}, opts ...PublishOption) (\*WaitFor, error)** + + ```go + func (c *Client) Publish(topic string, payload interface{}, opts ...PublishOption) (*WaitFor, error) + ``` + + Publishes a message to a specified topic with optional configurations. + +- **Send(payload interface{}, opts ...SendOption) (\*WaitFor, error)** + + ```go + func (c *Client) Send(payload interface{}, opts ...SendOption) (*WaitFor, error) + ``` + + Sends a message to the server with optional configurations. + +- **WaitFor(event string, timeout time.Duration) (interface{}, error)** + + ```go + func (c *Client) WaitFor(event string, timeout time.Duration) (interface{}, error) + ``` + + Waits for a specific event to occur within the given timeout. + +#### Events + +- **'session.started'** + + Emitted when the session starts. + + ```go + client.On("session.started", func(args ...interface{}) { + info := args[0].(realtime_pubsub.ConnectionInfo) + //... + }) + ``` + +- **'error'** + + Emitted on WebSocket errors. + + ```go + client.On("error", func(args ...interface{}) { + // Handle error + }) + ``` + +- **'close'** + + Emitted when the WebSocket connection closes. + + ```go + client.On("close", func(args ...interface{}) { + // Handle close event + }) + ``` + +- **Custom Events** + + Handle custom events based on topic and message type. + + ```go + client.On("topic1.action1", func(args ...interface{}) { + // Handle specific message + }) + ``` + + **Wildcard Subscriptions** + + ```go + client.On("topic1.*", func(args ...interface{}) { + // Handle any message under topic1 + }) + ``` + +## Type Definitions + +### IncomingMessage + +Represents a message received from the server. + +```go +type IncomingMessage map[string]interface{} +``` + +#### Methods + +- **Topic() string** + + Extracts the "topic" from the message. + + ```go + func (m IncomingMessage) Topic() string + ``` + +- **Data() interface{}** + + Extracts the "data" from the message. + + ```go + func (m IncomingMessage) Data() interface{} + ``` + +- **MessageType() string** + + Extracts the "messageType" from the message. + + ```go + func (m IncomingMessage) MessageType() string + ``` + +- **Compression() bool** + + Extracts the "compression" flag from the message. + + ```go + func (m IncomingMessage) Compression() bool + ``` + +- **DataAsMap() map[string]interface{}** + + Extracts the "data" as a map from the message. + + ```go + func (m IncomingMessage) DataAsMap() map[string]interface{} + ``` + +- **DataAsPresenceMessage() PresenceMessage** + + Extracts the "data" as a PresenceMessage from the message. + + ```go + func (m IncomingMessage) DataAsPresenceMessage() PresenceMessage + ``` + +### ResponseMessage + +Represents a message sent by a client in response to an incoming message. + +```go +type ResponseMessage map[string]interface{} +``` + +#### Methods + +- **Id() string** + + Extracts the "id" from the response. + + ```go + func (m ResponseMessage) Id() string + ``` + +- **Data() interface{}** + + Extracts the "data" from the response. + + ```go + func (m ResponseMessage) Data() interface{} + ``` + +- **Status() string** + + Extracts the "status" from the response. + + ```go + func (m ResponseMessage) Status() string + ``` + +- **DataAsMap() map[string]interface{}** + + Extracts the "data" as a map from the response. + + ```go + func (m ResponseMessage) DataAsMap() map[string]interface{} + ``` + +### ReplyFunc + +Defines the signature for reply functions used in event listeners. + +```go +type ReplyFunc func (data interface{}, status string, opts ...ReplyOption) error +``` + +## Configuration Options + +### Config + +Configuration options for initializing the client. + +```go +type Config struct { + Logger *logrus.Logger + WebSocketOptions WebSocketOptions +} +``` + +- **Logger**: Optional. Pass a custom logger implementing the `logrus.Logger` interface for logging purposes. If not + provided, the client uses the standard logger. +- **WebSocketOptions**: Configuration options for the WebSocket connection. + +### WebSocketOptions + +Options for configuring the WebSocket connection. + +```go +type WebSocketOptions struct { + URLProvider func () (string, error) +} +``` + +- **URLProvider**: A function that returns the WebSocket URL for connecting to the Realtime Pub/Sub server. It can + include dynamic parameters like access tokens. + +## API Components + +### WaitFor + +Represents a mechanism to wait for acknowledgements or replies to messages. + +```go +type WaitFor struct { + // Internal fields +} +``` + +#### Methods + +- **WaitForAck(timeout time.Duration) (interface{}, error)** + + Waits for an acknowledgement of the published or sent message within the specified timeout. + + ```go + func (w *WaitFor) WaitForAck(timeout time.Duration) (interface{}, error) + ``` + +- **WaitForReply(timeout time.Duration) (map[string]interface{}, error)** + + Waits for a reply to the published or sent message within the specified timeout. + + ```go + func (w *WaitFor) WaitForReply(timeout time.Duration) (map[string]interface{}, error) + ``` + +## License + +This library is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. + +--- + +For more detailed examples and advanced configurations, please refer to +the [documentation](https://realtime.21no.de/docs). + +## Notes + +- **Environment Setup**: Ensure that you have an account and an app set up + with [Realtime Pub/Sub](https://realtime.21no.de). +- **Authentication**: Customize the `urlProvider` function to retrieve the access token for connecting to your realtime + application. +- **Logging**: Use the `Logger` option to integrate with your application's logging system. +- **Error Handling**: Handle errors and disconnections gracefully to improve the robustness of your application. +- **Timeouts**: Handle timeouts when waiting for replies to avoid hanging operations. +- **Concurrency**: The client is designed to be safe for use in concurrent environments, with thread-safe operations on + its internal data structures. However, keep in mind that listeners are executed sequentially by default within the + event loop. If your application requires concurrent handling of events, you can spawn goroutines within your listeners + or implement custom thread pools for parallel processing. Additionally, if your listeners modify shared state, ensure + that you manage synchronization (e.g., using mutexes or channels) to avoid race conditions or data corruption. +- **Wildcard Subscriptions**: Utilize wildcard subscriptions (`*` and `**`) to handle multiple message types under a + single topic efficiently. + +--- + +Feel free to contribute to this project by submitting issues or pull requests +on [GitHub](https://github.com/BackendStack21/realtime-pubsub-client-go). diff --git a/client.go b/client.go new file mode 100644 index 0000000..e1cd4ee --- /dev/null +++ b/client.go @@ -0,0 +1,661 @@ +package realtime_pubsub + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "github.com/sirupsen/logrus" + "net" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + writeWait = 10 * time.Second // Time allowed to write a message to the connection + pongWait = 15 * time.Second // Time allowed to read the next Pong message from the server + pingPeriod = (pongWait * 9) / 10 // Send pings to the server with this period + maxMessageSize = 6 * 1024 // Maximum message size allowed from the server +) + +// Client encapsulates WebSocket connection, subscription, and message handling. +type Client struct { + eventEmitter *EventEmitter + ws *websocket.Conn + config Config + logger *logrus.Logger + subscribedTopics map[string]struct{} + isConnecting bool + mu sync.Mutex + writeChan chan []byte + closeChan chan struct{} + closeOnce *sync.Once + ctx context.Context + cancel context.CancelFunc + done chan struct{} +} + +// NewClient initializes a new Client instance. +func NewClient(config Config) *Client { + logger := config.Logger + if logger == nil { + logger = logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.DebugLevel) + } + + ctx, cancel := context.WithCancel(context.Background()) + + client := &Client{ + eventEmitter: NewEventEmitter(), + config: config, + logger: logger, + subscribedTopics: make(map[string]struct{}), + writeChan: make(chan []byte, 256), + closeChan: make(chan struct{}), + closeOnce: new(sync.Once), + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + } + + // Register event listeners + client.eventEmitter.On("priv/acks.ack", client.onAck) + client.eventEmitter.On("*.response", client.onResponse) + client.eventEmitter.On("main.welcome", client.onWelcome) + return client +} + +// closeWebSocket safely closes the WebSocket connection once. +func (c *Client) closeWebSocket() { + c.mu.Lock() + if c.ws != nil { + _ = c.ws.Close() + c.ws = nil + } + c.mu.Unlock() +} + +// Connect establishes a connection to the WebSocket server. +func (c *Client) Connect() { + c.mu.Lock() + c.isConnecting = false + c.mu.Unlock() + go c.connectLoop() +} + +func (c *Client) Disconnect() error { + c.closeOnce.Do(func() { + c.logger.Infof("Disconnecting client") + close(c.done) // Signal all goroutines to exit + close(c.writeChan) + c.closeWebSocket() + }) + return nil +} + +// Publish publishes a message to a specified topic. +func (c *Client) Publish(topic string, payload interface{}, opts ...PublishOption) (*WaitFor, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.ws == nil { + c.logger.Errorf("Attempted to publish without an active WebSocket connection.") + return nil, fmt.Errorf("WebSocket connection is not established") + } + + if opts == nil { + opts = make([]PublishOption, 0) + } + + // Set default options + options := PublishOptions{ + ID: getRandomID(rand.Reader), + MessageType: "broadcast", + Compress: false, + } + + // Apply provided options + for _, opt := range opts { + opt(&options) + } + + message := map[string]interface{}{ + "type": "publish", + "data": map[string]interface{}{ + "topic": topic, + "messageType": options.MessageType, + "compress": options.Compress, + "payload": payload, + "id": options.ID, + }, + } + + messageBytes, err := json.Marshal(message) + if err != nil { + c.logger.Errorf("Failed to marshal message: %v", err) + return nil, err + } + + c.logger.Debugf("Publishing message to topic %s: %v", topic, payload) + err = c.sendMessage(messageBytes) + if err != nil { + c.logger.Errorf("Failed to send message: %v", err) + return nil, err + } + + return &WaitFor{ + client: c, + id: options.ID, + }, nil +} + +// Send sends a message directly to the server. +func (c *Client) Send(payload interface{}, opts ...SendOption) (*WaitFor, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.ws == nil { + c.logger.Errorf("Attempted to send without an active WebSocket connection.") + return nil, fmt.Errorf("WebSocket connection is not established") + } + + // Set default options + options := SendOptions{ + ID: getRandomID(rand.Reader), + MessageType: "", // Can be empty + Compress: false, // Default compress is false + } + + // Apply provided options + for _, opt := range opts { + opt(&options) + } + + data := map[string]interface{}{ + "messageType": options.MessageType, + "compress": options.Compress, + "payload": payload, + "id": options.ID, + } + + // Remove messageType from data if it's empty + if options.MessageType == "" { + delete(data, "messageType") + } + + message := map[string]interface{}{ + "type": "message", + "data": data, + } + + messageBytes, err := json.Marshal(message) + if err != nil { + c.logger.Errorf("Failed to marshal message: %v", err) + return nil, err + } + + c.logger.Debugf("Outgoing message: %v", string(messageBytes)) + err = c.sendMessage(messageBytes) + if err != nil { + c.logger.Errorf("Failed to send message: %v", err) + return nil, err + } + + return &WaitFor{ + client: c, + id: options.ID, + }, nil +} + +// SubscribeRemoteTopic subscribes to a remote topic to receive messages. +func (c *Client) SubscribeRemoteTopic(topic string) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.ws == nil { + c.logger.Errorf("Attempted to subscribe to %s without an active WebSocket connection.", topic) + return fmt.Errorf("WebSocket connection is not established") + } + // Check if already subscribed + if _, exists := c.subscribedTopics[topic]; exists { + c.logger.Warnf("Already subscribed to topic: %s", topic) + return nil + } + + c.subscribedTopics[topic] = struct{}{} + + message := map[string]interface{}{ + "type": "subscribe", + "data": map[string]interface{}{ + "topic": topic, + }, + } + messageBytes, err := json.Marshal(message) + if err != nil { + c.logger.Errorf("Failed to marshal subscribe message: %v", err) + return err + } + + c.logger.Infof("Subscribing to topic: %s", topic) + err = c.sendMessage(messageBytes) + if err != nil { + c.logger.Errorf("Failed to send subscribe message: %v", err) + return err + } + + return nil +} + +// UnsubscribeRemoteTopic unsubscribes from a previously subscribed topic. +func (c *Client) UnsubscribeRemoteTopic(topic string) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.ws == nil { + c.logger.Errorf("Attempted to unsubscribe from %s without an active WebSocket connection.", topic) + return fmt.Errorf("WebSocket connection is not established") + } + + if _, exists := c.subscribedTopics[topic]; !exists { + c.logger.Warnf("Not subscribed to topic: %s", topic) + return nil + } + + delete(c.subscribedTopics, topic) + + message := map[string]interface{}{ + "type": "unsubscribe", + "data": map[string]interface{}{ + "topic": topic, + }, + } + messageBytes, err := json.Marshal(message) + if err != nil { + c.logger.Errorf("Failed to marshal unsubscribe message: %v", err) + return err + } + + c.logger.Infof("Unsubscribing from topic: %s", topic) + err = c.sendMessage(messageBytes) + if err != nil { + c.logger.Errorf("Failed to send unsubscribe message: %v", err) + return err + } + + return nil +} + +// sendMessage queues a message to be sent to the WebSocket connection. +func (c *Client) sendMessage(message []byte) error { + select { + case c.writeChan <- message: + return nil + case <-time.After(5 * time.Second): + return fmt.Errorf("timeout sending message") + } +} + +// onAck handles acknowledgment messages received from the server. +func (c *Client) onAck(args ...interface{}) { + if len(args) < 1 { + c.logger.Errorf("onAck called with insufficient arguments") + return + } + message, ok := args[0].(IncomingMessage) + if !ok { + c.logger.Errorf("onAck received invalid message format") + return + } + data, ok := message["data"].(map[string]interface{}) + if !ok { + c.logger.Errorf("onAck message data is invalid") + return + } + ackID, ok := data["data"] + if !ok { + c.logger.Errorf("onAck data missing 'data' field") + return + } + c.logger.Debugf("Received ack: %v", data) + c.eventEmitter.Emit(fmt.Sprintf("ack.%v", ackID), data) +} + +// onResponse handles response messages received from other subscribers or backend services. +func (c *Client) onResponse(args ...interface{}) { + if len(args) < 1 { + c.logger.Errorf("onResponse called with insufficient arguments") + return + } + message, ok := args[0].(IncomingMessage) + if !ok { + c.logger.Errorf("onResponse received invalid message format") + return + } + topic, _ := message["topic"].(string) + if strings.HasPrefix(topic, "priv/") { + c.logger.Debugf("Received response for topic %s: %v", topic, message["data"]) + data, ok := message["data"].(map[string]interface{}) + if !ok { + c.logger.Errorf("onResponse message data is invalid") + return + } + payload, ok := data["payload"].(map[string]interface{}) + if !ok { + c.logger.Errorf("onResponse payload is invalid") + return + } + id := payload["id"] + c.eventEmitter.Emit(fmt.Sprintf("response.%v", id), payload) + } +} + +// onWelcome handles 'welcome' messages to indicate that the session has started. +func (c *Client) onWelcome(args ...interface{}) { + if len(args) < 1 { + c.logger.Errorf("onWelcome called with insufficient arguments") + return + } + message, _ := args[0].(IncomingMessage) + data, _ := message["data"].(map[string]interface{}) + connection, _ := data["connection"].(map[string]interface{}) + + c.logger.Infof("Session started, connection details: %v", connection) + c.eventEmitter.Emit("session.started", ConnectionInfo(connection)) +} + +// onMessage handles incoming WebSocket messages. +func (c *Client) onMessage(message []byte) { + var messageData map[string]interface{} + err := json.Unmarshal(message, &messageData) + if err != nil { + c.handleError(fmt.Errorf("failed to unmarshal message: %v", err)) + return + } + topic, ok := messageData["topic"].(string) + if !ok { + c.handleError(fmt.Errorf("message missing 'topic' field: %v", messageData)) + return + } + messageType, ok := messageData["messageType"].(string) + if !ok { + c.handleError(fmt.Errorf("message missing 'messageType' field: %v", messageData)) + return + } + data := messageData["data"] + messageEvent := IncomingMessage{ + "topic": topic, + "messageType": messageType, + "data": data, + "compression": false, + } + c.logger.Debugf("Incoming message: %v", string(message)) + if messageType != "" { + c.eventEmitter.Emit(fmt.Sprintf("%s.%s", topic, messageType), messageEvent, c.reply(messageEvent)) + } +} + +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.closeWebSocket() + }() + for { + select { + case message, ok := <-c.writeChan: + if !ok { + // The write channel was closed. + return + } + c.mu.Lock() + ws := c.ws + c.mu.Unlock() + if ws == nil { + return + } + _ = ws.SetWriteDeadline(time.Now().Add(writeWait)) + err := ws.WriteMessage(websocket.TextMessage, message) + if err != nil { + c.handleError(err) + return + } + case <-ticker.C: + c.mu.Lock() + ws := c.ws + c.mu.Unlock() + if ws == nil { + return + } + _ = ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.PingMessage, nil); err != nil { + c.handleError(err) + return + } + case <-c.done: + ticker.Stop() + return + } + } +} + +func (c *Client) readPump() { + defer func() { + c.closeWebSocket() + c.handleClose() + }() + c.ws.SetReadLimit(maxMessageSize) + _ = c.ws.SetReadDeadline(time.Now().Add(pongWait)) + c.ws.SetPongHandler(func(string) error { + _ = c.ws.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + for { + select { + case <-c.done: + // Client is shutting down. + return + default: + // Continue reading messages + } + + _, message, err := c.ws.ReadMessage() + if err != nil { + select { + case <-c.done: + return + default: + c.handleError(err) + return + } + } + if message != nil { + c.onMessage(message) + } else { + continue + } + + } +} + +func (c *Client) connectLoop() { + c.mu.Lock() + if c.isConnecting { + c.mu.Unlock() + return + } + c.isConnecting = true + c.mu.Unlock() + + backoff := 1 * time.Second + maxBackoff := 60 * time.Second + + for { + c.logger.Debugf("connectLoop tick...") + select { + case <-c.done: + c.logger.Debugf("Connect loop exiting due to client disconnection") + return + default: + // Continue with connection attempts + } + + wsURL, err := c.config.WebSocketOptions.URLProvider() + if err != nil { + c.logger.Warnf("Error obtaining WebSocket URL: %v", err) + time.Sleep(backoff) + continue + } + + u, err := url.Parse(wsURL) + if err != nil { + c.logger.Warnf("Invalid WebSocket URL: %v", err) + time.Sleep(backoff) + continue + } + + // Create a net.Dialer with timeouts + netDialer := &net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 5 * time.Second, + } + + // Create a websocket.Dialer that uses net.Dialer + dialer := websocket.Dialer{ + NetDialContext: netDialer.DialContext, + } + + // Attempt to connect with a context that can be canceled + ctx, cancel := context.WithCancel(context.Background()) + go func() { + select { + case <-c.done: + cancel() + case <-ctx.Done(): + // Do nothing + } + }() + + conn, _, err := dialer.DialContext(ctx, u.String(), nil) + cancel() // Ensure the context is canceled to free resources + + if err != nil { + select { + case <-c.done: + c.logger.Debugf("Dial canceled due to client disconnection") + return + default: + // Continue + } + c.handleError(err) + c.logger.Warnf("Retrying connection in %v after failure: %v", backoff, err) + time.Sleep(backoff) + if backoff < maxBackoff { + backoff *= 2 + } else { + backoff = maxBackoff + } + continue + } + + c.mu.Lock() + c.ws = conn + c.closeOnce = new(sync.Once) // Reset closeOnce for the new connection + c.logger.Infof("Connected to WebSocket URL: %.70s", wsURL) + c.isConnecting = false + c.mu.Unlock() + + // Start writePump and readPump + go c.writePump() + go c.readPump() + return + } +} + +// handleError handles WebSocket errors by logging and emitting an 'error' event. +func (c *Client) handleError(err error) { + c.logger.Errorf("WebSocket error: %v", err) + c.eventEmitter.Emit("error", err) +} + +func (c *Client) handleClose() { + c.mu.Lock() + c.ws = nil + c.mu.Unlock() + + select { + case <-c.done: + return + default: + // Continue + } + + c.logger.Warnf("WebSocket closed unexpectedly, attempting to reconnect.") + c.eventEmitter.Emit("close") + go c.connectLoop() +} + +// reply creates a reply function for the given client and message. +func (c *Client) reply(message map[string]interface{}) ReplyFunc { + return func(data interface{}, status string, opts ...ReplyOption) error { + if status == "" { + status = "ok" + } + + dataMap, ok := message["data"].(map[string]interface{}) + if !ok { + return fmt.Errorf("data field is missing or not a map in the message") + } + + clientMap, ok := dataMap["client"].(map[string]interface{}) + if !ok { + return fmt.Errorf("client field is missing or not a map in the message data") + } + + connectionID, ok := clientMap["connectionId"] + if !ok { + return fmt.Errorf("connectionId is missing in the client data") + } + + originalID := dataMap["id"] + + // Set default options + options := ReplyOptions{ + Compress: false, + } + + // Apply provided options + for _, opt := range opts { + opt(&options) + } + + payload := map[string]interface{}{ + "data": data, + "status": status, + "id": originalID, + } + + _, err := c.Publish(fmt.Sprintf("priv/%v", connectionID), payload, + WithPublishMessageType("response"), + WithPublishCompress(options.Compress), + ) + return err + } +} + +// On registers a listener for a specific event. +func (c *Client) On(event string, listener ListenerFunc) int { + return c.eventEmitter.On(event, listener) +} + +// Off removes a listener for a specific event using the listener ID. +func (c *Client) Off(event string, id int) { + c.eventEmitter.Off(event, id) +} + +// Once registers a listener for a specific event that will be called at most once. +func (c *Client) Once(event string, listener ListenerFunc) int { + return c.eventEmitter.Once(event, listener) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..fdbad98 --- /dev/null +++ b/client_test.go @@ -0,0 +1,832 @@ +package realtime_pubsub + +import ( + "encoding/json" + "github.com/sirupsen/logrus" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// closeConnection safely closes a WebSocket connection. +func closeConnection(conn *websocket.Conn) { + _ = conn.Close() +} + +// disconnectClient safely disconnects the client. +func disconnectClient(c *Client) { + _ = c.Disconnect() +} + +// mockWebSocketServer creates a mock WebSocket server for testing purposes. +// It accepts a testing object and a handler function that manages the WebSocket connection. +func mockWebSocketServer(t *testing.T, handler func(*websocket.Conn)) *httptest.Server { + upgrader := websocket.Upgrader{} + + // Create an HTTP test server with a handler that upgrades connections to WebSockets. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Upgrade initial HTTP request to a WebSocket connection. + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("Failed to upgrade WebSocket connection: %v", err) + return + } + + // Set read limits and handlers for control messages. + conn.SetReadLimit(maxMessageSize) + _ = conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(string) error { + _ = conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + // Start a goroutine to continuously read messages to keep the connection alive. + go func() { + defer closeConnection(conn) + for { + _, _, err := conn.ReadMessage() + if err != nil { + break + } + } + }() + + // Invoke the provided handler function to manage the connection. + handler(conn) + })) + + return server +} + +// TestClientSend verifies the client's ability to send messages directly to the server. +func TestClientSend(t *testing.T) { + var ( + receivedMessage []byte + wg sync.WaitGroup + ) + wg.Add(1) + + // Create a mock server that reads the client's message. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + for { + _, message, err := conn.ReadMessage() + if err != nil { + break + } + // Capture the first message and signal completion. + if receivedMessage == nil { + receivedMessage = message + wg.Done() + return + } + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client without a custom logger. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Define the payload to send. + payload := map[string]interface{}{ + "command": "ping", + } + + // Send the message to the server. + _, err := client.Send(payload) + if err != nil { + t.Errorf("Failed to send message: %v", err) + } + + // Wait for the server to receive the message. + wg.Wait() + + // Verify the received message content. + var messageData map[string]interface{} + err = json.Unmarshal(receivedMessage, &messageData) + if err != nil { + t.Errorf("Failed to unmarshal received message: %v", err) + } + + // Validate the message type. + if messageData["type"] != "message" { + t.Errorf("Expected message type 'message', got '%v'", messageData["type"]) + } + + // Validate the payload content. + data, ok := messageData["data"].(map[string]interface{}) + if !ok { + t.Errorf("Invalid data format in received message") + } + + payloadReceived, ok := data["payload"].(map[string]interface{}) + if !ok { + t.Errorf("Invalid payload format in received message") + } + + if payloadReceived["command"] != "ping" { + t.Errorf("Expected command 'ping', got '%v'", payloadReceived["command"]) + } +} + +// TestClientPublish verifies the client's ability to publish messages to a topic. +func TestClientPublish(t *testing.T) { + var ( + receivedMessage []byte + wg sync.WaitGroup + ) + wg.Add(1) + + // Create a mock server that reads the client's publish message. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + // Read the first message from the client. + _, message, err := conn.ReadMessage() + if err != nil { + t.Errorf("Server failed to read message: %v", err) + } + receivedMessage = message + wg.Done() + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client with a logger for debugging. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + Logger: logrus.New(), + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Define the payload to publish. + payload := map[string]interface{}{ + "content": "Hello, World!", + } + + // Publish the message to the topic. + _, err := client.Publish("test.topic", payload) + if err != nil { + t.Errorf("Failed to publish message: %v", err) + } + + // Wait for the server to receive the message. + wg.Wait() + + // Verify the received message content. + var messageData map[string]interface{} + err = json.Unmarshal(receivedMessage, &messageData) + if err != nil { + t.Errorf("Failed to unmarshal received message: %v", err) + } + + // Validate the message type. + if messageData["type"] != "publish" { + t.Errorf("Expected message type 'publish', got '%v'", messageData["type"]) + } + + // Validate the topic and payload. + data, ok := messageData["data"].(map[string]interface{}) + if !ok { + t.Errorf("Invalid data format in received message") + } + + if data["topic"] != "test.topic" { + t.Errorf("Expected topic 'test.topic', got '%v'", data["topic"]) + } + + if data["payload"] == nil { + t.Errorf("Expected payload, but got nil") + } +} + +// TestClientSubscribeRemoteTopic tests the client's ability to subscribe to a remote topic. +func TestClientSubscribeRemoteTopic(t *testing.T) { + var ( + receivedMessage []byte + wg sync.WaitGroup + ) + wg.Add(1) + + // Create a mock server that reads the client's subscribe message. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + for { + _, message, err := conn.ReadMessage() + if err != nil { + break + } + // Capture the subscribe message and signal completion. + if receivedMessage == nil { + receivedMessage = message + wg.Done() + return + } + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Subscribe to the topic. + err := client.SubscribeRemoteTopic("test.topic") + if err != nil { + t.Errorf("Failed to subscribe to topic: %v", err) + } + + // Wait for the server to receive the subscribe message. + wg.Wait() + + // Verify the received message content. + var messageData map[string]interface{} + err = json.Unmarshal(receivedMessage, &messageData) + if err != nil { + t.Errorf("Failed to unmarshal received message: %v", err) + } + + // Validate the message type. + if messageData["type"] != "subscribe" { + t.Errorf("Expected message type 'subscribe', got '%v'", messageData["type"]) + } + + // Validate the topic. + data, ok := messageData["data"].(map[string]interface{}) + if !ok { + t.Errorf("Invalid data format in received message") + } + + if data["topic"] != "test.topic" { + t.Errorf("Expected topic 'test.topic', got '%v'", data["topic"]) + } +} + +// TestClientUnsubscribeRemoteTopic tests the client's ability to unsubscribe from a remote topic. +func TestClientUnsubscribeRemoteTopic(t *testing.T) { + var ( + receivedMessages [][]byte + wg sync.WaitGroup + ) + // We expect to receive both subscribe and unsubscribe messages. + wg.Add(2) + + // Create a mock server that reads the client's messages. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + for { + _, message, err := conn.ReadMessage() + if err != nil { + break + } + receivedMessages = append(receivedMessages, message) + wg.Done() + if len(receivedMessages) >= 2 { + // Received both messages; exit the loop. + break + } + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Subscribe to the topic. + err := client.SubscribeRemoteTopic("test.topic") + if err != nil { + t.Errorf("Failed to subscribe to topic: %v", err) + } + + // Unsubscribe from the topic. + err = client.UnsubscribeRemoteTopic("test.topic") + if err != nil { + t.Errorf("Failed to unsubscribe from topic: %v", err) + } + + // Wait for the server to receive both messages. + wg.Wait() + + // Verify the unsubscribe message content. + var unsubscribeMessage map[string]interface{} + err = json.Unmarshal(receivedMessages[1], &unsubscribeMessage) + if err != nil { + t.Errorf("Failed to unmarshal unsubscribe message: %v", err) + } + + // Validate the message type. + if unsubscribeMessage["type"] != "unsubscribe" { + t.Errorf("Expected message type 'unsubscribe', got '%v'", unsubscribeMessage["type"]) + } + + // Validate the topic. + data, ok := unsubscribeMessage["data"].(map[string]interface{}) + if !ok { + t.Errorf("Invalid data format in unsubscribe message") + } + + if data["topic"] != "test.topic" { + t.Errorf("Expected topic 'test.topic', got '%v'", data["topic"]) + } +} + +// TestClientOnMessage tests the client's ability to handle incoming messages from the server. +func TestClientOnMessage(t *testing.T) { + var ( + messageReceived bool + wg sync.WaitGroup + ) + wg.Add(1) + + // Mock message to be sent from the server to the client. + mockMessage := map[string]interface{}{ + "topic": "test.topic", + "messageType": "update", + "data": map[string]interface{}{ + "content": "Hello, Client!", + }, + } + mockMessageBytes, _ := json.Marshal(mockMessage) + + // Create a mock server that sends a message to the client. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + // Read messages in a goroutine to handle Ping/Pong frames. + go func() { + for { + _, _, err := conn.ReadMessage() + if err != nil { + break + } + } + }() + + // Send the mock message to the client. + time.Sleep(100 * time.Millisecond) // Ensure client is ready. + err := conn.WriteMessage(websocket.TextMessage, mockMessageBytes) + if err != nil { + t.Errorf("Server write error: %v", err) + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Register a listener for the specific message type. + client.On("test.topic.update", func(args ...interface{}) { + defer wg.Done() + messageReceived = true + if len(args) < 1 { + t.Errorf("Expected at least 1 argument, got %d", len(args)) + return + } + message, ok := args[0].(IncomingMessage) + if !ok { + t.Errorf("Invalid message format") + return + } + data, ok := message.Data().(map[string]interface{}) + if !ok { + t.Errorf("Invalid data format in message") + return + } + if data["content"] != "Hello, Client!" { + t.Errorf("Expected content 'Hello, Client!', got '%v'", data["content"]) + } + }) + + // Wait for the message to be received. + wg.Wait() + + if !messageReceived { + t.Errorf("Expected to receive a message, but did not") + } +} + +// TestClientHandleError verifies that the client handles errors correctly. +func TestClientHandleError(t *testing.T) { + var ( + errorReceived error + wg sync.WaitGroup + ) + wg.Add(1) + + // Create a mock server that closes the connection immediately to simulate an error. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + closeConnection(conn) + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create the client. + client := NewClient(config) + + // Register an error listener. + client.On("error", func(args ...interface{}) { + defer wg.Done() + if len(args) < 1 { + t.Errorf("Expected at least 1 argument, got %d", len(args)) + return + } + err, ok := args[0].(error) + if !ok { + t.Errorf("Expected an error, got %v", args[0]) + return + } + errorReceived = err + }) + + // Connect the client. + client.Connect() + defer disconnectClient(client) + + // Wait for the error to be received. + waitCh := make(chan struct{}) + go func() { + wg.Wait() + close(waitCh) + }() + + select { + case <-waitCh: + // Error received. + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for error event") + } + + if errorReceived == nil { + t.Errorf("Expected an error, but got nil") + } +} + +// TestClientWaitFor tests the client's ability to wait for acknowledgments. +func TestClientWaitFor(t *testing.T) { + // Create a mock server that sends an acknowledgment back to the client. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + for { + _, message, err := conn.ReadMessage() + if err != nil { + break + } + + // Unmarshal the message to get the ID. + var msgData map[string]interface{} + err = json.Unmarshal(message, &msgData) + if err != nil { + t.Errorf("Failed to unmarshal message: %v", err) + continue + } + + data, _ := msgData["data"].(map[string]interface{}) + id := data["id"] + + // Send an acknowledgment message back to the client. + ackMessage := map[string]interface{}{ + "type": "event", + "topic": "priv/acks", + "messageType": "ack", + "data": map[string]interface{}{ + "data": id, + }, + } + + ackBytes, _ := json.Marshal(ackMessage) + err = conn.WriteMessage(websocket.TextMessage, ackBytes) + if err != nil { + t.Errorf("Server write error: %v", err) + } + break // Exit after sending acknowledgment. + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Define the payload to publish. + payload := map[string]interface{}{ + "content": "Hello, World!", + } + + // Publish the message and wait for an acknowledgment. + waitFor, err := client.Publish("test.topic", payload) + if err != nil { + t.Errorf("Failed to publish message: %v", err) + return + } + + // Wait for the acknowledgment with a timeout. + result, err := waitFor.WaitForAck(1 * time.Second) + if err != nil { + t.Errorf("Failed to receive acknowledgment: %v", err) + return + } + + // Verify the acknowledgment ID matches. + args, ok := result.([]interface{}) + if !ok || len(args) == 0 { + t.Errorf("Expected result to be non-empty []interface{}, got %v", result) + return + } + + ackData, ok := args[0].(map[string]interface{}) + if !ok { + t.Errorf("Expected acknowledgment data to be map[string]interface{}, got %T", args[0]) + return + } + + ackID := ackData["data"] + if ackID != waitFor.id { + t.Errorf("Expected acknowledgment ID '%s', got '%v'", waitFor.id, ackID) + } +} + +// TestClientDisconnect verifies the client's ability to disconnect gracefully. +func TestClientDisconnect(t *testing.T) { + var ( + wg sync.WaitGroup + connectionClosed bool + ) + wg.Add(1) + + // Create a mock server that detects when the connection is closed. + server := mockWebSocketServer(t, func(conn *websocket.Conn) { + defer closeConnection(conn) + for { + _, _, err := conn.ReadMessage() + if err != nil { + connectionClosed = true + wg.Done() + break + } + } + }) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + + // Allow time for the connection to establish. + time.Sleep(100 * time.Millisecond) + + // Disconnect the client. + err := client.Disconnect() + if err != nil { + t.Errorf("Failed to disconnect client: %v", err) + } + + // Wait for the server to detect the disconnection. + waitCh := make(chan struct{}) + go func() { + wg.Wait() + close(waitCh) + }() + + select { + case <-waitCh: + // Disconnection detected. + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for server to detect disconnection") + } + + if !connectionClosed { + t.Errorf("Expected server to detect connection closure, but it did not") + } + + // Wait to ensure the client does not attempt to reconnect. + time.Sleep(1 * time.Second) + + // Verify that the client's WebSocket connection is nil after disconnection. + client.mu.Lock() + if client.ws != nil { + t.Errorf("Expected client.ws to be nil after disconnect, but it is not") + } + client.mu.Unlock() +} + +// TestClientReconnect tests the client's ability to automatically reconnect after a disconnection. +func TestClientReconnect(t *testing.T) { + // Set a timeout for the test to prevent it from hanging indefinitely. + testTimeout := time.After(10 * time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + + // WaitGroups to synchronize the test steps. + var ( + initialConnectionEstablished sync.WaitGroup + reconnectionEstablished sync.WaitGroup + ) + initialConnectionEstablished.Add(1) + reconnectionEstablished.Add(1) + + upgrader := websocket.Upgrader{} + var ( + connCount int + connCountMu sync.Mutex + ) + + // Start the mock WebSocket server. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Upgrade the HTTP connection to a WebSocket. + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("Failed to upgrade: %v", err) + return + } + + // Increment the connection count. + connCountMu.Lock() + connCount++ + currentConnCount := connCount + connCountMu.Unlock() + + if currentConnCount == 1 { + // Initial connection established. + initialConnectionEstablished.Done() + + // Simulate server closing the connection after 500ms. + go func() { + time.Sleep(500 * time.Millisecond) + closeConnection(conn) + }() + } else if currentConnCount == 2 { + // Reconnection established. + reconnectionEstablished.Done() + } + + // Read messages to keep the connection open. + for { + _, _, err := conn.ReadMessage() + if err != nil { + return + } + } + })) + defer server.Close() + + // Prepare the WebSocket URL for the client. + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + // Configure the client with a logger. + config := Config{ + WebSocketOptions: WebSocketOptions{ + URLProvider: func() (string, error) { + return wsURL, nil + }, + }, + Logger: logrus.New(), + } + + // Create and connect the client. + client := NewClient(config) + client.Connect() + defer disconnectClient(client) + + // Wait for the client to establish the initial connection. + initialConnectionEstablished.Wait() + + // Wait for the client to reconnect after disconnection. + reconnectionEstablished.Wait() + + // Check that the client reconnected. + connCountMu.Lock() + if connCount < 2 { + t.Errorf("Expected client to reconnect, but it did not") + } + connCountMu.Unlock() + }() + + // Enforce the test timeout. + select { + case <-done: + // Test completed successfully. + case <-testTimeout: + t.Fatal("Test timed out") + } +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..1135c96 --- /dev/null +++ b/config.go @@ -0,0 +1,16 @@ +package realtime_pubsub + +import ( + "github.com/sirupsen/logrus" +) + +// Config represents the configuration options for Client. +type Config struct { + Logger *logrus.Logger + WebSocketOptions WebSocketOptions +} + +// WebSocketOptions represents the WebSocket configuration options. +type WebSocketOptions struct { + URLProvider func() (string, error) +} diff --git a/event_emitter.go b/event_emitter.go new file mode 100644 index 0000000..2bd4e68 --- /dev/null +++ b/event_emitter.go @@ -0,0 +1,113 @@ +package realtime_pubsub + +import ( + "strings" + "sync" +) + +// EventEmitter allows registering and emitting events with wildcard support. +type EventEmitter struct { + mu sync.RWMutex + events map[string]map[int]ListenerFunc + listenerID int +} + +// NewEventEmitter initializes and returns a new EventEmitter instance. +func NewEventEmitter() *EventEmitter { + return &EventEmitter{ + events: make(map[string]map[int]ListenerFunc), + } +} + +// On registers a listener for a specific event, supporting wildcards. +// Returns a unique listener ID for future removal. +func (e *EventEmitter) On(event string, listener ListenerFunc) int { + e.mu.Lock() + defer e.mu.Unlock() + e.listenerID++ + if e.events[event] == nil { + e.events[event] = make(map[int]ListenerFunc) + } + e.events[event][e.listenerID] = listener + return e.listenerID +} + +// Off removes a listener for a specific event using the listener ID. +func (e *EventEmitter) Off(event string, id int) { + e.mu.Lock() + defer e.mu.Unlock() + if listeners, ok := e.events[event]; ok { + delete(listeners, id) + if len(listeners) == 0 { + delete(e.events, event) + } + } +} + +// Once registers a listener for a specific event that will be called at most once. +func (e *EventEmitter) Once(event string, listener ListenerFunc) int { + var id int + wrapper := func(args ...interface{}) { + listener(args...) + e.Off(event, id) + } + id = e.On(event, wrapper) + return id +} + +// Emit emits an event with optional arguments, calling all matching listeners. +func (e *EventEmitter) Emit(event string, args ...interface{}) { + // Collect the listeners to be called. + e.mu.RLock() + var listenersToCall []ListenerFunc + for eventPattern, listeners := range e.events { + if eventMatches(eventPattern, event) { + for _, listener := range listeners { + listenersToCall = append(listenersToCall, listener) + } + } + } + e.mu.RUnlock() // Release the lock before calling listeners. + + // Call the listeners outside the lock. + for _, listener := range listenersToCall { + listener(args...) + } +} + +// eventMatches checks if an event pattern matches an event name, supporting wildcards '*' and '**'. +func eventMatches(pattern, eventName string) bool { + patternSegments := strings.Split(pattern, ".") + eventSegments := strings.Split(eventName, ".") + return matchSegments(patternSegments, eventSegments) +} + +// matchSegments matches event segments with pattern segments, handling wildcards. +func matchSegments(patternSegments, eventSegments []string) bool { + i, j := 0, 0 + for i < len(patternSegments) && j < len(eventSegments) { + if patternSegments[i] == "**" { + if i == len(patternSegments)-1 { + return true + } + for k := j; k <= len(eventSegments); k++ { + if matchSegments(patternSegments[i+1:], eventSegments[k:]) { + return true + } + } + return false + } else if patternSegments[i] == "*" { + i++ + j++ + } else if patternSegments[i] == eventSegments[j] { + i++ + j++ + } else { + return false + } + } + for i < len(patternSegments) && patternSegments[i] == "**" { + i++ + } + return i == len(patternSegments) && j == len(eventSegments) +} diff --git a/event_emitter_test.go b/event_emitter_test.go new file mode 100644 index 0000000..c93d4e0 --- /dev/null +++ b/event_emitter_test.go @@ -0,0 +1,320 @@ +package realtime_pubsub + +import ( + "sync" + "testing" + "time" +) + +// TestEventEmitterOnAndEmit tests the basic functionality of registering and emitting events. +func TestEventEmitterOnAndEmit(t *testing.T) { + emitter := NewEventEmitter() + var wg sync.WaitGroup + wg.Add(1) + + eventName := "test.event" + expectedData := "Hello, World!" + + emitter.On(eventName, func(args ...interface{}) { + defer wg.Done() + if len(args) != 1 { + t.Errorf("Expected 1 argument, got %d", len(args)) + return + } + if args[0] != expectedData { + t.Errorf("Expected data '%s', got '%v'", expectedData, args[0]) + } + }) + + emitter.Emit(eventName, expectedData) + + // Wait for the listener to be called + wg.Wait() +} + +// TestEventEmitterOff tests that listeners can be removed and no longer receive events. +func TestEventEmitterOff(t *testing.T) { + emitter := NewEventEmitter() + eventName := "test.event" + var called bool + + listenerID := emitter.On(eventName, func(args ...interface{}) { + called = true + }) + + // Remove the listener + emitter.Off(eventName, listenerID) + + emitter.Emit(eventName, "data") + + // Ensure the listener was not called + if called { + t.Errorf("Listener was called after being removed") + } +} + +// TestEventEmitterOnce tests that a listener registered with Once is only called once. +func TestEventEmitterOnce(t *testing.T) { + emitter := NewEventEmitter() + var callCount int + eventName := "test.event" + + emitter.Once(eventName, func(args ...interface{}) { + callCount++ + }) + + // Emit the event twice + emitter.Emit(eventName, "data1") + emitter.Emit(eventName, "data2") + + time.Sleep(100 * time.Millisecond) + // Listener should have been called only once + if callCount != 1 { + t.Errorf("Expected listener to be called once, but was called %d times", callCount) + } +} + +// TestEventEmitterWildcard tests that wildcard patterns in event names are supported. +func TestEventEmitterWildcard(t *testing.T) { + emitter := NewEventEmitter() + var wg sync.WaitGroup + wg.Add(4) + + eventPatterns := []string{"user.*", "order.**", "*.created"} + + expectedCalls := map[string]int{ + "user.*": 1, + "order.**": 1, + "*.created": 2, + } + + callCounts := make(map[string]int) + mu := sync.Mutex{} + + for _, pattern := range eventPatterns { + pattern := pattern + emitter.On(pattern, func(args ...interface{}) { + defer wg.Done() + mu.Lock() + callCounts[pattern]++ + mu.Unlock() + }) + } + + // Emit events + emitter.Emit("user.login", "User logged in") + emitter.Emit("order.created", "Order created") + emitter.Emit("product.created", "Product created") + + // Wait for all listeners to be called + wg.Wait() + + // Verify call counts + for pattern, expected := range expectedCalls { + if callCounts[pattern] != expected { + t.Errorf("Expected pattern '%s' to be called %d times, but was called %d times", pattern, expected, callCounts[pattern]) + } + } +} + +// TestEventEmitterNoMatch tests that listeners are not called if the event does not match. +func TestEventEmitterNoMatch(t *testing.T) { + emitter := NewEventEmitter() + var called bool + + emitter.On("test.event", func(args ...interface{}) { + called = true + }) + + // Emit a different event + emitter.Emit("other.event", "data") + + if called { + t.Errorf("Listener was called for a non-matching event") + } +} + +// TestEventEmitterMultipleListeners tests that multiple listeners for the same event are all called. +func TestEventEmitterMultipleListeners(t *testing.T) { + emitter := NewEventEmitter() + var wg sync.WaitGroup + wg.Add(2) + + eventName := "test.event" + + emitter.On(eventName, func(args ...interface{}) { + defer wg.Done() + }) + + emitter.On(eventName, func(args ...interface{}) { + defer wg.Done() + }) + + emitter.Emit(eventName, "data") + + // Wait for both listeners to be called + wg.Wait() +} + +// TestEventEmitterListenerDataIsolation tests that listeners receive their own data and are not affected by others. +func TestEventEmitterListenerDataIsolation(t *testing.T) { + emitter := NewEventEmitter() + var wg sync.WaitGroup + wg.Add(2) + + eventName1 := "test.event1" + eventName2 := "test.event2" + + emitter.On(eventName1, func(args ...interface{}) { + defer wg.Done() + if args[0] != "data1" { + t.Errorf("Listener for %s expected 'data1', got '%v'", eventName1, args[0]) + } + }) + + emitter.On(eventName2, func(args ...interface{}) { + defer wg.Done() + if args[0] != "data2" { + t.Errorf("Listener for %s expected 'data2', got '%v'", eventName2, args[0]) + } + }) + + // Emit events separately + emitter.Emit(eventName1, "data1") + emitter.Emit(eventName2, "data2") + + // Wait for both listeners to be called + wg.Wait() +} + +// TestEventEmitterConcurrentAccess tests that the EventEmitter is safe for concurrent use. +func TestEventEmitterConcurrentAccess(t *testing.T) { + emitter := NewEventEmitter() + eventName := "test.event" + var wg sync.WaitGroup + listenerCount := 100 + wg.Add(listenerCount) + + for i := 0; i < listenerCount; i++ { + emitter.On(eventName, func(args ...interface{}) { + defer wg.Done() + }) + } + + // Emit the event + emitter.Emit(eventName, "data") + + // Wait for all listeners to be called + wg.Wait() +} + +// TestEventEmitterRecursiveEmit tests that emitting events within listeners does not cause deadlocks. +func TestEventEmitterRecursiveEmit(t *testing.T) { + emitter := NewEventEmitter() + eventName1 := "event1" + eventName2 := "event2" + var wg sync.WaitGroup + wg.Add(2) + + emitter.On(eventName1, func(args ...interface{}) { + defer wg.Done() + // Emit in a new goroutine to prevent potential deadlock + go emitter.Emit(eventName2, "data") + }) + + emitter.On(eventName2, func(args ...interface{}) { + defer wg.Done() + }) + + // Start the chain by emitting event1 + emitter.Emit(eventName1, "data") + + // Wait for both listeners to be called + wg.Wait() +} + +// TestEventEmitterWildcardMultipleLevels tests '**' wildcard matching across multiple event levels. +func TestEventEmitterWildcardMultipleLevels(t *testing.T) { + emitter := NewEventEmitter() + var wg sync.WaitGroup + wg.Add(3) + + eventPattern := "app.**" + emittedEvents := []string{ + "app.start", + "app.module.init", + "app.module.component.load", + } + + emitter.On(eventPattern, func(args ...interface{}) { + defer wg.Done() + }) + + // Emit events + for _, event := range emittedEvents { + emitter.Emit(event, "data") + } + + // Wait for all listeners to be called + wg.Wait() +} + +// TestEventEmitterOffNonExistentListener tests that calling Off on a non-existent listener does not cause a panic. +func TestEventEmitterOffNonExistentListener(t *testing.T) { + emitter := NewEventEmitter() + // Attempt to remove a listener that doesn't exist + emitter.Off("nonexistent.event", 999) + // If no panic occurs, the test passes +} + +// TestEventEmitterEmitNoListeners tests that emitting an event with no listeners does not cause any issues. +func TestEventEmitterEmitNoListeners(t *testing.T) { + emitter := NewEventEmitter() + // Emit an event that has no listeners + emitter.Emit("no.listeners", "data") + // If no panic occurs, the test passes +} + +// TestEventEmitterListenerIDUniqueness tests that listener IDs are unique. +func TestEventEmitterListenerIDUniqueness(t *testing.T) { + emitter := NewEventEmitter() + eventName := "test.event" + + listenerIDs := make(map[int]struct{}) + for i := 0; i < 100; i++ { + id := emitter.On(eventName, func(args ...interface{}) {}) + if _, exists := listenerIDs[id]; exists { + t.Errorf("Duplicate listener ID detected: %d", id) + } + listenerIDs[id] = struct{}{} + } +} + +// TestEventEmitterEventPatternMatching tests various event pattern matching scenarios. +func TestEventEmitterEventPatternMatching(t *testing.T) { + testCases := []struct { + pattern string + eventName string + expected bool + }{ + {"user.*", "user.login", true}, + {"user.*", "user.profile.update", false}, + {"user.**", "user.profile.update", true}, + {"**", "any.event.name", true}, + {"order.**", "order.created", true}, + {"order.**", "order.processed.shipped", true}, + {"*.created", "user.created", true}, + {"*.created", "order.created", true}, + {"*.created", "user.profile.created", false}, + {"app.*.start", "app.server.start", true}, + {"app.*.start", "app.client.start", true}, + {"app.*.start", "app.server.stop", false}, + } + + for _, tc := range testCases { + match := eventMatches(tc.pattern, tc.eventName) + if match != tc.expected { + t.Errorf("Pattern '%s' matching event '%s': expected %v, got %v", tc.pattern, tc.eventName, tc.expected, match) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..72e36a4 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/backendstack21/realtime-pubsub-client-go + +go 1.16 + +require ( + github.com/gorilla/websocket v1.5.3 + github.com/sirupsen/logrus v1.9.3 +) + +require github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4cabbf7 --- /dev/null +++ b/go.sum @@ -0,0 +1,19 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..a70b714 --- /dev/null +++ b/listener.go @@ -0,0 +1,118 @@ +package realtime_pubsub + +// ListenerFunc represents the function signature for event listeners. +type ListenerFunc func(args ...interface{}) + +// ReplyFunc defines the signature for reply functions used in event listeners. +type ReplyFunc func(interface{}, string, ...ReplyOption) error + +// IncomingMessage represents a message received from the server. +type IncomingMessage map[string]interface{} + +// Topic receiver function that extracts the "topic" from IncomingMessage. +func (m IncomingMessage) Topic() string { + topic, _ := m["topic"].(string) + return topic +} + +// Data receiver function that extracts the "data" from IncomingMessage. +func (m IncomingMessage) Data() interface{} { + data, _ := m["data"] + return data +} + +// MessageType receiver function that extracts the "messageType" from IncomingMessage. +func (m IncomingMessage) MessageType() string { + messageType, _ := m["messageType"].(string) + return messageType +} + +// Compression receiver function that extracts the "compression" from IncomingMessage. +func (m IncomingMessage) Compression() bool { + compression, _ := m["compression"].(bool) + return compression +} + +func (m IncomingMessage) DataAsPresenceMessage() PresenceMessage { + return NewPresenceMessage(m) +} + +// ResponseMessage represents a message sent by a client in response to an incoming message. +type ResponseMessage map[string]interface{} + +// Id receiver function that extracts the "id" from ResponseMessage. +func (m ResponseMessage) Id() string { + topic, _ := m["id"].(string) + return topic +} + +// Data receiver function that extracts the "data" from ResponseMessage. +func (m ResponseMessage) Data() interface{} { + data, _ := m["data"] + return data +} + +func (m ResponseMessage) DataAsMap() map[string]interface{} { + data, _ := m["data"].(map[string]interface{}) + return data +} + +// Status receiver function that extracts the "status" from ResponseMessage. +func (m ResponseMessage) Status() string { + status, _ := m["status"].(string) + return status +} + +// ConnectionInfo represents the connection information received from the server. +type ConnectionInfo map[string]interface{} + +// AppId receiver function that extracts the "appId" from ConnectionInfo. +func (c ConnectionInfo) AppId() string { + appId, _ := c["appId"].(string) + return appId +} + +// ConnectionId receiver function that extracts the "id" from ConnectionInfo. +func (c ConnectionInfo) ConnectionId() string { + sessionId, _ := c["id"].(string) + return sessionId +} + +// RemoteAddress receiver function that extracts the "remoteAddress" from ConnectionInfo. +func (c ConnectionInfo) RemoteAddress() string { + remoteAddress, _ := c["remoteAddress"].(string) + return remoteAddress +} + +// PresenceMessage represents a presence event received from the server. +type PresenceMessage struct { + IncomingMessage +} + +// NewPresenceMessage creates a PresenceMessage from an IncomingMessage. +func NewPresenceMessage(msg IncomingMessage) PresenceMessage { + return PresenceMessage{msg} +} + +// Status receiver function that extracts the "status" attribute from payload. +func (m PresenceMessage) Status() string { + payload, _ := m.IncomingMessage.Data().(map[string]interface{})["payload"].(map[string]interface{}) + status, _ := payload["status"].(string) // Add type assertion check + return status +} + +// ConnectionId receiver function that extracts the "connectionId" attribute from client. +func (m PresenceMessage) ConnectionId() string { + payload, _ := m.IncomingMessage.Data().(map[string]interface{}) + client, _ := payload["client"].(map[string]interface{}) + connectionId, _ := client["connectionId"].(string) // Add type assertion check + return connectionId +} + +// Permissions receiver function that extracts the "permissions" attribute from client. +func (m PresenceMessage) Permissions() []string { + payload, _ := m.IncomingMessage.Data().(map[string]interface{}) + client, _ := payload["client"].(map[string]interface{}) + permissions, _ := client["permissions"].([]string) // Add type assertion check + return permissions +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 0000000..5a84d58 --- /dev/null +++ b/listener_test.go @@ -0,0 +1,289 @@ +// listener_test.go + +package realtime_pubsub + +import ( + "reflect" + "testing" +) + +func TestIncomingMessage_Topic(t *testing.T) { + msg := IncomingMessage{ + "topic": "test/topic", + } + + expected := "test/topic" + actual := msg.Topic() + + if actual != expected { + t.Errorf("Expected topic '%s', got '%s'", expected, actual) + } +} + +func TestIncomingMessage_Data(t *testing.T) { + data := map[string]interface{}{ + "key": "value", + } + + msg := IncomingMessage{ + "data": data, + } + + actual := msg.Data() + if !reflect.DeepEqual(actual, data) { + t.Errorf("Expected data '%v', got '%v'", data, actual) + } +} + +func TestIncomingMessage_MessageType(t *testing.T) { + msg := IncomingMessage{ + "messageType": "testType", + } + + expected := "testType" + actual := msg.MessageType() + + if actual != expected { + t.Errorf("Expected messageType '%s', got '%s'", expected, actual) + } +} + +func TestIncomingMessage_Compression(t *testing.T) { + msg := IncomingMessage{ + "compression": true, + } + + expected := true + actual := msg.Compression() + + if actual != expected { + t.Errorf("Expected compression '%v', got '%v'", expected, actual) + } +} + +func TestIncomingMessage_DataAsPresenceMessage(t *testing.T) { + data := map[string]interface{}{ + "payload": map[string]interface{}{ + "status": "connected", + }, + "client": map[string]interface{}{ + "connectionId": "12345", + "permissions": []string{"read", "write"}, + }, + } + + msg := IncomingMessage{ + "data": data, + } + + presenceMsg := msg.DataAsPresenceMessage() + + status := presenceMsg.Status() + if status != "connected" { + t.Errorf("Expected status 'connected', got '%s'", status) + } + + connectionId := presenceMsg.ConnectionId() + if connectionId != "12345" { + t.Errorf("Expected connectionId '12345', got '%s'", connectionId) + } + + expectedPermissions := []string{"read", "write"} + permissions := presenceMsg.Permissions() + if !reflect.DeepEqual(permissions, expectedPermissions) { + t.Errorf("Expected permissions '%v', got '%v'", expectedPermissions, permissions) + } +} + +func TestResponseMessage_Id(t *testing.T) { + msg := ResponseMessage{ + "id": "response-123", + } + + expected := "response-123" + actual := msg.Id() + + if actual != expected { + t.Errorf("Expected id '%s', got '%s'", expected, actual) + } +} + +func TestResponseMessage_Data(t *testing.T) { + data := map[string]interface{}{ + "result": "success", + } + + msg := ResponseMessage{ + "data": data, + } + + actual := msg.Data() + if !reflect.DeepEqual(actual, data) { + t.Errorf("Expected data '%v', got '%v'", data, actual) + } +} + +func TestResponseMessage_DataAsMap(t *testing.T) { + data := map[string]interface{}{ + "result": "success", + } + + msg := ResponseMessage{ + "data": data, + } + + actual := msg.DataAsMap() + if !reflect.DeepEqual(actual, data) { + t.Errorf("Expected data as map '%v', got '%v'", data, actual) + } +} + +func TestResponseMessage_Status(t *testing.T) { + msg := ResponseMessage{ + "status": "ok", + } + + expected := "ok" + actual := msg.Status() + + if actual != expected { + t.Errorf("Expected status '%s', got '%s'", expected, actual) + } +} + +func TestConnectionInfo_AppId(t *testing.T) { + connInfo := ConnectionInfo{ + "appId": "app-123", + } + + expected := "app-123" + actual := connInfo.AppId() + + if actual != expected { + t.Errorf("Expected appId '%s', got '%s'", expected, actual) + } +} + +func TestConnectionInfo_ConnectionId(t *testing.T) { + connInfo := ConnectionInfo{ + "id": "conn-456", + } + + expected := "conn-456" + actual := connInfo.ConnectionId() + + if actual != expected { + t.Errorf("Expected connectionId '%s', got '%s'", expected, actual) + } +} + +func TestConnectionInfo_RemoteAddress(t *testing.T) { + connInfo := ConnectionInfo{ + "remoteAddress": "192.168.1.1", + } + + expected := "192.168.1.1" + actual := connInfo.RemoteAddress() + + if actual != expected { + t.Errorf("Expected remoteAddress '%s', got '%s'", expected, actual) + } +} + +func TestPresenceMessage_Status(t *testing.T) { + data := map[string]interface{}{ + "payload": map[string]interface{}{ + "status": "connected", + }, + } + + msg := IncomingMessage{ + "data": data, + } + + presenceMsg := NewPresenceMessage(msg) + + expected := "connected" + actual := presenceMsg.Status() + + if actual != expected { + t.Errorf("Expected status '%s', got '%s'", expected, actual) + } +} + +func TestPresenceMessage_ConnectionId(t *testing.T) { + data := map[string]interface{}{ + "client": map[string]interface{}{ + "connectionId": "conn-789", + }, + } + + msg := IncomingMessage{ + "data": data, + } + + presenceMsg := NewPresenceMessage(msg) + + expected := "conn-789" + actual := presenceMsg.ConnectionId() + + if actual != expected { + t.Errorf("Expected connectionId '%s', got '%s'", expected, actual) + } +} + +func TestPresenceMessage_Permissions(t *testing.T) { + data := map[string]interface{}{ + "client": map[string]interface{}{ + "permissions": []string{"read", "write"}, + }, + } + + msg := IncomingMessage{ + "data": data, + } + + presenceMsg := NewPresenceMessage(msg) + + expected := []string{"read", "write"} + actual := presenceMsg.Permissions() + + if !reflect.DeepEqual(actual, expected) { + t.Errorf("Expected permissions '%v', got '%v'", expected, actual) + } +} + +func TestPresenceMessage_InvalidTypeAssertions(t *testing.T) { + // Test with incorrect types to ensure type assertions are handled safely + data := map[string]interface{}{ + "payload": "invalid_payload_type", + "client": "invalid_client_type", + } + + msg := IncomingMessage{ + "data": data, + } + + presenceMsg := NewPresenceMessage(msg) + + // Expect empty string due to failed type assertion + expectedStatus := "" + actualStatus := presenceMsg.Status() + if actualStatus != expectedStatus { + t.Errorf("Expected status '%s', got '%s'", expectedStatus, actualStatus) + } + + // Expect empty string due to failed type assertion + expectedConnectionId := "" + actualConnectionId := presenceMsg.ConnectionId() + if actualConnectionId != expectedConnectionId { + t.Errorf("Expected connectionId '%s', got '%s'", expectedConnectionId, actualConnectionId) + } + + // Expect nil slice due to failed type assertion + expectedPermissions := []string(nil) + actualPermissions := presenceMsg.Permissions() + if actualPermissions != nil { + t.Errorf("Expected permissions '%v', got '%v'", expectedPermissions, actualPermissions) + } +} diff --git a/main/basic.go b/main/basic.go new file mode 100644 index 0000000..18fd5a9 --- /dev/null +++ b/main/basic.go @@ -0,0 +1,117 @@ +package main + +import ( + "fmt" + "github.com/backendstack21/realtime-pubsub-client-go" + "github.com/sirupsen/logrus" + "os" + "time" + + "github.com/joho/godotenv" +) + +func main() { + logger := logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.DebugLevel) + + // Load environment variables from .env file + if err := godotenv.Load(); err != nil { + logger.Errorf("Error loading .env file") + } + + appID := os.Getenv("APP_ID") + accessToken := os.Getenv("ACCESS_TOKEN") + if appID == "" || accessToken == "" { + logger.Fatal("APP_ID and ACCESS_TOKEN must be set in the environment") + } + + // Create a URL provider function + urlProvider := func() (string, error) { + url := fmt.Sprintf("wss://genesis.r7.21no.de/apps/%s?access_token=%s", appID, accessToken) + return url, nil + } + + // Initialize the client with configuration + + client := realtime_pubsub.NewClient(realtime_pubsub.Config{ + Logger: logger, + WebSocketOptions: realtime_pubsub.WebSocketOptions{ + URLProvider: urlProvider, + }, + }) + + client.On("error", func(args ...interface{}) { + logger.Infof("Received error: %v", args) + }) + + client.On("close", func(args ...interface{}) { + logger.Infof("Connection closed: %v", args) + }) + + client.Connect() + + // Handle session started event + client.On("session.started", func(args ...interface{}) { + info := args[0].(realtime_pubsub.ConnectionInfo) + logger.Infof("Connection ID: %v", info.ConnectionId()) + + // IMPORTANT: Subscribe to remote topics here + // so subscriptions are re-established on reconnection + + // Subscribe to the "chat" topic + if err := client.SubscribeRemoteTopic("chat"); err != nil { + logger.Errorf("Failed to subscribe to topic: %v", err) + } + }) + // Wait for the session to start + _, err := client.WaitFor("session.started", 10*time.Second) + if err != nil { + logger.Errorf("Failed to start session: %v", err) + } + + // Handle incoming messages on "chat.text-message" + client.On("chat.text-message", func(args ...interface{}) { + // Extract message and reply function + message := args[0].(realtime_pubsub.IncomingMessage) + replyFunc := args[1].(realtime_pubsub.ReplyFunc) + + logger.Infof("Received message: %v", message) + + // Reply to the message + if err := replyFunc("Got it!", "ok"); err != nil { + logger.Errorf("Failed to reply to message: %v", err) + } + }) + + // Handle incoming messages on the "main" topic + client.On("main.*", func(args ...interface{}) { + message := args[0].(realtime_pubsub.IncomingMessage) + logger.Infof("Received message on main topic: %v", message) + }) + + // Publish a message to the "chat" topic + waitFor, err := client.Publish("chat", "Hello, World!", + realtime_pubsub.WithPublishMessageType("text-message"), + realtime_pubsub.WithPublishCompress(true), + ) + if err != nil { + logger.Errorf("Failed to publish message: %v", err) + } + + // Wait for acknowledgment + if _, err := waitFor.WaitForAck(100 * time.Millisecond); err != nil { + logger.Errorf("Failed to receive acknowledgment: %v", err) + } + + // Wait for a reply + reply, err := waitFor.WaitForReply(500 * time.Millisecond) + if err != nil { + logger.Errorf("Failed to receive reply: %v", err) + } else { + logger.Infof("Received reply: %v", reply.Data()) + } + + _ = client.Disconnect() +} diff --git a/main/full.go b/main/full.go new file mode 100644 index 0000000..7a43c47 --- /dev/null +++ b/main/full.go @@ -0,0 +1,128 @@ +package main + +import ( + "fmt" + "github.com/backendstack21/realtime-pubsub-client-go" + "github.com/sirupsen/logrus" + "os" + "time" + + "github.com/joho/godotenv" +) + +func main() { + logger := logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.DebugLevel) + + // Load environment variables from .env file + if err := godotenv.Load(); err != nil { + logger.Errorf("Error loading .env file") + } + + appID := os.Getenv("APP_ID") + accessToken := os.Getenv("ACCESS_TOKEN") + if appID == "" || accessToken == "" { + logger.Fatal("APP_ID and ACCESS_TOKEN must be set in the environment") + } + + // Create a URL provider function + urlProvider := func() (string, error) { + url := fmt.Sprintf("wss://genesis.r7.21no.de/apps/%s?access_token=%s", appID, accessToken) + return url, nil + } + + // Initialize the client with configuration + + client := realtime_pubsub.NewClient(realtime_pubsub.Config{ + Logger: logger, + WebSocketOptions: realtime_pubsub.WebSocketOptions{ + URLProvider: urlProvider, + }, + }) + + client.On("error", func(args ...interface{}) { + logger.Infof("Received error: %v", args) + }) + + client.On("close", func(args ...interface{}) { + logger.Infof("Connection closed: %v", args) + }) + + client.Connect() + + // Handle session started event + client.On("session.started", func(args ...interface{}) { + info := args[0].(realtime_pubsub.ConnectionInfo) + logger.Infof("Connection ID: %v", info.ConnectionId()) + + // IMPORTANT: Subscribe to remote topics here + // so subscriptions are re-established on reconnection + + // Subscribe to the "chat" topic + if err := client.SubscribeRemoteTopic("chat"); err != nil { + logger.Errorf("Failed to subscribe to topic: %v", err) + } + }) + // Wait for the session to start + _, err := client.WaitFor("session.started", 10*time.Second) + if err != nil { + logger.Errorf("Failed to start session: %v", err) + } + + // Handle incoming messages on "chat.text-message" + client.On("chat.text-message", func(args ...interface{}) { + // Extract message and reply function + message := args[0].(realtime_pubsub.IncomingMessage) + replyFunc := args[1].(realtime_pubsub.ReplyFunc) + + logger.Infof("Received message: %v", message) + + // Reply to the message + if err := replyFunc("Got it!", "ok"); err != nil { + logger.Errorf("Failed to reply to message: %v", err) + } + }) + + // Handle incoming messages on the "main" topic + client.On("main.*", func(args ...interface{}) { + message := args[0].(realtime_pubsub.IncomingMessage) + logger.Infof("Received message on main topic: %v", message) + }) + + // Publish a message to the "chat" topic + waitFor, err := client.Publish("chat", "Hello, World!", + realtime_pubsub.WithPublishMessageType("text-message"), + realtime_pubsub.WithPublishCompress(true), + ) + if err != nil { + logger.Errorf("Failed to publish message: %v", err) + } + + // Wait for acknowledgment + if _, err := waitFor.WaitForAck(100 * time.Millisecond); err != nil { + logger.Errorf("Failed to receive acknowledgment: %v", err) + } + + // Wait for a reply + reply, err := waitFor.WaitForReply(500 * time.Millisecond) + if err != nil { + logger.Errorf("Failed to receive reply: %v", err) + } else { + logger.Infof("Received reply: %v", reply.Data()) + } + + // Send a message every 5 second + go func() { + for { + time.Sleep(5 * time.Second) + if _, err := client.Publish("secure/chat", "Hello!"); err != nil { + logger.Errorf("Failed to publish message: %v", err) + } + } + }() + + // Keep the process running indefinitely + select {} +} diff --git a/main/rpc/client.go b/main/rpc/client.go new file mode 100644 index 0000000..f8b47e0 --- /dev/null +++ b/main/rpc/client.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "github.com/backendstack21/realtime-pubsub-client-go" + "github.com/sirupsen/logrus" + "os" + "time" + + "github.com/joho/godotenv" +) + +func main() { + done := make(chan bool) + + logger := logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.DebugLevel) + + // Load environment variables from .env file + if err := godotenv.Load(); err != nil { + logger.Errorf("Error loading .env file") + } + + appID := os.Getenv("APP_ID") + accessToken := os.Getenv("ACCESS_TOKEN") + + // Create a URL provider function + urlProvider := func() (string, error) { + url := fmt.Sprintf("wss://genesis.r7.21no.de/apps/%s?access_token=%s", appID, accessToken) + return url, nil + } + + // Initialize the client with configuration + + client := realtime_pubsub.NewClient(realtime_pubsub.Config{ + Logger: logger, + WebSocketOptions: realtime_pubsub.WebSocketOptions{ + URLProvider: urlProvider, + }, + }) + + client.On("session.started", func(args ...interface{}) { + go func() { + // Publish a message to the "chat" topic + waitFor, _ := client.Send("", + realtime_pubsub.WithSendMessageType("gettime"), + ) + + // Wait for acknowledgment + reply, err := waitFor.WaitForReply(500 * time.Millisecond) + if err != nil { + logger.Errorf("Error waiting for reply: %v", err) + } + + // Print reply + logger.Infof("Received date: %v\n", reply.DataAsMap()["time"]) + + done <- true + }() + }) + + // Handle error events + client.Connect() + + select { + case <-done: + } +} diff --git a/main/rpc/server.go b/main/rpc/server.go new file mode 100644 index 0000000..39fcbef --- /dev/null +++ b/main/rpc/server.go @@ -0,0 +1,76 @@ +package main + +import ( + "fmt" + "github.com/backendstack21/realtime-pubsub-client-go" + "github.com/joho/godotenv" + "github.com/sirupsen/logrus" + "os" + "time" +) + +func main() { + logger := logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.InfoLevel) + + // Load environment variables from .env file + if err := godotenv.Load(); err != nil { + logger.Errorf("Error loading .env file") + } + + appID := os.Getenv("APP_ID") + accessToken := os.Getenv("ACCESS_TOKEN") + + // Create a URL provider function + urlProvider := func() (string, error) { + url := fmt.Sprintf("wss://genesis.r7.21no.de/apps/%s?access_token=%s", appID, accessToken) + return url, nil + } + + // Initialize the client with configuration + + client := realtime_pubsub.NewClient(realtime_pubsub.Config{ + Logger: logger, + WebSocketOptions: realtime_pubsub.WebSocketOptions{ + URLProvider: urlProvider, + }, + }) + + client.On("secure/inbound.gettime", func(args ...interface{}) { + logger.Infof("Responding to gettime request...") + replyFunc := args[1].(realtime_pubsub.ReplyFunc) + + go func() { + if err := replyFunc(map[string]interface{}{ + "time": time.Now().Format(time.RFC3339), + }, "ok"); err != nil { + logger.Errorf("Failed to reply to message: %v", err) + } + }() + }) + + client.On("secure/inbound.presence", func(args ...interface{}) { + message := args[0].(realtime_pubsub.IncomingMessage) + presence := message.DataAsPresenceMessage() + + if presence.Status() == "connected" { + logger.Infof("Client %v connected with %v permissions...", presence.ConnectionId(), presence.Permissions()) + } else if presence.Status() == "disconnected" { + logger.Infof("Client %v disconnected...", presence.ConnectionId()) + } + }) + + client.On("session.started", func(args ...interface{}) { + err := client.SubscribeRemoteTopic("secure/inbound") + if err != nil { + logger.Errorf("Error subscribing to remote topic: %v", err) + } + }) + + // Handle error events + client.Connect() + + select {} +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..fc168b9 --- /dev/null +++ b/options.go @@ -0,0 +1,78 @@ +package realtime_pubsub + +// PublishOptions represents options for publishing messages. +type PublishOptions struct { + ID string + MessageType string + Compress bool +} + +// PublishOption defines a function type for setting PublishOptions. +type PublishOption func(*PublishOptions) + +// WithPublishID sets the ID in PublishOptions. +func WithPublishID(id string) PublishOption { + return func(opts *PublishOptions) { + opts.ID = id + } +} + +// WithPublishMessageType sets the MessageType in PublishOptions. +func WithPublishMessageType(messageType string) PublishOption { + return func(opts *PublishOptions) { + opts.MessageType = messageType + } +} + +// WithPublishCompress sets the Compress flag in PublishOptions. +func WithPublishCompress(compress bool) PublishOption { + return func(opts *PublishOptions) { + opts.Compress = compress + } +} + +// SendOptions represents options for sending messages. +type SendOptions struct { + ID string + MessageType string + Compress bool +} + +// SendOption defines a function type for setting SendOptions. +type SendOption func(*SendOptions) + +// WithSendID sets the ID in SendOptions. +func WithSendID(id string) SendOption { + return func(opts *SendOptions) { + opts.ID = id + } +} + +// WithSendMessageType sets the MessageType in SendOptions. +func WithSendMessageType(messageType string) SendOption { + return func(opts *SendOptions) { + opts.MessageType = messageType + } +} + +// WithSendCompress sets the Compress flag in SendOptions. +func WithSendCompress(compress bool) SendOption { + return func(opts *SendOptions) { + opts.Compress = compress + } +} + +// ReplyOptions represents options for replying to messages. +type ReplyOptions struct { + Compress bool +} + +// ReplyOption defines a function type for setting ReplyOptions. +type ReplyOption func(*ReplyOptions) + +// WithReplyCompress sets the Compress flag in ReplyOptions. +func WithReplyCompress(compress bool) ReplyOption { + return func(opts *ReplyOptions) { + opts.Compress = compress + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..6dd44f3 --- /dev/null +++ b/options_test.go @@ -0,0 +1,228 @@ +package realtime_pubsub + +import ( + "testing" +) + +// TestPublishOptions tests the PublishOptions functional options. +func TestPublishOptions(t *testing.T) { + // Create a default PublishOptions instance + opts := PublishOptions{ + ID: "default-id", + MessageType: "default-type", + Compress: false, + } + + // Apply functional options + options := []PublishOption{ + WithPublishID("custom-id"), + WithPublishMessageType("custom-type"), + WithPublishCompress(true), + } + + for _, opt := range options { + opt(&opts) + } + + // Verify that the options have been set correctly + if opts.ID != "custom-id" { + t.Errorf("Expected ID to be 'custom-id', got '%s'", opts.ID) + } + + if opts.MessageType != "custom-type" { + t.Errorf("Expected MessageType to be 'custom-type', got '%s'", opts.MessageType) + } + + if opts.Compress != true { + t.Errorf("Expected Compress to be 'true', got '%v'", opts.Compress) + } +} + +// TestSendOptions tests the SendOptions functional options. +func TestSendOptions(t *testing.T) { + // Create a default SendOptions instance + opts := SendOptions{ + ID: "default-id", + MessageType: "default-type", + Compress: false, + } + + // Apply functional options + options := []SendOption{ + WithSendID("custom-id"), + WithSendMessageType("custom-type"), + WithSendCompress(true), + } + + for _, opt := range options { + opt(&opts) + } + + // Verify that the options have been set correctly + if opts.ID != "custom-id" { + t.Errorf("Expected ID to be 'custom-id', got '%s'", opts.ID) + } + + if opts.MessageType != "custom-type" { + t.Errorf("Expected MessageType to be 'custom-type', got '%s'", opts.MessageType) + } + + if opts.Compress != true { + t.Errorf("Expected Compress to be 'true', got '%v'", opts.Compress) + } +} + +// TestReplyOptions tests the ReplyOptions functional options. +func TestReplyOptions(t *testing.T) { + // Create a default ReplyOptions instance + opts := ReplyOptions{ + Compress: false, + } + + // Apply functional options + options := []ReplyOption{ + WithReplyCompress(true), + } + + for _, opt := range options { + opt(&opts) + } + + // Verify that the options have been set correctly + if opts.Compress != true { + t.Errorf("Expected Compress to be 'true', got '%v'", opts.Compress) + } +} + +// TestDefaultPublishOptions tests the default values of PublishOptions when no options are applied. +func TestDefaultPublishOptions(t *testing.T) { + // Create a default PublishOptions instance + opts := PublishOptions{ + ID: "", + MessageType: "", + Compress: false, + } + + // No options applied + + // Verify default values + if opts.ID != "" { + t.Errorf("Expected default ID to be empty, got '%s'", opts.ID) + } + + if opts.MessageType != "" { + t.Errorf("Expected default MessageType to be empty, got '%s'", opts.MessageType) + } + + if opts.Compress != false { + t.Errorf("Expected default Compress to be 'false', got '%v'", opts.Compress) + } +} + +// TestDefaultSendOptions tests the default values of SendOptions when no options are applied. +func TestDefaultSendOptions(t *testing.T) { + // Create a default SendOptions instance + opts := SendOptions{ + ID: "", + MessageType: "", + Compress: false, + } + + // No options applied + + // Verify default values + if opts.ID != "" { + t.Errorf("Expected default ID to be empty, got '%s'", opts.ID) + } + + if opts.MessageType != "" { + t.Errorf("Expected default MessageType to be empty, got '%s'", opts.MessageType) + } + + if opts.Compress != false { + t.Errorf("Expected default Compress to be 'false', got '%v'", opts.Compress) + } +} + +// TestDefaultReplyOptions tests the default values of ReplyOptions when no options are applied. +func TestDefaultReplyOptions(t *testing.T) { + // Create a default ReplyOptions instance + opts := ReplyOptions{ + Compress: false, + } + + // No options applied + + // Verify default values + if opts.Compress != false { + t.Errorf("Expected default Compress to be 'false', got '%v'", opts.Compress) + } +} + +// TestPartialPublishOptions tests applying some, but not all, functional options to PublishOptions. +func TestPartialPublishOptions(t *testing.T) { + // Create a default PublishOptions instance + opts := PublishOptions{ + ID: "default-id", + MessageType: "default-type", + Compress: false, + } + + // Apply some options + options := []PublishOption{ + WithPublishID("custom-id"), + // MessageType remains default + WithPublishCompress(true), + } + + for _, opt := range options { + opt(&opts) + } + + // Verify that the options have been set correctly + if opts.ID != "custom-id" { + t.Errorf("Expected ID to be 'custom-id', got '%s'", opts.ID) + } + + if opts.MessageType != "default-type" { + t.Errorf("Expected MessageType to be 'default-type', got '%s'", opts.MessageType) + } + + if opts.Compress != true { + t.Errorf("Expected Compress to be 'true', got '%v'", opts.Compress) + } +} + +// TestPartialSendOptions tests applying some, but not all, functional options to SendOptions. +func TestPartialSendOptions(t *testing.T) { + // Create a default SendOptions instance + opts := SendOptions{ + ID: "default-id", + MessageType: "default-type", + Compress: false, + } + + // Apply some options + options := []SendOption{ + WithSendMessageType("custom-type"), + // ID remains default + // Compress remains default + } + + for _, opt := range options { + opt(&opts) + } + + // Verify that the options have been set correctly + if opts.ID != "default-id" { + t.Errorf("Expected ID to be 'default-id', got '%s'", opts.ID) + } + + if opts.MessageType != "custom-type" { + t.Errorf("Expected MessageType to be 'custom-type', got '%s'", opts.MessageType) + } + + if opts.Compress != false { + t.Errorf("Expected Compress to be 'false', got '%v'", opts.Compress) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..394e70d --- /dev/null +++ b/utils.go @@ -0,0 +1,15 @@ +package realtime_pubsub + +import ( + "encoding/hex" + "io" +) + +// getRandomID generates a random identifier string. +// Accepts an io.Reader to allow for testing. +func getRandomID(reader io.Reader) string { + b := make([]byte, 8) + _, _ = reader.Read(b) + + return hex.EncodeToString(b) +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..55ae164 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,95 @@ +package realtime_pubsub + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "testing" +) + +// errorReader is an io.Reader that always returns an error. +type errorReader struct{} + +func (e *errorReader) Read(_ []byte) (int, error) { + return 0, errors.New("simulated read error") +} + +// TestGetRandomIDLength tests that getRandomID returns a string of the correct length. +func TestGetRandomIDLength(t *testing.T) { + id := getRandomID(rand.Reader) + expectedLength := 16 + if len(id) != expectedLength { + t.Errorf("Expected ID length to be %d, got %d", expectedLength, len(id)) + } +} + +// TestGetRandomIDUniqueness tests that multiple calls to getRandomID return unique values. +func TestGetRandomIDUniqueness(t *testing.T) { + ids := make(map[string]struct{}) + count := 1000 // Number of IDs to generate for the test + + for i := 0; i < count; i++ { + id := getRandomID(rand.Reader) + if _, exists := ids[id]; exists { + t.Errorf("Duplicate ID generated: %s", id) + } + ids[id] = struct{}{} + } +} + +// TestGetRandomIDFallback tests the fallback behavior when the reader fails. +func TestGetRandomIDFallback(t *testing.T) { + id := getRandomID(&errorReader{}) + + // Check that the ID is a numeric string (UnixNano timestamp) + if _, err := fmt.Sscanf(id, "%d", new(int64)); err != nil { + t.Errorf("Expected fallback ID to be a numeric string, got '%s'", id) + } +} + +// TestGetRandomIDConcurrency tests that getRandomID behaves correctly under concurrent usage. +func TestGetRandomIDConcurrency(t *testing.T) { + ids := make(chan string, 1000) + concurrency := 50 + perGoroutine := 20 + + // Function to generate IDs and send them to the channel + generateIDs := func() { + for i := 0; i < perGoroutine; i++ { + id := getRandomID(rand.Reader) + ids <- id + } + } + + // Start concurrent goroutines + for i := 0; i < concurrency; i++ { + go generateIDs() + } + + // Collect generated IDs + idMap := make(map[string]struct{}) + totalIDs := concurrency * perGoroutine + for i := 0; i < totalIDs; i++ { + id := <-ids + if _, exists := idMap[id]; exists { + t.Errorf("Duplicate ID generated in concurrent execution: %s", id) + } + idMap[id] = struct{}{} + } +} + +// TestGetRandomIDDeterministicReader tests getRandomID with a deterministic reader. +func TestGetRandomIDDeterministicReader(t *testing.T) { + // Use a bytes.Reader with known content + data := []byte("1234567890abcdef1234567890abcdef") + reader := bytes.NewReader(data) + + id := getRandomID(reader) + + expectedID := hex.EncodeToString(data[:8]) + if id != expectedID { + t.Errorf("Expected ID to be '%s', got '%s'", expectedID, id) + } +} diff --git a/wait_for.go b/wait_for.go new file mode 100644 index 0000000..2140ea6 --- /dev/null +++ b/wait_for.go @@ -0,0 +1,70 @@ +package realtime_pubsub + +import ( + "context" + "fmt" + "time" +) + +// WaitFor provides methods to wait for acknowledgments or replies. +type WaitFor struct { + client WaitForClient + id string +} + +// WaitForClient defines the interface that the client must implement. +type WaitForClient interface { + WaitFor(eventName string, timeout time.Duration) (interface{}, error) +} + +// WaitForAck waits for an acknowledgment event with a timeout. +func (w *WaitFor) WaitForAck(timeout time.Duration) (interface{}, error) { + eventName := fmt.Sprintf("ack.%v", w.id) + return w.client.WaitFor(eventName, timeout) +} + +// WaitForReply waits for a reply event with a timeout. +func (w *WaitFor) WaitForReply(timeout time.Duration) (ResponseMessage, error) { + eventName := fmt.Sprintf("response.%v", w.id) + + // Wait for the reply event + result, err := w.client.WaitFor(eventName, timeout) + if err != nil { + return nil, err + } + + // Extract the reply message directly + args, ok := result.([]interface{}) + if !ok || len(args) == 0 { + return nil, fmt.Errorf("invalid reply format") + } + + msg, ok := args[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid reply message format") + } + + return msg, nil +} + +// WaitFor waits for a specific event to occur within a timeout period. +func (c *Client) WaitFor(eventName string, timeout time.Duration) (interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + ch := make(chan interface{}, 1) + + listenerID := c.On(eventName, func(args ...interface{}) { + select { + case ch <- args: + default: + } + }) + defer c.Off(eventName, listenerID) + + select { + case result := <-ch: + return result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/wait_for_test.go b/wait_for_test.go new file mode 100644 index 0000000..bf31e1f --- /dev/null +++ b/wait_for_test.go @@ -0,0 +1,317 @@ +package realtime_pubsub + +import ( + "context" + "errors" + "log" + "sync" + "testing" + "time" +) + +// MockEventEmitter is a mock implementation of EventEmitter. +type MockEventEmitter struct { + mu sync.RWMutex + listeners map[string][]ListenerFunc +} + +func NewMockEventEmitter() *MockEventEmitter { + return &MockEventEmitter{ + listeners: make(map[string][]ListenerFunc), + } +} + +func (e *MockEventEmitter) On(event string, listener ListenerFunc) int { + e.mu.Lock() + defer e.mu.Unlock() + e.listeners[event] = append(e.listeners[event], listener) + // Return a dummy listener ID (not used in the mock) + return len(e.listeners[event]) - 1 +} + +func (e *MockEventEmitter) Off(event string, id int) { + e.mu.Lock() + defer e.mu.Unlock() + if listeners, ok := e.listeners[event]; ok { + if id >= 0 && id < len(listeners) { + // Remove the listener at index id + e.listeners[event] = append(listeners[:id], listeners[id+1:]...) + } + } +} + +func (e *MockEventEmitter) Emit(event string, args ...interface{}) { + e.mu.RLock() + defer e.mu.RUnlock() + if listeners, ok := e.listeners[event]; ok { + for _, listener := range listeners { + // Call the listener in a separate goroutine to mimic asynchronous behavior + go listener(args...) + } + } +} + +// MockClient is a mock implementation of Client. +type MockClient struct { + eventEmitter *MockEventEmitter + logger *log.Logger +} + +func NewMockClient() *MockClient { + return &MockClient{ + eventEmitter: NewMockEventEmitter(), + logger: log.New(log.Writer(), "mockclient: ", log.LstdFlags), + } +} + +func (c *MockClient) On(event string, listener ListenerFunc) int { + return c.eventEmitter.On(event, listener) +} + +func (c *MockClient) Off(event string, id int) { + c.eventEmitter.Off(event, id) +} + +func (c *MockClient) WaitFor(eventName string, timeout time.Duration) (interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + ch := make(chan interface{}, 1) + + listenerID := c.On(eventName, func(args ...interface{}) { + select { + case ch <- args: + default: + } + }) + defer c.Off(eventName, listenerID) + + select { + case result := <-ch: + c.logger.Printf("Event '%s' received: %v", eventName, result) + return result, nil + case <-ctx.Done(): + c.logger.Printf("Timeout waiting for event '%s'", eventName) + return nil, ctx.Err() + } +} + +func (c *MockClient) Emit(eventName string, args ...interface{}) { + c.eventEmitter.Emit(eventName, args...) +} + +// Ensure MockClient implements WaitForClient interface +var _ WaitForClient = (*MockClient)(nil) + +// TestWaitForAckSuccess tests that WaitForAck successfully returns when the ack event is emitted. +func TestWaitForAckSuccess(t *testing.T) { + client := NewMockClient() + waitFor := &WaitFor{ + client: client, // Now acceptable because MockClient implements WaitForClient + id: "test-id", + } + + // Emit the ack event after a short delay + go func() { + time.Sleep(100 * time.Millisecond) + client.Emit("ack.test-id", "ack data") + }() + + result, err := waitFor.WaitForAck(1 * time.Second) + if err != nil { + t.Errorf("Expected WaitForAck to succeed, but got error: %v", err) + } + + args, ok := result.([]interface{}) + if !ok || len(args) == 0 { + t.Errorf("Expected result to be non-empty []interface{}, got %v", result) + } + + if args[0] != "ack data" { + t.Errorf("Expected ack data to be 'ack data', got '%v'", args[0]) + } +} + +// TestWaitForAckTimeout tests that WaitForAck times out if the ack event is not emitted. +func TestWaitForAckTimeout(t *testing.T) { + client := NewMockClient() + waitFor := &WaitFor{ + client: client, + id: "test-id", + } + + // Do not emit the ack event + result, err := waitFor.WaitForAck(200 * time.Millisecond) + if err == nil { + t.Errorf("Expected WaitForAck to timeout, but got result: %v", result) + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected error to be context.DeadlineExceeded, got %v", err) + } +} + +// TestWaitForReplySuccess tests that WaitForReply successfully returns when the response event is emitted. +func TestWaitForReplySuccess(t *testing.T) { + client := NewMockClient() + waitFor := &WaitFor{ + client: client, + id: "reply-id", + } + + // Emit the response event after a short delay + go func() { + time.Sleep(100 * time.Millisecond) + client.Emit("response.reply-id", map[string]interface{}{ + "data": "reply data", + "id": "reply-id", + "status": "ok", + }) + }() + + result, err := waitFor.WaitForReply(1 * time.Second) + if err != nil { + t.Errorf("Expected WaitForReply to succeed, but got error: %v", err) + } + + if result.Data() != "reply data" { + t.Errorf("Expected reply data to be 'reply data', got '%v'", result.Data()) + } + if result.Id() != "reply-id" { + t.Errorf("Expected reply id to be 'reply-id', got '%v'", result.Id()) + } + if result.Status() != "ok" { + t.Errorf("Expected reply status to be 'ok', got '%v'", result.Status()) + } + +} + +// TestWaitForReplyTimeout tests that WaitForReply times out if the response event is not emitted. +func TestWaitForReplyTimeout(t *testing.T) { + client := NewMockClient() + waitFor := &WaitFor{ + client: client, + id: "reply-id", + } + + // Do not emit the response event + result, err := waitFor.WaitForReply(200 * time.Millisecond) + if err == nil { + t.Errorf("Expected WaitForReply to timeout, but got result: %v", result) + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected error to be context.DeadlineExceeded, got %v", err) + } +} + +// TestClientWaitForSuccess tests that Client.WaitFor successfully returns when the event is emitted. +func TestClientWaitForSuccess(t *testing.T) { + client := NewMockClient() + + // Emit the event after a short delay + go func() { + time.Sleep(100 * time.Millisecond) + client.Emit("test.event", "event data") + }() + + result, err := client.WaitFor("test.event", 1*time.Second) + if err != nil { + t.Errorf("Expected WaitFor to succeed, but got error: %v", err) + } + + args, ok := result.([]interface{}) + if !ok || len(args) == 0 { + t.Errorf("Expected result to be non-empty []interface{}, got %v", result) + } + + if args[0] != "event data" { + t.Errorf("Expected event data to be 'event data', got '%v'", args[0]) + } +} + +// TestClientWaitForTimeout tests that Client.WaitFor times out if the event is not emitted. +func TestClientWaitForTimeout(t *testing.T) { + client := NewMockClient() + + // Do not emit the event + result, err := client.WaitFor("test.event", 200*time.Millisecond) + if err == nil { + t.Errorf("Expected WaitFor to timeout, but got result: %v", result) + } + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected error to be context.DeadlineExceeded, got %v", err) + } +} + +// TestWaitForMultipleListeners tests that multiple listeners can wait for the same event. +func TestWaitForMultipleListeners(t *testing.T) { + client := NewMockClient() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + result, err := client.WaitFor("test.event", 1*time.Second) + if err != nil { + t.Errorf("Listener 1: Expected WaitFor to succeed, but got error: %v", err) + return + } + args, _ := result.([]interface{}) + if args[0] != "event data" { + t.Errorf("Listener 1: Expected event data to be 'event data', got '%v'", args[0]) + } + }() + + go func() { + defer wg.Done() + result, err := client.WaitFor("test.event", 1*time.Second) + if err != nil { + t.Errorf("Listener 2: Expected WaitFor to succeed, but got error: %v", err) + return + } + args, _ := result.([]interface{}) + if args[0] != "event data" { + t.Errorf("Listener 2: Expected event data to be 'event data', got '%v'", args[0]) + } + }() + + // Emit the event after a short delay + time.Sleep(100 * time.Millisecond) + client.Emit("test.event", "event data") + + wg.Wait() +} + +// TestWaitForEventReceivedAfterTimeout tests that events received after timeout are ignored. +func TestWaitForEventReceivedAfterTimeout(t *testing.T) { + client := NewMockClient() + + resultCh := make(chan interface{}) + errCh := make(chan error) + + go func() { + result, err := client.WaitFor("test.event", 100*time.Millisecond) + if err != nil { + errCh <- err + return + } + resultCh <- result + }() + + // Emit the event after the timeout + time.Sleep(200 * time.Millisecond) + client.Emit("test.event", "event data") + + select { + case <-resultCh: + t.Errorf("Expected WaitFor to timeout, but it succeeded") + case err := <-errCh: + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected error to be context.DeadlineExceeded, got %v", err) + } + case <-time.After(1 * time.Second): + t.Errorf("Test timed out waiting for WaitFor to return") + } +}