From 1d8099efef32683649c631f220c94d358d971b6d Mon Sep 17 00:00:00 2001 From: Michal Iskierko Date: Fri, 22 Mar 2024 13:23:43 +0100 Subject: [PATCH] fix: handle bridge message edits Issue #14044 --- VERSION | 2 +- protocol/message_persistence.go | 55 ++++++++++++++++-- protocol/messenger.go | 4 ++ protocol/messenger_edit_message_test.go | 74 +++++++++++++++++++++++++ protocol/messenger_messages.go | 10 +++- 5 files changed, 138 insertions(+), 7 deletions(-) diff --git a/VERSION b/VERSION index eca6fc9b29b..4ef231165ad 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.177.0 +0.177.1 diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index 92108b1b7c8..dc7837caeda 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -1551,6 +1551,18 @@ func (db sqlitePersistence) SaveMessages(messages []*common.Message) (err error) } if msg.ContentType == protobuf.ChatMessage_BRIDGE_MESSAGE { + // check updates first + var hasMessage bool + hasMessage, err = db.bridgeMessageExists(tx, msg.GetBridgeMessage().MessageID) + if err != nil { + return + } + if hasMessage { + // bridge message exists, this is edit + err = db.updateBridgeMessageContent(tx, msg.GetBridgeMessage().MessageID, msg.GetBridgeMessage().Content) + return + } + err = db.saveBridgeMessage(tx, msg.GetBridgeMessage(), msg.ID) if err != nil { return @@ -2967,9 +2979,26 @@ func (db sqlitePersistence) findStatusMessageIdsReplies(tx *sql.Tx, bridgeMessag return statusMessageIDs, nil } -// Finds status messages id which are replies for bridgeMessageID -func (db sqlitePersistence) findStatusMessageIdsRepliedTo(tx *sql.Tx, parentMessageID string) (string, error) { - rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE message_id = ?`, parentMessageID) +func (db sqlitePersistence) FindStatusMessageIdForBridgeMessageId(messageID string) (string, error) { + rows, err := db.db.Query(`SELECT user_messages_id FROM bridge_messages WHERE message_id = ?`, messageID) + if err != nil { + return "", err + } + defer rows.Close() + + if rows.Next() { + var statusMessageID string + err = rows.Scan(&statusMessageID) + if err != nil { + return "", err + } + return statusMessageID, nil + } + return "", nil +} + +func (db sqlitePersistence) findStatusMessageIdForBridgeMessageId(tx *sql.Tx, messageID string) (string, error) { + rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE message_id = ?`, messageID) if err != nil { return "", err } @@ -3003,6 +3032,23 @@ func (db sqlitePersistence) updateStatusMessagesWithResponse(tx *sql.Tx, statusM return err } +func (db sqlitePersistence) bridgeMessageExists(tx *sql.Tx, bridgeMessageID string) (exists bool, err error) { + err = tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM bridge_messages WHERE message_id = ?)`, bridgeMessageID).Scan(&exists) + return exists, err +} + +func (db sqlitePersistence) updateBridgeMessageContent(tx *sql.Tx, bridgeMessageID string, content string) error { + sql := "UPDATE bridge_messages SET content = ? WHERE message_id = ?" + stmt, err := tx.Prepare(sql) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(content, bridgeMessageID) + return err +} + // Finds if there are any messages that are replies to that message (in case replies were received earlier) func (db sqlitePersistence) findAndUpdateReplies(tx *sql.Tx, bridgeMessageID string, statusMessageID string) error { replyMessageIds, err := db.findStatusMessageIdsReplies(tx, bridgeMessageID) @@ -3016,7 +3062,8 @@ func (db sqlitePersistence) findAndUpdateReplies(tx *sql.Tx, bridgeMessageID str } func (db sqlitePersistence) findAndUpdateRepliedTo(tx *sql.Tx, discordParentMessageID string, statusMessageID string) error { - repliedMessageID, err := db.findStatusMessageIdsRepliedTo(tx, discordParentMessageID) + // Finds status messages id which are replies for bridgeMessageID + repliedMessageID, err := db.findStatusMessageIdForBridgeMessageId(tx, discordParentMessageID) if err != nil { return err } diff --git a/protocol/messenger.go b/protocol/messenger.go index a65fd65646d..e834067d96f 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -6042,3 +6042,7 @@ func (m *Messenger) startMessageSegmentsCleanupLoop() { } }() } + +func (m *Messenger) FindStatusMessageIdForBridgeMessageId(bridgeMessageID string) (string, error) { + return m.persistence.FindStatusMessageIdForBridgeMessageId(bridgeMessageID) +} diff --git a/protocol/messenger_edit_message_test.go b/protocol/messenger_edit_message_test.go index c7882dc43a6..e670329a7e6 100644 --- a/protocol/messenger_edit_message_test.go +++ b/protocol/messenger_edit_message_test.go @@ -90,6 +90,80 @@ func (s *MessengerEditMessageSuite) TestEditMessage() { s.Require().Equal(ErrInvalidEditOrDeleteAuthor, err) } +func (s *MessengerEditMessageSuite) TestEditBridgeMessage() { + theirMessenger := s.newMessenger() + defer TearDownMessenger(&s.Suite, theirMessenger) + + theirChat := CreateOneToOneChat("Their 1TO1", &s.privateKey.PublicKey, s.m.transport) + err := theirMessenger.SaveChat(theirChat) + s.Require().NoError(err) + + ourChat := CreateOneToOneChat("Our 1TO1", &theirMessenger.identity.PublicKey, s.m.transport) + err = s.m.SaveChat(ourChat) + s.Require().NoError(err) + + bridgeMessage := buildTestMessage(*theirChat) + bridgeMessage.ContentType = protobuf.ChatMessage_BRIDGE_MESSAGE + bridgeMessage.Payload = &protobuf.ChatMessage_BridgeMessage{ + BridgeMessage: &protobuf.BridgeMessage{ + BridgeName: "discord", + UserName: "user1", + UserAvatar: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADIAAAAyCAIAAACRXR/mAAAAjklEQVR4nOzXwQmFMBAAUZXUYh32ZB32ZB02sxYQQSZGsod55/91WFgSS0RM+SyjA56ZRZhFmEWYRRT6h+M6G16zrxv6fdJpmUWYRbxsYr13dKfanpN0WmYRZhGzXz6AWYRZRIfbaX26fT9Jk07LLMIsosPt9I/dTDotswizCG+nhFmEWYRZhFnEHQAA///z1CFkYamgfQAAAABJRU5ErkJggg==", + UserID: "123", + Content: "text1", + MessageID: "456", + ParentMessageID: "789", + }, + } + + sendResponse, err := theirMessenger.SendChatMessage(context.Background(), bridgeMessage) + s.NoError(err) + s.Require().Len(sendResponse.Messages(), 1) + + response, err := WaitOnMessengerResponse( + s.m, + func(r *MessengerResponse) bool { return len(r.messages) > 0 }, + "no messages", + ) + s.Require().NoError(err) + s.Require().Len(response.Chats(), 1) + s.Require().Len(response.Messages(), 1) + + messageToEdit := sendResponse.Messages()[0] + + messageID, err := types.DecodeHex(messageToEdit.ID) + s.Require().NoError(err) + + editedText := "edited text" + editedMessage := &requests.EditMessage{ + ID: messageID, + Text: editedText, + ContentType: protobuf.ChatMessage_BRIDGE_MESSAGE, + } + + sendResponse, err = theirMessenger.EditMessage(context.Background(), editedMessage) + s.Require().NoError(err) + s.Require().Len(sendResponse.Messages(), 1) + s.Require().NotEmpty(sendResponse.Messages()[0].EditedAt) + s.Require().Equal(sendResponse.Messages()[0].Text, "text-input-message") + s.Require().Equal(sendResponse.Messages()[0].GetBridgeMessage().Content, editedText) + s.Require().Len(sendResponse.Chats(), 1) + s.Require().NotNil(sendResponse.Chats()[0].LastMessage) + s.Require().NotEmpty(sendResponse.Chats()[0].LastMessage.EditedAt) + + response, err = WaitOnMessengerResponse( + s.m, + func(r *MessengerResponse) bool { return len(r.messages) > 0 }, + "no messages", + ) + s.Require().NoError(err) + s.Require().Len(response.Chats(), 1) + s.Require().Len(response.Messages(), 1) + + s.Require().NotEmpty(response.Chats()[0].LastMessage.EditedAt) + s.Require().Equal(response.Messages()[0].GetBridgeMessage().Content, "edited text") +} + func (s *MessengerEditMessageSuite) TestEditMessageEdgeCases() { theirMessenger := s.newMessenger() defer TearDownMessenger(&s.Suite, theirMessenger) diff --git a/protocol/messenger_messages.go b/protocol/messenger_messages.go index afc6c6043cf..2884149410c 100644 --- a/protocol/messenger_messages.go +++ b/protocol/messenger_messages.go @@ -33,7 +33,7 @@ func (m *Messenger) EditMessage(ctx context.Context, request *requests.EditMessa return nil, ErrInvalidEditOrDeleteAuthor } - if message.ContentType != protobuf.ChatMessage_TEXT_PLAIN && message.ContentType != protobuf.ChatMessage_EMOJI && message.ContentType != protobuf.ChatMessage_IMAGE { + if message.ContentType != protobuf.ChatMessage_TEXT_PLAIN && message.ContentType != protobuf.ChatMessage_EMOJI && message.ContentType != protobuf.ChatMessage_IMAGE && message.ContentType != protobuf.ChatMessage_BRIDGE_MESSAGE { return nil, ErrInvalidEditContentType } @@ -373,7 +373,13 @@ func (m *Messenger) applyEditMessage(editMessage *protobuf.EditMessage, message if err := ValidateText(editMessage.Text); err != nil { return err } - message.Text = editMessage.Text + + if editMessage.ContentType != protobuf.ChatMessage_BRIDGE_MESSAGE { + message.Text = editMessage.Text + } else { + message.GetBridgeMessage().Content = editMessage.Text + } + message.EditedAt = editMessage.Clock message.UnfurledLinks = editMessage.UnfurledLinks message.UnfurledStatusLinks = editMessage.UnfurledStatusLinks