Skip to content

Commit 0874037

Browse files
committed
Allow using Discord threads
1 parent 6c820bf commit 0874037

File tree

5 files changed

+109
-54
lines changed

5 files changed

+109
-54
lines changed

Diff for: CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ however, insignificant breaking changes do not guarantee a major version bump, s
88

99
# v4.1.1
1010

11+
### Breaking
12+
- Modmail threads are now potentially Discord threads
13+
1114
### Fixed
1215
- `?msglink` now supports threads with multiple recipients. ([PR #3341](https://github.com/modmail-dev/Modmail/pull/3341))
1316
- Fixed persistent notes not working due to discord.py internal change. ([PR #3324](https://github.com/modmail-dev/Modmail/pull/3324))

Diff for: bot.py

+23-14
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,21 @@ def log_channel(self) -> typing.Optional[discord.TextChannel]:
305305
logger.debug("LOG_CHANNEL_ID was invalid, removed.")
306306
self.config.remove("log_channel_id")
307307
if self.main_category is not None:
308-
try:
309-
channel = self.main_category.channels[0]
310-
self.config["log_channel_id"] = channel.id
311-
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
312-
return channel
313-
except IndexError:
314-
pass
308+
if isinstance(self.main_category, discord.CategoryChannel):
309+
try:
310+
channel = self.main_category.channels[0]
311+
self.config["log_channel_id"] = channel.id
312+
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
313+
return channel
314+
except IndexError:
315+
pass
316+
elif isinstance(self.main_category, discord.TextChannel):
317+
self.config["log_channel_id"] = self.main_category.id
318+
logger.warning(
319+
"No log channel set, setting #%s to be the log channel.", self.main_category.name
320+
)
321+
return self.main_category
322+
315323
logger.warning(
316324
"No log channel set, set one with `%ssetup` or `%sconfig set log_channel_id <id>`.",
317325
self.prefix,
@@ -419,13 +427,13 @@ def using_multiple_server_setup(self) -> bool:
419427
return self.modmail_guild != self.guild
420428

421429
@property
422-
def main_category(self) -> typing.Optional[discord.CategoryChannel]:
430+
def main_category(self) -> typing.Optional[discord.abc.GuildChannel]:
423431
if self.modmail_guild is not None:
424432
category_id = self.config["main_category_id"]
425433
if category_id is not None:
426434
try:
427-
cat = discord.utils.get(self.modmail_guild.categories, id=int(category_id))
428-
if cat is not None:
435+
cat = discord.utils.get(self.modmail_guild.channels, id=int(category_id))
436+
if cat is not None and isinstance(cat, (discord.CategoryChannel, discord.TextChannel)):
429437
return cat
430438
except ValueError:
431439
pass
@@ -1351,11 +1359,12 @@ async def on_guild_channel_delete(self, channel):
13511359
if channel.guild != self.modmail_guild:
13521360
return
13531361

1362+
if self.main_category == channel:
1363+
logger.debug("Main category was deleted.")
1364+
self.config.remove("main_category_id")
1365+
await self.config.update()
1366+
13541367
if isinstance(channel, discord.CategoryChannel):
1355-
if self.main_category == channel:
1356-
logger.debug("Main category was deleted.")
1357-
self.config.remove("main_category_id")
1358-
await self.config.update()
13591368
return
13601369

13611370
if not isinstance(channel, discord.TextChannel):

Diff for: cogs/modmail.py

+9
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,9 @@ async def unsubscribe(self, ctx, *, user_or_role: Union[discord.Role, User, str.
678678
@checks.thread_only()
679679
async def nsfw(self, ctx):
680680
"""Flags a Modmail thread as NSFW (not safe for work)."""
681+
if isinstance(ctx.channel, discord.Thread):
682+
await ctx.send("Unable to set NSFW status for Discord threads.")
683+
return
681684
await ctx.channel.edit(nsfw=True)
682685
sent_emoji, _ = await self.bot.retrieve_emoji()
683686
await self.bot.add_reaction(ctx.message, sent_emoji)
@@ -687,6 +690,9 @@ async def nsfw(self, ctx):
687690
@checks.thread_only()
688691
async def sfw(self, ctx):
689692
"""Flags a Modmail thread as SFW (safe for work)."""
693+
if isinstance(ctx.channel, discord.Thread):
694+
await ctx.send("Unable to set NSFW status for Discord threads.")
695+
return
690696
await ctx.channel.edit(nsfw=False)
691697
sent_emoji, _ = await self.bot.retrieve_emoji()
692698
await self.bot.add_reaction(ctx.message, sent_emoji)
@@ -775,6 +781,9 @@ def format_log_embeds(self, logs, avatar_url):
775781
@commands.cooldown(1, 600, BucketType.channel)
776782
async def title(self, ctx, *, name: str):
777783
"""Sets title for a thread"""
784+
if isinstance(ctx.channel, discord.Thread):
785+
await ctx.send("Unable to set titles for Discord threads.")
786+
return
778787
await ctx.thread.set_title(name)
779788
sent_emoji, _ = await self.bot.retrieve_emoji()
780789
await ctx.message.pin()

Diff for: core/thread.py

+63-33
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,17 @@ async def from_channel(cls, manager: "ThreadManager", channel: discord.TextChann
150150

151151
async def get_genesis_message(self) -> discord.Message:
152152
if self._genesis_message is None:
153-
async for m in self.channel.history(limit=5, oldest_first=True):
154-
if m.author == self.bot.user:
155-
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
156-
self._genesis_message = m
153+
self._genesis_message = await self._get_genesis_message(self.channel, self.bot.user)
157154

158155
return self._genesis_message
159156

157+
@staticmethod
158+
async def _get_genesis_message(channel, own_user) -> discord.Message | None:
159+
async for m in channel.history(limit=5, oldest_first=True):
160+
if m.author == own_user:
161+
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
162+
return m
163+
160164
async def setup(self, *, creator=None, category=None, initial_message=None):
161165
"""Create the thread channel and other io related initialisation tasks"""
162166
self.bot.dispatch("thread_initiate", self, creator, category, initial_message)
@@ -434,9 +438,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
434438
self.channel.id,
435439
{
436440
"open": False,
437-
"title": match_title(self.channel.topic),
441+
"title": match_title(self.channel.topic)
442+
if isinstance(self.channel, discord.TextChannel)
443+
else None,
438444
"closed_at": str(discord.utils.utcnow()),
439-
"nsfw": self.channel.nsfw,
445+
"nsfw": self.channel.nsfw if isinstance(self.channel, discord.TextChannel) else False,
440446
"close_message": message,
441447
"closer": {
442448
"id": str(closer.id),
@@ -466,7 +472,7 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
466472
else:
467473
sneak_peak = "No content"
468474

469-
if self.channel.nsfw:
475+
if isinstance(self.channel, discord.TextChannel) and self.channel.nsfw:
470476
_nsfw = "NSFW-"
471477
else:
472478
_nsfw = ""
@@ -1230,39 +1236,39 @@ async def _update_users_genesis(self):
12301236
await genesis_message.edit(embed=embed)
12311237

12321238
async def add_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
1233-
topic = ""
1234-
title, _, _ = parse_channel_topic(self.channel.topic)
1235-
if title is not None:
1236-
topic += f"Title: {title}\n"
1237-
1238-
topic += f"User ID: {self._id}"
1239-
12401239
self._other_recipients += users
12411240
self._other_recipients = list(set(self._other_recipients))
1241+
if isinstance(self.channel, discord.TextChannel):
1242+
topic = ""
1243+
title, _, _ = parse_channel_topic(self.channel.topic)
1244+
if title is not None:
1245+
topic += f"Title: {title}\n"
12421246

1243-
ids = ",".join(str(i.id) for i in self._other_recipients)
1247+
topic += f"User ID: {self._id}"
12441248

1245-
topic += f"\nOther Recipients: {ids}"
1249+
ids = ",".join(str(i.id) for i in self._other_recipients)
12461250

1247-
await self.channel.edit(topic=topic)
1251+
topic += f"\nOther Recipients: {ids}"
1252+
1253+
await self.channel.edit(topic=topic)
12481254
await self._update_users_genesis()
12491255

12501256
async def remove_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
1251-
topic = ""
1252-
title, user_id, _ = parse_channel_topic(self.channel.topic)
1253-
if title is not None:
1254-
topic += f"Title: {title}\n"
1255-
1256-
topic += f"User ID: {user_id}"
1257-
12581257
for u in users:
12591258
self._other_recipients.remove(u)
1259+
if isinstance(self.channel, discord.TextChannel):
1260+
topic = ""
1261+
title, user_id, _ = parse_channel_topic(self.channel.topic)
1262+
if title is not None:
1263+
topic += f"Title: {title}\n"
12601264

1261-
if self._other_recipients:
1262-
ids = ",".join(str(i.id) for i in self._other_recipients)
1263-
topic += f"\nOther Recipients: {ids}"
1265+
topic += f"User ID: {user_id}"
12641266

1265-
await self.channel.edit(topic=topic)
1267+
if self._other_recipients:
1268+
ids = ",".join(str(i.id) for i in self._other_recipients)
1269+
topic += f"\nOther Recipients: {ids}"
1270+
1271+
await self.channel.edit(topic=topic)
12661272
await self._update_users_genesis()
12671273

12681274

@@ -1276,6 +1282,13 @@ def __init__(self, bot):
12761282
async def populate_cache(self) -> None:
12771283
for channel in self.bot.modmail_guild.text_channels:
12781284
await self.find(channel=channel)
1285+
for thread in self.bot.modmail_guild.threads:
1286+
await self.find(channel=thread)
1287+
# handle any threads archived while bot was offline (is this slow? yes. whatever....)
1288+
# (maybe this should only iterate until the archived_at timestamp is fine)
1289+
if isinstance(self.bot.main_category, discord.TextChannel):
1290+
async for thread in self.bot.main_category.archived_threads():
1291+
await self.find(channel=thread)
12791292

12801293
def __len__(self):
12811294
return len(self.cache)
@@ -1290,11 +1303,15 @@ async def find(
12901303
self,
12911304
*,
12921305
recipient: typing.Union[discord.Member, discord.User] = None,
1293-
channel: discord.TextChannel = None,
1306+
channel: discord.TextChannel | discord.Thread = None,
12941307
recipient_id: int = None,
12951308
) -> typing.Optional[Thread]:
12961309
"""Finds a thread from cache or from discord channel topics."""
1297-
if recipient is None and channel is not None and isinstance(channel, discord.TextChannel):
1310+
if (
1311+
recipient is None
1312+
and channel is not None
1313+
and isinstance(channel, (discord.TextChannel, discord.Thread))
1314+
):
12981315
thread = await self._find_from_channel(channel)
12991316
if thread is None:
13001317
user_id, thread = next(
@@ -1357,10 +1374,23 @@ async def _find_from_channel(self, channel):
13571374
extracts user_id from that.
13581375
"""
13591376

1360-
if not channel.topic:
1361-
return None
1377+
if isinstance(channel, discord.Thread) or not channel.topic:
1378+
# actually check for genesis embed :)
1379+
msg = await Thread._get_genesis_message(channel, self.bot.user)
1380+
if not msg:
1381+
return None
13621382

1363-
_, user_id, other_ids = parse_channel_topic(channel.topic)
1383+
embed = msg.embeds[0]
1384+
user_id = int((embed.footer.text or "-1").removeprefix("User ID: ").split(" ", 1)[0])
1385+
other_ids = []
1386+
for field in embed.fields:
1387+
if field.name == "Other Recipients" and field.value:
1388+
other_ids = map(
1389+
lambda mention: int(mention.removeprefix("<@").removeprefix("!").removesuffix(">")),
1390+
field.value.split(" "),
1391+
)
1392+
else:
1393+
_, user_id, other_ids = parse_channel_topic(channel.topic)
13641394

13651395
if user_id == -1:
13661396
return None

Diff for: core/utils.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,17 @@ async def create_thread_channel(bot, recipient, category, overwrites, *, name=No
457457
errors_raised = errors_raised or []
458458

459459
try:
460-
channel = await bot.modmail_guild.create_text_channel(
461-
name=name,
462-
category=category,
463-
overwrites=overwrites,
464-
topic=f"User ID: {recipient.id}",
465-
reason="Creating a thread channel.",
466-
)
460+
if isinstance(category, discord.TextChannel):
461+
# we ignore `overwrites`... maybe make private threads so it's similar?
462+
channel = await category.create_thread(name=name, reason="Creating a thread channel.")
463+
else:
464+
channel = await bot.modmail_guild.create_text_channel(
465+
name=name,
466+
category=category,
467+
overwrites=overwrites,
468+
topic=f"User ID: {recipient.id}",
469+
reason="Creating a thread channel.",
470+
)
467471
except discord.HTTPException as e:
468472
if (e.text, (category, name)) in errors_raised:
469473
# Just raise the error to prevent infinite recursion after retrying

0 commit comments

Comments
 (0)