Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

feat 💥: Forums rewrite #646

Merged
merged 19 commits into from
Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion naff/api/events/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def on_guild_join(event):
"StageInstanceDelete",
"StageInstanceUpdate",
"ThreadCreate",
"NewThreadCreate",
"ThreadDelete",
"ThreadListSync",
"ThreadMemberUpdate",
Expand Down Expand Up @@ -161,11 +162,16 @@ class ChannelPinsUpdate(ChannelCreate):

@define(kw_only=False)
class ThreadCreate(BaseEvent):
"""Dispatched when a thread is created."""
"""Dispatched when a thread is created, or a thread is new to the client"""

thread: "TYPE_THREAD_CHANNEL" = field(metadata=docs("The thread this event is dispatched from"))


@define(kw_only=False)
class NewThreadCreate(ThreadCreate):
"""Dispatched when a thread is newly created."""


@define(kw_only=False)
class ThreadUpdate(ThreadCreate):
"""Dispatched when a thread is updated."""
Expand Down
8 changes: 5 additions & 3 deletions naff/api/events/processors/thread_events.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import TYPE_CHECKING

import naff.api.events as events

from ._template import EventMixinTemplate, Processor
from naff.models import to_snowflake
from ._template import EventMixinTemplate, Processor

if TYPE_CHECKING:
from naff.api.events import RawGatewayEvent
Expand All @@ -14,7 +13,10 @@
class ThreadEvents(EventMixinTemplate):
@Processor.define()
async def _on_raw_thread_create(self, event: "RawGatewayEvent") -> None:
self.dispatch(events.ThreadCreate(self.cache.place_channel_data(event.data)))
thread = self.cache.place_channel_data(event.data)
if event.data.get("newly_created"):
self.dispatch(events.NewThreadCreate(thread))
self.dispatch(events.ThreadCreate(thread))

@Processor.define()
async def _on_raw_thread_update(self, event: "RawGatewayEvent") -> None:
Expand Down
6 changes: 3 additions & 3 deletions naff/api/http/http_requests/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,23 @@ async def create_forum_thread(
name: The name of the thread
auto_archive_duration: Time before the thread will be automatically archived. Note 3 day and 7 day archive durations require the server to be boosted.
message: The message-content for the post/thread
applied_tags: The tags to apply to the thread
rate_limit_per_user: The time users must wait between sending messages
files: The files to upload
reason: The reason for creating this thread

Returns:
The created thread object
"""
# note: `{"use_nested_fields": 1}` seems to be a temporary flag until forums launch
return await self.request(
Route("POST", f"/channels/{channel_id}/threads"),
payload={
"name": name,
"auto_archive_duration": auto_archive_duration,
"rate_limit_per_user": rate_limit_per_user,
"applied_tags": applied_tags,
"message": message,
"applied_tags": applied_tags,
},
params={"use_nested_fields": 1},
files=files,
reason=reason,
)
196 changes: 171 additions & 25 deletions naff/models/discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
import attrs

import naff.models as models

from naff.client.const import MISSING, DISCORD_EPOCH, Absent, logger
from naff.client.errors import NotFound, VoiceNotConnected, TooManyChanges
from naff.client.mixins.send import SendMixin
from naff.client.mixins.serialization import DictSerializationMixin
from naff.client.utils.attr_utils import define, field
from naff.client.utils.attr_converters import optional as optional_c
from naff.client.utils.attr_converters import timestamp_converter
from naff.client.utils.attr_utils import define, field
from naff.client.utils.misc_utils import get
from naff.client.utils.serializer import to_dict, to_image_data
from naff.models.discord.base import DiscordObject
from naff.models.discord.emoji import PartialEmoji
from naff.models.discord.file import UPLOADABLE_TYPE
from naff.models.discord.snowflake import Snowflake_Type, to_snowflake, to_optional_snowflake, SnowflakeObject
from naff.models.misc.iterator import AsyncIterator
from naff.models.discord.thread import ThreadTag
from naff.models.discord.emoji import PartialEmoji
from naff.models.misc.context_manager import Typing
from naff.models.misc.iterator import AsyncIterator
from .enums import (
ChannelFlags,
ChannelTypes,
Expand All @@ -32,7 +32,6 @@
MessageFlags,
InviteTargetTypes,
)
from naff.models.misc.context_manager import Typing

if TYPE_CHECKING:
from aiohttp import FormData
Expand Down Expand Up @@ -739,6 +738,12 @@ def from_dict_factory(cls, data: dict, client: "Client") -> "TYPE_ALL_CHANNEL":
logger.error(f"Unsupported channel type for {data} ({channel_type}).")
channel_class = BaseChannel

if channel_class == GuildPublicThread:
# attempt to determine if this thread is a forum post (thanks discord)
parent_channel = client.cache.get_channel(data["parent_id"])
if parent_channel and parent_channel.type == ChannelTypes.GUILD_FORUM:
channel_class = GuildForumPost

return channel_class.from_dict(data, client)

@property
Expand Down Expand Up @@ -1905,8 +1910,54 @@ async def edit(

@define()
class GuildPublicThread(ThreadChannel):
async def edit(
self,
name: Absent[str] = MISSING,
archived: Absent[bool] = MISSING,
auto_archive_duration: Absent[AutoArchiveDuration] = MISSING,
locked: Absent[bool] = MISSING,
rate_limit_per_user: Absent[int] = MISSING,
flags: Absent[Union[int, ChannelFlags]] = MISSING,
reason: Absent[str] = MISSING,
**kwargs,
) -> "GuildPublicThread":
LordOfPolls marked this conversation as resolved.
Show resolved Hide resolved
"""
Edit this thread.

_applied_tags: List[Snowflake_Type] = field(factory=list)
Args:
name: 1-100 character channel name
archived: whether the thread is archived
auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080
locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it
rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600)
flags: channel flags for forum threads
reason: The reason for this change

