diff --git a/test/asynchronous/test_auth_oidc.py b/test/asynchronous/test_auth_oidc.py index 8c06c4a21d..f450c75df7 100644 --- a/test/asynchronous/test_auth_oidc.py +++ b/test/asynchronous/test_auth_oidc.py @@ -1085,6 +1085,25 @@ async def test_4_4_speculative_authentication_should_be_ignored_on_reauthenticat # Assert there were `SaslStart` commands executed. assert any(event.command_name.lower() == "saslstart" for event in listener.started_events) + async def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self): + # Create an OIDC configured client. + client = await self.create_client() + + # Set a fail point for `find` commands of the form: + async with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Start a new session. + async with client.start_session() as session: + # In the started session perform a `find` operation that succeeds. + await client.test.test.find_one({}, session=session) + + # Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication). + self.assertEqual(self.request_called, 2) + async def test_5_1_azure_with_no_username(self): if ENVIRON != "azure": raise unittest.SkipTest("Test is only supported on Azure") diff --git a/test/test_auth_oidc.py b/test/test_auth_oidc.py index e7a8dce957..33a1e55fe2 100644 --- a/test/test_auth_oidc.py +++ b/test/test_auth_oidc.py @@ -1083,6 +1083,25 @@ def test_4_4_speculative_authentication_should_be_ignored_on_reauthentication(se # Assert there were `SaslStart` commands executed. assert any(event.command_name.lower() == "saslstart" for event in listener.started_events) + def test_4_5_reauthentication_succeeds_when_a_session_is_involved(self): + # Create an OIDC configured client. + client = self.create_client() + + # Set a fail point for `find` commands of the form: + with self.fail_point( + { + "mode": {"times": 1}, + "data": {"failCommands": ["find"], "errorCode": 391}, + } + ): + # Start a new session. + with client.start_session() as session: + # In the started session perform a `find` operation that succeeds. + client.test.test.find_one({}, session=session) + + # Assert that the callback was called 2 times (once during the connection handshake, and again during reauthentication). + self.assertEqual(self.request_called, 2) + def test_5_1_azure_with_no_username(self): if ENVIRON != "azure": raise unittest.SkipTest("Test is only supported on Azure")