diff --git a/naff/api/events/discord.py b/naff/api/events/discord.py index 763a9852c..1778a2c06 100644 --- a/naff/api/events/discord.py +++ b/naff/api/events/discord.py @@ -71,6 +71,7 @@ def on_guild_join(event): "StageInstanceDelete", "StageInstanceUpdate", "ThreadCreate", + "NewThreadCreate", "ThreadDelete", "ThreadListSync", "ThreadMemberUpdate", @@ -168,11 +169,16 @@ class ChannelPinsUpdate(ChannelCreate): @define(kw_only=False) class ThreadCreate(BaseEvent): - """Dispatched when a thread is created.""" + """Dispatched when a thread is created, or a thread is new to the client""" thread: "TYPE_THREAD_CHANNEL" = field(metadata=docs("The thread this event is dispatched from")) +@define(kw_only=False) +class NewThreadCreate(ThreadCreate): + """Dispatched when a thread is newly created.""" + + @define(kw_only=False) class ThreadUpdate(ThreadCreate): """Dispatched when a thread is updated.""" diff --git a/naff/api/events/processors/thread_events.py b/naff/api/events/processors/thread_events.py index 5b4886111..232aea8cf 100644 --- a/naff/api/events/processors/thread_events.py +++ b/naff/api/events/processors/thread_events.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING import naff.api.events as events - -from ._template import EventMixinTemplate, Processor from naff.models import to_snowflake +from ._template import EventMixinTemplate, Processor if TYPE_CHECKING: from naff.api.events import RawGatewayEvent @@ -14,7 +13,10 @@ class ThreadEvents(EventMixinTemplate): @Processor.define() async def _on_raw_thread_create(self, event: "RawGatewayEvent") -> None: - self.dispatch(events.ThreadCreate(self.cache.place_channel_data(event.data))) + thread = self.cache.place_channel_data(event.data) + if event.data.get("newly_created"): + self.dispatch(events.NewThreadCreate(thread)) + self.dispatch(events.ThreadCreate(thread)) @Processor.define() async def _on_raw_thread_update(self, event: "RawGatewayEvent") -> None: diff --git a/naff/api/http/http_requests/channels.py b/naff/api/http/http_requests/channels.py index 81be399ee..ae70c518e 100644 --- a/naff/api/http/http_requests/channels.py +++ b/naff/api/http/http_requests/channels.py @@ -567,7 +567,7 @@ async def create_tag( payload: PAYLOAD_TYPE = { "name": name, "emoji_id": int(emoji_id) if emoji_id else None, - "emoji_name": emoji_name, + "emoji_name": emoji_name if emoji_name else None, } payload = dict_filter_none(payload) diff --git a/naff/api/http/http_requests/threads.py b/naff/api/http/http_requests/threads.py index 6b317542c..7206f5420 100644 --- a/naff/api/http/http_requests/threads.py +++ b/naff/api/http/http_requests/threads.py @@ -210,23 +210,23 @@ async def create_forum_thread( name: The name of the thread auto_archive_duration: Time before the thread will be automatically archived. Note 3 day and 7 day archive durations require the server to be boosted. message: The message-content for the post/thread + applied_tags: The tags to apply to the thread rate_limit_per_user: The time users must wait between sending messages + files: The files to upload reason: The reason for creating this thread Returns: The created thread object """ - # note: `{"use_nested_fields": 1}` seems to be a temporary flag until forums launch return await self.request( Route("POST", f"/channels/{channel_id}/threads"), payload={ "name": name, "auto_archive_duration": auto_archive_duration, "rate_limit_per_user": rate_limit_per_user, - "applied_tags": applied_tags, "message": message, + "applied_tags": applied_tags, }, - params={"use_nested_fields": 1}, files=files, reason=reason, ) diff --git a/naff/models/discord/channel.py b/naff/models/discord/channel.py index 3e60220d4..60090317e 100644 --- a/naff/models/discord/channel.py +++ b/naff/models/discord/channel.py @@ -1,26 +1,27 @@ import time +from asyncio import QueueEmpty from collections import namedtuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Callable import attrs import naff.models as models - from naff.client.const import MISSING, DISCORD_EPOCH, Absent, logger from naff.client.errors import NotFound, VoiceNotConnected, TooManyChanges from naff.client.mixins.send import SendMixin from naff.client.mixins.serialization import DictSerializationMixin -from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c from naff.client.utils.attr_converters import timestamp_converter +from naff.client.utils.attr_utils import define, field from naff.client.utils.misc_utils import get from naff.client.utils.serializer import to_dict, to_image_data from naff.models.discord.base import DiscordObject +from naff.models.discord.emoji import PartialEmoji from naff.models.discord.file import UPLOADABLE_TYPE from naff.models.discord.snowflake import Snowflake_Type, to_snowflake, to_optional_snowflake, SnowflakeObject -from naff.models.misc.iterator import AsyncIterator from naff.models.discord.thread import ThreadTag -from naff.models.discord.emoji import PartialEmoji +from naff.models.misc.context_manager import Typing +from naff.models.misc.iterator import AsyncIterator from .enums import ( ChannelFlags, ChannelTypes, @@ -32,7 +33,6 @@ MessageFlags, InviteTargetTypes, ) -from naff.models.misc.context_manager import Typing if TYPE_CHECKING: from aiohttp import FormData @@ -128,6 +128,33 @@ async def fetch(self) -> List["models.Message"]: return messages +class ArchivedForumPosts(AsyncIterator): + def __init__(self, channel: "BaseChannel", limit: int = 50, before: Snowflake_Type = None) -> None: + self.channel: "BaseChannel" = channel + self.before: Snowflake_Type = before + self._more: bool = True + super().__init__(limit) + + if self.before: + self.last = self.before + + async def fetch(self) -> list["GuildForumPost"]: + if self._more: + expected = self.get_limit + + rcv = await self.channel._client.http.list_public_archived_threads( + self.channel.id, limit=expected, before=self.last + ) + threads = [self.channel._client.cache.place_channel_data(data) for data in rcv["threads"]] + + if not rcv: + raise QueueEmpty + + self._more = rcv.get("has_more", False) + return threads + raise QueueEmpty + + @define() class PermissionOverwrite(SnowflakeObject, DictSerializationMixin): """ @@ -739,6 +766,12 @@ def from_dict_factory(cls, data: dict, client: "Client") -> "TYPE_ALL_CHANNEL": logger.error(f"Unsupported channel type for {data} ({channel_type}).") channel_class = BaseChannel + if channel_class == GuildPublicThread: + # attempt to determine if this thread is a forum post (thanks discord) + parent_channel = client.cache.get_channel(data["parent_id"]) + if parent_channel and parent_channel.type == ChannelTypes.GUILD_FORUM: + channel_class = GuildForumPost + return channel_class.from_dict(data, client) @property @@ -1905,8 +1938,60 @@ async def edit( @define() class GuildPublicThread(ThreadChannel): + async def edit( + self, + name: Absent[str] = MISSING, + archived: Absent[bool] = MISSING, + auto_archive_duration: Absent[AutoArchiveDuration] = MISSING, + locked: Absent[bool] = MISSING, + rate_limit_per_user: Absent[int] = MISSING, + flags: Absent[Union[int, ChannelFlags]] = MISSING, + reason: Absent[str] = MISSING, + **kwargs, + ) -> "GuildPublicThread": + """ + Edit this thread. - _applied_tags: List[Snowflake_Type] = field(factory=list) + Args: + name: 1-100 character channel name + archived: whether the thread is archived + auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080 + locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it + rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600) + flags: channel flags for forum threads + reason: The reason for this change + + Returns: + The edited thread channel object. + """ + return await super().edit( + name=name, + archived=archived, + auto_archive_duration=auto_archive_duration, + locked=locked, + rate_limit_per_user=rate_limit_per_user, + reason=reason, + flags=flags, + **kwargs, + ) + + +@define() +class GuildForumPost(GuildPublicThread): + """ + A forum post + + !!! note + This model is an abstraction of the api - In reality all posts are GuildPublicThread + """ + + _applied_tags: list[Snowflake_Type] = field(factory=list) + + @classmethod + def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any]: + data = super()._process_dict(data, client) + data["_applied_tags"] = data.pop("applied_tags") if "applied_tags" in data else [] + return data async def edit( self, @@ -1919,18 +2004,18 @@ async def edit( flags: Absent[Union[int, ChannelFlags]] = MISSING, reason: Absent[str] = MISSING, **kwargs, - ) -> "GuildPublicThread": + ) -> "GuildForumPost": """ Edit this thread. Args: name: 1-100 character channel name archived: whether the thread is archived - applied_tags: list of tags to apply to a forum post (!!! This is for forum threads only) + applied_tags: list of tags to apply auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080 locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600) - flags: channel flags for forum threads + flags: channel flags to apply reason: The reason for this change Returns: @@ -1953,29 +2038,44 @@ async def edit( @property def applied_tags(self) -> list[ThreadTag]: - """ - The tags applied to this thread. - - !!! note - This is only on forum threads. - - """ + """The tags applied to this thread.""" if not isinstance(self.parent_channel, GuildForum): raise AttributeError("This is only available on forum threads.") return [tag for tag in self.parent_channel.available_tags if str(tag.id) in self._applied_tags] @property def initial_post(self) -> Optional["Message"]: + """The initial message posted by the OP.""" + if not isinstance(self.parent_channel, GuildForum): + raise AttributeError("This is only available on forum threads.") + return self.get_message(self.id) + + @property + def pinned(self) -> bool: + """Whether this thread is pinned.""" + return ChannelFlags.PINNED in self.flags + + async def pin(self, reason: Absent[str] = MISSING) -> None: """ - The initial message posted by the OP. + Pin this thread. - !!! note - This is only on forum threads. + Args: + reason: The reason for this pin """ - if not isinstance(self.parent_channel, GuildForum): - raise AttributeError("This is only available on forum threads.") - return self.get_message(self.id) + flags = self.flags | ChannelFlags.PINNED + await self.edit(flags=flags, reason=reason) + + async def unpin(self, reason: Absent[str] = MISSING) -> None: + """ + Unpin this thread. + + Args: + reason: The reason for this unpin + + """ + flags = self.flags & ~ChannelFlags.PINNED + await self.edit(flags=flags, reason=reason) @define() @@ -2218,7 +2318,7 @@ async def create_post( self, name: str, content: str | None, - applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag"]]] = MISSING, + applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag", str]]] = MISSING, *, auto_archive_duration: AutoArchiveDuration = AutoArchiveDuration.ONE_DAY, rate_limit_per_user: Absent[int] = MISSING, @@ -2233,7 +2333,7 @@ async def create_post( file: Optional["UPLOADABLE_TYPE"] = None, tts: bool = False, reason: Absent[str] = MISSING, - ) -> "GuildPublicThread": + ) -> "GuildForumPost": """ Create a post within this channel. @@ -2254,10 +2354,23 @@ async def create_post( reason: The reason for creating this post Returns: - A GuildPublicThread object representing the created post. + A GuildForumPost object representing the created post. """ if applied_tags != MISSING: - applied_tags = [str(tag.id) if isinstance(tag, ThreadTag) else str(tag) for tag in applied_tags] + processed = [] + for tag in applied_tags: + if isinstance(tag, ThreadTag): + tag = tag.id + elif isinstance(tag, (str, int)): + tag = self.get_tag(tag, case_insensitive=True) + if not tag: + continue + tag = tag.id + elif isinstance(tag, dict): + tag = tag["id"] + processed.append(tag) + + applied_tags = processed message_payload = models.discord.message.process_message_payload( content=content, @@ -2280,7 +2393,86 @@ async def create_post( ) return self._client.cache.place_channel_data(data) - async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, str]) -> "ThreadTag": + async def fetch_posts(self) -> List["GuildForumPost"]: + """ + Requests all active posts within this channel. + + Returns: + A list of GuildForumPost objects representing the posts. + """ + # I can guarantee this endpoint will need to be converted to an async iterator eventually + data = await self._client.http.list_active_threads(self._guild_id) + threads = [self._client.cache.place_channel_data(post_data) for post_data in data["threads"]] + + return [thread for thread in threads if thread.parent_id == self.id] + + def get_posts(self, *, exclude_archived: bool = True) -> List["GuildForumPost"]: + """ + List all, cached, active posts within this channel. + + Args: + exclude_archived: Whether to exclude archived posts from the response + + Returns: + A list of GuildForumPost objects representing the posts. + """ + out = [thread for thread in self.guild.threads if thread.parent_id == self.id] + if exclude_archived: + return [thread for thread in out if not thread.archived] + return out + + def archived_posts(self, limit: int = 0, before: Snowflake_Type | None = None) -> ArchivedForumPosts: + """An async iterator for all archived posts in this channel.""" + return ArchivedForumPosts(self, limit, before) + + async def fetch_post(self, id: "Snowflake_Type") -> "GuildForumPost": + """ + Fetch a post within this channel. + + Args: + id: The id of the post to fetch + + Returns: + A GuildForumPost object representing the post. + """ + return await self._client.fetch_channel(id) + + def get_post(self, id: "Snowflake_Type") -> "GuildForumPost": + """ + Get a post within this channel. + + Args: + id: The id of the post to get + + Returns: + A GuildForumPost object representing the post. + """ + return self._client.cache.get_channel(id) + + def get_tag(self, value: str | Snowflake_Type, *, case_insensitive: bool = False) -> Optional["ThreadTag"]: + """ + Get a tag within this channel. + + Args: + value: The name or ID of the tag to get + case_insensitive: Whether to ignore case when searching for the tag + + Returns: + A ThreadTag object representing the tag. + """ + + def maybe_insensitive(string: str) -> str: + return string.lower() if case_insensitive else string + + def predicate(tag: ThreadTag) -> Optional["ThreadTag"]: + if str(tag.id) == str(value): + return tag + if maybe_insensitive(tag.name) == maybe_insensitive(value): + return tag + + return next((tag for tag in self.available_tags if predicate(tag)), None) + + async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, str, None] = None) -> "ThreadTag": """ Create a tag for this forum. @@ -2295,15 +2487,20 @@ async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, The created tag object. """ - if isinstance(emoji, str): - emoji = PartialEmoji.from_str(emoji) - elif isinstance(emoji, dict): - emoji = PartialEmoji.from_dict(emoji) + payload = {"channel_id": self.id, "name": name} - if emoji.id: - data = await self._client.http.create_tag(self.id, name, emoji_id=emoji.id) - else: - data = await self._client.http.create_tag(self.id, name, emoji_name=emoji.name) + if emoji: + if isinstance(emoji, str): + emoji = PartialEmoji.from_str(emoji) + elif isinstance(emoji, dict): + emoji = PartialEmoji.from_dict(emoji) + + if emoji.id: + payload["emoji_id"] = emoji.id + else: + payload["emoji_name"] = emoji.name + + data = await self._client.http.create_tag(**payload) channel_data = self._client.cache.place_channel_data(data) return [tag for tag in channel_data.available_tags if tag.name == name][0] @@ -2383,6 +2580,7 @@ def process_permission_overwrites( GuildStageVoice, GuildCategory, GuildPublicThread, + GuildForumPost, GuildPrivateThread, GuildNewsThread, DM,