@@ -150,13 +150,17 @@ async def from_channel(cls, manager: "ThreadManager", channel: discord.TextChann
150
150
151
151
async def get_genesis_message (self ) -> discord .Message :
152
152
if self ._genesis_message is None :
153
- async for m in self .channel .history (limit = 5 , oldest_first = True ):
154
- if m .author == self .bot .user :
155
- if m .embeds and m .embeds [0 ].fields and m .embeds [0 ].fields [0 ].name == "Roles" :
156
- self ._genesis_message = m
153
+ self ._genesis_message = await self ._get_genesis_message (self .channel , self .bot .user )
157
154
158
155
return self ._genesis_message
159
156
157
+ @staticmethod
158
+ async def _get_genesis_message (channel , own_user ) -> discord .Message | None :
159
+ async for m in channel .history (limit = 5 , oldest_first = True ):
160
+ if m .author == own_user :
161
+ if m .embeds and m .embeds [0 ].fields and m .embeds [0 ].fields [0 ].name == "Roles" :
162
+ return m
163
+
160
164
async def setup (self , * , creator = None , category = None , initial_message = None ):
161
165
"""Create the thread channel and other io related initialisation tasks"""
162
166
self .bot .dispatch ("thread_initiate" , self , creator , category , initial_message )
@@ -434,9 +438,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
434
438
self .channel .id ,
435
439
{
436
440
"open" : False ,
437
- "title" : match_title (self .channel .topic ),
441
+ "title" : match_title (self .channel .topic )
442
+ if isinstance (self .channel , discord .TextChannel )
443
+ else None ,
438
444
"closed_at" : str (discord .utils .utcnow ()),
439
- "nsfw" : self .channel .nsfw ,
445
+ "nsfw" : self .channel .nsfw if isinstance ( self . channel , discord . TextChannel ) else False ,
440
446
"close_message" : message ,
441
447
"closer" : {
442
448
"id" : str (closer .id ),
@@ -466,7 +472,7 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
466
472
else :
467
473
sneak_peak = "No content"
468
474
469
- if self .channel .nsfw :
475
+ if isinstance ( self . channel , discord . TextChannel ) and self .channel .nsfw :
470
476
_nsfw = "NSFW-"
471
477
else :
472
478
_nsfw = ""
@@ -1230,39 +1236,39 @@ async def _update_users_genesis(self):
1230
1236
await genesis_message .edit (embed = embed )
1231
1237
1232
1238
async def add_users (self , users : typing .List [typing .Union [discord .Member , discord .User ]]) -> None :
1233
- topic = ""
1234
- title , _ , _ = parse_channel_topic (self .channel .topic )
1235
- if title is not None :
1236
- topic += f"Title: { title } \n "
1237
-
1238
- topic += f"User ID: { self ._id } "
1239
-
1240
1239
self ._other_recipients += users
1241
1240
self ._other_recipients = list (set (self ._other_recipients ))
1241
+ if isinstance (self .channel , discord .TextChannel ):
1242
+ topic = ""
1243
+ title , _ , _ = parse_channel_topic (self .channel .topic )
1244
+ if title is not None :
1245
+ topic += f"Title: { title } \n "
1242
1246
1243
- ids = "," . join ( str ( i . id ) for i in self ._other_recipients )
1247
+ topic += f"User ID: { self ._id } "
1244
1248
1245
- topic += f" \n Other Recipients: { ids } "
1249
+ ids = "," . join ( str ( i . id ) for i in self . _other_recipients )
1246
1250
1247
- await self .channel .edit (topic = topic )
1251
+ topic += f"\n Other Recipients: { ids } "
1252
+
1253
+ await self .channel .edit (topic = topic )
1248
1254
await self ._update_users_genesis ()
1249
1255
1250
1256
async def remove_users (self , users : typing .List [typing .Union [discord .Member , discord .User ]]) -> None :
1251
- topic = ""
1252
- title , user_id , _ = parse_channel_topic (self .channel .topic )
1253
- if title is not None :
1254
- topic += f"Title: { title } \n "
1255
-
1256
- topic += f"User ID: { user_id } "
1257
-
1258
1257
for u in users :
1259
1258
self ._other_recipients .remove (u )
1259
+ if isinstance (self .channel , discord .TextChannel ):
1260
+ topic = ""
1261
+ title , user_id , _ = parse_channel_topic (self .channel .topic )
1262
+ if title is not None :
1263
+ topic += f"Title: { title } \n "
1260
1264
1261
- if self ._other_recipients :
1262
- ids = "," .join (str (i .id ) for i in self ._other_recipients )
1263
- topic += f"\n Other Recipients: { ids } "
1265
+ topic += f"User ID: { user_id } "
1264
1266
1265
- await self .channel .edit (topic = topic )
1267
+ if self ._other_recipients :
1268
+ ids = "," .join (str (i .id ) for i in self ._other_recipients )
1269
+ topic += f"\n Other Recipients: { ids } "
1270
+
1271
+ await self .channel .edit (topic = topic )
1266
1272
await self ._update_users_genesis ()
1267
1273
1268
1274
@@ -1276,6 +1282,13 @@ def __init__(self, bot):
1276
1282
async def populate_cache (self ) -> None :
1277
1283
for channel in self .bot .modmail_guild .text_channels :
1278
1284
await self .find (channel = channel )
1285
+ for thread in self .bot .modmail_guild .threads :
1286
+ await self .find (channel = thread )
1287
+ # handle any threads archived while bot was offline (is this slow? yes. whatever....)
1288
+ # (maybe this should only iterate until the archived_at timestamp is fine)
1289
+ if isinstance (self .bot .main_category , discord .TextChannel ):
1290
+ async for thread in self .bot .main_category .archived_threads ():
1291
+ await self .find (channel = thread )
1279
1292
1280
1293
def __len__ (self ):
1281
1294
return len (self .cache )
@@ -1290,11 +1303,15 @@ async def find(
1290
1303
self ,
1291
1304
* ,
1292
1305
recipient : typing .Union [discord .Member , discord .User ] = None ,
1293
- channel : discord .TextChannel = None ,
1306
+ channel : discord .TextChannel | discord . Thread = None ,
1294
1307
recipient_id : int = None ,
1295
1308
) -> typing .Optional [Thread ]:
1296
1309
"""Finds a thread from cache or from discord channel topics."""
1297
- if recipient is None and channel is not None and isinstance (channel , discord .TextChannel ):
1310
+ if (
1311
+ recipient is None
1312
+ and channel is not None
1313
+ and isinstance (channel , (discord .TextChannel , discord .Thread ))
1314
+ ):
1298
1315
thread = await self ._find_from_channel (channel )
1299
1316
if thread is None :
1300
1317
user_id , thread = next (
@@ -1357,10 +1374,23 @@ async def _find_from_channel(self, channel):
1357
1374
extracts user_id from that.
1358
1375
"""
1359
1376
1360
- if not channel .topic :
1361
- return None
1377
+ if isinstance (channel , discord .Thread ) or not channel .topic :
1378
+ # actually check for genesis embed :)
1379
+ msg = await Thread ._get_genesis_message (channel , self .bot .user )
1380
+ if not msg :
1381
+ return None
1362
1382
1363
- _ , user_id , other_ids = parse_channel_topic (channel .topic )
1383
+ embed = msg .embeds [0 ]
1384
+ user_id = int ((embed .footer .text or "-1" ).removeprefix ("User ID: " ).split (" " , 1 )[0 ])
1385
+ other_ids = []
1386
+ for field in embed .fields :
1387
+ if field .name == "Other Recipients" and field .value :
1388
+ other_ids = map (
1389
+ lambda mention : int (mention .removeprefix ("<@" ).removeprefix ("!" ).removesuffix (">" )),
1390
+ field .value .split (" " ),
1391
+ )
1392
+ else :
1393
+ _ , user_id , other_ids = parse_channel_topic (channel .topic )
1364
1394
1365
1395
if user_id == - 1 :
1366
1396
return None
0 commit comments