@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
37
37
async def get_relations_for_event (
38
38
self ,
39
39
event_id : str ,
40
+ room_id : str ,
40
41
relation_type : Optional [str ] = None ,
41
42
event_type : Optional [str ] = None ,
42
43
aggregation_key : Optional [str ] = None ,
@@ -49,6 +50,7 @@ async def get_relations_for_event(
49
50
50
51
Args:
51
52
event_id: Fetch events that relate to this event ID.
53
+ room_id: The room the event belongs to.
52
54
relation_type: Only fetch events with this relation type, if given.
53
55
event_type: Only fetch events with this event type, if given.
54
56
aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ async def get_relations_for_event(
63
65
the form `{"event_id": "..."}`.
64
66
"""
65
67
66
- where_clause = ["relates_to_id = ?" ]
67
- where_args : List [Union [str , int ]] = [event_id ]
68
+ where_clause = ["relates_to_id = ?" , "room_id = ?" ]
69
+ where_args : List [Union [str , int ]] = [event_id , room_id ]
68
70
69
71
if relation_type is not None :
70
72
where_clause .append ("relation_type = ?" )
@@ -199,6 +201,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
199
201
async def get_aggregation_groups_for_event (
200
202
self ,
201
203
event_id : str ,
204
+ room_id : str ,
202
205
event_type : Optional [str ] = None ,
203
206
limit : int = 5 ,
204
207
direction : str = "b" ,
@@ -213,6 +216,7 @@ async def get_aggregation_groups_for_event(
213
216
214
217
Args:
215
218
event_id: Fetch events that relate to this event ID.
219
+ room_id: The room the event belongs to.
216
220
event_type: Only fetch events with this event type, if given.
217
221
limit: Only fetch the `limit` groups.
218
222
direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ async def get_aggregation_groups_for_event(
225
229
`type`, `key` and `count` fields.
226
230
"""
227
231
228
- where_clause = ["relates_to_id = ?" , "relation_type = ?" ]
229
- where_args : List [Union [str , int ]] = [event_id , RelationTypes .ANNOTATION ]
232
+ where_clause = ["relates_to_id = ?" , "room_id = ?" , "relation_type = ?" ]
233
+ where_args : List [Union [str , int ]] = [
234
+ event_id ,
235
+ room_id ,
236
+ RelationTypes .ANNOTATION ,
237
+ ]
230
238
231
239
if event_type :
232
240
where_clause .append ("type = ?" )
@@ -288,14 +296,17 @@ def _get_aggregation_groups_for_event_txn(
288
296
)
289
297
290
298
@cached ()
291
- async def get_applicable_edit (self , event_id : str ) -> Optional [EventBase ]:
299
+ async def get_applicable_edit (
300
+ self , event_id : str , room_id : str
301
+ ) -> Optional [EventBase ]:
292
302
"""Get the most recent edit (if any) that has happened for the given
293
303
event.
294
304
295
305
Correctly handles checking whether edits were allowed to happen.
296
306
297
307
Args:
298
308
event_id: The original event ID
309
+ room_id: The original event's room ID
299
310
300
311
Returns:
301
312
The most recent edit, if any.
@@ -317,13 +328,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
317
328
WHERE
318
329
relates_to_id = ?
319
330
AND relation_type = ?
331
+ AND edit.room_id = ?
320
332
AND edit.type = 'm.room.message'
321
333
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
322
334
LIMIT 1
323
335
"""
324
336
325
337
def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [str ]:
326
- txn .execute (sql , (event_id , RelationTypes .REPLACE ))
338
+ txn .execute (sql , (event_id , RelationTypes .REPLACE , room_id ))
327
339
row = txn .fetchone ()
328
340
if row :
329
341
return row [0 ]
@@ -340,13 +352,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
340
352
341
353
@cached ()
342
354
async def get_thread_summary (
343
- self , event_id : str
355
+ self , event_id : str , room_id : str
344
356
) -> Tuple [int , Optional [EventBase ]]:
345
357
"""Get the number of threaded replies, the senders of those replies, and
346
358
the latest reply (if any) for the given event.
347
359
348
360
Args:
349
- event_id: The original event ID
361
+ event_id: Summarize the thread related to this event ID.
362
+ room_id: The room the event belongs to.
350
363
351
364
Returns:
352
365
The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ def _get_thread_summary_txn(
363
376
INNER JOIN events USING (event_id)
364
377
WHERE
365
378
relates_to_id = ?
379
+ AND room_id = ?
366
380
AND relation_type = ?
367
381
ORDER BY topological_ordering DESC, stream_ordering DESC
368
382
LIMIT 1
369
383
"""
370
384
371
- txn .execute (sql , (event_id , RelationTypes .THREAD ))
385
+ txn .execute (sql , (event_id , room_id , RelationTypes .THREAD ))
372
386
row = txn .fetchone ()
373
387
if row is None :
374
388
return 0 , None
@@ -378,11 +392,13 @@ def _get_thread_summary_txn(
378
392
sql = """
379
393
SELECT COALESCE(COUNT(event_id), 0)
380
394
FROM event_relations
395
+ INNER JOIN events USING (event_id)
381
396
WHERE
382
397
relates_to_id = ?
398
+ AND room_id = ?
383
399
AND relation_type = ?
384
400
"""
385
- txn .execute (sql , (event_id , RelationTypes .THREAD ))
401
+ txn .execute (sql , (event_id , room_id , RelationTypes .THREAD ))
386
402
count = txn .fetchone ()[0 ] # type: ignore[index]
387
403
388
404
return count , latest_event_id
0 commit comments