diff --git a/comms/internal/rpcz/chat.go b/comms/internal/rpcz/chat.go index 8bf44dec410..1e24409bea3 100644 --- a/comms/internal/rpcz/chat.go +++ b/comms/internal/rpcz/chat.go @@ -66,8 +66,13 @@ func chatSendMessage(tx *sqlx.Tx, userId int32, chatId string, messageId string, return err } -func chatReactMessage(tx *sqlx.Tx, userId int32, messageId string, reaction string, messageTimestamp time.Time) error { - _, err := tx.Exec("insert into chat_message_reactions (user_id, message_id, reaction, created_at, updated_at) values ($1, $2, $3, $4, $4) on conflict (user_id, message_id) do update set reaction = $3, updated_at = $4", userId, messageId, reaction, messageTimestamp) +func chatReactMessage(tx *sqlx.Tx, userId int32, messageId string, reaction *string, messageTimestamp time.Time) error { + var err error + if reaction != nil { + _, err = tx.Exec("insert into chat_message_reactions (user_id, message_id, reaction, created_at, updated_at) values ($1, $2, $3, $4, $4) on conflict (user_id, message_id) do update set reaction = $3, updated_at = $4", userId, messageId, *reaction, messageTimestamp) + } else { + _, err = tx.Exec("delete from chat_message_reactions where user_id = $1 and message_id = $2", userId, messageId) + } if err != nil { return err } diff --git a/comms/internal/rpcz/chat_test.go b/comms/internal/rpcz/chat_test.go index 9005fcebff8..54ecc2e2c98 100644 --- a/comms/internal/rpcz/chat_test.go +++ b/comms/internal/rpcz/chat_test.go @@ -72,11 +72,15 @@ func TestChat(t *testing.T) { assert.Equal(t, expected, unreadCount) } - assertReaction := func(userId int, messageId string, expected string) { + assertReaction := func(userId int, messageId string, expected *string) { var reaction string err := tx.Get(&reaction, "select reaction from chat_message_reactions where user_id = $1 and message_id = $2", userId, messageId) - assert.NoError(t, err) - assert.Equal(t, expected, reaction) + if expected != nil { + assert.NoError(t, err) + assert.Equal(t, *expected, reaction) + } else { + assert.ErrorIs(t, err, sql.ErrNoRows) + } } // assert sender has no unread messages @@ -119,14 +123,19 @@ func TestChat(t *testing.T) { // 91 reacts to 92's message reactTs := time.Now() reaction := "fire" - err = chatReactMessage(tx, 91, replyMessageId, reaction, reactTs) - assertReaction(91, replyMessageId, reaction) + err = chatReactMessage(tx, 91, replyMessageId, &reaction, reactTs) + assertReaction(91, replyMessageId, &reaction) - // 91 changes reaction to 92's old message + // 91 changes reaction to 92's message changedReactTs := time.Now() newReaction := "heart" - err = chatReactMessage(tx, 91, replyMessageId, newReaction, changedReactTs) - assertReaction(91, replyMessageId, newReaction) + err = chatReactMessage(tx, 91, replyMessageId, &newReaction, changedReactTs) + assertReaction(91, replyMessageId, &newReaction) + + // 91 removes reaction to 92's message + removedReactTs := time.Now() + err = chatReactMessage(tx, 91, replyMessageId, nil, removedReactTs) + assertReaction(91, replyMessageId, nil) tx.Rollback() } diff --git a/comms/schema/schema.go b/comms/schema/schema.go index 3f277051c76..f73ec50f6da 100644 --- a/comms/schema/schema.go +++ b/comms/schema/schema.go @@ -57,9 +57,9 @@ type ChatReactRPC struct { } type ChatReactRPCParams struct { - ChatID string `json:"chat_id"` - MessageID string `json:"message_id"` - Reaction string `json:"reaction"` + ChatID string `json:"chat_id"` + MessageID string `json:"message_id"` + Reaction *string `json:"reaction"` } type ChatReadRPC struct { @@ -109,7 +109,7 @@ type RPCPayloadParams struct { Message *string `json:"message,omitempty"` MessageID *string `json:"message_id,omitempty"` ParentMessageID *string `json:"parent_message_id,omitempty"` - Reaction *string `json:"reaction,omitempty"` + Reaction *string `json:"reaction"` UserID *string `json:"user_id,omitempty"` Permit *ChatPermission `json:"permit,omitempty"` } diff --git a/comms/schema/schema.ts b/comms/schema/schema.ts index ee70498f034..eed787c8b79 100644 --- a/comms/schema/schema.ts +++ b/comms/schema/schema.ts @@ -44,7 +44,7 @@ export type ChatReactRPC = { params: { chat_id: string message_id: string - reaction: string + reaction: string | null } } diff --git a/libs/src/sdk/api/chats/serverTypes.ts b/libs/src/sdk/api/chats/serverTypes.ts index ee70498f034..eed787c8b79 100644 --- a/libs/src/sdk/api/chats/serverTypes.ts +++ b/libs/src/sdk/api/chats/serverTypes.ts @@ -44,7 +44,7 @@ export type ChatReactRPC = { params: { chat_id: string message_id: string - reaction: string + reaction: string | null } }