diff --git a/librarian/discord/cogs/background/github.py b/librarian/discord/cogs/background/github.py index ad63665..5bea33d 100644 --- a/librarian/discord/cogs/background/github.py +++ b/librarian/discord/cogs/background/github.py @@ -274,6 +274,7 @@ async def sort_for_updates(self, pulls: typing.List[storage.models.pull.Pull]) - async def handle_update_exception(self, exc: errors.LibrarianException, item: typing.Tuple[int, int]): if isinstance(exc, errors.NoDiscordChannel): await self.bot.settings.reset(exc.channel_id) + self.storage.discord.delete_channel_messages(exc.channel_id) async def status(self) -> dict: # FIXME: make something useful here diff --git a/librarian/storage/models/discord.py b/librarian/storage/models/discord.py index dfdd125..4b9656c 100644 --- a/librarian/storage/models/discord.py +++ b/librarian/storage/models/discord.py @@ -90,6 +90,12 @@ def delete_message(self, message_id, channel_id, s): DiscordMessage.channel_id == channel_id, ).delete() + @utils.optional_session + def delete_channel_messages(self, channel_id, s) -> int: + return s.query(DiscordMessage).filter( + DiscordMessage.channel_id == channel_id, + ).delete() + @utils.optional_session def messages_by_pull_numbers(self, *pull_numbers: typing.List[int], s: orm.Session = None): """ Return all known messages that are tied to the specified pulls. """ diff --git a/tests/discord/cogs/background/test_github.py b/tests/discord/cogs/background/test_github.py index bab2ecb..5d46ac8 100644 --- a/tests/discord/cogs/background/test_github.py +++ b/tests/discord/cogs/background/test_github.py @@ -265,8 +265,10 @@ async def test__exception_handling_no_channel( response = collections.namedtuple("Response", "status reason")(404, "testing stuff") client.fetch_channel = mocker.AsyncMock(side_effect=discord_errors.NotFound(response, "error")) client.settings.reset = mocker.AsyncMock() + client.storage.discord.delete_channel_messages = mocker.Mock() await monitor.sort_for_updates([pp]) client.settings.reset.assert_called() + client.storage.discord.delete_channel_messages.assert_called() args, _ = client.settings.reset.call_args assert args == (123,) diff --git a/tests/storage/models/test_discord.py b/tests/storage/models/test_discord.py index 9fe49b7..c66cb67 100644 --- a/tests/storage/models/test_discord.py +++ b/tests/storage/models/test_discord.py @@ -30,6 +30,22 @@ def test__delete_message(self, storage, existing_pulls): storage.discord.delete_message(123, 456) assert not storage.discord.messages_by_pull_numbers(789) + def test__delete_channel_messages(self, storage, existing_pulls): + assert storage.discord.delete_channel_messages(123) == 0 + storage.discord.save_messages(*( + stg.DiscordMessage( + id=i, + channel_id=123, + pull_number=1 + ) + for i in range(10) + )) + storage.discord.save_messages(stg.DiscordMessage(id=15, channel_id=124, pull_number=1)) + + assert len(storage.discord.messages_by_pull_numbers(1)) == 11 + assert storage.discord.delete_channel_messages(123) == 10 + assert len(storage.discord.messages_by_pull_numbers(1)) == 1 + class TestDiscordUsers: def test__promote_one(self, storage):