Returns:
The edited thread channel object.
"""
return await super().edit(
name=name,
archived=archived,
auto_archive_duration=auto_archive_duration,
locked=locked,
rate_limit_per_user=rate_limit_per_user,
reason=reason,
flags=flags,
**kwargs,
)


@define()
class GuildForumPost(GuildPublicThread):
"""
A forum post

!!! note
This model is an abstraction of the api - In reality all posts are GuildPublicThread
"""

_applied_tags: list[Snowflake_Type] = field(factory=list)

async def edit(
self,
Expand All @@ -1926,11 +1977,11 @@ async def edit(
Args:
name: 1-100 character channel name
archived: whether the thread is archived
applied_tags: list of tags to apply to a forum post (!!! This is for forum threads only)
applied_tags: list of tags to apply
Kigstn marked this conversation as resolved.
Show resolved Hide resolved
auto_archive_duration: duration in minutes to automatically archive the thread after recent activity, can be set to: 60, 1440, 4320, 10080
locked: whether the thread is locked; when a thread is locked, only users with MANAGE_THREADS can unarchive it
rate_limit_per_user: amount of seconds a user has to wait before sending another message (0-21600)
flags: channel flags for forum threads
flags: channel flags to apply
reason: The reason for this change

Returns:
Expand All @@ -1953,29 +2004,44 @@ async def edit(

@property
def applied_tags(self) -> list[ThreadTag]:
"""
The tags applied to this thread.

!!! note
This is only on forum threads.

"""
"""The tags applied to this thread."""
if not isinstance(self.parent_channel, GuildForum):
raise AttributeError("This is only available on forum threads.")
return [tag for tag in self.parent_channel.available_tags if str(tag.id) in self._applied_tags]

@property
def initial_post(self) -> Optional["Message"]:
"""The initial message posted by the OP."""
if not isinstance(self.parent_channel, GuildForum):
raise AttributeError("This is only available on forum threads.")
return self.get_message(self.id)

@property
def pinned(self) -> bool:
"""Whether this thread is pinned."""
return ChannelFlags.PINNED in self.flags

async def pin(self, reason: Absent[str] = MISSING) -> None:
"""
The initial message posted by the OP.
Pin this thread.

!!! note
This is only on forum threads.
Args:
reason: The reason for this pin

"""
if not isinstance(self.parent_channel, GuildForum):
raise AttributeError("This is only available on forum threads.")
return self.get_message(self.id)
flags = self.flags | ChannelFlags.PINNED
await self.edit(flags=flags, reason=reason)

async def unpin(self, reason: Absent[str] = MISSING) -> None:
"""
Unpin this thread.

