From 58ca6f0b00bb22caafb0f4b727c74bfbdd1be0c6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 18 Feb 2025 09:41:17 -0800 Subject: [PATCH 1/3] PYTHON-5113 - Refactor test utils for async --- test/asynchronous/test_async_cancellation.py | 2 +- test/asynchronous/test_auth.py | 2 +- test/asynchronous/test_bulk.py | 2 +- test/asynchronous/test_change_stream.py | 2 +- test/asynchronous/test_client.py | 14 +- test/asynchronous/test_client_bulk_write.py | 2 +- test/asynchronous/test_collation.py | 2 +- test/asynchronous/test_collection.py | 2 +- test/asynchronous/test_comment.py | 2 +- test/asynchronous/test_concurrency.py | 2 +- .../test_connection_monitoring.py | 2 +- ...nnections_survive_primary_stepdown_spec.py | 2 +- test/asynchronous/test_cursor.py | 2 +- test/asynchronous/test_data_lake.py | 2 +- test/asynchronous/test_database.py | 2 +- test/asynchronous/test_dns.py | 2 +- test/asynchronous/test_encryption.py | 2 +- test/asynchronous/test_examples.py | 2 +- test/asynchronous/test_grid_file.py | 2 +- test/asynchronous/test_gridfs.py | 2 +- test/asynchronous/test_gridfs_bucket.py | 2 +- .../asynchronous/test_heartbeat_monitoring.py | 2 +- test/asynchronous/test_index_management.py | 2 +- test/asynchronous/test_load_balancer.py | 2 +- test/asynchronous/test_max_staleness.py | 2 +- .../test_mongos_load_balancing.py | 2 +- test/asynchronous/test_monitoring.py | 2 +- test/asynchronous/test_pooling.py | 2 +- test/asynchronous/test_read_concern.py | 2 +- test/asynchronous/test_read_preferences.py | 2 +- .../test_read_write_concern_spec.py | 2 +- test/asynchronous/test_retryable_reads.py | 2 +- test/asynchronous/test_retryable_writes.py | 2 +- .../asynchronous/test_sdam_monitoring_spec.py | 2 +- test/asynchronous/test_server_selection.py | 14 +- .../test_server_selection_in_window.py | 2 +- test/asynchronous/test_session.py | 2 +- test/asynchronous/test_srv_polling.py | 4 +- test/asynchronous/test_ssl.py | 2 +- test/asynchronous/test_streaming_protocol.py | 2 +- test/asynchronous/test_transactions.py | 2 +- .../test_versioned_api_integration.py | 2 +- test/asynchronous/unified_format.py | 2 +- test/asynchronous/utils.py | 276 +++++ test/asynchronous/utils_selection_tests.py | 4 +- test/asynchronous/utils_spec_runner.py | 2 +- test/auth_oidc/test_auth_oidc.py | 2 +- test/test_auth.py | 2 +- test/test_bulk.py | 2 +- test/test_change_stream.py | 2 +- test/test_client.py | 10 +- test/test_client_bulk_write.py | 2 +- test/test_collation.py | 2 +- test/test_collection.py | 2 +- test/test_comment.py | 2 +- test/test_connection_monitoring.py | 2 +- ...nnections_survive_primary_stepdown_spec.py | 2 +- test/test_cursor.py | 2 +- test/test_data_lake.py | 2 +- test/test_database.py | 2 +- test/test_discovery_and_monitoring.py | 2 +- test/test_dns.py | 2 +- test/test_encryption.py | 2 +- test/test_examples.py | 2 +- test/test_fork.py | 2 +- test/test_grid_file.py | 2 +- test/test_gridfs.py | 2 +- test/test_gridfs_bucket.py | 2 +- test/test_heartbeat_monitoring.py | 2 +- test/test_index_management.py | 2 +- test/test_load_balancer.py | 2 +- test/test_max_staleness.py | 2 +- test/test_mongos_load_balancing.py | 2 +- test/test_monitor.py | 2 +- test/test_monitoring.py | 2 +- test/test_objectid.py | 2 +- test/test_pooling.py | 2 +- test/test_read_concern.py | 2 +- test/test_read_preferences.py | 2 +- test/test_read_write_concern_spec.py | 2 +- test/test_replica_set_reconfig.py | 2 +- test/test_retryable_reads.py | 2 +- test/test_retryable_writes.py | 2 +- test/test_sdam_monitoring_spec.py | 2 +- test/test_server_selection.py | 14 +- test/test_server_selection_in_window.py | 4 +- test/test_session.py | 2 +- test/test_ssl.py | 2 +- test/test_streaming_protocol.py | 2 +- test/test_threads.py | 2 +- test/test_topology.py | 2 +- test/test_transactions.py | 2 +- test/test_versioned_api_integration.py | 2 +- test/unified_format.py | 2 +- test/unified_format_shared.py | 2 +- test/utils.py | 1005 ++--------------- test/utils_selection_tests.py | 4 +- test/utils_shared.py | 661 +++++++++++ test/utils_spec_runner.py | 2 +- tools/synchro.py | 2 + 100 files changed, 1162 insertions(+), 1026 deletions(-) create mode 100644 test/asynchronous/utils.py create mode 100644 test/utils_shared.py diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index b73c7a8084..d071fae317 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,7 +17,7 @@ import asyncio import sys -from test.utils import async_get_pool, delay, one +from test.utils_shared import async_get_pool, delay, one sys.path[0:0] = [""] diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 7172152d69..904674db16 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -30,7 +30,7 @@ async_client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 86568b666b..5573c3987f 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 08da00cc1e..4025c13730 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 744a170be2..07b4cdee53 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -60,14 +60,16 @@ unittest, ) from test.asynchronous.pymongo_mocks import AsyncMockClient -from test.test_binary import BinaryData -from test.utils import ( - NTHREADS, - CMAPListener, - FunctionCallRecorder, +from test.asynchronous.utils import ( + AsyncFunctionCallRecorder, async_get_pool, async_wait_until, asyncAssertRaisesExactly, +) +from test.test_binary import BinaryData +from test.utils_shared import ( + NTHREADS, + CMAPListener, delay, gevent_monkey_patched, is_greenthread_patched, @@ -511,7 +513,7 @@ async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import _resolve - patched_resolver = FunctionCallRecorder(_resolve) + patched_resolver = AsyncFunctionCallRecorder(_resolve) pymongo.srv_resolver._resolve = patched_resolver def reset_resolver(): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 282009f554..f8b9465b09 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -25,7 +25,7 @@ async_client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index d7fd85b168..05e548c79e 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.asynchronous.helpers import anext diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index beb58012a8..6e64c2d668 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -33,7 +33,7 @@ AsyncUnitTest, async_client_context, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, diff --git a/test/asynchronous/test_comment.py b/test/asynchronous/test_comment.py index be3626a8b8..d3ddaf2b65 100644 --- a/test/asynchronous/test_comment.py +++ b/test/asynchronous/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.asynchronous.command_cursor import AsyncCommandCursor diff --git a/test/asynchronous/test_concurrency.py b/test/asynchronous/test_concurrency.py index 1683b8413b..193ecf05c8 100644 --- a/test/asynchronous/test_concurrency.py +++ b/test/asynchronous/test_concurrency.py @@ -18,7 +18,7 @@ import asyncio import time from test.asynchronous import AsyncIntegrationTest, async_client_context -from test.utils import delay +from test.utils_shared import delay _IS_SYNC = False diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index a68b2a90cb..a547e0bcd4 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask -from test.utils import ( +from test.utils_shared import ( CMAPListener, async_client_context, async_get_pool, diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 7c11742a90..8754125f23 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -25,7 +25,7 @@ unittest, ) from test.asynchronous.helpers import async_repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, async_ensure_all_connected, ) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d843ffb4aa..90d5e7801e 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/asynchronous/test_data_lake.py b/test/asynchronous/test_data_lake.py index 0b259fb0d0..33b9d33d76 100644 --- a/test/asynchronous/test_data_lake.py +++ b/test/asynchronous/test_data_lake.py @@ -25,7 +25,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncUnitTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 55a8cc3ab2..d5ef582a67 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -26,7 +26,7 @@ from test import unittest from test.asynchronous import AsyncIntegrationTest, async_client_context from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_dns.py b/test/asynchronous/test_dns.py index e24e0fb5ce..a622062fec 100644 --- a/test/asynchronous/test_dns.py +++ b/test/asynchronous/test_dns.py @@ -29,7 +29,7 @@ async_client_context, unittest, ) -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 335aa9d81c..000d98a111 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -64,7 +64,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, diff --git a/test/asynchronous/test_examples.py b/test/asynchronous/test_examples.py index 7fea9d41af..1312f1e215 100644 --- a/test/asynchronous/test_examples.py +++ b/test/asynchronous/test_examples.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_wait_until +from test.utils_shared import async_wait_until import pymongo from pymongo.asynchronous.helpers import anext diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index affdacde91..3f864367de 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index b1c1e754ff..9582f9eca0 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_joinall, one +from test.utils_shared import async_joinall, one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_gridfs_bucket.py b/test/asynchronous/test_gridfs_bucket.py index 5d1cf5beff..009bc7c620 100644 --- a/test/asynchronous/test_gridfs_bucket.py +++ b/test/asynchronous/test_gridfs_bucket.py @@ -29,7 +29,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import async_joinall, joinall, one +from test.utils_shared import async_joinall, joinall, one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py index ff595a8144..92e63460dc 100644 --- a/test/asynchronous/test_heartbeat_monitoring.py +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest -from test.utils import AsyncMockPool, HeartbeatEventListener, async_wait_until +from test.utils_shared import AsyncMockPool, HeartbeatEventListener, async_wait_until from pymongo.asynchronous.monitor import Monitor from pymongo.errors import ConnectionFailure diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py index 2920c48b2f..73da83456f 100644 --- a/test/asynchronous/test_index_management.py +++ b/test/asynchronous/test_index_management.py @@ -29,7 +29,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import AllowListEventListener, OvertCommandListener +from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index fd50841c87..78afa914ed 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -30,7 +30,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( async_get_pool, async_wait_until, create_async_event, diff --git a/test/asynchronous/test_max_staleness.py b/test/asynchronous/test_max_staleness.py index 7dbf17021f..747d7870fc 100644 --- a/test/asynchronous/test_max_staleness.py +++ b/test/asynchronous/test_max_staleness.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest -from test.utils_selection_tests import create_selection_tests +from test.utils_shared_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index 0bc6a405f4..97170aa9e0 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest from test.asynchronous.pymongo_mocks import AsyncMockClient -from test.utils import async_wait_until +from test.utils_shared import async_wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index eaad60beac..a7d56a8cf7 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 09b8fb0853..8f787b5328 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -33,7 +33,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.helpers import ConcurrentRunner -from test.utils import async_get_pool, async_joinall, delay +from test.utils_shared import async_get_pool, async_joinall, delay from pymongo.asynchronous.pool import Pool, PoolOptions from pymongo.socket_checker import SocketChecker diff --git a/test/asynchronous/test_read_concern.py b/test/asynchronous/test_read_concern.py index fbc07a5c36..8659bf80b2 100644 --- a/test/asynchronous/test_read_concern.py +++ b/test/asynchronous/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py index 077bc21eaf..5bea174058 100644 --- a/test/asynchronous/test_read_preferences.py +++ b/test/asynchronous/test_read_preferences.py @@ -33,7 +33,7 @@ connected, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, async_wait_until, one, diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py index 3fb13ba194..86f79fd28d 100644 --- a/test/asynchronous/test_read_write_concern_spec.py +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -25,7 +25,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo import DESCENDING from pymongo.asynchronous.mongo_client import AsyncMongoClient diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index bde7a9f2ee..a0e46ff65a 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -31,7 +31,7 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, async_set_fail_point, diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index 738ce04192..bbcf8d5ed2 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -30,7 +30,7 @@ unittest, ) from test.asynchronous.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, diff --git a/test/asynchronous/test_sdam_monitoring_spec.py b/test/asynchronous/test_sdam_monitoring_spec.py index 8b0ec63cfe..71ec6c6b46 100644 --- a/test/asynchronous/test_sdam_monitoring_spec.py +++ b/test/asynchronous/test_sdam_monitoring_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, async_wait_until, server_name_to_type, diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py index f0451841cd..64f7784569 100644 --- a/test/asynchronous/test_server_selection.py +++ b/test/asynchronous/test_server_selection.py @@ -31,17 +31,17 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.utils import AsyncFunctionCallRecorder, async_wait_until from test.asynchronous.utils_selection_tests import ( create_selection_tests, - get_addresses, get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, make_server_description, ) -from test.utils import ( - EventListener, - FunctionCallRecorder, +from test.utils_shared import ( OvertCommandListener, - async_wait_until, ) _IS_SYNC = False @@ -122,7 +122,7 @@ async def test_invalid_server_selector(self): @async_client_context.require_replica_set async def test_selector_called(self): - selector = FunctionCallRecorder(lambda x: x) + selector = AsyncFunctionCallRecorder(lambda x: x) # Client setup. mongo_client = await self.async_rs_or_single_client(server_selector=selector) @@ -175,7 +175,7 @@ async def test_latency_threshold_application(self): @async_client_context.require_replica_set async def test_server_selector_bypassed(self): - selector = FunctionCallRecorder(lambda x: x) + selector = AsyncFunctionCallRecorder(lambda x: x) scenario_def = { "topology_description": { diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index e2ae92a27c..8151472719 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -23,7 +23,7 @@ from test.asynchronous.helpers import ConcurrentRunner from test.asynchronous.utils_selection_tests import create_topology from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, async_get_pool, diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 03d1032b5b..9633124aa6 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,7 +36,7 @@ async_client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, async_wait_until, diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index 763c80e665..85e39f4658 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest -from test.utils import FunctionCallRecorder, async_wait_until +from test.asynchronous.utils import AsyncFunctionCallRecorder, async_wait_until import pymongo from pymongo import common @@ -69,7 +69,7 @@ def mock_get_hosts_and_min_ttl(resolver, *args): patch_func: Any if self.count_resolver_calls: - patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl) + patch_func = AsyncFunctionCallRecorder(mock_get_hosts_and_min_ttl) else: patch_func = mock_get_hosts_and_min_ttl diff --git a/test/asynchronous/test_ssl.py b/test/asynchronous/test_ssl.py index d50bb220b1..d920b77ac2 100644 --- a/test/asynchronous/test_ssl.py +++ b/test/asynchronous/test_ssl.py @@ -32,7 +32,7 @@ remove_all_users, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, cat_files, diff --git a/test/asynchronous/test_streaming_protocol.py b/test/asynchronous/test_streaming_protocol.py index fd890d29fb..1206e7b2fa 100644 --- a/test/asynchronous/test_streaming_protocol.py +++ b/test/asynchronous/test_streaming_protocol.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( HeartbeatEventListener, ServerEventListener, async_wait_until, diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index d11d0a9776..59a42d7291 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, async_wait_until, ) diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py index 7e9a79da90..46e62d5c14 100644 --- a/test/asynchronous/test_versioned_api_integration.py +++ b/test/asynchronous/test_versioned_api_integration.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo.server_api import ServerApi diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 149aad9786..baca5a9902 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -49,7 +49,7 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( +from test.utils_shared import ( async_get_pool, async_wait_until, camel_to_snake, diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py new file mode 100644 index 0000000000..f87589bf63 --- /dev/null +++ b/test/asynchronous/utils.py @@ -0,0 +1,276 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing pymongo that require synchronization.""" +from __future__ import annotations + +import asyncio +import contextlib +import copy +import functools +import os +import random +import re +import shutil +import sys +import threading +import time +import unittest +import warnings +from asyncio import iscoroutinefunction +from collections import abc, defaultdict +from functools import partial +from test import client_context, db_pwd, db_user +from test.asynchronous import async_client_context +from typing import Any, List + +from bson import json_util +from bson.objectid import ObjectId +from bson.son import SON +from pymongo import AsyncMongoClient, monitoring, operations, read_preferences +from pymongo._asyncio_task import create_task +from pymongo.cursor_shared import CursorType +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.hello import HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS +from pymongo.lock import _async_create_lock, _create_lock +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.operations import _Op +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_selectors import any_server_selector, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.mongo_client import MongoClient +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration +from pymongo.uri_parser import parse_uri +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +async def async_get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = await client._get_topology() + server = await topology._select_server(writable_server_selector, _Op.TEST) + return server.pool + + +async def async_get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in await (await client._get_topology()).select_servers( + any_server_selector, _Op.TEST + ) + ] + + +async def async_wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. + + E.g.: + + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') + + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). + + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() + if retval: + return retval + + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) + + await asyncio.sleep(interval) + + +async def async_is_mongos(client): + res = await client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" + + +async def async_ensure_all_connected(client: AsyncMongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. + + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} + + # Run hello until we have connected to each host at least once. + async def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = await client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list + + try: + + async def predicate(): + return target_host_list == await discover() + + await async_wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) + + +async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): + """ + Unlike the standard assertRaises, this checks that a function raises a + specific class of exception, and not a subclass. E.g., check that + MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. + """ + try: + await fn(*args, **kwargs) + except Exception as e: + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" + else: + raise AssertionError("%s not raised" % cls) + + +async def async_set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + +async def async_joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t + else: + await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) + + +class AsyncMockConnection: + def __init__(self): + self.cancel_context = _CancellationContext() + self.more_to_come = False + self.id = random.randint(0, 100) + + def close_conn(self, reason): + pass + + def __aenter__(self): + return self + + def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + +class AsyncMockPool: + def __init__(self, address, options, handshake=True, client_id=None): + self.gen = _PoolGeneration() + self._lock = _async_create_lock() + self.opts = options + self.operation_count = 0 + self.conns = [] + + def stale_generation(self, gen, service_id): + return self.gen.stale(gen, service_id) + + @contextlib.asynccontextmanager + async def checkout(self, handler=None): + yield AsyncMockConnection() + + async def checkin(self, *args, **kwargs): + pass + + async def _reset(self, service_id=None): + async with self._lock: + self.gen.inc(service_id) + + async def ready(self): + pass + + async def reset(self, service_id=None, interrupt_connections=False): + await self._reset() + + async def reset_without_pause(self): + await self._reset() + + async def close(self): + await self._reset() + + async def update_is_writable(self, is_writable): + pass + + async def remove_stale_sockets(self, *args, **kwargs): + pass + + +class AsyncFunctionCallRecorder: + """Utility class to wrap a callable and record its invocations.""" + + def __init__(self, function): + self._function = function + self._call_list = [] + + async def __call__(self, *args, **kwargs): + self._call_list.append((args, kwargs)) + if iscoroutinefunction(self._function): + return await self._function(*args, **kwargs) + else: + return self._function(*args, **kwargs) + + def reset(self): + """Wipes the call list.""" + self._call_list = [] + + def call_list(self): + """Returns a copy of the call list.""" + return self._call_list[:] + + @property + def call_count(self): + """Returns the number of times the function has been called.""" + return len(self._call_list) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index 71e287569a..6a63bfe6e7 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -24,8 +24,8 @@ from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils import AsyncMockPool, parse_read_preference -from test.utils_selection_tests_shared import ( +from test.utils_shared import AsyncMockPool, parse_read_preference +from test.utils_shared_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 11d88850fc..6cc9a625b2 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -23,7 +23,7 @@ from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.asynchronous.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_shared import ( CMAPListener, CompareType, EventListener, diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 7a78f3d2f6..a5334d79bd 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.unified_format import generate_test_classes -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from bson import SON from pymongo import MongoClient diff --git a/test/test_auth.py b/test/test_auth.py index 345d16121b..27f6743fae 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -30,7 +30,7 @@ client_context, unittest, ) -from test.utils import AllowListEventListener, delay, ignore_deprecations +from test.utils_shared import AllowListEventListener, delay, ignore_deprecations import pytest diff --git a/test/test_bulk.py b/test/test_bulk.py index 6a72bddfc0..8a863cc49b 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, remove_all_users, unittest -from test.utils import wait_until +from test.utils_shared import wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 4ed21f55cf..e50f4667f6 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -36,7 +36,7 @@ unittest, ) from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/test_client.py b/test/test_client.py index cdc7691c28..3a1aa4d1b4 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -61,17 +61,19 @@ from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( - NTHREADS, - CMAPListener, FunctionCallRecorder, assertRaisesExactly, - delay, get_pool, + wait_until, +) +from test.utils_shared import ( + NTHREADS, + CMAPListener, + delay, gevent_monkey_patched, is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index f8d92668ea..b00b2c1b03 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -25,7 +25,7 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) from unittest.mock import patch diff --git a/test/test_collation.py b/test/test_collation.py index 06436f0638..5425551dc6 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, OvertCommandListener +from test.utils_shared import EventListener, OvertCommandListener from typing import Any from pymongo.collation import ( diff --git a/test/test_collection.py b/test/test_collection.py index 8a862646eb..264f642921 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -33,7 +33,7 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, diff --git a/test/test_comment.py b/test/test_comment.py index 9f9bf98640..b6c17c14fe 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from asyncio import iscoroutinefunction from test import IntegrationTest, client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.dbref import DBRef from pymongo.operations import IndexModel diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 810d440932..4e89c69de5 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import ( +from test.utils_shared import ( CMAPListener, camel_to_snake, client_context, diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 9cac633301..bb2a1f4ac6 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -25,7 +25,7 @@ unittest, ) from test.helpers import repl_set_step_down -from test.utils import ( +from test.utils_shared import ( CMAPListener, ensure_all_connected, ) diff --git a/test/test_cursor.py b/test/test_cursor.py index 84e431f8cb..a9cbe99942 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, EventListener, OvertCommandListener, diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 797ef85000..0228d5f2c6 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -25,7 +25,7 @@ from test import IntegrationTest, UnitTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, ) diff --git a/test/test_database.py b/test/test_database.py index aad9089bd8..f4df118ce8 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_context, unittest from test.test_custom_types import DECIMAL_CODECOPTS -from test.utils import ( +from test.utils_shared import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, wait_until, diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ce7a52f1a0..c5e6e058aa 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -25,7 +25,7 @@ from test import IntegrationTest, PyMongoTestCase, unittest from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, diff --git a/test/test_dns.py b/test/test_dns.py index 6f4736fd5e..71326ae49e 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -29,7 +29,7 @@ client_context, unittest, ) -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError diff --git a/test/test_encryption.py b/test/test_encryption.py index 9224310144..6efb167442 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -63,7 +63,7 @@ ) from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( AllowListEventListener, OvertCommandListener, TopologyEventListener, diff --git a/test/test_examples.py b/test/test_examples.py index 9bcc276248..ef06a77b9a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import wait_until +from test.utils_shared import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure diff --git a/test/test_fork.py b/test/test_fork.py index 1a89159435..fe88d778d2 100644 --- a/test/test_fork.py +++ b/test/test_fork.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest -from test.utils import is_greenthread_patched +from test.utils_shared import is_greenthread_patched from bson.objectid import ObjectId diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 6534bc11bf..0baeb5ae19 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.objectid import ObjectId from gridfs.errors import NoFile diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 47e38141b2..c8b9a7ee1b 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils_shared import joinall, one import gridfs from bson.binary import Binary diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index e7486cb237..5cb19c557a 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -29,7 +29,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one +from test.utils_shared import joinall, one import gridfs from bson.binary import Binary diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 0523d0ba4d..13516088cd 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, wait_until +from test.utils_shared import HeartbeatEventListener, MockPool, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat diff --git a/test/test_index_management.py b/test/test_index_management.py index 5135e43f1f..a5c02b232f 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -29,7 +29,7 @@ from test import IntegrationTest, PyMongoTestCase, unittest from test.unified_format import generate_test_classes -from test.utils import AllowListEventListener, OvertCommandListener +from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure from pymongo.operations import SearchIndexModel diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 7db19b46b5..41bbe0ea41 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -30,7 +30,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( +from test.utils_shared import ( create_event, get_pool, wait_until, diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 56e047fd4b..05943e3d64 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import PyMongoTestCase, client_context, unittest -from test.utils_selection_tests import create_selection_tests +from test.utils_shared_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index ca2f3cfd1e..8c31854343 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -26,7 +26,7 @@ from test import MockClientTest, client_context, connected, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector diff --git a/test/test_monitor.py b/test/test_monitor.py index a704f3d8cb..837464ba36 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, connected, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, wait_until, ) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 670558c0a0..ae3e50db77 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -29,7 +29,7 @@ sanitize_cmd, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, wait_until, diff --git a/test/test_objectid.py b/test/test_objectid.py index 26670832f6..d7db7229ea 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] from test import SkipTest, unittest -from test.utils import oid_generated_on_process +from test.utils_shared import oid_generated_on_process from bson.errors import InvalidId from bson.objectid import _MAX_COUNTER_VALUE, ObjectId diff --git a/test/test_pooling.py b/test/test_pooling.py index 5d23b85f23..81f8caf31d 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -33,7 +33,7 @@ from test import IntegrationTest, client_context, unittest from test.helpers import ConcurrentRunner -from test.utils import delay, get_pool, joinall +from test.utils_shared import delay, get_pool, joinall from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 8ec9865eaa..62b2491475 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 0d38f3f00d..e754c896ad 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -33,7 +33,7 @@ connected, unittest, ) -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, one, wait_until, diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 8543991f72..383dc70902 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -25,7 +25,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo import DESCENDING from pymongo.errors import ( diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 4c23d71b69..3371543f27 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -21,7 +21,7 @@ from test import MockClientTest, client_context, client_knobs, unittest from test.pymongo_mocks import MockClient -from test.utils import wait_until +from test.utils_shared import wait_until from pymongo import ReadPreference from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 9c3f6b170f..c131b225ca 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -31,7 +31,7 @@ client_knobs, unittest, ) -from test.utils import ( +from test.utils_shared import ( CMAPListener, OvertCommandListener, set_fail_point, diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 07bd1db0ba..854f6c2a90 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -30,7 +30,7 @@ unittest, ) from test.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( CMAPListener, DeprecationFilter, EventListener, diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 6a53c062cc..2167e561cf 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, client_knobs, unittest -from test.utils import ( +from test.utils_shared import ( ServerAndTopologyEventListener, server_name_to_type, wait_until, diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 3e7f9a8671..2e8d9f4f87 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -31,18 +31,18 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( - EventListener, - FunctionCallRecorder, - OvertCommandListener, - wait_until, -) +from test.utils import FunctionCallRecorder, wait_until from test.utils_selection_tests import ( create_selection_tests, - get_addresses, get_topology_settings_dict, +) +from test.utils_selection_tests_shared import ( + get_addresses, make_server_description, ) +from test.utils_shared import ( + OvertCommandListener, +) _IS_SYNC = True diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 7ccd4b529e..92d8ad89b4 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -21,13 +21,13 @@ from pathlib import Path from test import IntegrationTest, client_context, unittest from test.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_selection_tests import create_topology +from test.utils_shared import ( CMAPListener, OvertCommandListener, get_pool, wait_until, ) -from test.utils_selection_tests import create_topology from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node diff --git a/test/test_session.py b/test/test_session.py index 175a282495..7040d87f55 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,7 +36,7 @@ client_context, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, wait_until, diff --git a/test/test_ssl.py b/test/test_ssl.py index 7d6c3f7cd1..a66fe21be5 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -32,7 +32,7 @@ remove_all_users, unittest, ) -from test.utils import ( +from test.utils_shared import ( EventListener, OvertCommandListener, cat_files, diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 894e89e208..acf7610c94 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( HeartbeatEventListener, ServerEventListener, wait_until, diff --git a/test/test_threads.py b/test/test_threads.py index 3e469e28fe..7c128720e1 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -17,7 +17,7 @@ import threading from test import IntegrationTest, client_context, unittest -from test.utils import joinall +from test.utils_shared import joinall @client_context.require_connection diff --git a/test/test_topology.py b/test/test_topology.py index 86aa87c2cc..0840e42ba8 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -23,7 +23,7 @@ from test import client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, wait_until +from test.utils_shared import MockPool, wait_until from bson.objectid import ObjectId from pymongo import common diff --git a/test/test_transactions.py b/test/test_transactions.py index 949b88e60b..f1357badd3 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import ( +from test.utils_shared import ( OvertCommandListener, wait_until, ) diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py index 502198576a..0066ecd977 100644 --- a/test/test_versioned_api_integration.py +++ b/test/test_versioned_api_integration.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import OvertCommandListener +from test.utils_shared import OvertCommandListener from pymongo.server_api import ServerApi diff --git a/test/unified_format.py b/test/unified_format.py index b2e6ae1e83..8894aff94b 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -48,7 +48,7 @@ parse_collection_or_database_options, with_metaclass, ) -from test.utils import ( +from test.utils_shared import ( camel_to_snake, camel_to_snake_args, get_pool, diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py index 0c685366f4..009c5c7e28 100644 --- a/test/unified_format_shared.py +++ b/test/unified_format_shared.py @@ -35,7 +35,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.utils import CMAPListener, camel_to_snake, parse_collection_options +from test.utils_shared import CMAPListener, camel_to_snake, parse_collection_options from typing import Any, Union from bson import ( diff --git a/test/utils.py b/test/utils.py index e089b3fc2f..2d8693a77e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for testing pymongo""" +"""Utilities for testing pymongo that require synchronization.""" from __future__ import annotations import asyncio @@ -32,19 +32,18 @@ from collections import abc, defaultdict from functools import partial from test import client_context, db_pwd, db_user -from test.asynchronous import async_client_context from typing import Any, List from bson import json_util from bson.objectid import ObjectId from bson.son import SON -from pymongo import AsyncMongoClient, monitoring, operations, read_preferences +from pymongo import MongoClient, monitoring, operations, read_preferences from pymongo._asyncio_task import create_task from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS -from pymongo.lock import _async_create_lock, _create_lock +from pymongo.lock import _create_lock from pymongo.monitoring import ( ConnectionCheckedInEvent, ConnectionCheckedOutEvent, @@ -69,264 +68,124 @@ from pymongo.uri_parser import parse_uri from pymongo.write_concern import WriteConcern -IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) +_IS_SYNC = True -class BaseListener: - def __init__(self): - self.events = [] - - def reset(self): - self.events = [] - - def add_event(self, event): - self.events.append(event) - - def event_count(self, event_type): - return len(self.events_by_type(event_type)) - - def events_by_type(self, event_type): - """Return the matching events by event class. - - event_type can be a single class or a tuple of classes. - """ - return self.matching(lambda e: isinstance(e, event_type)) - - def matching(self, matcher): - """Return the matching events.""" - return [event for event in self.events[:] if matcher(event)] - - def wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") - - async def async_wait_for_event(self, event, count): - """Wait for a number of events to be published, or fail.""" - await async_wait_until( - lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" - ) - - -class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): - def connection_created(self, event): - assert isinstance(event, ConnectionCreatedEvent) - self.add_event(event) - - def connection_ready(self, event): - assert isinstance(event, ConnectionReadyEvent) - self.add_event(event) - - def connection_closed(self, event): - assert isinstance(event, ConnectionClosedEvent) - self.add_event(event) - - def connection_check_out_started(self, event): - assert isinstance(event, ConnectionCheckOutStartedEvent) - self.add_event(event) - - def connection_check_out_failed(self, event): - assert isinstance(event, ConnectionCheckOutFailedEvent) - self.add_event(event) - - def connection_checked_out(self, event): - assert isinstance(event, ConnectionCheckedOutEvent) - self.add_event(event) - - def connection_checked_in(self, event): - assert isinstance(event, ConnectionCheckedInEvent) - self.add_event(event) - - def pool_created(self, event): - assert isinstance(event, PoolCreatedEvent) - self.add_event(event) - - def pool_ready(self, event): - assert isinstance(event, PoolReadyEvent) - self.add_event(event) - - def pool_cleared(self, event): - assert isinstance(event, PoolClearedEvent) - self.add_event(event) - - def pool_closed(self, event): - assert isinstance(event, PoolClosedEvent) - self.add_event(event) - - -class EventListener(BaseListener, monitoring.CommandListener): - def __init__(self): - super().__init__() - self.results = defaultdict(list) - - @property - def started_events(self) -> List[monitoring.CommandStartedEvent]: - return self.results["started"] - - @property - def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: - return self.results["succeeded"] - - @property - def failed_events(self) -> List[monitoring.CommandFailedEvent]: - return self.results["failed"] - - def started(self, event: monitoring.CommandStartedEvent) -> None: - self.started_events.append(event) - self.add_event(event) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - self.succeeded_events.append(event) - self.add_event(event) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - self.failed_events.append(event) - self.add_event(event) - - def started_command_names(self) -> List[str]: - """Return list of command names started.""" - return [event.command_name for event in self.started_events] - - def reset(self) -> None: - """Reset the state of this listener.""" - self.results.clear() - super().reset() - - -class TopologyEventListener(monitoring.TopologyListener): - def __init__(self): - self.results = defaultdict(list) - - def closed(self, event): - self.results["closed"].append(event) - - def description_changed(self, event): - self.results["description_changed"].append(event) - - def opened(self, event): - self.results["opened"].append(event) - - def reset(self): - """Reset the state of this listener.""" - self.results.clear() - - -class AllowListEventListener(EventListener): - def __init__(self, *commands): - self.commands = set(commands) - super().__init__() - - def started(self, event): - if event.command_name in self.commands: - super().started(event) - - def succeeded(self, event): - if event.command_name in self.commands: - super().succeeded(event) - - def failed(self, event): - if event.command_name in self.commands: - super().failed(event) - - -class OvertCommandListener(EventListener): - """A CommandListener that ignores sensitive commands.""" - - ignore_list_collections = False - - def started(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().started(event) - - def succeeded(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().succeeded(event) - - def failed(self, event): - if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super().failed(event) - - -class _ServerEventListener: - """Listens to all events.""" +def get_pool(client): + """Get the standalone, primary, or mongos pool.""" + topology = client._get_topology() + server = topology._select_server(writable_server_selector, _Op.TEST) + return server.pool - def __init__(self): - self.results = [] - def opened(self, event): - self.results.append(event) +def get_pools(client): + """Get all pools.""" + return [ + server.pool + for server in (client._get_topology()).select_servers(any_server_selector, _Op.TEST) + ] - def description_changed(self, event): - self.results.append(event) - def closed(self, event): - self.results.append(event) +def wait_until(predicate, success_description, timeout=10): + """Wait up to 10 seconds (by default) for predicate to be true. - def matching(self, matcher): - """Return the matching events.""" - results = self.results[:] - return [event for event in results if matcher(event)] + E.g.: - def reset(self): - self.results = [] + wait_until(lambda: client.primary == ('a', 1), + 'connect to the primary') + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). -class ServerEventListener(_ServerEventListener, monitoring.ServerListener): - """Listens to Server events.""" + Returns the predicate's first true value. + """ + start = time.time() + interval = min(float(timeout) / 100, 0.1) + while True: + if iscoroutinefunction(predicate): + retval = predicate() + else: + retval = predicate() + if retval: + return retval + if time.time() - start > timeout: + raise AssertionError("Didn't ever %s" % success_description) -class ServerAndTopologyEventListener( # type: ignore[misc] - ServerEventListener, monitoring.TopologyListener -): - """Listens to Server and Topology events.""" + time.sleep(interval) -class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): - """Listens to only server heartbeat events.""" +def is_mongos(client): + res = client.admin.command(HelloCompat.LEGACY_CMD) + return res.get("msg", "") == "isdbgrid" - def started(self, event): - self.add_event(event) - def succeeded(self, event): - self.add_event(event) +def ensure_all_connected(client: MongoClient) -> None: + """Ensure that the client's connection pool has socket connections to all + members of a replica set. Raises ConfigurationError when called with a + non-replica set client. - def failed(self, event): - self.add_event(event) + Depending on the use-case, the caller may need to clear any event listeners + that are configured on the client. + """ + hello: dict = client.admin.command(HelloCompat.LEGACY_CMD) + if "setName" not in hello: + raise ConfigurationError("cluster is not a replica set") + target_host_list = set(hello["hosts"] + hello.get("passives", [])) + connected_host_list = {hello["me"]} -class HeartbeatEventsListListener(HeartbeatEventListener): - """Listens to only server heartbeat events and publishes them to a provided list.""" + # Run hello until we have connected to each host at least once. + def discover(): + i = 0 + while i < 100 and connected_host_list != target_host_list: + hello: dict = client.admin.command( + HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY + ) + connected_host_list.update([hello["me"]]) + i += 1 + return connected_host_list - def __init__(self, events): - super().__init__() - self.event_list = events + try: - def started(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatStartedEvent") + def predicate(): + return target_host_list == discover() - def succeeded(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatSucceededEvent") + wait_until(predicate, "connected to all hosts") + except AssertionError as exc: + raise AssertionError( + f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" + ) - def failed(self, event): - self.add_event(event) - self.event_list.append("serverHeartbeatFailedEvent") +def assertRaisesExactly(cls, fn, *args, **kwargs): + """ + Unlike the standard assertRaises, this checks that a function raises a + specific class of exception, and not a subclass. E.g., check that + MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. + """ + try: + fn(*args, **kwargs) + except Exception as e: + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" + else: + raise AssertionError("%s not raised" % cls) -class AsyncMockConnection: - def __init__(self): - self.cancel_context = _CancellationContext() - self.more_to_come = False - self.id = random.randint(0, 100) - def close_conn(self, reason): - pass +def set_fail_point(client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + client.admin.command(cmd) - def __aenter__(self): - return self - def __aexit__(self, exc_type, exc_val, exc_tb): - pass +def joinall(tasks): + """Join threads with a 5-minute timeout, assert joins succeeded""" + if _IS_SYNC: + for t in tasks: + t.join(300) + assert not t.is_alive(), "Thread %s hung" % t + else: + asyncio.wait([t.task for t in tasks if t is not None], timeout=300) class MockConnection: @@ -345,47 +204,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass -class AsyncMockPool: - def __init__(self, address, options, handshake=True, client_id=None): - self.gen = _PoolGeneration() - self._lock = _async_create_lock() - self.opts = options - self.operation_count = 0 - self.conns = [] - - def stale_generation(self, gen, service_id): - return self.gen.stale(gen, service_id) - - @contextlib.asynccontextmanager - async def checkout(self, handler=None): - yield AsyncMockConnection() - - async def checkin(self, *args, **kwargs): - pass - - async def _reset(self, service_id=None): - async with self._lock: - self.gen.inc(service_id) - - async def ready(self): - pass - - async def reset(self, service_id=None, interrupt_connections=False): - await self._reset() - - async def reset_without_pause(self): - await self._reset() - - async def close(self): - await self._reset() - - async def update_is_writable(self, is_writable): - pass - - async def remove_stale_sockets(self, *args, **kwargs): - pass - - class MockPool: def __init__(self, address, options, handshake=True, client_id=None): self.gen = _PoolGeneration() @@ -397,8 +215,9 @@ def __init__(self, address, options, handshake=True, client_id=None): def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) + @contextlib.contextmanager def checkout(self, handler=None): - return MockConnection() + yield MockConnection() def checkin(self, *args, **kwargs): pass @@ -426,39 +245,6 @@ def remove_stale_sockets(self, *args, **kwargs): pass -class ScenarioDict(dict): - """Dict that returns {} for any unknown key, recursively.""" - - def __init__(self, data): - def convert(v): - if isinstance(v, abc.Mapping): - return ScenarioDict(v) - if isinstance(v, (str, bytes)): - return v - if isinstance(v, abc.Sequence): - return [convert(item) for item in v] - return v - - dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) - - def __getitem__(self, item): - try: - return dict.__getitem__(self, item) - except KeyError: - # Unlike a defaultdict, don't set the key, just return a dict. - return ScenarioDict({}) - - -class CompareType: - """Class that compares equal to any object of the given type(s).""" - - def __init__(self, types): - self.types = types - - def __eq__(self, other): - return isinstance(other, self.types) - - class FunctionCallRecorder: """Utility class to wrap a callable and record its invocations.""" @@ -468,7 +254,10 @@ def __init__(self, function): def __call__(self, *args, **kwargs): self._call_list.append((args, kwargs)) - return self._function(*args, **kwargs) + if iscoroutinefunction(self._function): + return self._function(*args, **kwargs) + else: + return self._function(*args, **kwargs) def reset(self): """Wipes the call list.""" @@ -482,599 +271,3 @@ def call_list(self): def call_count(self): """Returns the number of times the function has been called.""" return len(self._call_list) - - -def ensure_all_connected(client: MongoClient) -> None: - """Ensure that the client's connection pool has socket connections to all - members of a replica set. Raises ConfigurationError when called with a - non-replica set client. - - Depending on the use-case, the caller may need to clear any event listeners - that are configured on the client. - """ - hello: dict = client.admin.command(HelloCompat.LEGACY_CMD) - if "setName" not in hello: - raise ConfigurationError("cluster is not a replica set") - - target_host_list = set(hello["hosts"] + hello.get("passives", [])) - connected_host_list = {hello["me"]} - - # Run hello until we have connected to each host at least once. - def discover(): - i = 0 - while i < 100 and connected_host_list != target_host_list: - hello: dict = client.admin.command( - HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY - ) - connected_host_list.update([hello["me"]]) - i += 1 - return connected_host_list - - try: - wait_until(lambda: target_host_list == discover(), "connected to all hosts") - except AssertionError as exc: - raise AssertionError( - f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" - ) - - -async def async_ensure_all_connected(client: AsyncMongoClient) -> None: - """Ensure that the client's connection pool has socket connections to all - members of a replica set. Raises ConfigurationError when called with a - non-replica set client. - - Depending on the use-case, the caller may need to clear any event listeners - that are configured on the client. - """ - hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD) - if "setName" not in hello: - raise ConfigurationError("cluster is not a replica set") - - target_host_list = set(hello["hosts"] + hello.get("passives", [])) - connected_host_list = {hello["me"]} - - # Run hello until we have connected to each host at least once. - async def discover(): - i = 0 - while i < 100 and connected_host_list != target_host_list: - hello: dict = await client.admin.command( - HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY - ) - connected_host_list.update([hello["me"]]) - i += 1 - return connected_host_list - - try: - - async def predicate(): - return target_host_list == await discover() - - await async_wait_until(predicate, "connected to all hosts") - except AssertionError as exc: - raise AssertionError( - f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}" - ) - - -def one(s): - """Get one element of a set""" - return next(iter(s)) - - -def oid_generated_on_process(oid): - """Makes a determination as to whether the given ObjectId was generated - by the current process, based on the 5-byte random number in the ObjectId. - """ - return ObjectId._random() == oid.binary[4:9] - - -def delay(sec): - return """function() { sleep(%f * 1000); return true; }""" % sec - - -def get_command_line(client): - command_line = client.admin.command("getCmdLineOpts") - assert command_line["ok"] == 1, "getCmdLineOpts() failed" - return command_line - - -def camel_to_snake(camel): - # Regex to convert CamelCase to snake_case. - snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() - - -def camel_to_upper_camel(camel): - return camel[0].upper() + camel[1:] - - -def camel_to_snake_args(arguments): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - arguments[c2s] = arguments.pop(arg_name) - return arguments - - -def snake_to_camel(snake): - # Regex to convert snake_case to lowerCamelCase. - return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) - - -def parse_collection_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - - if "writeConcern" in opts: - opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) - - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - - if "timeoutMS" in opts: - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 - return opts - - -def server_started_with_option(client, cmdline_opt, config_opt): - """Check if the server was started with a particular option. - - :Parameters: - - `cmdline_opt`: The command line option (i.e. --nojournal) - - `config_opt`: The config file option (i.e. nojournal) - """ - command_line = get_command_line(client) - if "parsed" in command_line: - parsed = command_line["parsed"] - if config_opt in parsed: - return parsed[config_opt] - argv = command_line["argv"] - return cmdline_opt in argv - - -def server_started_with_auth(client): - try: - command_line = get_command_line(client) - except OperationFailure as e: - assert e.details is not None - msg = e.details.get("errmsg", "") - if e.code == 13 or "unauthorized" in msg or "login" in msg: - # Unauthorized. - return True - raise - - # MongoDB >= 2.0 - if "parsed" in command_line: - parsed = command_line["parsed"] - # MongoDB >= 2.6 - if "security" in parsed: - security = parsed["security"] - # >= rc3 - if "authorization" in security: - return security["authorization"] == "enabled" - # < rc3 - return security.get("auth", False) or bool(security.get("keyFile")) - return parsed.get("auth", False) or bool(parsed.get("keyFile")) - # Legacy - argv = command_line["argv"] - return "--auth" in argv or "--keyFile" in argv - - -def joinall(threads): - """Join threads with a 5-minute timeout, assert joins succeeded""" - for t in threads: - t.join(300) - assert not t.is_alive(), "Thread %s hung" % t - - -async def async_joinall(tasks): - """Join threads with a 5-minute timeout, assert joins succeeded""" - await asyncio.wait([t.task for t in tasks if t is not None], timeout=300) - - -def wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - time.sleep(interval) - - -async def async_wait_until(predicate, success_description, timeout=10): - """Wait up to 10 seconds (by default) for predicate to be true. - - E.g.: - - wait_until(lambda: client.primary == ('a', 1), - 'connect to the primary') - - If the lambda-expression isn't true after 10 seconds, we raise - AssertionError("Didn't ever connect to the primary"). - - Returns the predicate's first true value. - """ - start = time.time() - interval = min(float(timeout) / 100, 0.1) - while True: - if iscoroutinefunction(predicate): - retval = await predicate() - else: - retval = predicate() - if retval: - return retval - - if time.time() - start > timeout: - raise AssertionError("Didn't ever %s" % success_description) - - await asyncio.sleep(interval) - - -def is_mongos(client): - res = client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - -async def async_is_mongos(client): - res = await client.admin.command(HelloCompat.LEGACY_CMD) - return res.get("msg", "") == "isdbgrid" - - -def assertRaisesExactly(cls, fn, *args, **kwargs): - """ - Unlike the standard assertRaises, this checks that a function raises a - specific class of exception, and not a subclass. E.g., check that - MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. - """ - try: - fn(*args, **kwargs) - except Exception as e: - assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" - else: - raise AssertionError("%s not raised" % cls) - - -async def asyncAssertRaisesExactly(cls, fn, *args, **kwargs): - """ - Unlike the standard assertRaises, this checks that a function raises a - specific class of exception, and not a subclass. E.g., check that - MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. - """ - try: - await fn(*args, **kwargs) - except Exception as e: - assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" - else: - raise AssertionError("%s not raised" % cls) - - -@contextlib.contextmanager -def _ignore_deprecations(): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - yield - - -def ignore_deprecations(wrapped=None): - """A context manager or a decorator.""" - if wrapped: - if iscoroutinefunction(wrapped): - - @functools.wraps(wrapped) - async def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return await wrapped(*args, **kwargs) - else: - - @functools.wraps(wrapped) - def wrapper(*args, **kwargs): - with _ignore_deprecations(): - return wrapped(*args, **kwargs) - - return wrapper - - else: - return _ignore_deprecations() - - -class DeprecationFilter: - def __init__(self, action="ignore"): - """Start filtering deprecations.""" - self.warn_context = warnings.catch_warnings() - self.warn_context.__enter__() - warnings.simplefilter(action, DeprecationWarning) - - def stop(self): - """Stop filtering deprecations.""" - self.warn_context.__exit__() # type: ignore - self.warn_context = None # type: ignore - - -def get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = client._get_topology() - server = topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -async def async_get_pool(client): - """Get the standalone, primary, or mongos pool.""" - topology = await client._get_topology() - server = await topology._select_server(writable_server_selector, _Op.TEST) - return server.pool - - -def get_pools(client): - """Get all pools.""" - return [ - server.pool - for server in client._get_topology().select_servers(any_server_selector, _Op.TEST) - ] - - -async def async_get_pools(client): - """Get all pools.""" - return [ - server.pool - for server in await (await client._get_topology()).select_servers( - any_server_selector, _Op.TEST - ) - ] - - -# Constants for run_threads and lazy_client_trial. -NTRIALS = 5 -NTHREADS = 10 - - -def run_threads(collection, target): - """Run a target function in many threads. - - target is a function taking a Collection and an integer. - """ - threads = [] - for i in range(NTHREADS): - bound_target = partial(target, collection, i) - threads.append(threading.Thread(target=bound_target)) - - for t in threads: - t.start() - - for t in threads: - t.join(60) - assert not t.is_alive() - - -@contextlib.contextmanager -def frequent_thread_switches(): - """Make concurrency bugs more likely to manifest.""" - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - - try: - yield - finally: - sys.setswitchinterval(interval) - - -def lazy_client_trial(reset, target, test, get_client): - """Test concurrent operations on a lazily-connecting client. - - `reset` takes a collection and resets it for the next trial. - - `target` takes a lazily-connecting collection and an index from - 0 to NTHREADS, and performs some operation, e.g. an insert. - - `test` takes the lazily-connecting collection and asserts a - post-condition to prove `target` succeeded. - """ - collection = client_context.client.pymongo_test.test - - with frequent_thread_switches(): - for _i in range(NTRIALS): - reset(collection) - lazy_client = get_client() - lazy_collection = lazy_client.pymongo_test.test - run_threads(lazy_collection, target) - test(lazy_collection) - - -def gevent_monkey_patched(): - """Check if gevent's monkey patching is active.""" - try: - import socket - - import gevent.socket # type:ignore[import] - - return socket.socket is gevent.socket.socket - except ImportError: - return False - - -def eventlet_monkey_patched(): - """Check if eventlet's monkey patching is active.""" - import threading - - return threading.current_thread.__module__ == "eventlet.green.threading" - - -def is_greenthread_patched(): - return gevent_monkey_patched() or eventlet_monkey_patched() - - -def parse_read_preference(pref): - # Make first letter lowercase to match read_pref's modes. - mode_string = pref.get("mode", "primary") - mode_string = mode_string[:1].lower() + mode_string[1:] - mode = read_preferences.read_pref_mode_from_name(mode_string) - max_staleness = pref.get("maxStalenessSeconds", -1) - tag_sets = pref.get("tagSets") or pref.get("tag_sets") - return read_preferences.make_read_preference( - mode, tag_sets=tag_sets, max_staleness=max_staleness - ) - - -def server_name_to_type(name): - """Convert a ServerType name to the corresponding value. For SDAM tests.""" - # Special case, some tests in the spec include the PossiblePrimary - # type, but only single-threaded drivers need that type. We call - # possible primaries Unknown. - if name == "PossiblePrimary": - return SERVER_TYPE.Unknown - return getattr(SERVER_TYPE, name) - - -def cat_files(dest, *sources): - """Cat multiple files into dest.""" - with open(dest, "wb") as fdst: - for src in sources: - with open(src, "rb") as fsrc: - shutil.copyfileobj(fsrc, fdst) - - -@contextlib.contextmanager -def assertion_context(msg): - """A context manager that adds info to an assertion failure.""" - try: - yield - except AssertionError as exc: - raise AssertionError(f"{msg}: {exc}") - - -def parse_spec_options(opts): - if "readPreference" in opts: - opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - - if "writeConcern" in opts: - w_opts = opts.pop("writeConcern") - if "journal" in w_opts: - w_opts["j"] = w_opts.pop("journal") - if "wtimeoutMS" in w_opts: - w_opts["wtimeout"] = w_opts.pop("wtimeoutMS") - opts["write_concern"] = WriteConcern(**dict(w_opts)) - - if "readConcern" in opts: - opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - - if "timeoutMS" in opts: - assert isinstance(opts["timeoutMS"], int) - opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 - - if "maxTimeMS" in opts: - opts["max_time_ms"] = opts.pop("maxTimeMS") - - if "maxCommitTimeMS" in opts: - opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") - - return dict(opts) - - -def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): - for arg_name in list(arguments): - c2s = camel_to_snake(arg_name) - # Named "key" instead not fieldName. - if arg_name == "fieldName": - arguments["key"] = arguments.pop(arg_name) - # Aggregate uses "batchSize", while find uses batch_size. - elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": - continue - elif arg_name == "timeoutMode": - raise unittest.SkipTest("PyMongo does not support timeoutMode") - # Requires boolean returnDocument. - elif arg_name == "returnDocument": - arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper()) - elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"): - # Parse each request into a bulk write model. - requests = [] - for request in arguments[c2s]: - if "name" in request: - # CRUD v2 format - bulk_model = camel_to_upper_camel(request["name"]) - bulk_class = getattr(operations, bulk_model) - bulk_arguments = camel_to_snake_args(request["arguments"]) - else: - # Unified test format - bulk_model, spec = next(iter(request.items())) - bulk_class = getattr(operations, camel_to_upper_camel(bulk_model)) - bulk_arguments = camel_to_snake_args(spec) - requests.append(bulk_class(**dict(bulk_arguments))) - arguments[c2s] = requests - elif arg_name == "session": - arguments["session"] = entity_map[arguments["session"]] - elif opname == "open_download_stream" and arg_name == "id": - arguments["file_id"] = arguments.pop(arg_name) - elif opname not in ("find", "find_one") and c2s == "max_time_ms": - # find is the only method that accepts snake_case max_time_ms. - # All other methods take kwargs which must use the server's - # camelCase maxTimeMS. See PYTHON-1855. - arguments["maxTimeMS"] = arguments.pop("max_time_ms") - elif opname == "with_transaction" and arg_name == "callback": - if "operations" in arguments[arg_name]: - # CRUD v2 format - callback_ops = arguments[arg_name]["operations"] - else: - # Unified test format - callback_ops = arguments[arg_name] - arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) - elif opname == "drop_collection" and arg_name == "collection": - arguments["name_or_collection"] = arguments.pop(arg_name) - elif opname == "create_collection": - if arg_name == "collection": - arguments["name"] = arguments.pop(arg_name) - arguments["check_exists"] = False - # Any other arguments to create_collection are passed through - # **kwargs. - elif opname == "create_index" and arg_name == "keys": - arguments["keys"] = list(arguments.pop(arg_name).items()) - elif opname == "drop_index" and arg_name == "name": - arguments["index_or_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "to": - arguments["new_name"] = arguments.pop(arg_name) - elif opname == "rename" and arg_name == "dropTarget": - arguments["dropTarget"] = arguments.pop(arg_name) - elif arg_name == "cursorType": - cursor_type = arguments.pop(arg_name) - if cursor_type == "tailable": - arguments["cursor_type"] = CursorType.TAILABLE - elif cursor_type == "tailableAwait": - arguments["cursor_type"] = CursorType.TAILABLE - else: - raise AssertionError(f"Unsupported cursorType: {cursor_type}") - else: - arguments[c2s] = arguments.pop(arg_name) - - -def set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - client.admin.command(cmd) - - -async def async_set_fail_point(client, command_args): - cmd = SON([("configureFailPoint", "failCommand")]) - cmd.update(command_args) - await client.admin.command(cmd) - - -def create_async_event(): - return asyncio.Event() - - -def create_event(): - return threading.Event() diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 9667ea701b..ee6fcfae0c 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -24,8 +24,8 @@ from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, parse_read_preference -from test.utils_selection_tests_shared import ( +from test.utils_shared import MockPool, parse_read_preference +from test.utils_shared_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, diff --git a/test/utils_shared.py b/test/utils_shared.py new file mode 100644 index 0000000000..8f282484e6 --- /dev/null +++ b/test/utils_shared.py @@ -0,0 +1,661 @@ +# Copyright 2012-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utilities for testing pymongo""" +from __future__ import annotations + +import asyncio +import contextlib +import copy +import functools +import random +import re +import shutil +import sys +import threading +import unittest +import warnings +from asyncio import iscoroutinefunction +from collections import abc, defaultdict +from functools import partial +from test import client_context +from test.asynchronous.utils import async_wait_until +from test.utils import wait_until +from typing import List + +from bson.objectid import ObjectId +from pymongo import monitoring, operations, read_preferences +from pymongo.cursor_shared import CursorType +from pymongo.errors import OperationFailure +from pymongo.helpers_shared import _SENSITIVE_COMMANDS +from pymongo.lock import _async_create_lock, _create_lock +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_concern import ReadConcern +from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.collection import ReturnDocument +from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration +from pymongo.write_concern import WriteConcern + +IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) + + +class BaseListener: + def __init__(self): + self.events = [] + + def reset(self): + self.events = [] + + def add_event(self, event): + self.events.append(event) + + def event_count(self, event_type): + return len(self.events_by_type(event_type)) + + def events_by_type(self, event_type): + """Return the matching events by event class. + + event_type can be a single class or a tuple of classes. + """ + return self.matching(lambda e: isinstance(e, event_type)) + + def matching(self, matcher): + """Return the matching events.""" + return [event for event in self.events[:] if matcher(event)] + + def wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + + +class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): + def connection_created(self, event): + assert isinstance(event, ConnectionCreatedEvent) + self.add_event(event) + + def connection_ready(self, event): + assert isinstance(event, ConnectionReadyEvent) + self.add_event(event) + + def connection_closed(self, event): + assert isinstance(event, ConnectionClosedEvent) + self.add_event(event) + + def connection_check_out_started(self, event): + assert isinstance(event, ConnectionCheckOutStartedEvent) + self.add_event(event) + + def connection_check_out_failed(self, event): + assert isinstance(event, ConnectionCheckOutFailedEvent) + self.add_event(event) + + def connection_checked_out(self, event): + assert isinstance(event, ConnectionCheckedOutEvent) + self.add_event(event) + + def connection_checked_in(self, event): + assert isinstance(event, ConnectionCheckedInEvent) + self.add_event(event) + + def pool_created(self, event): + assert isinstance(event, PoolCreatedEvent) + self.add_event(event) + + def pool_ready(self, event): + assert isinstance(event, PoolReadyEvent) + self.add_event(event) + + def pool_cleared(self, event): + assert isinstance(event, PoolClearedEvent) + self.add_event(event) + + def pool_closed(self, event): + assert isinstance(event, PoolClosedEvent) + self.add_event(event) + + +class EventListener(BaseListener, monitoring.CommandListener): + def __init__(self): + super().__init__() + self.results = defaultdict(list) + + @property + def started_events(self) -> List[monitoring.CommandStartedEvent]: + return self.results["started"] + + @property + def succeeded_events(self) -> List[monitoring.CommandSucceededEvent]: + return self.results["succeeded"] + + @property + def failed_events(self) -> List[monitoring.CommandFailedEvent]: + return self.results["failed"] + + def started(self, event: monitoring.CommandStartedEvent) -> None: + self.started_events.append(event) + self.add_event(event) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + self.succeeded_events.append(event) + self.add_event(event) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + self.failed_events.append(event) + self.add_event(event) + + def started_command_names(self) -> List[str]: + """Return list of command names started.""" + return [event.command_name for event in self.started_events] + + def reset(self) -> None: + """Reset the state of this listener.""" + self.results.clear() + super().reset() + + +class TopologyEventListener(monitoring.TopologyListener): + def __init__(self): + self.results = defaultdict(list) + + def closed(self, event): + self.results["closed"].append(event) + + def description_changed(self, event): + self.results["description_changed"].append(event) + + def opened(self, event): + self.results["opened"].append(event) + + def reset(self): + """Reset the state of this listener.""" + self.results.clear() + + +class AllowListEventListener(EventListener): + def __init__(self, *commands): + self.commands = set(commands) + super().__init__() + + def started(self, event): + if event.command_name in self.commands: + super().started(event) + + def succeeded(self, event): + if event.command_name in self.commands: + super().succeeded(event) + + def failed(self, event): + if event.command_name in self.commands: + super().failed(event) + + +class OvertCommandListener(EventListener): + """A CommandListener that ignores sensitive commands.""" + + ignore_list_collections = False + + def started(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().started(event) + + def succeeded(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().succeeded(event) + + def failed(self, event): + if event.command_name.lower() not in _SENSITIVE_COMMANDS: + super().failed(event) + + +class _ServerEventListener: + """Listens to all events.""" + + def __init__(self): + self.results = [] + + def opened(self, event): + self.results.append(event) + + def description_changed(self, event): + self.results.append(event) + + def closed(self, event): + self.results.append(event) + + def matching(self, matcher): + """Return the matching events.""" + results = self.results[:] + return [event for event in results if matcher(event)] + + def reset(self): + self.results = [] + + +class ServerEventListener(_ServerEventListener, monitoring.ServerListener): + """Listens to Server events.""" + + +class ServerAndTopologyEventListener( # type: ignore[misc] + ServerEventListener, monitoring.TopologyListener +): + """Listens to Server and Topology events.""" + + +class HeartbeatEventListener(BaseListener, monitoring.ServerHeartbeatListener): + """Listens to only server heartbeat events.""" + + def started(self, event): + self.add_event(event) + + def succeeded(self, event): + self.add_event(event) + + def failed(self, event): + self.add_event(event) + + +class HeartbeatEventsListListener(HeartbeatEventListener): + """Listens to only server heartbeat events and publishes them to a provided list.""" + + def __init__(self, events): + super().__init__() + self.event_list = events + + def started(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatStartedEvent") + + def succeeded(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatSucceededEvent") + + def failed(self, event): + self.add_event(event) + self.event_list.append("serverHeartbeatFailedEvent") + + +class ScenarioDict(dict): + """Dict that returns {} for any unknown key, recursively.""" + + def __init__(self, data): + def convert(v): + if isinstance(v, abc.Mapping): + return ScenarioDict(v) + if isinstance(v, (str, bytes)): + return v + if isinstance(v, abc.Sequence): + return [convert(item) for item in v] + return v + + dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) + + def __getitem__(self, item): + try: + return dict.__getitem__(self, item) + except KeyError: + # Unlike a defaultdict, don't set the key, just return a dict. + return ScenarioDict({}) + + +class CompareType: + """Class that compares equal to any object of the given type(s).""" + + def __init__(self, types): + self.types = types + + def __eq__(self, other): + return isinstance(other, self.types) + + +def one(s): + """Get one element of a set""" + return next(iter(s)) + + +def oid_generated_on_process(oid): + """Makes a determination as to whether the given ObjectId was generated + by the current process, based on the 5-byte random number in the ObjectId. + """ + return ObjectId._random() == oid.binary[4:9] + + +def delay(sec): + return """function() { sleep(%f * 1000); return true; }""" % sec + + +def camel_to_snake(camel): + # Regex to convert CamelCase to snake_case. + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() + + +def camel_to_upper_camel(camel): + return camel[0].upper() + camel[1:] + + +def camel_to_snake_args(arguments): + for arg_name in list(arguments): + c2s = camel_to_snake(arg_name) + arguments[c2s] = arguments.pop(arg_name) + return arguments + + +def snake_to_camel(snake): + # Regex to convert snake_case to lowerCamelCase. + return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) + + +def parse_collection_options(opts): + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) + + if "writeConcern" in opts: + opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) + + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) + + if "timeoutMS" in opts: + opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 + return opts + + +@contextlib.contextmanager +def _ignore_deprecations(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield + + +def ignore_deprecations(wrapped=None): + """A context manager or a decorator.""" + if wrapped: + if iscoroutinefunction(wrapped): + + @functools.wraps(wrapped) + async def wrapper(*args, **kwargs): + with _ignore_deprecations(): + return await wrapped(*args, **kwargs) + else: + + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + with _ignore_deprecations(): + return wrapped(*args, **kwargs) + + return wrapper + + else: + return _ignore_deprecations() + + +class DeprecationFilter: + def __init__(self, action="ignore"): + """Start filtering deprecations.""" + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() + warnings.simplefilter(action, DeprecationWarning) + + def stop(self): + """Stop filtering deprecations.""" + self.warn_context.__exit__() # type: ignore + self.warn_context = None # type: ignore + + +# Constants for run_threads and lazy_client_trial. +NTRIALS = 5 +NTHREADS = 10 + + +def run_threads(collection, target): + """Run a target function in many threads. + + target is a function taking a Collection and an integer. + """ + threads = [] + for i in range(NTHREADS): + bound_target = partial(target, collection, i) + threads.append(threading.Thread(target=bound_target)) + + for t in threads: + t.start() + + for t in threads: + t.join(60) + assert not t.is_alive() + + +@contextlib.contextmanager +def frequent_thread_switches(): + """Make concurrency bugs more likely to manifest.""" + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + + try: + yield + finally: + sys.setswitchinterval(interval) + + +def lazy_client_trial(reset, target, test, get_client): + """Test concurrent operations on a lazily-connecting client. + + `reset` takes a collection and resets it for the next trial. + + `target` takes a lazily-connecting collection and an index from + 0 to NTHREADS, and performs some operation, e.g. an insert. + + `test` takes the lazily-connecting collection and asserts a + post-condition to prove `target` succeeded. + """ + collection = client_context.client.pymongo_test.test + + with frequent_thread_switches(): + for _i in range(NTRIALS): + reset(collection) + lazy_client = get_client() + lazy_collection = lazy_client.pymongo_test.test + run_threads(lazy_collection, target) + test(lazy_collection) + + +def gevent_monkey_patched(): + """Check if gevent's monkey patching is active.""" + try: + import socket + + import gevent.socket # type:ignore[import] + + return socket.socket is gevent.socket.socket + except ImportError: + return False + + +def eventlet_monkey_patched(): + """Check if eventlet's monkey patching is active.""" + import threading + + return threading.current_thread.__module__ == "eventlet.green.threading" + + +def is_greenthread_patched(): + return gevent_monkey_patched() or eventlet_monkey_patched() + + +def parse_read_preference(pref): + # Make first letter lowercase to match read_pref's modes. + mode_string = pref.get("mode", "primary") + mode_string = mode_string[:1].lower() + mode_string[1:] + mode = read_preferences.read_pref_mode_from_name(mode_string) + max_staleness = pref.get("maxStalenessSeconds", -1) + tag_sets = pref.get("tagSets") or pref.get("tag_sets") + return read_preferences.make_read_preference( + mode, tag_sets=tag_sets, max_staleness=max_staleness + ) + + +def server_name_to_type(name): + """Convert a ServerType name to the corresponding value. For SDAM tests.""" + # Special case, some tests in the spec include the PossiblePrimary + # type, but only single-threaded drivers need that type. We call + # possible primaries Unknown. + if name == "PossiblePrimary": + return SERVER_TYPE.Unknown + return getattr(SERVER_TYPE, name) + + +def cat_files(dest, *sources): + """Cat multiple files into dest.""" + with open(dest, "wb") as fdst: + for src in sources: + with open(src, "rb") as fsrc: + shutil.copyfileobj(fsrc, fdst) + + +@contextlib.contextmanager +def assertion_context(msg): + """A context manager that adds info to an assertion failure.""" + try: + yield + except AssertionError as exc: + raise AssertionError(f"{msg}: {exc}") + + +def parse_spec_options(opts): + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) + + if "writeConcern" in opts: + w_opts = opts.pop("writeConcern") + if "journal" in w_opts: + w_opts["j"] = w_opts.pop("journal") + if "wtimeoutMS" in w_opts: + w_opts["wtimeout"] = w_opts.pop("wtimeoutMS") + opts["write_concern"] = WriteConcern(**dict(w_opts)) + + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) + + if "timeoutMS" in opts: + assert isinstance(opts["timeoutMS"], int) + opts["timeout"] = int(opts.pop("timeoutMS")) / 1000.0 + + if "maxTimeMS" in opts: + opts["max_time_ms"] = opts.pop("maxTimeMS") + + if "maxCommitTimeMS" in opts: + opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") + + return dict(opts) + + +def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): + for arg_name in list(arguments): + c2s = camel_to_snake(arg_name) + # Named "key" instead not fieldName. + if arg_name == "fieldName": + arguments["key"] = arguments.pop(arg_name) + # Aggregate uses "batchSize", while find uses batch_size. + elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": + continue + elif arg_name == "timeoutMode": + raise unittest.SkipTest("PyMongo does not support timeoutMode") + # Requires boolean returnDocument. + elif arg_name == "returnDocument": + arguments[c2s] = getattr(ReturnDocument, arguments.pop(arg_name).upper()) + elif "bulk_write" in opname and (c2s == "requests" or c2s == "models"): + # Parse each request into a bulk write model. + requests = [] + for request in arguments[c2s]: + if "name" in request: + # CRUD v2 format + bulk_model = camel_to_upper_camel(request["name"]) + bulk_class = getattr(operations, bulk_model) + bulk_arguments = camel_to_snake_args(request["arguments"]) + else: + # Unified test format + bulk_model, spec = next(iter(request.items())) + bulk_class = getattr(operations, camel_to_upper_camel(bulk_model)) + bulk_arguments = camel_to_snake_args(spec) + requests.append(bulk_class(**dict(bulk_arguments))) + arguments[c2s] = requests + elif arg_name == "session": + arguments["session"] = entity_map[arguments["session"]] + elif opname == "open_download_stream" and arg_name == "id": + arguments["file_id"] = arguments.pop(arg_name) + elif opname not in ("find", "find_one") and c2s == "max_time_ms": + # find is the only method that accepts snake_case max_time_ms. + # All other methods take kwargs which must use the server's + # camelCase maxTimeMS. See PYTHON-1855. + arguments["maxTimeMS"] = arguments.pop("max_time_ms") + elif opname == "with_transaction" and arg_name == "callback": + if "operations" in arguments[arg_name]: + # CRUD v2 format + callback_ops = arguments[arg_name]["operations"] + else: + # Unified test format + callback_ops = arguments[arg_name] + arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) + elif opname == "drop_collection" and arg_name == "collection": + arguments["name_or_collection"] = arguments.pop(arg_name) + elif opname == "create_collection": + if arg_name == "collection": + arguments["name"] = arguments.pop(arg_name) + arguments["check_exists"] = False + # Any other arguments to create_collection are passed through + # **kwargs. + elif opname == "create_index" and arg_name == "keys": + arguments["keys"] = list(arguments.pop(arg_name).items()) + elif opname == "drop_index" and arg_name == "name": + arguments["index_or_name"] = arguments.pop(arg_name) + elif opname == "rename" and arg_name == "to": + arguments["new_name"] = arguments.pop(arg_name) + elif opname == "rename" and arg_name == "dropTarget": + arguments["dropTarget"] = arguments.pop(arg_name) + elif arg_name == "cursorType": + cursor_type = arguments.pop(arg_name) + if cursor_type == "tailable": + arguments["cursor_type"] = CursorType.TAILABLE + elif cursor_type == "tailableAwait": + arguments["cursor_type"] = CursorType.TAILABLE + else: + raise AssertionError(f"Unsupported cursorType: {cursor_type}") + else: + arguments[c2s] = arguments.pop(arg_name) + + +def create_async_event(): + return asyncio.Event() + + +def create_event(): + return threading.Event() diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 98949431d0..11395f8159 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -23,7 +23,7 @@ from collections import abc from test import IntegrationTest, client_context, client_knobs from test.helpers import ConcurrentRunner -from test.utils import ( +from test.utils_shared import ( CMAPListener, CompareType, EventListener, diff --git a/tools/synchro.py b/tools/synchro.py index 519ebb102b..ee925ab790 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -125,6 +125,7 @@ "StopAsyncIteration": "StopIteration", "create_async_event": "create_event", "async_joinall": "joinall", + "AsyncFunctionCallRecorder": "FunctionCallRecorder", } docstring_replacements: dict[tuple[str, str], str] = { @@ -254,6 +255,7 @@ def async_only_test(f: str) -> bool: "test_versioned_api_integration.py", "unified_format.py", "utils_selection_tests.py", + "utils.py", ] From 4726c3c6a882b0036862c9d7bfe43cf49f111442 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 18 Feb 2025 10:16:29 -0800 Subject: [PATCH 2/3] Fix failures --- test/asynchronous/test_async_cancellation.py | 3 +- test/asynchronous/test_client.py | 4 +- test/asynchronous/test_collection.py | 3 +- .../test_connection_monitoring.py | 6 +- ...nnections_survive_primary_stepdown_spec.py | 2 +- test/asynchronous/test_gridfs.py | 3 +- test/asynchronous/test_gridfs_bucket.py | 3 +- .../asynchronous/test_heartbeat_monitoring.py | 3 +- test/asynchronous/test_load_balancer.py | 2 +- test/asynchronous/test_max_staleness.py | 2 +- test/asynchronous/test_pooling.py | 3 +- test/asynchronous/test_retryable_reads.py | 2 +- test/asynchronous/test_retryable_writes.py | 2 +- test/asynchronous/test_server_selection.py | 7 +- .../test_server_selection_in_window.py | 1 - test/asynchronous/test_srv_polling.py | 5 +- test/asynchronous/unified_format.py | 2 +- test/asynchronous/utils.py | 74 +------------------ test/asynchronous/utils_selection_tests.py | 5 +- test/test_client.py | 2 +- test/test_collection.py | 3 +- test/test_connection_monitoring.py | 6 +- ...nnections_survive_primary_stepdown_spec.py | 2 +- test/test_discovery_and_monitoring.py | 2 +- test/test_gridfs.py | 3 +- test/test_gridfs_bucket.py | 3 +- test/test_heartbeat_monitoring.py | 3 +- test/test_load_balancer.py | 2 +- test/test_max_staleness.py | 2 +- test/test_pooling.py | 3 +- test/test_retryable_reads.py | 2 +- test/test_retryable_writes.py | 2 +- test/test_server_selection.py | 3 +- test/test_server_selection_in_window.py | 1 - test/test_srv_polling.py | 3 +- test/test_threads.py | 2 +- test/test_topology.py | 3 +- test/unified_format.py | 2 +- test/utils.py | 69 +---------------- test/utils_selection_tests.py | 5 +- test/utils_shared.py | 28 +++++++ tools/synchro.py | 1 - 42 files changed, 94 insertions(+), 190 deletions(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index d071fae317..83a67a65d6 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,7 +17,8 @@ import asyncio import sys -from test.utils_shared import async_get_pool, delay, one +from test.asynchronous.utils import async_get_pool +from test.utils_shared import delay, one sys.path[0:0] = [""] diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 07b4cdee53..3398a6f092 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -61,7 +61,6 @@ ) from test.asynchronous.pymongo_mocks import AsyncMockClient from test.asynchronous.utils import ( - AsyncFunctionCallRecorder, async_get_pool, async_wait_until, asyncAssertRaisesExactly, @@ -70,6 +69,7 @@ from test.utils_shared import ( NTHREADS, CMAPListener, + FunctionCallRecorder, delay, gevent_monkey_patched, is_greenthread_patched, @@ -513,7 +513,7 @@ async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import _resolve - patched_resolver = AsyncFunctionCallRecorder(_resolve) + patched_resolver = FunctionCallRecorder(_resolve) pymongo.srv_resolver._resolve = patched_resolver def reset_resolver(): diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 6e64c2d668..00ed020d88 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.asynchronous.utils import async_get_pool, async_is_mongos from typing import Any, Iterable, no_type_check from pymongo.asynchronous.database import AsyncDatabase @@ -37,8 +38,6 @@ IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - async_get_pool, - async_is_mongos, async_wait_until, ) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index a547e0bcd4..687c615e1b 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -20,17 +20,15 @@ import sys import time from pathlib import Path +from test.asynchronous.utils import async_get_pool, async_get_pools sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask from test.utils_shared import ( CMAPListener, - async_client_context, - async_get_pool, - async_get_pools, async_wait_until, camel_to_snake, ) diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 8754125f23..92c750c4fe 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.asynchronous.utils import async_ensure_all_connected sys.path[0:0] = [""] @@ -27,7 +28,6 @@ from test.asynchronous.helpers import async_repl_set_step_down from test.utils_shared import ( CMAPListener, - async_ensure_all_connected, ) from bson import SON diff --git a/test/asynchronous/test_gridfs.py b/test/asynchronous/test_gridfs.py index 9582f9eca0..f886601f36 100644 --- a/test/asynchronous/test_gridfs.py +++ b/test/asynchronous/test_gridfs.py @@ -28,7 +28,8 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils_shared import async_joinall, one +from test.asynchronous.utils import async_joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_gridfs_bucket.py b/test/asynchronous/test_gridfs_bucket.py index 009bc7c620..29877ee9c4 100644 --- a/test/asynchronous/test_gridfs_bucket.py +++ b/test/asynchronous/test_gridfs_bucket.py @@ -29,7 +29,8 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.utils_shared import async_joinall, joinall, one +from test.asynchronous.utils import async_joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/asynchronous/test_heartbeat_monitoring.py b/test/asynchronous/test_heartbeat_monitoring.py index 92e63460dc..aa8a205021 100644 --- a/test/asynchronous/test_heartbeat_monitoring.py +++ b/test/asynchronous/test_heartbeat_monitoring.py @@ -16,11 +16,12 @@ from __future__ import annotations import sys +from test.asynchronous.utils import AsyncMockPool sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest -from test.utils_shared import AsyncMockPool, HeartbeatEventListener, async_wait_until +from test.utils_shared import HeartbeatEventListener, async_wait_until from pymongo.asynchronous.monitor import Monitor from pymongo.errors import ConnectionFailure diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 78afa914ed..127fdfd24d 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -23,6 +23,7 @@ import threading from asyncio import Event from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.asynchronous.utils import async_get_pool import pytest @@ -31,7 +32,6 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes from test.utils_shared import ( - async_get_pool, async_wait_until, create_async_event, ) diff --git a/test/asynchronous/test_max_staleness.py b/test/asynchronous/test_max_staleness.py index 747d7870fc..b6e15f9158 100644 --- a/test/asynchronous/test_max_staleness.py +++ b/test/asynchronous/test_max_staleness.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, async_client_context, unittest -from test.utils_shared_selection_tests import create_selection_tests +from test.asynchronous.utils_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/asynchronous/test_pooling.py b/test/asynchronous/test_pooling.py index 8f787b5328..a79251ca8f 100644 --- a/test/asynchronous/test_pooling.py +++ b/test/asynchronous/test_pooling.py @@ -21,6 +21,7 @@ import socket import sys import time +from test.asynchronous.utils import async_get_pool, async_joinall from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,7 +34,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.helpers import ConcurrentRunner -from test.utils_shared import async_get_pool, async_joinall, delay +from test.utils_shared import delay from pymongo.asynchronous.pool import Pool, PoolOptions from pymongo.socket_checker import SocketChecker diff --git a/test/asynchronous/test_retryable_reads.py b/test/asynchronous/test_retryable_reads.py index a0e46ff65a..10d9e738b4 100644 --- a/test/asynchronous/test_retryable_reads.py +++ b/test/asynchronous/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point from pymongo.errors import AutoReconnect @@ -34,7 +35,6 @@ from test.utils_shared import ( CMAPListener, OvertCommandListener, - async_set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index bbcf8d5ed2..2f6cb2b575 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.asynchronous.utils import async_set_fail_point sys.path[0:0] = [""] @@ -35,7 +36,6 @@ DeprecationFilter, EventListener, OvertCommandListener, - async_set_fail_point, ) from test.version import Version diff --git a/test/asynchronous/test_server_selection.py b/test/asynchronous/test_server_selection.py index 64f7784569..f98a05ee91 100644 --- a/test/asynchronous/test_server_selection.py +++ b/test/asynchronous/test_server_selection.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.utils import AsyncFunctionCallRecorder, async_wait_until +from test.asynchronous.utils import async_wait_until from test.asynchronous.utils_selection_tests import ( create_selection_tests, get_topology_settings_dict, @@ -41,6 +41,7 @@ make_server_description, ) from test.utils_shared import ( + FunctionCallRecorder, OvertCommandListener, ) @@ -122,7 +123,7 @@ async def test_invalid_server_selector(self): @async_client_context.require_replica_set async def test_selector_called(self): - selector = AsyncFunctionCallRecorder(lambda x: x) + selector = FunctionCallRecorder(lambda x: x) # Client setup. mongo_client = await self.async_rs_or_single_client(server_selector=selector) @@ -175,7 +176,7 @@ async def test_latency_threshold_application(self): @async_client_context.require_replica_set async def test_server_selector_bypassed(self): - selector = AsyncFunctionCallRecorder(lambda x: x) + selector = FunctionCallRecorder(lambda x: x) scenario_def = { "topology_description": { diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 8151472719..3fe448d4dd 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -26,7 +26,6 @@ from test.utils_shared import ( CMAPListener, OvertCommandListener, - async_get_pool, async_wait_until, ) diff --git a/test/asynchronous/test_srv_polling.py b/test/asynchronous/test_srv_polling.py index 85e39f4658..bf7807eb97 100644 --- a/test/asynchronous/test_srv_polling.py +++ b/test/asynchronous/test_srv_polling.py @@ -18,12 +18,13 @@ import asyncio import sys import time +from test.utils_shared import FunctionCallRecorder from typing import Any sys.path[0:0] = [""] from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest -from test.asynchronous.utils import AsyncFunctionCallRecorder, async_wait_until +from test.asynchronous.utils import async_wait_until import pymongo from pymongo import common @@ -69,7 +70,7 @@ def mock_get_hosts_and_min_ttl(resolver, *args): patch_func: Any if self.count_resolver_calls: - patch_func = AsyncFunctionCallRecorder(mock_get_hosts_and_min_ttl) + patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl) else: patch_func = mock_get_hosts_and_min_ttl diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index baca5a9902..cb26b08930 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ client_knobs, unittest, ) +from test.asynchronous.utils import async_get_pool from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, @@ -50,7 +51,6 @@ with_metaclass, ) from test.utils_shared import ( - async_get_pool, async_wait_until, camel_to_snake, camel_to_snake_args, diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py index f87589bf63..d1615fc0fa 100644 --- a/test/asynchronous/utils.py +++ b/test/asynchronous/utils.py @@ -17,57 +17,19 @@ import asyncio import contextlib -import copy -import functools -import os import random -import re -import shutil -import sys -import threading import time -import unittest -import warnings from asyncio import iscoroutinefunction -from collections import abc, defaultdict -from functools import partial -from test import client_context, db_pwd, db_user -from test.asynchronous import async_client_context -from typing import Any, List - -from bson import json_util -from bson.objectid import ObjectId + from bson.son import SON -from pymongo import AsyncMongoClient, monitoring, operations, read_preferences -from pymongo._asyncio_task import create_task -from pymongo.cursor_shared import CursorType -from pymongo.errors import ConfigurationError, OperationFailure +from pymongo import AsyncMongoClient +from pymongo.errors import ConfigurationError from pymongo.hello import HelloCompat -from pymongo.helpers_shared import _SENSITIVE_COMMANDS -from pymongo.lock import _async_create_lock, _create_lock -from pymongo.monitoring import ( - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, -) +from pymongo.lock import _async_create_lock from pymongo.operations import _Op -from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import any_server_selector, writable_server_selector -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration -from pymongo.uri_parser import parse_uri -from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -246,31 +208,3 @@ async def update_is_writable(self, is_writable): async def remove_stale_sockets(self, *args, **kwargs): pass - - -class AsyncFunctionCallRecorder: - """Utility class to wrap a callable and record its invocations.""" - - def __init__(self, function): - self._function = function - self._call_list = [] - - async def __call__(self, *args, **kwargs): - self._call_list.append((args, kwargs)) - if iscoroutinefunction(self._function): - return await self._function(*args, **kwargs) - else: - return self._function(*args, **kwargs) - - def reset(self): - """Wipes the call list.""" - self._call_list = [] - - def call_list(self): - """Returns a copy of the call list.""" - return self._call_list[:] - - @property - def call_count(self): - """Returns the number of times the function has been called.""" - return len(self._call_list) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index 6a63bfe6e7..d6b92fadb4 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -19,17 +19,18 @@ import os import sys from test.asynchronous import AsyncPyMongoTestCase +from test.asynchronous.utils import AsyncMockPool sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils_shared import AsyncMockPool, parse_read_preference -from test.utils_shared_selection_tests_shared import ( +from test.utils_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, ) +from test.utils_shared import parse_read_preference from bson import json_util from pymongo.asynchronous.settings import TopologySettings diff --git a/test/test_client.py b/test/test_client.py index 3a1aa4d1b4..6dbb16b626 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -61,7 +61,6 @@ from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( - FunctionCallRecorder, assertRaisesExactly, get_pool, wait_until, @@ -69,6 +68,7 @@ from test.utils_shared import ( NTHREADS, CMAPListener, + FunctionCallRecorder, delay, gevent_monkey_patched, is_greenthread_patched, diff --git a/test/test_collection.py b/test/test_collection.py index 264f642921..75c11383d0 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -21,6 +21,7 @@ import sys from codecs import utf_8_decode from collections import defaultdict +from test.utils import get_pool, is_mongos from typing import Any, Iterable, no_type_check from pymongo.synchronous.database import Database @@ -37,8 +38,6 @@ IMPOSSIBLE_WRITE_CONCERN, EventListener, OvertCommandListener, - get_pool, - is_mongos, wait_until, ) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 4e89c69de5..e97a21163e 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -20,17 +20,15 @@ import sys import time from pathlib import Path +from test.utils import get_pool, get_pools sys.path[0:0] = [""] -from test import IntegrationTest, client_knobs, unittest +from test import IntegrationTest, client_context, client_knobs, unittest from test.pymongo_mocks import DummyMonitor from test.utils_shared import ( CMAPListener, camel_to_snake, - client_context, - get_pool, - get_pools, wait_until, ) from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index bb2a1f4ac6..d923a477b5 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -16,6 +16,7 @@ from __future__ import annotations import sys +from test.utils import ensure_all_connected sys.path[0:0] = [""] @@ -27,7 +28,6 @@ from test.helpers import repl_set_step_down from test.utils_shared import ( CMAPListener, - ensure_all_connected, ) from bson import SON diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index c5e6e058aa..477c6a906a 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -19,6 +19,7 @@ import socketserver import sys import threading +from test.utils import get_pool sys.path[0:0] = [""] @@ -31,7 +32,6 @@ HeartbeatEventsListListener, assertion_context, client_context, - get_pool, server_name_to_type, wait_until, ) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index c8b9a7ee1b..75342ee437 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -28,7 +28,8 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils_shared import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 5cb19c557a..d68c9f6ba2 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -29,7 +29,8 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils_shared import joinall, one +from test.utils import joinall +from test.utils_shared import one import gridfs from bson.binary import Binary diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 13516088cd..7864caf6e1 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -16,11 +16,12 @@ from __future__ import annotations import sys +from test.utils import MockPool sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils_shared import HeartbeatEventListener, MockPool, wait_until +from test.utils_shared import HeartbeatEventListener, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 41bbe0ea41..d7f1d596cc 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -23,6 +23,7 @@ import threading from asyncio import Event from test.helpers import ConcurrentRunner, ExceptionCatchingTask +from test.utils import get_pool import pytest @@ -32,7 +33,6 @@ from test.unified_format import generate_test_classes from test.utils_shared import ( create_event, - get_pool, wait_until, ) diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 05943e3d64..56e047fd4b 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import PyMongoTestCase, client_context, unittest -from test.utils_shared_selection_tests import create_selection_tests +from test.utils_selection_tests import create_selection_tests from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/test_pooling.py b/test/test_pooling.py index 81f8caf31d..7cb6ab5945 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -21,6 +21,7 @@ import socket import sys import time +from test.utils import get_pool, joinall from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON @@ -33,7 +34,7 @@ from test import IntegrationTest, client_context, unittest from test.helpers import ConcurrentRunner -from test.utils_shared import delay, get_pool, joinall +from test.utils_shared import delay from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index c131b225ca..7ae4c41e70 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -19,6 +19,7 @@ import pprint import sys import threading +from test.utils import set_fail_point from pymongo.errors import AutoReconnect @@ -34,7 +35,6 @@ from test.utils_shared import ( CMAPListener, OvertCommandListener, - set_fail_point, ) from pymongo.monitoring import ( diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 854f6c2a90..b099820a45 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -20,6 +20,7 @@ import pprint import sys import threading +from test.utils import set_fail_point sys.path[0:0] = [""] @@ -35,7 +36,6 @@ DeprecationFilter, EventListener, OvertCommandListener, - set_fail_point, ) from test.version import Version diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 2e8d9f4f87..aec8e2e47a 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import FunctionCallRecorder, wait_until +from test.utils import wait_until from test.utils_selection_tests import ( create_selection_tests, get_topology_settings_dict, @@ -41,6 +41,7 @@ make_server_description, ) from test.utils_shared import ( + FunctionCallRecorder, OvertCommandListener, ) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 92d8ad89b4..4aad34050c 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -25,7 +25,6 @@ from test.utils_shared import ( CMAPListener, OvertCommandListener, - get_pool, wait_until, ) from test.utils_spec_runner import SpecTestCreator diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 86fad6d90e..6812465074 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -18,12 +18,13 @@ import asyncio import sys import time +from test.utils_shared import FunctionCallRecorder from typing import Any sys.path[0:0] = [""] from test import PyMongoTestCase, client_knobs, unittest -from test.utils import FunctionCallRecorder, wait_until +from test.utils import wait_until import pymongo from pymongo import common diff --git a/test/test_threads.py b/test/test_threads.py index 7c128720e1..3e469e28fe 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -17,7 +17,7 @@ import threading from test import IntegrationTest, client_context, unittest -from test.utils_shared import joinall +from test.utils import joinall @client_context.require_connection diff --git a/test/test_topology.py b/test/test_topology.py index 0840e42ba8..22e94739ee 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -23,7 +23,8 @@ from test import client_knobs, unittest from test.pymongo_mocks import DummyMonitor -from test.utils_shared import MockPool, wait_until +from test.utils import MockPool +from test.utils_shared import wait_until from bson.objectid import ObjectId from pymongo import common diff --git a/test/unified_format.py b/test/unified_format.py index 8894aff94b..34ee723c06 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -48,10 +48,10 @@ parse_collection_or_database_options, with_metaclass, ) +from test.utils import get_pool from test.utils_shared import ( camel_to_snake, camel_to_snake_args, - get_pool, parse_spec_options, prepare_spec_arguments, snake_to_camel, diff --git a/test/utils.py b/test/utils.py index 2d8693a77e..9e74023d30 100644 --- a/test/utils.py +++ b/test/utils.py @@ -17,56 +17,19 @@ import asyncio import contextlib -import copy -import functools -import os import random -import re -import shutil -import sys -import threading import time -import unittest -import warnings from asyncio import iscoroutinefunction -from collections import abc, defaultdict -from functools import partial -from test import client_context, db_pwd, db_user -from typing import Any, List -from bson import json_util -from bson.objectid import ObjectId from bson.son import SON -from pymongo import MongoClient, monitoring, operations, read_preferences -from pymongo._asyncio_task import create_task -from pymongo.cursor_shared import CursorType -from pymongo.errors import ConfigurationError, OperationFailure +from pymongo import MongoClient +from pymongo.errors import ConfigurationError from pymongo.hello import HelloCompat -from pymongo.helpers_shared import _SENSITIVE_COMMANDS from pymongo.lock import _create_lock -from pymongo.monitoring import ( - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, -) from pymongo.operations import _Op -from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import any_server_selector, writable_server_selector -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration -from pymongo.uri_parser import parse_uri -from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -243,31 +206,3 @@ def update_is_writable(self, is_writable): def remove_stale_sockets(self, *args, **kwargs): pass - - -class FunctionCallRecorder: - """Utility class to wrap a callable and record its invocations.""" - - def __init__(self, function): - self._function = function - self._call_list = [] - - def __call__(self, *args, **kwargs): - self._call_list.append((args, kwargs)) - if iscoroutinefunction(self._function): - return self._function(*args, **kwargs) - else: - return self._function(*args, **kwargs) - - def reset(self): - """Wipes the call list.""" - self._call_list = [] - - def call_list(self): - """Returns a copy of the call list.""" - return self._call_list[:] - - @property - def call_count(self): - """Returns the number of times the function has been called.""" - return len(self._call_list) diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index ee6fcfae0c..2772f06070 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -19,17 +19,18 @@ import os import sys from test import PyMongoTestCase +from test.utils import MockPool sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor -from test.utils_shared import MockPool, parse_read_preference -from test.utils_shared_selection_tests_shared import ( +from test.utils_selection_tests_shared import ( get_addresses, get_topology_type_name, make_server_description, ) +from test.utils_shared import parse_read_preference from bson import json_util from pymongo.common import HEARTBEAT_FREQUENCY diff --git a/test/utils_shared.py b/test/utils_shared.py index 8f282484e6..1fa9608029 100644 --- a/test/utils_shared.py +++ b/test/utils_shared.py @@ -336,6 +336,34 @@ def __eq__(self, other): return isinstance(other, self.types) +class FunctionCallRecorder: + """Utility class to wrap a callable and record its invocations.""" + + def __init__(self, function): + self._function = function + self._call_list = [] + + def __call__(self, *args, **kwargs): + self._call_list.append((args, kwargs)) + if iscoroutinefunction(self._function): + return self._function(*args, **kwargs) + else: + return self._function(*args, **kwargs) + + def reset(self): + """Wipes the call list.""" + self._call_list = [] + + def call_list(self): + """Returns a copy of the call list.""" + return self._call_list[:] + + @property + def call_count(self): + """Returns the number of times the function has been called.""" + return len(self._call_list) + + def one(s): """Get one element of a set""" return next(iter(s)) diff --git a/tools/synchro.py b/tools/synchro.py index ee925ab790..d995a3e510 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -125,7 +125,6 @@ "StopAsyncIteration": "StopIteration", "create_async_event": "create_event", "async_joinall": "joinall", - "AsyncFunctionCallRecorder": "FunctionCallRecorder", } docstring_replacements: dict[tuple[str, str], str] = { From 2dae369159058e10e710eaa68642a76121a242f1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 11 Mar 2025 13:34:36 -0400 Subject: [PATCH 3/3] Merge resolution --- .../test_discovery_and_monitoring.py | 19 +++++++++++++------ test/asynchronous/test_monitor.py | 4 ++-- test/asynchronous/test_session.py | 1 - test/asynchronous/utils.py | 1 + test/test_discovery_and_monitoring.py | 17 ++++++++++++----- test/test_monitor.py | 4 +--- test/test_session.py | 4 +--- test/utils.py | 19 +------------------ test/utils_shared.py | 16 ++++++++++++++++ 9 files changed, 47 insertions(+), 38 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index c3c2bb1a6c..b3de2c5a4d 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -26,25 +26,32 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + AsyncUnitTest, + async_client_context, + unittest, +) from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.unified_format import generate_test_classes -from test.utils import ( +from test.asynchronous.utils import ( + async_get_pool, +) +from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, async_barrier_wait, - async_client_context, async_create_barrier, - async_get_pool, async_wait_until, server_name_to_type, ) from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import AsyncMongoClient, common, monitoring +from pymongo import common, monitoring from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.errors import ( @@ -291,7 +298,7 @@ async def test_ignore_stale_connection_errors(self): if not _IS_SYNC and sys.version_info < (3, 11): self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") N_TASKS = 5 - barrier = async_create_barrier(N_TASKS, timeout=30) + barrier = async_create_barrier(N_TASKS) client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. diff --git a/test/asynchronous/test_monitor.py b/test/asynchronous/test_monitor.py index 2705fbda3b..195f6f9fac 100644 --- a/test/asynchronous/test_monitor.py +++ b/test/asynchronous/test_monitor.py @@ -25,10 +25,10 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest -from test.utils import ( - ServerAndTopologyEventListener, +from test.asynchronous.utils import ( async_wait_until, ) +from test.utils_shared import ServerAndTopologyEventListener from pymongo.periodic_executor import _EXECUTORS diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 15facbeab3..4431cbcb16 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -30,7 +30,6 @@ from test.asynchronous import ( AsyncIntegrationTest, - AsyncPyMongoTestCase, AsyncUnitTest, SkipTest, async_client_context, diff --git a/test/asynchronous/utils.py b/test/asynchronous/utils.py index d1615fc0fa..4b68595397 100644 --- a/test/asynchronous/utils.py +++ b/test/asynchronous/utils.py @@ -18,6 +18,7 @@ import asyncio import contextlib import random +import threading # Used in the synchronized version of this file import time from asyncio import iscoroutinefunction diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index fa58460ae5..00021310c9 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -20,23 +20,30 @@ import socketserver import sys import threading -from test.utils import get_pool from asyncio import StreamReader, StreamWriter from pathlib import Path from test.helpers import ConcurrentRunner sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + UnitTest, + client_context, + unittest, +) from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes +from test.utils import ( + get_pool, +) from test.utils_shared import ( CMAPListener, HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, barrier_wait, - client_context, create_barrier, server_name_to_type, wait_until, @@ -44,7 +51,7 @@ from unittest.mock import patch from bson import Timestamp, json_util -from pymongo import MongoClient, common, monitoring +from pymongo import common, monitoring from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -291,7 +298,7 @@ def test_ignore_stale_connection_errors(self): if not _IS_SYNC and sys.version_info < (3, 11): self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") N_TASKS = 5 - barrier = create_barrier(N_TASKS, timeout=30) + barrier = create_barrier(N_TASKS) client = self.rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. diff --git a/test/test_monitor.py b/test/test_monitor.py index efc500c3ef..25620a99e8 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -24,13 +24,11 @@ sys.path[0:0] = [""] -from test import IntegrationTest, connected, unittest -from test.utils_shared import ( from test import IntegrationTest, client_context, connected, unittest from test.utils import ( - ServerAndTopologyEventListener, wait_until, ) +from test.utils_shared import ServerAndTopologyEventListener from pymongo.periodic_executor import _EXECUTORS diff --git a/test/test_session.py b/test/test_session.py index fe233ae519..905539a1f8 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -30,15 +30,13 @@ from test import ( IntegrationTest, - PyMongoTestCase, SkipTest, UnitTest, client_context, unittest, ) -from test.utils_shared import ( from test.helpers import client_knobs -from test.utils import ( +from test.utils_shared import ( EventListener, HeartbeatEventListener, OvertCommandListener, diff --git a/test/utils.py b/test/utils.py index cd44b36b71..1459a8fba7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -18,6 +18,7 @@ import asyncio import contextlib import random +import threading # Used in the synchronized version of this file import time from asyncio import iscoroutinefunction @@ -206,21 +207,3 @@ def update_is_writable(self, is_writable): def remove_stale_sockets(self, *args, **kwargs): pass -def create_event(): - return threading.Event() - - -def async_create_barrier(N_TASKS, timeout: float | None = None): - return asyncio.Barrier(N_TASKS) - - -def create_barrier(N_TASKS, timeout: float | None = None): - return threading.Barrier(N_TASKS, timeout=timeout) - - -async def async_barrier_wait(barrier, timeout: float | None = None): - await asyncio.wait_for(barrier.wait(), timeout=timeout) - - -def barrier_wait(barrier, timeout: float | None = None): - barrier.wait() diff --git a/test/utils_shared.py b/test/utils_shared.py index 1fa9608029..2c52445968 100644 --- a/test/utils_shared.py +++ b/test/utils_shared.py @@ -687,3 +687,19 @@ def create_async_event(): def create_event(): return threading.Event() + + +def async_create_barrier(n_tasks: int): + return asyncio.Barrier(n_tasks) + + +def create_barrier(n_tasks: int, timeout: float | None = None): + return threading.Barrier(n_tasks, timeout=timeout) + + +async def async_barrier_wait(barrier, timeout: float | None = None): + await asyncio.wait_for(barrier.wait(), timeout=timeout) + + +def barrier_wait(barrier, timeout: float | None = None): + barrier.wait(timeout=timeout)