20
20
#
21
21
22
22
import logging
23
+ from typing import Tuple , Optional , Callable , Awaitable
24
+
23
25
import requests
24
- import json
25
26
import time
27
+ import synapse
28
+ from synapse import module_api
26
29
27
30
logger = logging .getLogger (__name__ )
28
31
29
32
30
33
class RestAuthProvider (object ):
31
34
32
- def __init__ (self , config , account_handler ):
33
- self .account_handler = account_handler
35
+ def __init__ (self , config : dict , api : module_api ):
36
+ self .account_handler = api
34
37
35
38
if not config .endpoint :
36
39
raise RuntimeError ('Missing endpoint config' )
@@ -42,6 +45,36 @@ def __init__(self, config, account_handler):
42
45
logger .info ('Endpoint: %s' , self .endpoint )
43
46
logger .info ('Enforce lowercase username during registration: %s' , self .regLower )
44
47
48
+ # register an auth callback handler
49
+ # see https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html
50
+ api .register_password_auth_provider_callbacks (
51
+ auth_checkers = {
52
+ ("m.login.password" , ("password" ,)): self .check_m_login_password
53
+ }
54
+ )
55
+
56
+ async def check_m_login_password (self , username : str ,
57
+ login_type : str ,
58
+ login_dict : "synapse.module_api.JsonDict" ) -> Optional [
59
+ Tuple [
60
+ str ,
61
+ Optional [Callable [["synapse.module_api.LoginResponse" ], Awaitable [None ]]],
62
+ ]
63
+ ]:
64
+ if login_type != "m.login.password" :
65
+ return None
66
+
67
+ # get the complete MXID
68
+ mxid = self .account_handler .get_qualified_user_id (username )
69
+
70
+ # check if the password is valid with the old function
71
+ password_valid = await self .check_password (mxid , login_dict .get ("password" ))
72
+
73
+ if password_valid :
74
+ return mxid , None
75
+ else :
76
+ return None
77
+
45
78
async def check_password (self , user_id , password ):
46
79
logger .info ("Got password check for " + user_id )
47
80
data = {'user' : {'id' : user_id , 'password' : password }}
0 commit comments