Args:
reason: The reason for this unpin

"""
flags = self.flags & ~ChannelFlags.PINNED
await self.edit(flags=flags, reason=reason)


@define()
Expand Down Expand Up @@ -2218,7 +2284,7 @@ async def create_post(
self,
name: str,
content: str | None,
applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag"]]] = MISSING,
applied_tags: Optional[List[Union["Snowflake_Type", "ThreadTag", str]]] = MISSING,
Kigstn marked this conversation as resolved.
Show resolved Hide resolved
*,
auto_archive_duration: AutoArchiveDuration = AutoArchiveDuration.ONE_DAY,
rate_limit_per_user: Absent[int] = MISSING,
Expand All @@ -2233,7 +2299,7 @@ async def create_post(
file: Optional["UPLOADABLE_TYPE"] = None,
tts: bool = False,
reason: Absent[str] = MISSING,
) -> "GuildPublicThread":
) -> "GuildForumPost":
"""
Create a post within this channel.

Expand All @@ -2254,10 +2320,21 @@ async def create_post(
reason: The reason for creating this post

Returns:
A GuildPublicThread object representing the created post.
A GuildForumPost object representing the created post.
"""
if applied_tags != MISSING:
applied_tags = [str(tag.id) if isinstance(tag, ThreadTag) else str(tag) for tag in applied_tags]
processed = []
for tag in applied_tags:
if isinstance(tag, ThreadTag):
tag = tag.id
if isinstance(tag, (str, int)):
LordOfPolls marked this conversation as resolved.
Show resolved Hide resolved
tag = self.get_tag(tag, case_insensitive=True)
if not tag:
continue
tag = tag.id
processed.append(tag)

applied_tags = processed

message_payload = models.discord.message.process_message_payload(
content=content,
Expand All @@ -2280,6 +2357,74 @@ async def create_post(
)
return self._client.cache.place_channel_data(data)

async def fetch_posts(self) -> List["GuildForumPost"]:
LordOfPolls marked this conversation as resolved.
Show resolved Hide resolved
"""
Requests all active posts within this channel.

Returns:
A list of GuildForumPost objects representing the posts.
"""
data = await self._client.http.list_active_threads(self._guild_id)
threads = [self._client.cache.place_channel_data(post_data) for post_data in data["threads"]]

return [thread for thread in threads if thread.parent_id == self.id]

def get_posts(self) -> List["GuildForumPost"]:
"""
List all, cached, active posts within this channel.

Returns:
A list of GuildForumPost objects representing the posts.
"""
return [thread for thread in self.guild.threads if thread.parent_id == self.id]

async def fetch_post(self, id: "Snowflake_Type") -> "GuildForumPost":
"""
Fetch a post within this channel.

Args:
id: The id of the post to fetch

Returns:
A GuildForumPost object representing the post.
"""
return await self._client.fetch_channel(id)

def get_post(self, id: "Snowflake_Type") -> "GuildForumPost":
"""
Get a post within this channel.

Args:
id: The id of the post to get

Returns:
A GuildForumPost object representing the post.
"""
return self._client.cache.get_channel(id)

def get_tag(self, value: str | Snowflake_Type, *, case_insensitive: bool = False) -> Optional["ThreadTag"]:
"""
Get a tag within this channel.

Args:
value: The name or ID of the tag to get
case_insensitive: Whether to ignore case when searching for the tag

Returns:
A ThreadTag object representing the tag.
"""

def maybe_insensitive(string: str) -> str:
return string.lower() if case_insensitive else string

def predicate(tag: ThreadTag) -> Optional["ThreadTag"]:
if str(tag.id) == str(value):
return tag
if maybe_insensitive(tag.name) == maybe_insensitive(value):
return tag

return next((tag for tag in self.available_tags if predicate(tag)), None)

async def create_tag(self, name: str, emoji: Union["models.PartialEmoji", dict, str]) -> "ThreadTag":
"""
Create a tag for this forum.
Expand Down Expand Up @@ -2383,6 +2528,7 @@ def process_permission_overwrites(
GuildStageVoice,
GuildCategory,
GuildPublicThread,
GuildForumPost,
GuildPrivateThread,
GuildNewsThread,
DM,
Expand Down