15
15
"""MONGODB-OIDC Authentication helpers."""
16
16
from __future__ import annotations
17
17
18
+ import asyncio
18
19
import threading
19
20
import time
20
21
from dataclasses import dataclass , field
36
37
)
37
38
from pymongo .errors import ConfigurationError , OperationFailure
38
39
from pymongo .helpers_shared import _AUTHENTICATION_FAILURE_CODE
40
+ from pymongo .lock import Lock , _async_create_lock
39
41
40
42
if TYPE_CHECKING :
41
43
from pymongo .asynchronous .pool import AsyncConnection
@@ -81,7 +83,11 @@ class _OIDCAuthenticator:
81
83
access_token : Optional [str ] = field (default = None )
82
84
idp_info : Optional [OIDCIdPInfo ] = field (default = None )
83
85
token_gen_id : int = field (default = 0 )
84
- lock : threading .Lock = field (default_factory = threading .Lock )
86
+ if not _IS_SYNC :
87
+ lock : Lock = field (default_factory = _async_create_lock ) # type: ignore[assignment]
88
+ else :
89
+ lock : threading .Lock = field (default_factory = _async_create_lock ) # type: ignore[assignment, no-redef]
90
+
85
91
last_call_time : float = field (default = 0 )
86
92
87
93
async def reauthenticate (self , conn : AsyncConnection ) -> Optional [Mapping [str , Any ]]:
@@ -164,7 +170,7 @@ async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[s
164
170
# Attempt to authenticate with a JwtStepRequest.
165
171
return await self ._sasl_continue_jwt (conn , start_resp )
166
172
167
- def _get_access_token (self ) -> Optional [str ]:
173
+ async def _get_access_token (self ) -> Optional [str ]:
168
174
properties = self .properties
169
175
cb : Union [None , OIDCCallback ]
170
176
resp : OIDCCallbackResult
@@ -186,7 +192,7 @@ def _get_access_token(self) -> Optional[str]:
186
192
return None
187
193
188
194
if not prev_token and cb is not None :
189
- with self .lock :
195
+ async with self .lock : # type: ignore[attr-defined]
190
196
# See if the token was changed while we were waiting for the
191
197
# lock.
192
198
new_token = self .access_token
@@ -196,7 +202,7 @@ def _get_access_token(self) -> Optional[str]:
196
202
# Ensure that we are waiting a min time between callback invocations.
197
203
delta = time .time () - self .last_call_time
198
204
if delta < TIME_BETWEEN_CALLS_SECONDS :
199
- time .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
205
+ await asyncio .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
200
206
self .last_call_time = time .time ()
201
207
202
208
if is_human :
@@ -211,7 +217,10 @@ def _get_access_token(self) -> Optional[str]:
211
217
idp_info = self .idp_info ,
212
218
username = self .properties .username ,
213
219
)
214
- resp = cb .fetch (context )
220
+ if not _IS_SYNC :
221
+ resp = await asyncio .get_running_loop ().run_in_executor (None , cb .fetch , context ) # type: ignore[assignment]
222
+ else :
223
+ resp = cb .fetch (context )
215
224
if not isinstance (resp , OIDCCallbackResult ):
216
225
raise ValueError (
217
226
f"Callback result must be of type OIDCCallbackResult, not { type (resp )} "
@@ -253,13 +262,13 @@ async def _sasl_continue_jwt(
253
262
start_payload : dict = bson .decode (start_resp ["payload" ])
254
263
if "issuer" in start_payload :
255
264
self .idp_info = OIDCIdPInfo (** start_payload )
256
- access_token = self ._get_access_token ()
265
+ access_token = await self ._get_access_token ()
257
266
conn .oidc_token_gen_id = self .token_gen_id
258
267
cmd = self ._get_continue_command ({"jwt" : access_token }, start_resp )
259
268
return await self ._run_command (conn , cmd )
260
269
261
270
async def _sasl_start_jwt (self , conn : AsyncConnection ) -> Mapping [str , Any ]:
262
- access_token = self ._get_access_token ()
271
+ access_token = await self ._get_access_token ()
263
272
conn .oidc_token_gen_id = self .token_gen_id
264
273
cmd = self ._get_start_command ({"jwt" : access_token })
265
274
return await self ._run_command (conn , cmd )
0 commit comments