Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 235d291

Browse files
authored
Fix slow performance of /logout in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. (#12056)
1 parent 6a1bad5 commit 235d291

File tree

4 files changed

+136
-4
lines changed

4 files changed

+136
-4
lines changed

changelog.d/12056.bugfix

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.

synapse/storage/databases/main/registration.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,8 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
16811681
user_id=row[1],
16821682
device_id=row[2],
16831683
next_token_id=row[3],
1684-
has_next_refresh_token_been_refreshed=row[4],
1684+
# SQLite returns 0 or 1 for false/true, so convert to a bool.
1685+
has_next_refresh_token_been_refreshed=bool(row[4]),
16851686
# This column is nullable, ensure it's a boolean
16861687
has_next_access_token_been_used=(row[5] or False),
16871688
expiry_ts=row[6],
@@ -1697,12 +1698,15 @@ async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None
16971698
Set the successor of a refresh token, removing the existing successor
16981699
if any.
16991700
1701+
This also deletes the predecessor refresh and access tokens,
1702+
since they cannot be valid anymore.
1703+
17001704
Args:
17011705
token_id: ID of the refresh token to update.
17021706
next_token_id: ID of its successor.
17031707
"""
17041708

1705-
def _replace_refresh_token_txn(txn) -> None:
1709+
def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
17061710
# First check if there was an existing refresh token
17071711
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
17081712
txn,
@@ -1728,6 +1732,16 @@ def _replace_refresh_token_txn(txn) -> None:
17281732
{"id": old_next_token_id},
17291733
)
17301734

1735+
# Delete the previous refresh token, since we only want to keep the
1736+
# last 2 refresh tokens in the database.
1737+
# (The predecessor of the latest refresh token is still useful in
1738+
# case the refresh was interrupted and the client re-uses the old
1739+
# one.)
1740+
# This cascades to delete the associated access token.
1741+
self.db_pool.simple_delete_txn(
1742+
txn, "refresh_tokens", {"next_token_id": token_id}
1743+
)
1744+
17311745
await self.db_pool.runInteraction(
17321746
"replace_refresh_token", _replace_refresh_token_txn
17331747
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/* Copyright 2022 The Matrix.org Foundation C.I.C
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
-- next_token_id is a foreign key reference, so previously required a table scan
17+
-- when a row in the referenced table was deleted.
18+
-- As it was self-referential and cascaded deletes, this led to O(t*n) time to
19+
-- delete a row, where t: number of rows in the table and n: number of rows in
20+
-- the ancestral 'chain' of refresh tokens.
21+
--
22+
-- This index is partial since we only require it for rows which reference
23+
-- another.
24+
-- Performance was tested to be the same regardless of whether the index was
25+
-- full or partial, but a partial index can be smaller.
26+
CREATE INDEX refresh_tokens_next_token_id
27+
ON refresh_tokens(next_token_id)
28+
WHERE next_token_id IS NOT NULL;

tests/rest/client/test_auth.py

+91-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from http import HTTPStatus
16-
from typing import Optional, Union
16+
from typing import Optional, Tuple, Union
1717

1818
from twisted.internet.defer import succeed
1919

2020
import synapse.rest.admin
2121
from synapse.api.constants import LoginType
2222
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
23-
from synapse.rest.client import account, auth, devices, login, register
23+
from synapse.rest.client import account, auth, devices, login, logout, register
2424
from synapse.rest.synapse.client import build_synapse_client_resource_tree
25+
from synapse.storage.database import LoggingTransaction
2526
from synapse.types import JsonDict, UserID
2627

2728
from tests import unittest
@@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
527528
auth.register_servlets,
528529
account.register_servlets,
529530
login.register_servlets,
531+
logout.register_servlets,
530532
synapse.rest.admin.register_servlets_for_client_rest_resource,
531533
register.register_servlets,
532534
]
@@ -984,3 +986,90 @@ def test_refresh_token_invalidation(self):
984986
self.assertEqual(
985987
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
986988
)
989+
990+
def test_many_token_refresh(self):
991+
"""
992+
If a refresh is performed many times during a session, there shouldn't be
993+
extra 'cruft' built up over time.
994+
995+
This test was written specifically to troubleshoot a case where logout
996+
was very slow if a lot of refreshes had been performed for the session.
997+
"""
998+
999+
def _refresh(refresh_token: str) -> Tuple[str, str]:
1000+
"""
1001+
Performs one refresh, returning the next refresh token and access token.
1002+
"""
1003+
refresh_response = self.use_refresh_token(refresh_token)
1004+
self.assertEqual(
1005+
refresh_response.code, HTTPStatus.OK, refresh_response.result
1006+
)
1007+
return (
1008+
refresh_response.json_body["refresh_token"],
1009+
refresh_response.json_body["access_token"],
1010+
)
1011+
1012+
def _table_length(table_name: str) -> int:
1013+
"""
1014+
Helper to get the size of a table, in rows.
1015+
For testing only; trivially vulnerable to SQL injection.
1016+
"""
1017+
1018+
def _txn(txn: LoggingTransaction) -> int:
1019+
txn.execute(f"SELECT COUNT(1) FROM {table_name}")
1020+
row = txn.fetchone()
1021+
# Query is infallible
1022+
assert row is not None
1023+
return row[0]
1024+
1025+
return self.get_success(
1026+
self.hs.get_datastores().main.db_pool.runInteraction(
1027+
"_table_length", _txn
1028+
)
1029+
)
1030+
1031+
# Before we log in, there are no access tokens or refresh tokens.
1032+
self.assertEqual(_table_length("access_tokens"), 0)
1033+
self.assertEqual(_table_length("refresh_tokens"), 0)
1034+
1035+
body = {
1036+
"type": "m.login.password",
1037+
"user": "test",
1038+
"password": self.user_pass,
1039+
"refresh_token": True,
1040+
}
1041+
login_response = self.make_request(
1042+
"POST",
1043+
"/_matrix/client/v3/login",
1044+
body,
1045+
)
1046+
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
1047+
1048+
access_token = login_response.json_body["access_token"]
1049+
refresh_token = login_response.json_body["refresh_token"]
1050+
1051+
# Now that we have logged in, there should be one access token and one
1052+
# refresh token
1053+
self.assertEqual(_table_length("access_tokens"), 1)
1054+
self.assertEqual(_table_length("refresh_tokens"), 1)
1055+
1056+
for _ in range(5):
1057+
refresh_token, access_token = _refresh(refresh_token)
1058+
1059+
# After 5 sequential refreshes, there should only be the latest two
1060+
# refresh/access token pairs.
1061+
# (The last one is preserved because it's in use!
1062+
# The one before that is preserved because it can still be used to
1063+
# replace the last token pair, in case of e.g. a network interruption.)
1064+
self.assertEqual(_table_length("access_tokens"), 2)
1065+
self.assertEqual(_table_length("refresh_tokens"), 2)
1066+
1067+
logout_response = self.make_request(
1068+
"POST", "/_matrix/client/v3/logout", {}, access_token=access_token
1069+
)
1070+
self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)
1071+
1072+
# Now that we have logged out, there should be no access token
1073+
# and no refresh token
1074+
self.assertEqual(_table_length("access_tokens"), 0)
1075+
self.assertEqual(_table_length("refresh_tokens"), 0)

0 commit comments

Comments
 (0)