18
18
from typing import Any , Dict , List , Optional , Type , Union
19
19
from unittest .mock import Mock
20
20
21
+ from twisted .test .proto_helpers import MemoryReactor
22
+
21
23
import synapse
22
24
from synapse .api .constants import LoginType
23
25
from synapse .api .errors import Codes
24
26
from synapse .handlers .account import AccountHandler
25
27
from synapse .module_api import ModuleApi
26
28
from synapse .rest .client import account , devices , login , logout , register
29
+ from synapse .server import HomeServer
27
30
from synapse .types import JsonDict , UserID
31
+ from synapse .util import Clock
28
32
29
33
from tests import unittest
30
34
from tests .server import FakeChannel
@@ -162,10 +166,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
162
166
CALLBACK_USERNAME = "get_username_for_registration"
163
167
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
164
168
165
- def setUp (self ) -> None :
169
+ def prepare (
170
+ self , reactor : MemoryReactor , clock : Clock , homeserver : HomeServer
171
+ ) -> None :
166
172
# we use a global mock device, so make sure we are starting with a clean slate
167
173
mock_password_provider .reset_mock ()
168
- super ().setUp ()
174
+
175
+ # The mock password provider doesn't register the users, so ensure they
176
+ # are registered first.
177
+ self .register_user ("u" , "not-the-tested-password" )
178
+ self .register_user ("user" , "not-the-tested-password" )
169
179
170
180
@override_config (legacy_providers_config (LegacyPasswordOnlyAuthProvider ))
171
181
def test_password_only_auth_progiver_login_legacy (self ) -> None :
@@ -185,33 +195,19 @@ def password_only_auth_provider_login_test_body(self) -> None:
185
195
mock_password_provider .reset_mock ()
186
196
187
197
# login with mxid should work too
188
- channel = self ._send_password_login ("@u:bz " , "p" )
198
+ channel = self ._send_password_login ("@u:test " , "p" )
189
199
self .assertEqual (channel .code , HTTPStatus .OK , channel .result )
190
- self .assertEqual ("@u:bz " , channel .json_body ["user_id" ])
191
- mock_password_provider .check_password .assert_called_once_with ("@u:bz " , "p" )
200
+ self .assertEqual ("@u:test " , channel .json_body ["user_id" ])
201
+ mock_password_provider .check_password .assert_called_once_with ("@u:test " , "p" )
192
202
mock_password_provider .reset_mock ()
193
203
194
- # try a weird username / pass. Honestly it's unclear what we *expect* to happen
195
- # in these cases, but at least we can guard against the API changing
196
- # unexpectedly
197
- channel = self ._send_password_login (" USER🙂NAME " , " pASS\U0001F622 word " )
198
- self .assertEqual (channel .code , HTTPStatus .OK , channel .result )
199
- self .assertEqual ("@ USER🙂NAME :test" , channel .json_body ["user_id" ])
200
- mock_password_provider .check_password .assert_called_once_with (
201
- "@ USER🙂NAME :test" , " pASS😢word "
202
- )
203
-
204
204
@override_config (legacy_providers_config (LegacyPasswordOnlyAuthProvider ))
205
205
def test_password_only_auth_provider_ui_auth_legacy (self ) -> None :
206
206
self .password_only_auth_provider_ui_auth_test_body ()
207
207
208
208
def password_only_auth_provider_ui_auth_test_body (self ) -> None :
209
209
"""UI Auth should delegate correctly to the password provider"""
210
210
211
- # create the user, otherwise access doesn't work
212
- module_api = self .hs .get_module_api ()
213
- self .get_success (module_api .register_user ("u" ))
214
-
215
211
# log in twice, to get two devices
216
212
mock_password_provider .check_password .return_value = make_awaitable (True )
217
213
tok1 = self .login ("u" , "p" )
@@ -401,29 +397,16 @@ def custom_auth_provider_login_test_body(self) -> None:
401
397
mock_password_provider .check_auth .assert_not_called ()
402
398
403
399
mock_password_provider .check_auth .return_value = make_awaitable (
404
- ("@user:bz " , None )
400
+ ("@user:test " , None )
405
401
)
406
402
channel = self ._send_login ("test.login_type" , "u" , test_field = "y" )
407
403
self .assertEqual (channel .code , HTTPStatus .OK , channel .result )
408
- self .assertEqual ("@user:bz " , channel .json_body ["user_id" ])
404
+ self .assertEqual ("@user:test " , channel .json_body ["user_id" ])
409
405
mock_password_provider .check_auth .assert_called_once_with (
410
406
"u" , "test.login_type" , {"test_field" : "y" }
411
407
)
412
408
mock_password_provider .reset_mock ()
413
409
414
- # try a weird username. Again, it's unclear what we *expect* to happen
415
- # in these cases, but at least we can guard against the API changing
416
- # unexpectedly
417
- mock_password_provider .check_auth .return_value = make_awaitable (
418
- ("@ MALFORMED! :bz" , None )
419
- )
420
- channel = self ._send_login ("test.login_type" , " USER🙂NAME " , test_field = " abc " )
421
- self .assertEqual (channel .code , HTTPStatus .OK , channel .result )
422
- self .assertEqual ("@ MALFORMED! :bz" , channel .json_body ["user_id" ])
423
- mock_password_provider .check_auth .assert_called_once_with (
424
- " USER🙂NAME " , "test.login_type" , {"test_field" : " abc " }
425
- )
426
-
427
410
@override_config (legacy_providers_config (LegacyCustomAuthProvider ))
428
411
def test_custom_auth_provider_ui_auth_legacy (self ) -> None :
429
412
self .custom_auth_provider_ui_auth_test_body ()
@@ -465,7 +448,7 @@ def custom_auth_provider_ui_auth_test_body(self) -> None:
465
448
466
449
# right params, but authing as the wrong user
467
450
mock_password_provider .check_auth .return_value = make_awaitable (
468
- ("@user:bz " , None )
451
+ ("@user:test " , None )
469
452
)
470
453
body ["auth" ]["test_field" ] = "foo"
471
454
channel = self ._delete_device (tok1 , "dev2" , body )
@@ -498,11 +481,11 @@ def custom_auth_provider_callback_test_body(self) -> None:
498
481
callback = Mock (return_value = make_awaitable (None ))
499
482
500
483
mock_password_provider .check_auth .return_value = make_awaitable (
501
- ("@user:bz " , callback )
484
+ ("@user:test " , callback )
502
485
)
503
486
channel = self ._send_login ("test.login_type" , "u" , test_field = "y" )
504
487
self .assertEqual (channel .code , HTTPStatus .OK , channel .result )
505
- self .assertEqual ("@user:bz " , channel .json_body ["user_id" ])
488
+ self .assertEqual ("@user:test " , channel .json_body ["user_id" ])
506
489
mock_password_provider .check_auth .assert_called_once_with (
507
490
"u" , "test.login_type" , {"test_field" : "y" }
508
491
)
@@ -512,7 +495,7 @@ def custom_auth_provider_callback_test_body(self) -> None:
512
495
call_args , call_kwargs = callback .call_args
513
496
# should be one positional arg
514
497
self .assertEqual (len (call_args ), 1 )
515
- self .assertEqual (call_args [0 ]["user_id" ], "@user:bz " )
498
+ self .assertEqual (call_args [0 ]["user_id" ], "@user:test " )
516
499
for p in ["user_id" , "access_token" , "device_id" , "home_server" ]:
517
500
self .assertIn (p , call_args [0 ])
518
501
0 commit comments