Skip to content

Commit 1366b91

Browse files
authored
PYTHON-5394 - Add native async support for OIDC (#2352)
1 parent 2759379 commit 1366b91

File tree

10 files changed

+1232
-33
lines changed

10 files changed

+1232
-33
lines changed

pymongo/asynchronous/auth_oidc.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""MONGODB-OIDC Authentication helpers."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import threading
1920
import time
2021
from dataclasses import dataclass, field
@@ -36,6 +37,7 @@
3637
)
3738
from pymongo.errors import ConfigurationError, OperationFailure
3839
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
40+
from pymongo.lock import Lock, _async_create_lock
3941

4042
if TYPE_CHECKING:
4143
from pymongo.asynchronous.pool import AsyncConnection
@@ -81,7 +83,11 @@ class _OIDCAuthenticator:
8183
access_token: Optional[str] = field(default=None)
8284
idp_info: Optional[OIDCIdPInfo] = field(default=None)
8385
token_gen_id: int = field(default=0)
84-
lock: threading.Lock = field(default_factory=threading.Lock)
86+
if not _IS_SYNC:
87+
lock: Lock = field(default_factory=_async_create_lock) # type: ignore[assignment]
88+
else:
89+
lock: threading.Lock = field(default_factory=_async_create_lock) # type: ignore[assignment, no-redef]
90+
8591
last_call_time: float = field(default=0)
8692

8793
async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
@@ -164,7 +170,7 @@ async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[s
164170
# Attempt to authenticate with a JwtStepRequest.
165171
return await self._sasl_continue_jwt(conn, start_resp)
166172

167-
def _get_access_token(self) -> Optional[str]:
173+
async def _get_access_token(self) -> Optional[str]:
168174
properties = self.properties
169175
cb: Union[None, OIDCCallback]
170176
resp: OIDCCallbackResult
@@ -186,7 +192,7 @@ def _get_access_token(self) -> Optional[str]:
186192
return None
187193

188194
if not prev_token and cb is not None:
189-
with self.lock:
195+
async with self.lock: # type: ignore[attr-defined]
190196
# See if the token was changed while we were waiting for the
191197
# lock.
192198
new_token = self.access_token
@@ -196,7 +202,7 @@ def _get_access_token(self) -> Optional[str]:
196202
# Ensure that we are waiting a min time between callback invocations.
197203
delta = time.time() - self.last_call_time
198204
if delta < TIME_BETWEEN_CALLS_SECONDS:
199-
time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
205+
await asyncio.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
200206
self.last_call_time = time.time()
201207

202208
if is_human:
@@ -211,7 +217,10 @@ def _get_access_token(self) -> Optional[str]:
211217
idp_info=self.idp_info,
212218
username=self.properties.username,
213219
)
214-
resp = cb.fetch(context)
220+
if not _IS_SYNC:
221+
resp = await asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
222+
else:
223+
resp = cb.fetch(context)
215224
if not isinstance(resp, OIDCCallbackResult):
216225
raise ValueError(
217226
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"
@@ -253,13 +262,13 @@ async def _sasl_continue_jwt(
253262
start_payload: dict = bson.decode(start_resp["payload"])
254263
if "issuer" in start_payload:
255264
self.idp_info = OIDCIdPInfo(**start_payload)
256-
access_token = self._get_access_token()
265+
access_token = await self._get_access_token()
257266
conn.oidc_token_gen_id = self.token_gen_id
258267
cmd = self._get_continue_command({"jwt": access_token}, start_resp)
259268
return await self._run_command(conn, cmd)
260269

261270
async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]:
262-
access_token = self._get_access_token()
271+
access_token = await self._get_access_token()
263272
conn.oidc_token_gen_id = self.token_gen_id
264273
cmd = self._get_start_command({"jwt": access_token})
265274
return await self._run_command(conn, cmd)

pymongo/asynchronous/cursor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,6 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11301130
except BaseException:
11311131
await self.close()
11321132
raise
1133-
11341133
self._address = response.address
11351134
if isinstance(response, PinnedResponse):
11361135
if not self._sock_mgr:

pymongo/asynchronous/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
6464
await conn.authenticate(reauthenticate=True)
6565
else:
6666
raise
67-
return func(*args, **kwargs)
67+
return await func(*args, **kwargs)
6868
raise
6969

7070
return cast(F, inner)

pymongo/synchronous/auth_oidc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""MONGODB-OIDC Authentication helpers."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import threading
1920
import time
2021
from dataclasses import dataclass, field
@@ -36,6 +37,7 @@
3637
)
3738
from pymongo.errors import ConfigurationError, OperationFailure
3839
from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE
40+
from pymongo.lock import Lock, _create_lock
3941

4042
if TYPE_CHECKING:
4143
from pymongo.auth_shared import MongoCredential
@@ -81,7 +83,11 @@ class _OIDCAuthenticator:
8183
access_token: Optional[str] = field(default=None)
8284
idp_info: Optional[OIDCIdPInfo] = field(default=None)
8385
token_gen_id: int = field(default=0)
84-
lock: threading.Lock = field(default_factory=threading.Lock)
86+
if not _IS_SYNC:
87+
lock: Lock = field(default_factory=_create_lock) # type: ignore[assignment]
88+
else:
89+
lock: threading.Lock = field(default_factory=_create_lock) # type: ignore[assignment, no-redef]
90+
8591
last_call_time: float = field(default=0)
8692

8793
def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
@@ -186,7 +192,7 @@ def _get_access_token(self) -> Optional[str]:
186192
return None
187193

188194
if not prev_token and cb is not None:
189-
with self.lock:
195+
with self.lock: # type: ignore[attr-defined]
190196
# See if the token was changed while we were waiting for the
191197
# lock.
192198
new_token = self.access_token
@@ -211,7 +217,10 @@ def _get_access_token(self) -> Optional[str]:
211217
idp_info=self.idp_info,
212218
username=self.properties.username,
213219
)
214-
resp = cb.fetch(context)
220+
if not _IS_SYNC:
221+
resp = asyncio.get_running_loop().run_in_executor(None, cb.fetch, context) # type: ignore[assignment]
222+
else:
223+
resp = cb.fetch(context)
215224
if not isinstance(resp, OIDCCallbackResult):
216225
raise ValueError(
217226
f"Callback result must be of type OIDCCallbackResult, not {type(resp)}"

pymongo/synchronous/cursor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,6 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
11281128
except BaseException:
11291129
self.close()
11301130
raise
1131-
11321131
self._address = response.address
11331132
if isinstance(response, PinnedResponse):
11341133
if not self._sock_mgr:

0 commit comments

Comments
 (0)