14
14
# limitations under the License.
15
15
import logging
16
16
17
- from six import iteritems , itervalues
17
+ from six import iteritems
18
18
19
19
from canonicaljson import json
20
20
@@ -72,67 +72,146 @@ def get_devices_by_user(self, user_id):
72
72
73
73
defer .returnValue ({d ["device_id" ]: d for d in devices })
74
74
75
- def get_devices_by_remote (self , destination , from_stream_id ):
75
+ @defer .inlineCallbacks
76
+ def get_devices_by_remote (self , destination , from_stream_id , limit ):
76
77
"""Get stream of updates to send to remote servers
77
78
78
79
Returns:
79
- (int, list[dict]): current stream id and list of updates
80
+ Deferred[tuple[int, list[dict]]]:
81
+ current stream id (ie, the stream id of the last update included in the
82
+ response), and the list of updates
80
83
"""
81
84
now_stream_id = self ._device_list_id_gen .get_current_token ()
82
85
83
86
has_changed = self ._device_list_federation_stream_cache .has_entity_changed (
84
87
destination , int (from_stream_id )
85
88
)
86
89
if not has_changed :
87
- return (now_stream_id , [])
88
-
89
- return self .runInteraction (
90
+ defer .returnValue ((now_stream_id , []))
91
+
92
+ # We retrieve n+1 devices from the list of outbound pokes where n is
93
+ # our outbound device update limit. We then check if the very last
94
+ # device has the same stream_id as the second-to-last device. If so,
95
+ # then we ignore all devices with that stream_id and only send the
96
+ # devices with a lower stream_id.
97
+ #
98
+ # If when culling the list we end up with no devices afterwards, we
99
+ # consider the device update to be too large, and simply skip the
100
+ # stream_id; the rationale being that such a large device list update
101
+ # is likely an error.
102
+ updates = yield self .runInteraction (
90
103
"get_devices_by_remote" ,
91
104
self ._get_devices_by_remote_txn ,
92
105
destination ,
93
106
from_stream_id ,
94
107
now_stream_id ,
108
+ limit + 1 ,
95
109
)
96
110
111
+ # Return an empty list if there are no updates
112
+ if not updates :
113
+ defer .returnValue ((now_stream_id , []))
114
+
115
+ # if we have exceeded the limit, we need to exclude any results with the
116
+ # same stream_id as the last row.
117
+ if len (updates ) > limit :
118
+ stream_id_cutoff = updates [- 1 ][2 ]
119
+ now_stream_id = stream_id_cutoff - 1
120
+ else :
121
+ stream_id_cutoff = None
122
+
123
+ # Perform the equivalent of a GROUP BY
124
+ #
125
+ # Iterate through the updates list and copy non-duplicate
126
+ # (user_id, device_id) entries into a map, with the value being
127
+ # the max stream_id across each set of duplicate entries
128
+ #
129
+ # maps (user_id, device_id) -> stream_id
130
+ # as long as their stream_id does not match that of the last row
131
+ query_map = {}
132
+ for update in updates :
133
+ if stream_id_cutoff is not None and update [2 ] >= stream_id_cutoff :
134
+ # Stop processing updates
135
+ break
136
+
137
+ key = (update [0 ], update [1 ])
138
+ query_map [key ] = max (query_map .get (key , 0 ), update [2 ])
139
+
140
+ # If we didn't find any updates with a stream_id lower than the cutoff, it
141
+ # means that there are more than limit updates all of which have the same
142
+ # steam_id.
143
+
144
+ # That should only happen if a client is spamming the server with new
145
+ # devices, in which case E2E isn't going to work well anyway. We'll just
146
+ # skip that stream_id and return an empty list, and continue with the next
147
+ # stream_id next time.
148
+ if not query_map :
149
+ defer .returnValue ((stream_id_cutoff , []))
150
+
151
+ results = yield self ._get_device_update_edus_by_remote (
152
+ destination ,
153
+ from_stream_id ,
154
+ query_map ,
155
+ )
156
+
157
+ defer .returnValue ((now_stream_id , results ))
158
+
97
159
def _get_devices_by_remote_txn (
98
- self , txn , destination , from_stream_id , now_stream_id
160
+ self , txn , destination , from_stream_id , now_stream_id , limit
99
161
):
162
+ """Return device update information for a given remote destination
163
+
164
+ Args:
165
+ txn (LoggingTransaction): The transaction to execute
166
+ destination (str): The host the device updates are intended for
167
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
168
+ now_stream_id (int): The maximum stream_id to filter updates by, inclusive
169
+ limit (int): Maximum number of device updates to return
170
+
171
+ Returns:
172
+ List: List of device updates
173
+ """
100
174
sql = """
101
- SELECT user_id, device_id, max( stream_id) FROM device_lists_outbound_pokes
175
+ SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
102
176
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
103
- GROUP BY user_id, device_id
104
- LIMIT 20
177
+ ORDER BY stream_id
178
+ LIMIT ?
105
179
"""
106
- txn .execute (sql , (destination , from_stream_id , now_stream_id , False ))
180
+ txn .execute (sql , (destination , from_stream_id , now_stream_id , False , limit ))
107
181
108
- # maps (user_id, device_id) -> stream_id
109
- query_map = {(r [0 ], r [1 ]): r [2 ] for r in txn }
110
- if not query_map :
111
- return (now_stream_id , [])
182
+ return list (txn )
112
183
113
- if len (query_map ) >= 20 :
114
- now_stream_id = max (stream_id for stream_id in itervalues (query_map ))
184
+ @defer .inlineCallbacks
185
+ def _get_device_update_edus_by_remote (
186
+ self , destination , from_stream_id , query_map ,
187
+ ):
188
+ """Returns a list of device update EDUs as well as E2EE keys
115
189
116
- devices = self ._get_e2e_device_keys_txn (
117
- txn ,
190
+ Args:
191
+ destination (str): The host the device updates are intended for
192
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
193
+ query_map (Dict[(str, str): int]): Dictionary mapping
194
+ user_id/device_id to update stream_id
195
+
196
+ Returns:
197
+ List[Dict]: List of objects representing an device update EDU
198
+
199
+ """
200
+ devices = yield self .runInteraction (
201
+ "_get_e2e_device_keys_txn" ,
202
+ self ._get_e2e_device_keys_txn ,
118
203
query_map .keys (),
119
204
include_all_devices = True ,
120
205
include_deleted_devices = True ,
121
206
)
122
207
123
- prev_sent_id_sql = """
124
- SELECT coalesce(max(stream_id), 0) as stream_id
125
- FROM device_lists_outbound_last_success
126
- WHERE destination = ? AND user_id = ? AND stream_id <= ?
127
- """
128
-
129
208
results = []
130
209
for user_id , user_devices in iteritems (devices ):
131
210
# The prev_id for the first row is always the last row before
132
211
# `from_stream_id`
133
- txn . execute ( prev_sent_id_sql , ( destination , user_id , from_stream_id ))
134
- rows = txn . fetchall ()
135
- prev_id = rows [ 0 ][ 0 ]
212
+ prev_id = yield self . _get_last_device_update_for_remote_user (
213
+ destination , user_id , from_stream_id ,
214
+ )
136
215
for device_id , device in iteritems (user_devices ):
137
216
stream_id = query_map [(user_id , device_id )]
138
217
result = {
@@ -156,7 +235,22 @@ def _get_devices_by_remote_txn(
156
235
157
236
results .append (result )
158
237
159
- return (now_stream_id , results )
238
+ defer .returnValue (results )
239
+
240
+ def _get_last_device_update_for_remote_user (
241
+ self , destination , user_id , from_stream_id ,
242
+ ):
243
+ def f (txn ):
244
+ prev_sent_id_sql = """
245
+ SELECT coalesce(max(stream_id), 0) as stream_id
246
+ FROM device_lists_outbound_last_success
247
+ WHERE destination = ? AND user_id = ? AND stream_id <= ?
248
+ """
249
+ txn .execute (prev_sent_id_sql , (destination , user_id , from_stream_id ))
250
+ rows = txn .fetchall ()
251
+ return rows [0 ][0 ]
252
+
253
+ return self .runInteraction ("get_last_device_update_for_remote_user" , f )
160
254
161
255
def mark_as_sent_devices_by_remote (self , destination , stream_id ):
162
256
"""Mark that updates have successfully been sent to the destination.
0 commit comments