Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 2d1d7b7

Browse files
anoadragon453richvdh
authored and committed
Prevent multiple device list updates from breaking a batch send (#5156)
fixes #5153
1 parent a118650 commit 2d1d7b7

File tree

4 files changed

+196
-31
lines changed

4 files changed

+196
-31
lines changed

changelog.d/5156.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Prevent federation device list updates breaking when processing multiple updates at once.

synapse/federation/sender/per_destination_queue.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,10 @@ def _pop_pending_edus(self, limit):
349349
@defer.inlineCallbacks
350350
def _get_new_device_messages(self, limit):
351351
last_device_list = self._last_device_list_stream_id
352-
# Will return at most 20 entries
352+
353+
# Retrieve list of new device updates to send to the destination
353354
now_stream_id, results = yield self._store.get_devices_by_remote(
354-
self._destination, last_device_list
355+
self._destination, last_device_list, limit=limit,
355356
)
356357
edus = [
357358
Edu(

synapse/storage/devices.py

+123-29
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import logging
1616

17-
from six import iteritems, itervalues
17+
from six import iteritems
1818

1919
from canonicaljson import json
2020

@@ -72,67 +72,146 @@ def get_devices_by_user(self, user_id):
7272

7373
defer.returnValue({d["device_id"]: d for d in devices})
7474

75-
def get_devices_by_remote(self, destination, from_stream_id):
75+
@defer.inlineCallbacks
76+
def get_devices_by_remote(self, destination, from_stream_id, limit):
7677
"""Get stream of updates to send to remote servers
7778
7879
Returns:
79-
(int, list[dict]): current stream id and list of updates
80+
Deferred[tuple[int, list[dict]]]:
81+
current stream id (ie, the stream id of the last update included in the
82+
response), and the list of updates
8083
"""
8184
now_stream_id = self._device_list_id_gen.get_current_token()
8285

8386
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
8487
destination, int(from_stream_id)
8588
)
8689
if not has_changed:
87-
return (now_stream_id, [])
88-
89-
return self.runInteraction(
90+
defer.returnValue((now_stream_id, []))
91+
92+
# We retrieve n+1 devices from the list of outbound pokes where n is
93+
# our outbound device update limit. We then check if the very last
94+
# device has the same stream_id as the second-to-last device. If so,
95+
# then we ignore all devices with that stream_id and only send the
96+
# devices with a lower stream_id.
97+
#
98+
# If when culling the list we end up with no devices afterwards, we
99+
# consider the device update to be too large, and simply skip the
100+
# stream_id; the rationale being that such a large device list update
101+
# is likely an error.
102+
updates = yield self.runInteraction(
90103
"get_devices_by_remote",
91104
self._get_devices_by_remote_txn,
92105
destination,
93106
from_stream_id,
94107
now_stream_id,
108+
limit + 1,
95109
)
96110

111+
# Return an empty list if there are no updates
112+
if not updates:
113+
defer.returnValue((now_stream_id, []))
114+
115+
# if we have exceeded the limit, we need to exclude any results with the
116+
# same stream_id as the last row.
117+
if len(updates) > limit:
118+
stream_id_cutoff = updates[-1][2]
119+
now_stream_id = stream_id_cutoff - 1
120+
else:
121+
stream_id_cutoff = None
122+
123+
# Perform the equivalent of a GROUP BY
124+
#
125+
# Iterate through the updates list and copy non-duplicate
126+
# (user_id, device_id) entries into a map, with the value being
127+
# the max stream_id across each set of duplicate entries
128+
#
129+
# maps (user_id, device_id) -> stream_id
130+
# as long as their stream_id does not match that of the last row
131+
query_map = {}
132+
for update in updates:
133+
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
134+
# Stop processing updates
135+
break
136+
137+
key = (update[0], update[1])
138+
query_map[key] = max(query_map.get(key, 0), update[2])
139+
140+
# If we didn't find any updates with a stream_id lower than the cutoff, it
141+
# means that there are more than limit updates all of which have the same
142+
# stream_id.
143+
144+
# That should only happen if a client is spamming the server with new
145+
# devices, in which case E2E isn't going to work well anyway. We'll just
146+
# skip that stream_id and return an empty list, and continue with the next
147+
# stream_id next time.
148+
if not query_map:
149+
defer.returnValue((stream_id_cutoff, []))
150+
151+
results = yield self._get_device_update_edus_by_remote(
152+
destination,
153+
from_stream_id,
154+
query_map,
155+
)
156+
157+
defer.returnValue((now_stream_id, results))
158+
97159
def _get_devices_by_remote_txn(
98-
self, txn, destination, from_stream_id, now_stream_id
160+
self, txn, destination, from_stream_id, now_stream_id, limit
99161
):
162+
"""Return device update information for a given remote destination
163+
164+
Args:
165+
txn (LoggingTransaction): The transaction to execute
166+
destination (str): The host the device updates are intended for
167+
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
168+
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
169+
limit (int): Maximum number of device updates to return
170+
171+
Returns:
172+
List: List of device updates
173+
"""
100174
sql = """
101-
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
175+
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
102176
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
103-
GROUP BY user_id, device_id
104-
LIMIT 20
177+
ORDER BY stream_id
178+
LIMIT ?
105179
"""
106-
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
180+
txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
107181

108-
# maps (user_id, device_id) -> stream_id
109-
query_map = {(r[0], r[1]): r[2] for r in txn}
110-
if not query_map:
111-
return (now_stream_id, [])
182+
return list(txn)
112183

113-
if len(query_map) >= 20:
114-
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
184+
@defer.inlineCallbacks
185+
def _get_device_update_edus_by_remote(
186+
self, destination, from_stream_id, query_map,
187+
):
188+
"""Returns a list of device update EDUs as well as E2EE keys
115189
116-
devices = self._get_e2e_device_keys_txn(
117-
txn,
190+
Args:
191+
destination (str): The host the device updates are intended for
192+
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
193+
query_map (Dict[(str, str): int]): Dictionary mapping
194+
user_id/device_id to update stream_id
195+
196+
Returns:
197+
List[Dict]: List of objects representing a device update EDU
198+
199+
"""
200+
devices = yield self.runInteraction(
201+
"_get_e2e_device_keys_txn",
202+
self._get_e2e_device_keys_txn,
118203
query_map.keys(),
119204
include_all_devices=True,
120205
include_deleted_devices=True,
121206
)
122207

123-
prev_sent_id_sql = """
124-
SELECT coalesce(max(stream_id), 0) as stream_id
125-
FROM device_lists_outbound_last_success
126-
WHERE destination = ? AND user_id = ? AND stream_id <= ?
127-
"""
128-
129208
results = []
130209
for user_id, user_devices in iteritems(devices):
131210
# The prev_id for the first row is always the last row before
132211
# `from_stream_id`
133-
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
134-
rows = txn.fetchall()
135-
prev_id = rows[0][0]
212+
prev_id = yield self._get_last_device_update_for_remote_user(
213+
destination, user_id, from_stream_id,
214+
)
136215
for device_id, device in iteritems(user_devices):
137216
stream_id = query_map[(user_id, device_id)]
138217
result = {
@@ -156,7 +235,22 @@ def _get_devices_by_remote_txn(
156235

157236
results.append(result)
158237

159-
return (now_stream_id, results)
238+
defer.returnValue(results)
239+
240+
def _get_last_device_update_for_remote_user(
241+
self, destination, user_id, from_stream_id,
242+
):
243+
def f(txn):
244+
prev_sent_id_sql = """
245+
SELECT coalesce(max(stream_id), 0) as stream_id
246+
FROM device_lists_outbound_last_success
247+
WHERE destination = ? AND user_id = ? AND stream_id <= ?
248+
"""
249+
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
250+
rows = txn.fetchall()
251+
return rows[0][0]
252+
253+
return self.runInteraction("get_last_device_update_for_remote_user", f)
160254

161255
def mark_as_sent_devices_by_remote(self, destination, stream_id):
162256
"""Mark that updates have successfully been sent to the destination.

tests/storage/test_devices.py

+69
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,75 @@ def test_get_devices_by_user(self):
7171
res["device2"],
7272
)
7373

74+
@defer.inlineCallbacks
75+
def test_get_devices_by_remote(self):
76+
device_ids = ["device_id1", "device_id2"]
77+
78+
# Add two device updates with a single stream_id
79+
yield self.store.add_device_change_to_streams(
80+
"user_id", device_ids, ["somehost"],
81+
)
82+
83+
# Get all device updates ever meant for this remote
84+
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
85+
"somehost", -1, limit=100,
86+
)
87+
88+
# Check original device_ids are contained within these updates
89+
self._check_devices_in_updates(device_ids, device_updates)
90+
91+
@defer.inlineCallbacks
92+
def test_get_devices_by_remote_limited(self):
93+
# Test breaking the update limit in 1, 101, and 1 device_id segments
94+
95+
# first add one device
96+
device_ids1 = ["device_id0"]
97+
yield self.store.add_device_change_to_streams(
98+
"user_id", device_ids1, ["someotherhost"],
99+
)
100+
101+
# then add 101
102+
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
103+
yield self.store.add_device_change_to_streams(
104+
"user_id", device_ids2, ["someotherhost"],
105+
)
106+
107+
# then one more
108+
device_ids3 = ["newdevice"]
109+
yield self.store.add_device_change_to_streams(
110+
"user_id", device_ids3, ["someotherhost"],
111+
)
112+
113+
#
114+
# now read them back.
115+
#
116+
117+
# first we should get a single update
118+
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
119+
"someotherhost", -1, limit=100,
120+
)
121+
self._check_devices_in_updates(device_ids1, device_updates)
122+
123+
# Then we should get an empty list back as the 101 devices broke the limit
124+
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
125+
"someotherhost", now_stream_id, limit=100,
126+
)
127+
self.assertEqual(len(device_updates), 0)
128+
129+
# The 101 devices should've been cleared, so we should now just get one device
130+
# update
131+
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
132+
"someotherhost", now_stream_id, limit=100,
133+
)
134+
self._check_devices_in_updates(device_ids3, device_updates)
135+
136+
def _check_devices_in_updates(self, expected_device_ids, device_updates):
137+
"""Check that an specific device ids exist in a list of device update EDUs"""
138+
self.assertEqual(len(device_updates), len(expected_device_ids))
139+
140+
received_device_ids = {update["device_id"] for update in device_updates}
141+
self.assertEqual(received_device_ids, set(expected_device_ids))
142+
74143
@defer.inlineCallbacks
75144
def test_update_device(self):
76145
yield self.store.store_device("user_id", "device_id", "display_name 1")

0 commit comments

Comments
 (0)