diff --git a/CHANGELOG.md b/CHANGELOG.md index 688462f439..060ffb47a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ however, insignificant breaking changes do not guarantee a major version bump, s # v4.1.1 +### Breaking +- Modmail threads are now potentially Discord threads + ### Fixed - `?msglink` now supports threads with multiple recipients. ([PR #3341](https://github.com/modmail-dev/Modmail/pull/3341)) - Fixed persistent notes not working due to discord.py internal change. ([PR #3324](https://github.com/modmail-dev/Modmail/pull/3324)) diff --git a/bot.py b/bot.py index 3c6ebe7911..47c0ca68cb 100644 --- a/bot.py +++ b/bot.py @@ -305,13 +305,21 @@ def log_channel(self) -> typing.Optional[discord.TextChannel]: logger.debug("LOG_CHANNEL_ID was invalid, removed.") self.config.remove("log_channel_id") if self.main_category is not None: - try: - channel = self.main_category.channels[0] - self.config["log_channel_id"] = channel.id - logger.warning("No log channel set, setting #%s to be the log channel.", channel.name) - return channel - except IndexError: - pass + if isinstance(self.main_category, discord.CategoryChannel): + try: + channel = self.main_category.channels[0] + self.config["log_channel_id"] = channel.id + logger.warning("No log channel set, setting #%s to be the log channel.", channel.name) + return channel + except IndexError: + pass + elif isinstance(self.main_category, discord.TextChannel): + self.config["log_channel_id"] = self.main_category.id + logger.warning( + "No log channel set, setting #%s to be the log channel.", self.main_category.name + ) + return self.main_category + logger.warning( "No log channel set, set one with `%ssetup` or `%sconfig set log_channel_id `.", self.prefix, @@ -419,13 +427,13 @@ def using_multiple_server_setup(self) -> bool: return self.modmail_guild != self.guild @property - def main_category(self) -> typing.Optional[discord.CategoryChannel]: + def main_category(self) -> typing.Optional[discord.abc.GuildChannel]: if self.modmail_guild is not None: category_id = self.config["main_category_id"] if category_id is not None: try: - cat = discord.utils.get(self.modmail_guild.categories, id=int(category_id)) - if cat is not None: + cat = discord.utils.get(self.modmail_guild.channels, id=int(category_id)) + if cat is not None and isinstance(cat, (discord.CategoryChannel, discord.TextChannel)): return cat except ValueError: pass @@ -1351,11 +1359,12 @@ async def on_guild_channel_delete(self, channel): if channel.guild != self.modmail_guild: return + if self.main_category == channel: + logger.debug("Main category was deleted.") + self.config.remove("main_category_id") + await self.config.update() + if isinstance(channel, discord.CategoryChannel): - if self.main_category == channel: - logger.debug("Main category was deleted.") - self.config.remove("main_category_id") - await self.config.update() return if not isinstance(channel, discord.TextChannel): diff --git a/cogs/modmail.py b/cogs/modmail.py index e2a0039384..5e964e0c73 100644 --- a/cogs/modmail.py +++ b/cogs/modmail.py @@ -678,6 +678,9 @@ async def unsubscribe(self, ctx, *, user_or_role: Union[discord.Role, User, str. @checks.thread_only() async def nsfw(self, ctx): """Flags a Modmail thread as NSFW (not safe for work).""" + if isinstance(ctx.channel, discord.Thread): + await ctx.send("Unable to set NSFW status for Discord threads.") + return await ctx.channel.edit(nsfw=True) sent_emoji, _ = await self.bot.retrieve_emoji() await self.bot.add_reaction(ctx.message, sent_emoji) @@ -687,6 +690,9 @@ async def nsfw(self, ctx): @checks.thread_only() async def sfw(self, ctx): """Flags a Modmail thread as SFW (safe for work).""" + if isinstance(ctx.channel, discord.Thread): + await ctx.send("Unable to set NSFW status for Discord threads.") + return await ctx.channel.edit(nsfw=False) sent_emoji, _ = await self.bot.retrieve_emoji() await self.bot.add_reaction(ctx.message, sent_emoji) @@ -775,6 +781,9 @@ def format_log_embeds(self, logs, avatar_url): @commands.cooldown(1, 600, BucketType.channel) async def title(self, ctx, *, name: str): """Sets title for a thread""" + if isinstance(ctx.channel, discord.Thread): + await ctx.send("Unable to set titles for Discord threads.") + return await ctx.thread.set_title(name) sent_emoji, _ = await self.bot.retrieve_emoji() await ctx.message.pin() diff --git a/core/thread.py b/core/thread.py index 81dc03f44d..19c3806be9 100644 --- a/core/thread.py +++ b/core/thread.py @@ -150,13 +150,17 @@ async def from_channel(cls, manager: "ThreadManager", channel: discord.TextChann async def get_genesis_message(self) -> discord.Message: if self._genesis_message is None: - async for m in self.channel.history(limit=5, oldest_first=True): - if m.author == self.bot.user: - if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles": - self._genesis_message = m + self._genesis_message = await self._get_genesis_message(self.channel, self.bot.user) return self._genesis_message + @staticmethod + async def _get_genesis_message(channel, own_user) -> discord.Message | None: + async for m in channel.history(limit=5, oldest_first=True): + if m.author == own_user: + if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles": + return m + async def setup(self, *, creator=None, category=None, initial_message=None): """Create the thread channel and other io related initialisation tasks""" self.bot.dispatch("thread_initiate", self, creator, category, initial_message) @@ -294,6 +298,11 @@ async def activate_auto_triggers(): activate_auto_triggers(), send_persistent_notes(), ) + if creator is not None: + # now that the genesis message is sent, + # we can cache things. + creator.cache[self.recipient.id] = self + self.bot.dispatch("thread_ready", self, creator, category, initial_message) def _format_info_embed(self, user, log_url, log_count, color): @@ -434,9 +443,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None, self.channel.id, { "open": False, - "title": match_title(self.channel.topic), + "title": match_title(self.channel.topic) + if isinstance(self.channel, discord.TextChannel) + else None, "closed_at": str(discord.utils.utcnow()), - "nsfw": self.channel.nsfw, + "nsfw": self.channel.nsfw if isinstance(self.channel, discord.TextChannel) else False, "close_message": message, "closer": { "id": str(closer.id), @@ -466,7 +477,7 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None, else: sneak_peak = "No content" - if self.channel.nsfw: + if isinstance(self.channel, discord.TextChannel) and self.channel.nsfw: _nsfw = "NSFW-" else: _nsfw = "" @@ -1230,39 +1241,39 @@ async def _update_users_genesis(self): await genesis_message.edit(embed=embed) async def add_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None: - topic = "" - title, _, _ = parse_channel_topic(self.channel.topic) - if title is not None: - topic += f"Title: {title}\n" - - topic += f"User ID: {self._id}" - self._other_recipients += users self._other_recipients = list(set(self._other_recipients)) + if isinstance(self.channel, discord.TextChannel): + topic = "" + title, _, _ = parse_channel_topic(self.channel.topic) + if title is not None: + topic += f"Title: {title}\n" - ids = ",".join(str(i.id) for i in self._other_recipients) + topic += f"User ID: {self._id}" - topic += f"\nOther Recipients: {ids}" + ids = ",".join(str(i.id) for i in self._other_recipients) - await self.channel.edit(topic=topic) + topic += f"\nOther Recipients: {ids}" + + await self.channel.edit(topic=topic) await self._update_users_genesis() async def remove_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None: - topic = "" - title, user_id, _ = parse_channel_topic(self.channel.topic) - if title is not None: - topic += f"Title: {title}\n" - - topic += f"User ID: {user_id}" - for u in users: self._other_recipients.remove(u) + if isinstance(self.channel, discord.TextChannel): + topic = "" + title, user_id, _ = parse_channel_topic(self.channel.topic) + if title is not None: + topic += f"Title: {title}\n" - if self._other_recipients: - ids = ",".join(str(i.id) for i in self._other_recipients) - topic += f"\nOther Recipients: {ids}" + topic += f"User ID: {user_id}" - await self.channel.edit(topic=topic) + if self._other_recipients: + ids = ",".join(str(i.id) for i in self._other_recipients) + topic += f"\nOther Recipients: {ids}" + + await self.channel.edit(topic=topic) await self._update_users_genesis() @@ -1276,6 +1287,13 @@ def __init__(self, bot): async def populate_cache(self) -> None: for channel in self.bot.modmail_guild.text_channels: await self.find(channel=channel) + for thread in self.bot.modmail_guild.threads: + await self.find(channel=thread) + # handle any threads archived while bot was offline (is this slow? yes. whatever....) + # (maybe this should only iterate until the archived_at timestamp is fine) + if isinstance(self.bot.main_category, discord.TextChannel): + async for thread in self.bot.main_category.archived_threads(): + await self.find(channel=thread) def __len__(self): return len(self.cache) @@ -1290,19 +1308,25 @@ async def find( self, *, recipient: typing.Union[discord.Member, discord.User] = None, - channel: discord.TextChannel = None, + channel: discord.TextChannel | discord.Thread = None, recipient_id: int = None, ) -> typing.Optional[Thread]: """Finds a thread from cache or from discord channel topics.""" - if recipient is None and channel is not None and isinstance(channel, discord.TextChannel): + if ( + recipient is None + and channel is not None + and isinstance(channel, (discord.TextChannel, discord.Thread)) + ): + # check cache *before* potentially awaiting + user_id, cache_thread = next( + ((k, v) for k, v in self.cache.items() if v.channel == channel), (-1, None) + ) + thread = await self._find_from_channel(channel) - if thread is None: - user_id, thread = next( - ((k, v) for k, v in self.cache.items() if v.channel == channel), (-1, None) - ) - if thread is not None: - logger.debug("Found thread with tempered ID.") - await channel.edit(topic=f"User ID: {user_id}") + if thread is None and cache_thread is not None: + logger.debug("Found thread with tampered ID.") + await channel.edit(topic=f"User ID: {user_id}") + thread = cache_thread return thread if recipient: @@ -1357,10 +1381,23 @@ async def _find_from_channel(self, channel): extracts user_id from that. """ - if not channel.topic: - return None + if isinstance(channel, discord.Thread) or not channel.topic: + # actually check for genesis embed :) + msg = await Thread._get_genesis_message(channel, self.bot.user) + if not msg: + return None - _, user_id, other_ids = parse_channel_topic(channel.topic) + embed = msg.embeds[0] + user_id = int((embed.footer.text or "-1").removeprefix("User ID: ").split(" ", 1)[0]) + other_ids = [] + for field in embed.fields: + if field.name == "Other Recipients" and field.value: + other_ids = map( + lambda mention: int(mention.removeprefix("<@").removeprefix("!").removesuffix(">")), + field.value.split(" "), + ) + else: + _, user_id, other_ids = parse_channel_topic(channel.topic) if user_id == -1: return None @@ -1419,8 +1456,6 @@ async def create( thread = Thread(self, recipient) - self.cache[recipient.id] = thread - if (message or not manual_trigger) and self.bot.config["confirm_thread_creation"]: if not manual_trigger: destination = recipient diff --git a/core/utils.py b/core/utils.py index 9f9f572f5a..c9c5b8955e 100644 --- a/core/utils.py +++ b/core/utils.py @@ -457,13 +457,17 @@ async def create_thread_channel(bot, recipient, category, overwrites, *, name=No errors_raised = errors_raised or [] try: - channel = await bot.modmail_guild.create_text_channel( - name=name, - category=category, - overwrites=overwrites, - topic=f"User ID: {recipient.id}", - reason="Creating a thread channel.", - ) + if isinstance(category, discord.TextChannel): + # we ignore `overwrites`... maybe make private threads so it's similar? + channel = await category.create_thread(name=name, reason="Creating a thread channel.", type=discord.ChannelType.public_thread) + else: + channel = await bot.modmail_guild.create_text_channel( + name=name, + category=category, + overwrites=overwrites, + topic=f"User ID: {recipient.id}", + reason="Creating a thread channel.", + ) except discord.HTTPException as e: if (e.text, (category, name)) in errors_raised: # Just raise the error to prevent infinite recursion after retrying