|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 | from http import HTTPStatus
|
16 |
| -from typing import Optional, Union |
| 16 | +from typing import Optional, Tuple, Union |
17 | 17 |
|
18 | 18 | from twisted.internet.defer import succeed
|
19 | 19 |
|
20 | 20 | import synapse.rest.admin
|
21 | 21 | from synapse.api.constants import LoginType
|
22 | 22 | from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
23 |
| -from synapse.rest.client import account, auth, devices, login, register |
| 23 | +from synapse.rest.client import account, auth, devices, login, logout, register |
24 | 24 | from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
| 25 | +from synapse.storage.database import LoggingTransaction |
25 | 26 | from synapse.types import JsonDict, UserID
|
26 | 27 |
|
27 | 28 | from tests import unittest
|
@@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
527 | 528 | auth.register_servlets,
|
528 | 529 | account.register_servlets,
|
529 | 530 | login.register_servlets,
|
| 531 | + logout.register_servlets, |
530 | 532 | synapse.rest.admin.register_servlets_for_client_rest_resource,
|
531 | 533 | register.register_servlets,
|
532 | 534 | ]
|
@@ -984,3 +986,90 @@ def test_refresh_token_invalidation(self):
|
984 | 986 | self.assertEqual(
|
985 | 987 | fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
|
986 | 988 | )
|
| 989 | + |
| 990 | + def test_many_token_refresh(self): |
| 991 | + """ |
| 992 | + If a refresh is performed many times during a session, there shouldn't be |
| 993 | + extra 'cruft' built up over time. |
| 994 | +
|
| 995 | + This test was written specifically to troubleshoot a case where logout |
| 996 | + was very slow if a lot of refreshes had been performed for the session. |
| 997 | + """ |
| 998 | + |
| 999 | + def _refresh(refresh_token: str) -> Tuple[str, str]: |
| 1000 | + """ |
| 1001 | + Performs one refresh, returning the next refresh token and access token. |
| 1002 | + """ |
| 1003 | + refresh_response = self.use_refresh_token(refresh_token) |
| 1004 | + self.assertEqual( |
| 1005 | + refresh_response.code, HTTPStatus.OK, refresh_response.result |
| 1006 | + ) |
| 1007 | + return ( |
| 1008 | + refresh_response.json_body["refresh_token"], |
| 1009 | + refresh_response.json_body["access_token"], |
| 1010 | + ) |
| 1011 | + |
| 1012 | + def _table_length(table_name: str) -> int: |
| 1013 | + """ |
| 1014 | + Helper to get the size of a table, in rows. |
| 1015 | + For testing only; trivially vulnerable to SQL injection. |
| 1016 | + """ |
| 1017 | + |
| 1018 | + def _txn(txn: LoggingTransaction) -> int: |
| 1019 | + txn.execute(f"SELECT COUNT(1) FROM {table_name}") |
| 1020 | + row = txn.fetchone() |
| 1021 | + # Query is infallible |
| 1022 | + assert row is not None |
| 1023 | + return row[0] |
| 1024 | + |
| 1025 | + return self.get_success( |
| 1026 | + self.hs.get_datastores().main.db_pool.runInteraction( |
| 1027 | + "_table_length", _txn |
| 1028 | + ) |
| 1029 | + ) |
| 1030 | + |
| 1031 | + # Before we log in, there are no access tokens. |
| 1032 | + self.assertEqual(_table_length("access_tokens"), 0) |
| 1033 | + self.assertEqual(_table_length("refresh_tokens"), 0) |
| 1034 | + |
| 1035 | + body = { |
| 1036 | + "type": "m.login.password", |
| 1037 | + "user": "test", |
| 1038 | + "password": self.user_pass, |
| 1039 | + "refresh_token": True, |
| 1040 | + } |
| 1041 | + login_response = self.make_request( |
| 1042 | + "POST", |
| 1043 | + "/_matrix/client/v3/login", |
| 1044 | + body, |
| 1045 | + ) |
| 1046 | + self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result) |
| 1047 | + |
| 1048 | + access_token = login_response.json_body["access_token"] |
| 1049 | + refresh_token = login_response.json_body["refresh_token"] |
| 1050 | + |
| 1051 | + # Now that we have logged in, there should be one access token and one |
| 1052 | + # refresh token |
| 1053 | + self.assertEqual(_table_length("access_tokens"), 1) |
| 1054 | + self.assertEqual(_table_length("refresh_tokens"), 1) |
| 1055 | + |
| 1056 | + for _ in range(5): |
| 1057 | + refresh_token, access_token = _refresh(refresh_token) |
| 1058 | + |
| 1059 | + # After 5 sequential refreshes, there should only be the latest two |
| 1060 | + # refresh/access token pairs. |
| 1061 | + # (The last one is preserved because it's in use! |
| 1062 | + # The one before that is preserved because it can still be used to |
| 1063 | + # replace the last token pair, in case of e.g. a network interruption.) |
| 1064 | + self.assertEqual(_table_length("access_tokens"), 2) |
| 1065 | + self.assertEqual(_table_length("refresh_tokens"), 2) |
| 1066 | + |
| 1067 | + logout_response = self.make_request( |
| 1068 | + "POST", "/_matrix/client/v3/logout", {}, access_token=access_token |
| 1069 | + ) |
| 1070 | + self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result) |
| 1071 | + |
| 1072 | + # Now that we have logged in, there should be no access token |
| 1073 | + # and no refresh token |
| 1074 | + self.assertEqual(_table_length("access_tokens"), 0) |
| 1075 | + self.assertEqual(_table_length("refresh_tokens"), 0) |
0 commit comments