57
57
from synapse .storage .databases .main .end_to_end_keys import EndToEndKeyWorkerStore
58
58
from synapse .storage .databases .main .roommember import RoomMemberWorkerStore
59
59
from synapse .storage .types import Cursor
60
- from synapse .storage .util .id_generators import (
61
- AbstractStreamIdGenerator ,
62
- StreamIdGenerator ,
63
- )
60
+ from synapse .storage .util .id_generators import MultiWriterIdGenerator
64
61
from synapse .types import (
65
62
JsonDict ,
66
63
JsonMapping ,
@@ -99,19 +96,21 @@ def __init__(
99
96
100
97
# In the worker store this is an ID tracker which we overwrite in the non-worker
101
98
# class below that is used on the main process.
102
- self ._device_list_id_gen = StreamIdGenerator (
103
- db_conn ,
104
- hs .get_replication_notifier (),
105
- "device_lists_stream" ,
106
- "stream_id" ,
107
- extra_tables = [
108
- ("user_signature_stream" , "stream_id" ),
109
- ("device_lists_outbound_pokes" , "stream_id" ),
110
- ("device_lists_changes_in_room" , "stream_id" ),
111
- ("device_lists_remote_pending" , "stream_id" ),
112
- ("device_lists_changes_converted_stream_position" , "stream_id" ),
99
+ self ._device_list_id_gen = MultiWriterIdGenerator (
100
+ db_conn = db_conn ,
101
+ db = database ,
102
+ notifier = hs .get_replication_notifier (),
103
+ stream_name = "device_lists_stream" ,
104
+ instance_name = self ._instance_name ,
105
+ tables = [
106
+ ("device_lists_stream" , "instance_name" , "stream_id" ),
107
+ ("user_signature_stream" , "instance_name" , "stream_id" ),
108
+ ("device_lists_outbound_pokes" , "instance_name" , "stream_id" ),
109
+ ("device_lists_changes_in_room" , "instance_name" , "stream_id" ),
110
+ ("device_lists_remote_pending" , "instance_name" , "stream_id" ),
113
111
],
114
- is_writer = hs .config .worker .worker_app is None ,
112
+ sequence_name = "device_lists_sequence" ,
113
+ writers = ["master" ],
115
114
)
116
115
117
116
device_list_max = self ._device_list_id_gen .get_current_token ()
@@ -762,6 +761,7 @@ def _add_user_signature_change_txn(
762
761
"stream_id" : stream_id ,
763
762
"from_user_id" : from_user_id ,
764
763
"user_ids" : json_encoder .encode (user_ids ),
764
+ "instance_name" : self ._instance_name ,
765
765
},
766
766
)
767
767
@@ -1582,6 +1582,8 @@ def __init__(
1582
1582
):
1583
1583
super ().__init__ (database , db_conn , hs )
1584
1584
1585
+ self ._instance_name = hs .get_instance_name ()
1586
+
1585
1587
self .db_pool .updates .register_background_index_update (
1586
1588
"device_lists_stream_idx" ,
1587
1589
index_name = "device_lists_stream_user_id" ,
@@ -1694,6 +1696,7 @@ def _txn(txn: LoggingTransaction) -> int:
1694
1696
"device_lists_outbound_pokes" ,
1695
1697
{
1696
1698
"stream_id" : stream_id ,
1699
+ "instance_name" : self ._instance_name ,
1697
1700
"destination" : destination ,
1698
1701
"user_id" : user_id ,
1699
1702
"device_id" : device_id ,
@@ -1730,10 +1733,6 @@ def _txn(txn: LoggingTransaction) -> int:
1730
1733
1731
1734
1732
1735
class DeviceStore (DeviceWorkerStore , DeviceBackgroundUpdateStore ):
1733
- # Because we have write access, this will be a StreamIdGenerator
1734
- # (see DeviceWorkerStore.__init__)
1735
- _device_list_id_gen : AbstractStreamIdGenerator
1736
-
1737
1736
def __init__ (
1738
1737
self ,
1739
1738
database : DatabasePool ,
@@ -2092,9 +2091,9 @@ def _add_device_change_to_stream_txn(
2092
2091
self .db_pool .simple_insert_many_txn (
2093
2092
txn ,
2094
2093
table = "device_lists_stream" ,
2095
- keys = ("stream_id" , "user_id" , "device_id" ),
2094
+ keys = ("instance_name" , " stream_id" , "user_id" , "device_id" ),
2096
2095
values = [
2097
- (stream_id , user_id , device_id )
2096
+ (self . _instance_name , stream_id , user_id , device_id )
2098
2097
for stream_id , device_id in zip (stream_ids , device_ids )
2099
2098
],
2100
2099
)
@@ -2124,6 +2123,7 @@ def _add_device_outbound_poke_to_stream_txn(
2124
2123
values = [
2125
2124
(
2126
2125
destination ,
2126
+ self ._instance_name ,
2127
2127
next (stream_id_iterator ),
2128
2128
user_id ,
2129
2129
device_id ,
@@ -2139,6 +2139,7 @@ def _add_device_outbound_poke_to_stream_txn(
2139
2139
table = "device_lists_outbound_pokes" ,
2140
2140
keys = (
2141
2141
"destination" ,
2142
+ "instance_name" ,
2142
2143
"stream_id" ,
2143
2144
"user_id" ,
2144
2145
"device_id" ,
@@ -2157,7 +2158,7 @@ def _add_device_outbound_poke_to_stream_txn(
2157
2158
device_id ,
2158
2159
{
2159
2160
stream_id : destination
2160
- for (destination , stream_id , _ , _ , _ , _ , _ ) in values
2161
+ for (destination , _ , stream_id , _ , _ , _ , _ , _ ) in values
2161
2162
},
2162
2163
)
2163
2164
@@ -2210,6 +2211,7 @@ def _add_device_outbound_room_poke_txn(
2210
2211
"device_id" ,
2211
2212
"room_id" ,
2212
2213
"stream_id" ,
2214
+ "instance_name" ,
2213
2215
"converted_to_destinations" ,
2214
2216
"opentracing_context" ,
2215
2217
),
@@ -2219,6 +2221,7 @@ def _add_device_outbound_room_poke_txn(
2219
2221
device_id ,
2220
2222
room_id ,
2221
2223
stream_id ,
2224
+ self ._instance_name ,
2222
2225
# We only need to calculate outbound pokes for local users
2223
2226
not self .hs .is_mine_id (user_id ),
2224
2227
encoded_context ,
@@ -2338,7 +2341,10 @@ async def add_remote_device_list_to_pending(
2338
2341
"user_id" : user_id ,
2339
2342
"device_id" : device_id ,
2340
2343
},
2341
- values = {"stream_id" : stream_id },
2344
+ values = {
2345
+ "stream_id" : stream_id ,
2346
+ "instance_name" : self ._instance_name ,
2347
+ },
2342
2348
desc = "add_remote_device_list_to_pending" ,
2343
2349
)
2344
2350
0 commit comments