diff --git a/Cargo.toml b/Cargo.toml index de5c41052..073b6a59f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ authors = [ "Kaitlin Mahar ", ] description = "The official MongoDB driver for Rust" -edition = "2018" +edition = "2021" keywords = ["mongo", "mongodb", "database", "bson", "nosql"] categories = ["asynchronous", "database", "web-programming"] repository = "https://github.com/mongodb/mongo-rust-driver" @@ -163,6 +163,7 @@ function_name = "0.2.1" futures = "0.3" home = "0.5" pretty_assertions = "1.1.0" +serde = { version = "*", features = ["rc"] } serde_json = "1.0.64" semver = "1.0.0" time = "0.3.9" diff --git a/src/bson_util/mod.rs b/src/bson_util/mod.rs index 1bedde6c1..a92969c78 100644 --- a/src/bson_util/mod.rs +++ b/src/bson_util/mod.rs @@ -225,6 +225,16 @@ pub(crate) fn serialize_error_as_string( serializer.serialize_str(&val.to_string()) } +/// Serializes a Result, serializing the error value as a string if present. +pub(crate) fn serialize_result_error_as_string( + val: &Result, + serializer: S, +) -> std::result::Result { + val.as_ref() + .map_err(|e| e.to_string()) + .serialize(serializer) +} + #[cfg(test)] mod test { use crate::bson_util::num_decimal_digits; diff --git a/src/client/mod.rs b/src/client/mod.rs index afb54c599..9324c3f32 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -475,6 +475,11 @@ impl Client { }) .ok() } + + #[cfg(test)] + pub(crate) fn topology(&self) -> &Topology { + &self.inner.topology + } } #[cfg(feature = "csfle")] diff --git a/src/cmap/test/file.rs b/src/cmap/test/file.rs index f4c81f1b8..bb7390ed8 100644 --- a/src/cmap/test/file.rs +++ b/src/cmap/test/file.rs @@ -21,7 +21,7 @@ pub struct TestFile { #[serde(default)] pub ignore: Vec, pub fail_point: Option, - pub run_on: Option>, + pub(crate) run_on: Option>, } #[derive(Debug, Deserialize)] diff --git a/src/sdam/description/server.rs b/src/sdam/description/server.rs index 54df48938..2b296819a 100644 --- a/src/sdam/description/server.rs +++ b/src/sdam/description/server.rs @@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize}; use crate::{ bson::{oid::ObjectId, DateTime}, + bson_util, client::ClusterTime, + error::{ErrorKind, Result}, hello::HelloReply, options::ServerAddress, selection_criteria::TagSet, @@ -106,7 +108,8 @@ pub(crate) struct ServerDescription { // allows us to ensure that only valid states are possible (e.g. preventing that both an error // and a reply are present) while still making it easy to define helper methods on // ServerDescription for information we need from the hello reply by propagating with `?`. 
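// A minimal standalone sketch of the serializer helper added in the bson_util hunk above.
// The generic parameters appear to have been dropped in rendering, so the type and trait
// bounds shown here are assumptions; the driver's actual signature may differ slightly.

use serde::{Serialize, Serializer};

fn serialize_result_error_as_string<T, E, S>(
    val: &Result<T, E>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    T: Serialize,
    E: std::fmt::Display,
    S: Serializer,
{
    // Map the error side to a String so the whole Result becomes serializable,
    // then serialize the mapped Result<&T, String> as usual.
    val.as_ref()
        .map_err(|e| e.to_string())
        .serialize(serializer)
}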
- pub(crate) reply: Result, String>, + #[serde(serialize_with = "bson_util::serialize_result_error_as_string")] + pub(crate) reply: Result>, } impl PartialEq for ServerDescription { @@ -122,17 +125,22 @@ impl PartialEq for ServerDescription { self_response == other_response } - (Err(self_err), Err(other_err)) => self_err == other_err, + (Err(self_err), Err(other_err)) => { + match (self_err.kind.as_ref(), other_err.kind.as_ref()) { + ( + ErrorKind::Command(self_command_err), + ErrorKind::Command(other_command_err), + ) => self_command_err.code == other_command_err.code, + _ => self_err.to_string() == other_err.to_string(), + } + } _ => false, } } } impl ServerDescription { - pub(crate) fn new( - mut address: ServerAddress, - hello_reply: Option>, - ) -> Self { + pub(crate) fn new(mut address: ServerAddress, hello_reply: Option>) -> Self { address = ServerAddress::Tcp { host: address.host().to_lowercase(), port: address.port(), @@ -231,7 +239,7 @@ impl ServerDescription { None } - pub(crate) fn set_name(&self) -> Result, String> { + pub(crate) fn set_name(&self) -> Result> { let set_name = self .reply .as_ref() @@ -241,7 +249,7 @@ impl ServerDescription { Ok(set_name) } - pub(crate) fn known_hosts(&self) -> Result, String> { + pub(crate) fn known_hosts(&self) -> Result> { let known_hosts = self .reply .as_ref() @@ -262,7 +270,7 @@ impl ServerDescription { Ok(known_hosts.into_iter().flatten()) } - pub(crate) fn invalid_me(&self) -> Result { + pub(crate) fn invalid_me(&self) -> Result { if let Some(ref reply) = self.reply.as_ref().map_err(Clone::clone)? { if let Some(ref me) = reply.command_response.me { return Ok(&self.address.to_string() != me); @@ -272,7 +280,7 @@ impl ServerDescription { Ok(false) } - pub(crate) fn set_version(&self) -> Result, String> { + pub(crate) fn set_version(&self) -> Result> { let me = self .reply .as_ref() @@ -282,7 +290,7 @@ impl ServerDescription { Ok(me) } - pub(crate) fn election_id(&self) -> Result, String> { + pub(crate) fn election_id(&self) -> Result> { let me = self .reply .as_ref() @@ -293,7 +301,7 @@ impl ServerDescription { } #[cfg(test)] - pub(crate) fn min_wire_version(&self) -> Result, String> { + pub(crate) fn min_wire_version(&self) -> Result> { let me = self .reply .as_ref() @@ -303,7 +311,7 @@ impl ServerDescription { Ok(me) } - pub(crate) fn max_wire_version(&self) -> Result, String> { + pub(crate) fn max_wire_version(&self) -> Result> { let me = self .reply .as_ref() @@ -313,7 +321,7 @@ impl ServerDescription { Ok(me) } - pub(crate) fn last_write_date(&self) -> Result, String> { + pub(crate) fn last_write_date(&self) -> Result> { match self.reply { Ok(None) => Ok(None), Ok(Some(ref reply)) => Ok(reply @@ -325,7 +333,7 @@ impl ServerDescription { } } - pub(crate) fn logical_session_timeout(&self) -> Result, String> { + pub(crate) fn logical_session_timeout(&self) -> Result> { match self.reply { Ok(None) => Ok(None), Ok(Some(ref reply)) => Ok(reply @@ -336,7 +344,7 @@ impl ServerDescription { } } - pub(crate) fn cluster_time(&self) -> Result, String> { + pub(crate) fn cluster_time(&self) -> Result> { match self.reply { Ok(None) => Ok(None), Ok(Some(ref reply)) => Ok(reply.cluster_time.clone()), diff --git a/src/sdam/description/topology/mod.rs b/src/sdam/description/topology/mod.rs index e13b321a8..ad4b773f7 100644 --- a/src/sdam/description/topology/mod.rs +++ b/src/sdam/description/topology/mod.rs @@ -13,6 +13,7 @@ use crate::{ bson::oid::ObjectId, client::ClusterTime, cmap::Command, + error::{Error, Result}, options::{ClientOptions, 
ServerAddress}, sdam::{ description::server::{ServerDescription, ServerType}, @@ -460,10 +461,7 @@ impl TopologyDescription { /// Update the topology based on the new information about the topology contained by the /// ServerDescription. - pub(crate) fn update( - &mut self, - mut server_description: ServerDescription, - ) -> Result<(), String> { + pub(crate) fn update(&mut self, mut server_description: ServerDescription) -> Result<()> { // Ignore updates from servers not currently in the cluster. if !self.servers.contains_key(&server_description.address) { return Ok(()); @@ -516,10 +514,7 @@ impl TopologyDescription { } /// Update the Unknown topology description based on the server description. - fn update_unknown_topology( - &mut self, - server_description: ServerDescription, - ) -> Result<(), String> { + fn update_unknown_topology(&mut self, server_description: ServerDescription) -> Result<()> { match server_description.server_type { ServerType::Unknown | ServerType::RsGhost => {} ServerType::Standalone => { @@ -535,7 +530,7 @@ impl TopologyDescription { self.update_rs_without_primary_server(server_description)?; } ServerType::LoadBalancer => { - return Err("cannot transition to a load balancer".to_string()) + return Err(Error::internal("cannot transition to a load balancer")) } } @@ -556,7 +551,7 @@ impl TopologyDescription { fn update_replica_set_no_primary_topology( &mut self, server_description: ServerDescription, - ) -> Result<(), String> { + ) -> Result<()> { match server_description.server_type { ServerType::Unknown | ServerType::RsGhost => {} ServerType::Standalone | ServerType::Mongos => { @@ -570,7 +565,7 @@ impl TopologyDescription { self.update_rs_without_primary_server(server_description)?; } ServerType::LoadBalancer => { - return Err("cannot transition to a load balancer".to_string()) + return Err(Error::internal("cannot transition to a load balancer")) } } @@ -581,7 +576,7 @@ impl TopologyDescription { fn update_replica_set_with_primary_topology( &mut self, server_description: ServerDescription, - ) -> Result<(), String> { + ) -> Result<()> { match server_description.server_type { ServerType::Unknown | ServerType::RsGhost => { self.record_primary_state(); @@ -595,7 +590,7 @@ impl TopologyDescription { self.update_rs_with_primary_from_member(server_description)?; } ServerType::LoadBalancer => { - return Err("cannot transition to a load balancer".to_string()) + return Err(Error::internal("cannot transition to a load balancer")); } } @@ -616,7 +611,7 @@ impl TopologyDescription { fn update_rs_without_primary_server( &mut self, server_description: ServerDescription, - ) -> Result<(), String> { + ) -> Result<()> { if self.set_name.is_none() { self.set_name = server_description.set_name()?; } else if self.set_name != server_description.set_name()? { @@ -639,7 +634,7 @@ impl TopologyDescription { fn update_rs_with_primary_from_member( &mut self, server_description: ServerDescription, - ) -> Result<(), String> { + ) -> Result<()> { if self.set_name != server_description.set_name()? { self.servers.remove(&server_description.address); self.record_primary_state(); @@ -661,7 +656,7 @@ impl TopologyDescription { fn update_rs_from_primary_server( &mut self, server_description: ServerDescription, - ) -> Result<(), String> { + ) -> Result<()> { if self.set_name.is_none() { self.set_name = server_description.set_name()?; } else if self.set_name != server_description.set_name()? 
{ @@ -750,13 +745,8 @@ impl TopologyDescription { } /// Create a new ServerDescription for each address and add it to the topology. - fn add_new_servers<'a>( - &mut self, - servers: impl Iterator, - ) -> Result<(), String> { - let servers: Result, String> = servers - .map(|server| ServerAddress::parse(server).map_err(|e| e.to_string())) - .collect(); + fn add_new_servers<'a>(&mut self, servers: impl Iterator) -> Result<()> { + let servers: Result> = servers.map(ServerAddress::parse).collect(); self.add_new_servers_from_addresses(servers?.iter()); Ok(()) @@ -856,16 +846,13 @@ pub(crate) struct TopologyDescriptionDiff<'a> { } fn verify_max_staleness(max_staleness: Option) -> crate::error::Result<()> { - verify_max_staleness_inner(max_staleness) - .map_err(|s| crate::error::ErrorKind::InvalidArgument { message: s }.into()) -} - -fn verify_max_staleness_inner(max_staleness: Option) -> std::result::Result<(), String> { if max_staleness .map(|staleness| staleness > Duration::from_secs(0) && staleness < Duration::from_secs(90)) .unwrap_or(false) { - return Err("max staleness cannot be both positive and below 90 seconds".into()); + return Err(Error::invalid_argument( + "max staleness cannot be both positive and below 90 seconds", + )); } Ok(()) diff --git a/src/sdam/description/topology/server_selection/mod.rs b/src/sdam/description/topology/server_selection/mod.rs index f01b87ae1..0d42f9ed9 100644 --- a/src/sdam/description/topology/server_selection/mod.rs +++ b/src/sdam/description/topology/server_selection/mod.rs @@ -185,6 +185,11 @@ impl TopologyDescription { .filter(move |server| types.contains(&server.server_type)) } + #[cfg(test)] + pub(crate) fn primary(&self) -> Option<&ServerDescription> { + self.servers_with_type(&[ServerType::RsPrimary]).next() + } + fn suitable_servers_in_replica_set<'a>( &self, read_preference: &'a ReadPreference, diff --git a/src/sdam/public.rs b/src/sdam/public.rs index 8c433b508..ab24a7341 100644 --- a/src/sdam/public.rs +++ b/src/sdam/public.rs @@ -5,6 +5,7 @@ use serde::Serialize; pub use crate::sdam::description::{server::ServerType, topology::TopologyType}; use crate::{ bson::DateTime, + error::Error, hello::HelloCommandResponse, options::ServerAddress, sdam::ServerDescription, @@ -100,6 +101,15 @@ impl<'a> ServerInfo<'a> { pub fn tags(&self) -> Option<&TagSet> { self.command_response_getter(|r| r.tags.as_ref()) } + + /// Gets the error that caused the server's state to be transitioned to Unknown, if any. + /// + /// When the driver encounters certain errors during operation execution or server monitoring, + /// it transitions the affected server's state to Unknown, rendering the server unusable for + /// future operations until it is confirmed to be in healthy state again. 
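// Illustrative aside, not part of the patch: the getter below reduces to
// Result::as_ref().err(), which borrows the stored error without cloning it.
// A tiny self-contained sketch of that behavior:

fn error_of<T, E>(reply: &Result<T, E>) -> Option<&E> {
    // Some(&error) only when the stored reply is the error side; None for a healthy reply.
    reply.as_ref().err()
}

// error_of(&Ok::<i32, String>(5)) is None, while
// error_of(&Err::<i32, String>("network timeout".into())) yields Some of the stored error.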
+ pub fn error(&self) -> Option<&Error> { + self.description.reply.as_ref().err() + } } impl<'a> fmt::Debug for ServerInfo<'a> { diff --git a/src/sdam/topology.rs b/src/sdam/topology.rs index 92d2ea0d9..d0d01918c 100644 --- a/src/sdam/topology.rs +++ b/src/sdam/topology.rs @@ -186,10 +186,7 @@ impl Topology { average_round_trip_time: Some(Duration::from_nanos(0)), ..ServerDescription::new(server_address.clone(), None) }; - new_state - .description - .update(new_desc) - .map_err(Error::internal)?; + new_state.description.update(new_desc)?; } worker.process_topology_diff(&old_description, &new_state.description); @@ -480,7 +477,10 @@ impl TopologyWorker { if sd.is_available() { let got_name = sd.set_name(); if latest_state.description.topology_type() == TopologyType::Single - && got_name.as_ref().map(|opt| opt.as_ref()) != Ok(Some(expected_name)) + && !matches!( + got_name.as_ref().map(|opt| opt.as_ref()), + Ok(Some(name)) if name == expected_name + ) { let got_display = match got_name { Ok(Some(s)) => format!("{:?}", s), @@ -490,10 +490,10 @@ impl TopologyWorker { // Mark server as unknown. sd = ServerDescription::new( sd.address, - Some(Err(format!( + Some(Err(Error::invalid_argument(format!( "Connection string replicaSet name {:?} does not match actual name {}", expected_name, got_display, - ))), + )))), ); } } @@ -580,7 +580,7 @@ impl TopologyWorker { /// Mark the server at the given address as Unknown using the provided error as the cause. async fn mark_server_as_unknown(&mut self, address: ServerAddress, error: Error) -> bool { - let description = ServerDescription::new(address, Some(Err(error.to_string()))); + let description = ServerDescription::new(address, Some(Err(error))); self.update_server(description).await } diff --git a/src/test/atlas_planned_maintenance_testing/mod.rs b/src/test/atlas_planned_maintenance_testing/mod.rs index c00b12ec6..5ff7256a0 100644 --- a/src/test/atlas_planned_maintenance_testing/mod.rs +++ b/src/test/atlas_planned_maintenance_testing/mod.rs @@ -22,6 +22,8 @@ use crate::{ use json_models::{Events, Results}; +use super::spec::unified_runner::EntityMap; + #[test] fn get_exe_name() { let mut file = File::create("exe_name.txt").expect("Failed to create file"); @@ -51,7 +53,8 @@ async fn workload_executor() { let mut test_runner = TestRunner::new_with_connection_string(&connection_string).await; let execution_errors = execute_workload(&mut test_runner, workload).await; - write_json(&mut test_runner, execution_errors); + let mut entities = test_runner.entities.write().await; + write_json(&mut entities, execution_errors); } async fn execute_workload(test_runner: &mut TestRunner, workload: Value) -> Vec { @@ -62,7 +65,7 @@ async fn execute_workload(test_runner: &mut TestRunner, workload: Value) -> Vec< log_uncaptured("Running planned maintenance tests"); - if AssertUnwindSafe(test_runner.run_test(test_file, |_| true)) + if AssertUnwindSafe(test_runner.run_test(None, test_file, |_| true)) .catch_unwind() .await .is_err() @@ -81,27 +84,25 @@ async fn execute_workload(test_runner: &mut TestRunner, workload: Value) -> Vec< execution_errors } -fn write_json(test_runner: &mut TestRunner, mut errors: Vec) { +fn write_json(entities: &mut EntityMap, mut errors: Vec) { log_uncaptured("Writing planned maintenance test results to files"); let mut events = Events::new_empty(); - if let Some(Entity::Bson(Bson::Array(mut operation_errors))) = - test_runner.entities.remove("errors") - { + if let Some(Entity::Bson(Bson::Array(mut operation_errors))) = entities.remove("errors") 
{ errors.append(&mut operation_errors); } events.errors = errors; - if let Some(Entity::Bson(Bson::Array(failures))) = test_runner.entities.remove("failures") { + if let Some(Entity::Bson(Bson::Array(failures))) = entities.remove("failures") { events.failures = failures; } let mut results = Results::new_empty(); results.num_errors = events.errors.len().into(); results.num_failures = events.failures.len().into(); - if let Some(Entity::Bson(Bson::Int64(iterations))) = test_runner.entities.remove("iterations") { + if let Some(Entity::Bson(Bson::Int64(iterations))) = entities.remove("iterations") { results.num_iterations = iterations.into(); } - if let Some(Entity::Bson(Bson::Int64(successes))) = test_runner.entities.remove("successes") { + if let Some(Entity::Bson(Bson::Int64(successes))) = entities.remove("successes") { results.num_successes = successes.into(); } @@ -120,7 +121,20 @@ fn write_json(test_runner: &mut TestRunner, mut errors: Vec) { // The events key is expected to be present regardless of whether storeEventsAsEntities was // defined. write!(&mut writer, ",\"events\":[").unwrap(); - test_runner.write_events_list_to_file("events", &mut writer); + let event_list_entity = match entities.get("events") { + Some(entity) => entity.as_event_list().to_owned(), + None => return, + }; + let client = entities + .get(&event_list_entity.client_id) + .unwrap() + .as_client(); + let names: Vec<&str> = event_list_entity + .event_names + .iter() + .map(String::as_ref) + .collect(); + client.write_events_list_to_file(&names, &mut writer); write!(&mut writer, "]}}").unwrap(); let mut results_path = PathBuf::from(&path); diff --git a/src/test/spec/change_streams.rs b/src/test/spec/change_streams.rs index 0836fffea..106bd871e 100644 --- a/src/test/spec/change_streams.rs +++ b/src/test/spec/change_streams.rs @@ -1,13 +1,10 @@ -use crate::test::{run_spec_test, LOCK}; +use crate::test::LOCK; -use super::run_unified_format_test; +use super::{run_spec_test_with_path, run_unified_format_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] // multi_thread required for FailPoint #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run() { let _guard = LOCK.run_exclusively().await; - run_spec_test(&["change-streams", "unified"], |f| { - run_unified_format_test(f) - }) - .await; + run_spec_test_with_path(&["change-streams", "unified"], run_unified_format_test).await; } diff --git a/src/test/spec/collection_management.rs b/src/test/spec/collection_management.rs index 2dfc32a51..5df60ecb3 100644 --- a/src/test/spec/collection_management.rs +++ b/src/test/spec/collection_management.rs @@ -1,10 +1,10 @@ -use crate::test::{run_spec_test, LOCK}; +use crate::test::LOCK; -use super::run_unified_format_test; +use super::{run_spec_test_with_path, run_unified_format_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run() { let _guard = LOCK.run_exclusively().await; - run_spec_test(&["collection-management"], run_unified_format_test).await; + run_spec_test_with_path(&["collection-management"], run_unified_format_test).await; } diff --git a/src/test/spec/command_monitoring/mod.rs b/src/test/spec/command_monitoring/mod.rs index 67dac90a8..2b1ffbbe0 100644 --- a/src/test/spec/command_monitoring/mod.rs +++ b/src/test/spec/command_monitoring/mod.rs @@ -1,6 +1,6 @@ -use crate::test::{spec::run_spec_test, LOCK}; +use crate::test::LOCK; -use super::{run_unified_format_test_filtered, 
unified_runner::TestCase}; +use super::{run_spec_test_with_path, run_unified_format_test_filtered, unified_runner::TestCase}; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] @@ -11,8 +11,8 @@ async fn command_monitoring_unified() { // This test relies on old OP_QUERY behavior that many drivers still use for < 4.4, but we do not use, due to never implementing OP_QUERY. tc.description != "A successful find event with a getmore and the server kills the cursor (<= 4.4)"; - run_spec_test(&["command-monitoring", "unified"], |f| { - run_unified_format_test_filtered(f, pred) + run_spec_test_with_path(&["command-monitoring", "unified"], |path, f| { + run_unified_format_test_filtered(path, f, pred) }) .await; } diff --git a/src/test/spec/crud.rs b/src/test/spec/crud.rs index a77addb9c..d3cfc5294 100644 --- a/src/test/spec/crud.rs +++ b/src/test/spec/crud.rs @@ -1,15 +1,15 @@ use tokio::sync::RwLockWriteGuard; -use crate::test::{run_spec_test, LOCK}; +use crate::test::LOCK; -use super::{run_unified_format_test_filtered, unified_runner::TestCase}; +use super::{run_spec_test_with_path, run_unified_format_test_filtered, unified_runner::TestCase}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run() { let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; - run_spec_test(&["crud", "unified"], |file| { - run_unified_format_test_filtered(file, test_predicate) + run_spec_test_with_path(&["crud", "unified"], |path, file| { + run_unified_format_test_filtered(path, file, test_predicate) }) .await; } diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.json b/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.json new file mode 100644 index 000000000..b296719f1 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.json @@ -0,0 +1,18 @@ +{ + "description": "entity-thread-additionalProperties", + "schemaVersion": "1.10", + "createEntities": [ + { + "thread": { + "id": "thread0", + "foo": "bar" + } + } + ], + "tests": [ + { + "description": "foo", + "operations": [] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.yml b/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.yml new file mode 100644 index 000000000..b3fb1dc51 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-additionalProperties.yml @@ -0,0 +1,12 @@ +description: "entity-thread-additionalProperties" + +schemaVersion: "1.10" + +createEntities: + - thread: + id: &thread0 "thread0" + foo: "bar" + +tests: + - description: "foo" + operations: [] diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.json b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.json new file mode 100644 index 000000000..3b197e3d6 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.json @@ -0,0 +1,15 @@ +{ + "description": "entity-thread-id-required", + "schemaVersion": "1.10", + "createEntities": [ + { + "thread": {} + } + ], + "tests": [ + { + "description": "foo", + "operations": [] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.yml b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.yml new file mode 100644 
index 000000000..b940d4d5c --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-required.yml @@ -0,0 +1,10 @@ +description: "entity-thread-id-required" + +schemaVersion: "1.10" + +createEntities: + - thread: {} + +tests: + - description: "foo" + operations: [] diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.json b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.json new file mode 100644 index 000000000..8f281ef6f --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.json @@ -0,0 +1,17 @@ +{ + "description": "entity-thread-id-type", + "schemaVersion": "1.10", + "createEntities": [ + { + "thread": { + "id": 0 + } + } + ], + "tests": [ + { + "description": "foo", + "operations": [] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.yml b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.yml new file mode 100644 index 000000000..85646ce9c --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/entity-thread-id-type.yml @@ -0,0 +1,11 @@ +description: "entity-thread-id-type" + +schemaVersion: "1.10" + +createEntities: + - thread: + id: 0 + +tests: + - description: "foo" + operations: [] diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.json b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.json new file mode 100644 index 000000000..1c6ec460b --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.json @@ -0,0 +1,23 @@ +{ + "description": "expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties", + "schemaVersion": "1.10", + "tests": [ + { + "description": "foo", + "operations": [], + "expectEvents": [ + { + "client": "client0", + "eventType": "sdam", + "events": [ + { + "serverDescriptionChangedEvent": { + "foo": "bar" + } + } + ] + } + ] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.yml b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.yml new file mode 100644 index 000000000..7d9580fe7 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties.yml @@ -0,0 +1,13 @@ +description: expectedSdamEvent-serverDescriptionChangedEvent-additionalProperties + +schemaVersion: '1.10' + +tests: + - description: foo + operations: [] + expectEvents: + - client: client0 + eventType: sdam + events: + - serverDescriptionChangedEvent: + foo: bar diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.json b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.json new file mode 100644 index 000000000..58f686739 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.json @@ -0,0 +1,25 @@ +{ + "description": "expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties", + "schemaVersion": "1.10", + "tests": [ + { + "description": "foo", + "operations": [], + 
"expectEvents": [ + { + "client": "client0", + "eventType": "sdam", + "events": [ + { + "serverDescriptionChangedEvent": { + "previousDescription": { + "foo": "bar" + } + } + } + ] + } + ] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.yml b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.yml new file mode 100644 index 000000000..4f5d74422 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties.yml @@ -0,0 +1,14 @@ +description: expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-additionalProperties + +schemaVersion: '1.10' + +tests: + - description: foo + operations: [] + expectEvents: + - client: client0 + eventType: sdam + events: + - serverDescriptionChangedEvent: + previousDescription: + foo: bar diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.json b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.json new file mode 100644 index 000000000..1b4a7e2e7 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.json @@ -0,0 +1,25 @@ +{ + "description": "expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum", + "schemaVersion": "1.10", + "tests": [ + { + "description": "foo", + "operations": [], + "expectEvents": [ + { + "client": "client0", + "eventType": "sdam", + "events": [ + { + "serverDescriptionChangedEvent": { + "previousDescription": { + "type": "not a server type" + } + } + } + ] + } + ] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.yml b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.yml new file mode 100644 index 000000000..5211cde78 --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum.yml @@ -0,0 +1,14 @@ +description: expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-enum + +schemaVersion: '1.10' + +tests: + - description: foo + operations: [] + expectEvents: + - client: client0 + eventType: sdam + events: + - serverDescriptionChangedEvent: + previousDescription: + type: "not a server type" diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.json b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.json new file mode 100644 index 000000000..c7ea9cc9b --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.json @@ -0,0 +1,25 @@ +{ + "description": "expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type", + "schemaVersion": "1.10", + "tests": [ + { + "description": "foo", + "operations": [], + "expectEvents": [ + { + "client": "client0", + "eventType": "sdam", + "events": [ + { + "serverDescriptionChangedEvent": { + "previousDescription": { + "type": 12 + } + } + } + ] 
+ } + ] + } + ] +} diff --git a/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.yml b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.yml new file mode 100644 index 000000000..3f856bbda --- /dev/null +++ b/src/test/spec/json/unified-test-format/invalid/expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type.yml @@ -0,0 +1,14 @@ +description: expectedSdamEvent-serverDescriptionChangedEvent-serverDescription-type-type + +schemaVersion: '1.10' + +tests: + - description: foo + operations: [] + expectEvents: + - client: client0 + eventType: sdam + events: + - serverDescriptionChangedEvent: + previousDescription: + type: 12 diff --git a/src/test/spec/load_balancers.rs b/src/test/spec/load_balancers.rs index c79ac70f7..bd102999d 100644 --- a/src/test/spec/load_balancers.rs +++ b/src/test/spec/load_balancers.rs @@ -42,7 +42,7 @@ async fn run() { } } } - run_unified_format_test_filtered(test_file, |tc| { + run_unified_format_test_filtered(path, test_file, |tc| { // TODO RUST-142 unskip this when change streams are implemented. if tc.description == "change streams pin to a connection" { log_uncaptured("skipping due to change streams not being implemented"); diff --git a/src/test/spec/mod.rs b/src/test/spec/mod.rs index e06a9db8e..a150054b7 100644 --- a/src/test/spec/mod.rs +++ b/src/test/spec/mod.rs @@ -29,7 +29,7 @@ use std::{ path::PathBuf, }; -pub use self::{ +pub(crate) use self::{ unified_runner::{ merge_uri_options, run_unified_format_test, @@ -84,11 +84,11 @@ where pub(crate) async fn run_single_test(path: PathBuf, run_test_file: &F) where - F: Fn(T) -> G, + F: Fn(PathBuf, T) -> G, G: Future, T: DeserializeOwned, { - run_single_test_with_path(path, &|_, t| run_test_file(t)).await + run_single_test_with_path(path, run_test_file).await } pub(crate) async fn run_single_test_with_path(path: PathBuf, run_test_file: &F) @@ -97,10 +97,8 @@ where G: Future, T: DeserializeOwned, { - let json: Value = serde_json::from_reader(File::open(path.as_path()).unwrap()).unwrap(); - - // Printing the name of the test file makes it easier to debug deserialization errors. 
- println!("Running tests from {}", path.display()); + let json: Value = serde_json::from_reader(File::open(path.as_path()).unwrap()) + .unwrap_or_else(|err| panic!("{}: {}", path.display(), err)); run_test_file( path.clone(), diff --git a/src/test/spec/retryable_reads.rs b/src/test/spec/retryable_reads.rs index 5123af458..2cb516c4f 100644 --- a/src/test/spec/retryable_reads.rs +++ b/src/test/spec/retryable_reads.rs @@ -26,7 +26,7 @@ use crate::{ }, }; -use super::{run_unified_format_test, run_v2_test}; +use super::{run_spec_test_with_path, run_unified_format_test, run_v2_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] @@ -39,7 +39,7 @@ async fn run_legacy() { #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run_unified() { let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; - run_spec_test(&["retryable-reads", "unified"], run_unified_format_test).await; + run_spec_test_with_path(&["retryable-reads", "unified"], run_unified_format_test).await; } /// Test ensures that the connection used in the first attempt of a retry is released back into the diff --git a/src/test/spec/retryable_writes/mod.rs b/src/test/spec/retryable_writes/mod.rs index fe35ab318..a36a35bd7 100644 --- a/src/test/spec/retryable_writes/mod.rs +++ b/src/test/spec/retryable_writes/mod.rs @@ -36,13 +36,13 @@ use crate::{ }, }; -use super::run_unified_format_test; +use super::{run_spec_test_with_path, run_unified_format_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run_unified() { let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; - run_spec_test(&["retryable-writes", "unified"], run_unified_format_test).await; + run_spec_test_with_path(&["retryable-writes", "unified"], run_unified_format_test).await; } #[cfg_attr(feature = "tokio-runtime", tokio::test)] diff --git a/src/test/spec/retryable_writes/test_file.rs b/src/test/spec/retryable_writes/test_file.rs index 92472f691..7350822a9 100644 --- a/src/test/spec/retryable_writes/test_file.rs +++ b/src/test/spec/retryable_writes/test_file.rs @@ -8,46 +8,46 @@ use crate::{ #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct TestFile { - pub run_on: Option>, - pub data: Vec, - pub tests: Vec, +pub(crate) struct TestFile { + pub(crate) run_on: Option>, + pub(crate) data: Vec, + pub(crate) tests: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct TestCase { - pub description: String, - pub client_options: Option, - pub use_multiple_mongoses: Option, - pub fail_point: Option, - pub operation: Operation, - pub outcome: Outcome, +pub(crate) struct TestCase { + pub(crate) description: String, + pub(crate) client_options: Option, + pub(crate) use_multiple_mongoses: Option, + pub(crate) fail_point: Option, + pub(crate) operation: Operation, + pub(crate) outcome: Outcome, } #[derive(Debug, Deserialize)] -pub struct Outcome { - pub error: Option, - pub result: Option, - pub collection: CollectionOutcome, +pub(crate) struct Outcome { + pub(crate) error: Option, + pub(crate) result: Option, + pub(crate) collection: CollectionOutcome, } #[derive(Debug, Deserialize)] #[serde(untagged)] -pub enum TestResult { +pub(crate) enum TestResult { Labels(Labels), Value(Bson), } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Labels { - pub 
error_labels_contain: Option>, - pub error_labels_omit: Option>, +pub(crate) struct Labels { + pub(crate) error_labels_contain: Option>, + pub(crate) error_labels_omit: Option>, } #[derive(Debug, Deserialize)] -pub struct CollectionOutcome { - pub name: Option, - pub data: Vec, +pub(crate) struct CollectionOutcome { + pub(crate) name: Option, + pub(crate) data: Vec, } diff --git a/src/test/spec/sessions.rs b/src/test/spec/sessions.rs index d987c03e6..bfa730d67 100644 --- a/src/test/spec/sessions.rs +++ b/src/test/spec/sessions.rs @@ -4,16 +4,16 @@ use crate::{ bson::doc, error::ErrorKind, options::SessionOptions, - test::{run_spec_test, TestClient, LOCK}, + test::{TestClient, LOCK}, }; -use super::run_unified_format_test; +use super::{run_spec_test_with_path, run_unified_format_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run_unified() { let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; - run_spec_test(&["sessions"], run_unified_format_test).await; + run_spec_test_with_path(&["sessions"], run_unified_format_test).await; } // Sessions prose test 1 diff --git a/src/test/spec/transactions.rs b/src/test/spec/transactions.rs index 0d78e4bda..97541c507 100644 --- a/src/test/spec/transactions.rs +++ b/src/test/spec/transactions.rs @@ -7,7 +7,7 @@ use crate::{ Collection, }; -use super::{run_unified_format_test, run_v2_test}; +use super::{run_spec_test_with_path, run_unified_format_test, run_v2_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] @@ -23,7 +23,7 @@ async fn run_unified() { let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await; // TODO RUST-902: Reduce transactionLifetimeLimitSeconds. 
- run_spec_test(&["transactions", "unified"], run_unified_format_test).await; + run_spec_test_with_path(&["transactions", "unified"], run_unified_format_test).await; } #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] diff --git a/src/test/spec/unified_runner/entity.rs b/src/test/spec/unified_runner/entity.rs index c58d510eb..e98c73283 100644 --- a/src/test/spec/unified_runner/entity.rs +++ b/src/test/spec/unified_runner/entity.rs @@ -3,15 +3,19 @@ use std::{ io::BufWriter, ops::{Deref, DerefMut}, sync::Arc, + time::Duration, }; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; use crate::{ bson::{Bson, Document}, change_stream::ChangeStream, client::{HELLO_COMMAND_NAMES, REDACTED_COMMANDS}, + error::Error, event::command::CommandStartedEvent, + runtime, + sdam::TopologyDescription, test::{ spec::unified_runner::{ExpectedEventType, ObserveEvent}, CommandEvent, @@ -26,9 +30,11 @@ use crate::{ SessionCursor, }; +use super::{observer::EventObserver, test_file::ThreadMessage, Operation}; + #[derive(Debug)] #[allow(clippy::large_enum_variant)] -pub enum Entity { +pub(crate) enum Entity { Client(ClientEntity), Database(Database), Collection(Collection), @@ -36,27 +42,30 @@ pub enum Entity { Cursor(TestCursor), Bson(Bson), EventList(EventList), + Thread(ThreadEntity), + TopologyDescription(TopologyDescription), None, } #[derive(Clone, Debug)] -pub struct ClientEntity { +pub(crate) struct ClientEntity { client: Client, - observer: Arc, + handler: Arc, + pub(crate) observer: Arc>, observe_events: Option>, ignore_command_names: Option>, observe_sensitive_commands: bool, } #[derive(Debug)] -pub struct SessionEntity { - pub lsid: Document, - pub client_session: Option>, +pub(crate) struct SessionEntity { + pub(crate) lsid: Document, + pub(crate) client_session: Option>, } #[allow(clippy::large_enum_variant)] #[derive(Debug)] -pub enum TestCursor { +pub(crate) enum TestCursor { // Due to https://github.com/rust-lang/rust/issues/59245, the `Entity` type is required to be // `Sync`; however, `Cursor` is `!Sync` due to internally storing a `BoxFuture`, which only // has a `Send` bound. Wrapping it in `Mutex` works around this. @@ -83,7 +92,7 @@ impl From for Entity { } impl TestCursor { - pub async fn make_kill_watcher(&mut self) -> oneshot::Receiver<()> { + pub(crate) async fn make_kill_watcher(&mut self) -> oneshot::Receiver<()> { match self { Self::Normal(cursor) => { let (tx, rx) = oneshot::channel(); @@ -106,16 +115,18 @@ impl TestCursor { } impl ClientEntity { - pub fn new( + pub(crate) fn new( client: Client, - observer: Arc, + handler: Arc, observe_events: Option>, ignore_command_names: Option>, observe_sensitive_commands: bool, ) -> Self { + let observer = EventObserver::new(handler.broadcaster().subscribe()); Self { client, - observer, + handler, + observer: Arc::new(Mutex::new(observer)), observe_events, ignore_command_names, observe_sensitive_commands, @@ -125,8 +136,8 @@ impl ClientEntity { /// Gets a list of all of the events of the requested event types that occurred on this client. /// Ignores any event with a name in the ignore list. Also ignores all configureFailPoint /// events. 
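// Hedged sketch, not the driver's exact helper: the filtering applied in
// get_filtered_events below boils down to a predicate along these lines, which always
// drops configureFailPoint traffic and honors the test file's ignore list. The function
// name and parameters here are illustrative.

fn is_observable_command(command_name: &str, ignore_list: Option<&[String]>) -> bool {
    // Fail point configuration is test scaffolding, never an observed event.
    if command_name == "configureFailPoint" {
        return false;
    }
    // Skip anything the test file asked to ignore.
    !ignore_list
        .map(|names| names.iter().any(|n| n.as_str() == command_name))
        .unwrap_or(false)
}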
- pub fn get_filtered_events(&self, expected_type: ExpectedEventType) -> Vec { - self.observer.get_filtered_events(expected_type, |event| { + pub(crate) fn get_filtered_events(&self, expected_type: ExpectedEventType) -> Vec { + self.handler.get_filtered_events(expected_type, |event| { if let Event::Command(cev) = event { if !self.allow_command_event(cev) { return false; @@ -172,26 +183,55 @@ impl ClientEntity { } /// Gets all events of type commandStartedEvent, excluding configureFailPoint events. - pub fn get_all_command_started_events(&self) -> Vec { - self.observer.get_all_command_started_events() + pub(crate) fn get_all_command_started_events(&self) -> Vec { + self.handler.get_all_command_started_events() } /// Writes all events with the given name to the given BufWriter. pub fn write_events_list_to_file(&self, names: &[&str], writer: &mut BufWriter) { - self.observer.write_events_list_to_file(names, writer); + self.handler.write_events_list_to_file(names, writer); } /// Gets the count of connections currently checked out. - pub fn connections_checked_out(&self) -> u32 { - self.observer.connections_checked_out() + pub(crate) fn connections_checked_out(&self) -> u32 { + self.handler.connections_checked_out() } /// Synchronize all connection pool worker threads. - pub async fn sync_workers(&self) { + pub(crate) async fn sync_workers(&self) { self.client.sync_workers().await; } } +#[derive(Clone, Debug)] +pub(crate) struct ThreadEntity { + pub(crate) sender: mpsc::UnboundedSender, +} + +impl ThreadEntity { + pub(crate) fn run_operation(&self, op: Arc) { + self.sender + .send(ThreadMessage::ExecuteOperation(op)) + .unwrap(); + } + + pub(crate) async fn wait(&self) -> bool { + let (tx, rx) = oneshot::channel(); + + // if the task panicked, this send will fail + if self.sender.send(ThreadMessage::Stop(tx)).is_err() { + return false; + } + + // return that both the timeout was satisfied and that the task responded to the + // acknowledgment request. 
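// The acknowledgment wait below uses a "timeout around a oneshot receiver" pattern; a
// self-contained sketch of the same idea using tokio's public timeout (the diff uses the
// crate-internal runtime::timeout wrapper instead):

use std::time::Duration;
use tokio::sync::oneshot;

async fn await_ack(rx: oneshot::Receiver<()>) -> bool {
    // True only if the thread acknowledged the Stop message within the deadline;
    // both an elapsed timeout and a dropped sender count as failure.
    matches!(
        tokio::time::timeout(Duration::from_secs(10), rx).await,
        Ok(Ok(()))
    )
}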
+ runtime::timeout(Duration::from_secs(10), rx) + .await + .and_then(|r| r.map_err(|_| Error::internal(""))) // flatten tokio error into mongodb::Error + .is_ok() + } +} + impl From for Entity { fn from(database: Database) -> Self { Self::Database(database) @@ -210,6 +250,12 @@ impl From for Entity { } } +impl From for Entity { + fn from(td: TopologyDescription) -> Self { + Self::TopologyDescription(td) + } +} + impl Deref for ClientEntity { type Target = Client; @@ -219,7 +265,7 @@ impl Deref for ClientEntity { } impl SessionEntity { - pub fn new(client_session: ClientSession) -> Self { + pub(crate) fn new(client_session: ClientSession) -> Self { let lsid = client_session.id().clone(); Self { client_session: Some(Box::new(client_session)), @@ -246,56 +292,70 @@ impl DerefMut for SessionEntity { } impl Entity { - pub fn as_client(&self) -> &ClientEntity { + pub(crate) fn as_client(&self) -> &ClientEntity { match self { Self::Client(client) => client, _ => panic!("Expected client entity, got {:?}", &self), } } - pub fn as_database(&self) -> &Database { + pub(crate) fn as_database(&self) -> &Database { match self { Self::Database(database) => database, _ => panic!("Expected database entity, got {:?}", &self), } } - pub fn as_collection(&self) -> &Collection { + pub(crate) fn as_collection(&self) -> &Collection { match self { Self::Collection(collection) => collection, _ => panic!("Expected collection entity, got {:?}", &self), } } - pub fn as_session_entity(&self) -> &SessionEntity { + pub(crate) fn as_session_entity(&self) -> &SessionEntity { match self { Self::Session(client_session) => client_session, _ => panic!("Expected client session entity, got {:?}", &self), } } - pub fn as_mut_session_entity(&mut self) -> &mut SessionEntity { + pub(crate) fn as_mut_session_entity(&mut self) -> &mut SessionEntity { match self { Self::Session(client_session) => client_session, _ => panic!("Expected mutable client session entity, got {:?}", &self), } } - pub fn as_bson(&self) -> &Bson { + pub(crate) fn as_bson(&self) -> &Bson { match self { Self::Bson(bson) => bson, _ => panic!("Expected BSON entity, got {:?}", &self), } } - pub fn as_mut_cursor(&mut self) -> &mut TestCursor { + pub(crate) fn as_mut_cursor(&mut self) -> &mut TestCursor { match self { Self::Cursor(cursor) => cursor, _ => panic!("Expected cursor, got {:?}", &self), } } - pub fn into_cursor(self) -> TestCursor { + pub(crate) fn as_thread(&self) -> &ThreadEntity { + match self { + Self::Thread(thread) => thread, + _ => panic!("Expected thread, got {:?}", self), + } + } + + pub(crate) fn as_topology_description(&self) -> &TopologyDescription { + match self { + Self::TopologyDescription(desc) => desc, + _ => panic!("Expected Topologydescription, got {:?}", self), + } + } + + pub(crate) fn into_cursor(self) -> TestCursor { match self { Self::Cursor(cursor) => cursor, _ => panic!("Expected cursor, got {:?}", &self), diff --git a/src/test/spec/unified_runner/matcher.rs b/src/test/spec/unified_runner/matcher.rs index 4cc1bc1ad..70f07a712 100644 --- a/src/test/spec/unified_runner/matcher.rs +++ b/src/test/spec/unified_runner/matcher.rs @@ -3,12 +3,19 @@ use bson::Document; use crate::{ bson::{doc, spec::ElementType, Bson}, bson_util::get_int, - test::{CmapEvent, CommandEvent, Event}, + event::sdam::ServerDescription, + test::{CmapEvent, CommandEvent, Event, SdamEvent}, }; -use super::{EntityMap, ExpectedCmapEvent, ExpectedCommandEvent, ExpectedEvent}; +use super::{ + test_event::{ExpectedSdamEvent, TestServerDescription}, + EntityMap, + 
ExpectedCmapEvent, + ExpectedCommandEvent, + ExpectedEvent, +}; -pub fn results_match( +pub(crate) fn results_match( actual: Option<&Bson>, expected: &Bson, returns_root_documents: bool, @@ -17,7 +24,7 @@ pub fn results_match( results_match_inner(actual, expected, returns_root_documents, true, entities) } -pub fn events_match( +pub(crate) fn events_match( actual: &Event, expected: &ExpectedEvent, entities: Option<&EntityMap>, @@ -27,6 +34,7 @@ pub fn events_match( command_events_match(act, exp, entities) } (Event::Cmap(act), ExpectedEvent::Cmap(exp)) => cmap_events_match(act, exp), + (Event::Sdam(act), ExpectedEvent::Sdam(exp)) => sdam_events_match(act, exp), _ => expected_err(actual, expected), } } @@ -154,6 +162,34 @@ fn cmap_events_match(actual: &CmapEvent, expected: &ExpectedCmapEvent) -> Result } } +fn sdam_events_match(actual: &SdamEvent, expected: &ExpectedSdamEvent) -> Result<(), String> { + match (actual, expected) { + ( + SdamEvent::ServerDescriptionChanged(actual), + ExpectedSdamEvent::ServerDescriptionChanged { + previous_description, + new_description, + }, + ) => { + let match_sd = |actual: &ServerDescription, + expected: &TestServerDescription| + -> std::result::Result<(), String> { + match_opt(&actual.server_type(), &expected.server_type)?; + Ok(()) + }; + + if let Some(expected_previous_description) = previous_description { + match_sd(&actual.previous_description, expected_previous_description)?; + } + if let Some(expected_new_description) = new_description { + match_sd(&actual.new_description, expected_new_description)?; + } + Ok(()) + } + _ => expected_err(actual, expected), + } +} + fn results_match_inner( actual: Option<&Bson>, expected: &Bson, diff --git a/src/test/spec/unified_runner/mod.rs b/src/test/spec/unified_runner/mod.rs index aa6b54df8..296caa9f2 100644 --- a/src/test/spec/unified_runner/mod.rs +++ b/src/test/spec/unified_runner/mod.rs @@ -1,9 +1,10 @@ -pub mod entity; -pub mod matcher; -mod operation; -mod test_event; -pub mod test_file; -pub mod test_runner; +pub(crate) mod entity; +pub(crate) mod matcher; +pub(crate) mod observer; +pub(crate) mod operation; +pub(crate) mod test_event; +pub(crate) mod test_file; +pub(crate) mod test_runner; use std::{convert::TryFrom, ffi::OsStr, fs::read_dir, path::PathBuf}; @@ -11,12 +12,12 @@ use futures::future::FutureExt; use semver::Version; use tokio::sync::RwLockWriteGuard; -use crate::test::{log_uncaptured, run_single_test, run_spec_test, LOCK}; +use crate::test::{log_uncaptured, run_single_test, LOCK}; -pub use self::{ +pub(crate) use self::{ entity::{ClientEntity, Entity, SessionEntity, TestCursor}, matcher::{events_match, results_match}, - operation::{Operation, OperationObject}, + operation::Operation, test_event::{ExpectedCmapEvent, ExpectedCommandEvent, ExpectedEvent, ObserveEvent}, test_file::{ merge_uri_options, @@ -31,14 +32,21 @@ pub use self::{ test_runner::{EntityMap, TestRunner}, }; +use super::run_spec_test_with_path; + static MIN_SPEC_VERSION: Version = Version::new(1, 0, 0); -static MAX_SPEC_VERSION: Version = Version::new(1, 7, 0); +static MAX_SPEC_VERSION: Version = Version::new(1, 10, 0); + +fn file_level_log(message: impl AsRef) { + log_uncaptured(format!("\n------------\n{}\n", message.as_ref())); +} -pub async fn run_unified_format_test(test_file: TestFile) { - run_unified_format_test_filtered(test_file, |_| true).await +pub(crate) async fn run_unified_format_test(path: PathBuf, test_file: TestFile) { + run_unified_format_test_filtered(path, test_file, |_| true).await } -pub async fn 
run_unified_format_test_filtered( +pub(crate) async fn run_unified_format_test_filtered( + path: PathBuf, test_file: TestFile, pred: impl Fn(&TestCase) -> bool, ) { @@ -49,15 +57,15 @@ pub async fn run_unified_format_test_filtered( &test_file.schema_version ); - let mut test_runner = TestRunner::new().await; - test_runner.run_test(test_file, pred).await; + let test_runner = TestRunner::new().await; + test_runner.run_test(path, test_file, pred).await; } #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn test_examples() { let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await; - run_spec_test( + run_spec_test_with_path( &["unified-test-format", "examples"], run_unified_format_test, ) @@ -97,7 +105,7 @@ async fn valid_fail() { #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn valid_pass() { let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await; - run_spec_test( + run_spec_test_with_path( &["unified-test-format", "valid-pass"], run_unified_format_test, ) @@ -142,12 +150,14 @@ async fn invalid() { .iter() .any(|skip| *skip == test_file_str) { - log_uncaptured(format!("Skipping {}", test_file_str)); + file_level_log(format!("Skipping {}", test_file_str)); continue; } let path = path.join(&test_file_path); let path_display = path.display().to_string(); + file_level_log(format!("Attempting to parse {}", path_display)); + let json: serde_json::Value = serde_json::from_reader(std::fs::File::open(path.as_path()).unwrap()).unwrap(); let result: Result = bson::from_bson( diff --git a/src/test/spec/unified_runner/observer.rs b/src/test/spec/unified_runner/observer.rs new file mode 100644 index 000000000..2224dde88 --- /dev/null +++ b/src/test/spec/unified_runner/observer.rs @@ -0,0 +1,104 @@ +use tokio::sync::{ + broadcast::{ + self, + error::{RecvError, TryRecvError}, + }, + RwLock, +}; + +use std::{sync::Arc, time::Duration}; + +use crate::{ + error::{Error, Result}, + runtime, + test::Event, +}; + +use super::{events_match, EntityMap, ExpectedEvent}; + +// TODO: RUST-1424: consolidate this with `EventHandler` +/// Observer used to cache all the seen events for a given client in a unified test. +/// Used to implement assertEventCount and waitForEvent operations. 
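// Hedged usage sketch (assumed runner plumbing, not part of the new file below): a
// waitForEvent-style operation takes the per-client observer lock and blocks until the
// expected number of matching events has been seen, relying on the 10-second timeout in
// wait_for_matching_events. ClientEntity, ExpectedEvent, EntityMap, and Result are the
// types from this module; the Arc<RwLock<_>> entity map is an assumption about the
// surrounding TestRunner.

async fn wait_for_two_matches(
    client: &ClientEntity,
    expected: &ExpectedEvent,
    entities: std::sync::Arc<tokio::sync::RwLock<EntityMap>>,
) -> Result<()> {
    // One observer per client entity; hold its lock for the duration of the wait.
    let mut observer = client.observer.lock().await;
    observer.wait_for_matching_events(expected, 2, entities).await
}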
+#[derive(Debug)] +pub(crate) struct EventObserver { + seen_events: Vec, + receiver: broadcast::Receiver, +} + +impl EventObserver { + pub fn new(receiver: broadcast::Receiver) -> Self { + Self { + seen_events: Vec::new(), + receiver, + } + } + + pub(crate) async fn recv(&mut self) -> Option { + match self.receiver.recv().await { + Ok(e) => { + self.seen_events.push(e.clone()); + Some(e) + } + Err(RecvError::Lagged(_)) => panic!("event receiver lagged"), + Err(RecvError::Closed) => None, + } + } + + fn try_recv(&mut self) -> Option { + match self.receiver.try_recv() { + Ok(e) => { + self.seen_events.push(e.clone()); + Some(e) + } + Err(TryRecvError::Lagged(_)) => panic!("event receiver lagged"), + Err(TryRecvError::Closed | TryRecvError::Empty) => None, + } + } + + pub(crate) async fn matching_event_count( + &mut self, + event: &ExpectedEvent, + entities: Arc>, + ) -> usize { + // first retrieve all the events buffered in the channel + while self.try_recv().is_some() {} + let es = entities.read().await; + // then count + self.seen_events + .iter() + .filter(|e| events_match(e, event, Some(&es)).is_ok()) + .count() + } + + pub async fn wait_for_matching_events( + &mut self, + event: &ExpectedEvent, + count: usize, + entities: Arc>, + ) -> Result<()> { + let mut seen = self.matching_event_count(event, entities.clone()).await; + + if seen >= count { + return Ok(()); + } + + runtime::timeout(Duration::from_secs(10), async { + while let Some(e) = self.recv().await { + let es = entities.read().await; + if events_match(&e, event, Some(&es)).is_ok() { + seen += 1; + if seen == count { + return Ok(()); + } + } + } + Err(Error::internal(format!( + "ran out of events before, only saw {} of {}", + seen, count + ))) + }) + .await??; + + Ok(()) + } +} diff --git a/src/test/spec/unified_runner/operation.rs b/src/test/spec/unified_runner/operation.rs index cf5b036b7..9c4b20e98 100644 --- a/src/test/spec/unified_runner/operation.rs +++ b/src/test/spec/unified_runner/operation.rs @@ -3,7 +3,6 @@ use std::{ convert::TryInto, fmt::Debug, ops::Deref, - panic::{catch_unwind, AssertUnwindSafe}, sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -20,13 +19,22 @@ use serde::{de::Deserializer, Deserialize}; use time::OffsetDateTime; use tokio::sync::Mutex; -use super::{Entity, ExpectError, TestCursor, TestRunner}; +use super::{ + results_match, + Entity, + EntityMap, + ExpectError, + ExpectedEvent, + TestCursor, + TestFileEntity, + TestRunner, +}; use crate::{ bson::{doc, to_bson, Bson, Deserializer as BsonDeserializer, Document}, bson_util, change_stream::options::ChangeStreamOptions, - client::session::{ClientSession, TransactionState}, + client::session::TransactionState, coll::options::Hint, collation::Collation, error::{ErrorKind, Result}, @@ -57,16 +65,18 @@ use crate::{ }, runtime, selection_criteria::ReadPreference, - test::{spec::unified_runner::matcher::results_match, FailPoint}, + test::FailPoint, Collection, Database, IndexModel, + ServerType, + TopologyType, }; -pub trait TestOperation: Debug + Send + Sync { +pub(crate) trait TestOperation: Debug + Send + Sync { fn execute_test_runner_operation<'a>( &'a self, - _test_runner: &'a mut TestRunner, + _test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { todo!() } @@ -74,7 +84,7 @@ pub trait TestOperation: Debug + Send + Sync { fn execute_entity_operation<'a>( &'a self, _id: &'a str, - _test_runner: &'a mut TestRunner, + _test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { Err(ErrorKind::InvalidArgument { @@ -93,16 +103,108 @@ pub trait 
TestOperation: Debug + Send + Sync { } } +/// To facilitate working with sessions through the lock, this macro pops it out of the entity map, +/// "passes" it to the provided block, and then returns it to the entity map. It does it this way +/// so that we can continue to borrow the entity map in other ways even when we're using a session, +/// which we'd have to borrow mutably from the map. +macro_rules! with_mut_session { + ($test_runner:ident, $id:expr, |$session:ident| $body:expr) => { + async { + let id = $id; + let mut session_owned = match $test_runner.entities.write().await.remove(id).unwrap() { + Entity::Session(session_owned) => session_owned, + o => panic!( + "expected {} to be a session entity, instead was {:?}", + $id, o + ), + }; + let $session = &mut session_owned; + let out = $body.await; + $test_runner + .entities + .write() + .await + .insert(id.to_string(), Entity::Session(session_owned)); + out + } + }; +} + #[derive(Debug)] -pub struct Operation { +pub(crate) struct Operation { operation: Box, - pub name: String, - pub object: OperationObject, - pub expectation: Expectation, + pub(crate) name: String, + pub(crate) object: OperationObject, + pub(crate) expectation: Expectation, +} + +impl Operation { + pub(crate) async fn execute<'a>(&self, test_runner: &TestRunner, description: &str) { + match self.object { + OperationObject::TestRunner => { + self.execute_test_runner_operation(test_runner).await; + } + OperationObject::Entity(ref id) => { + let result = self.execute_entity_operation(id, test_runner).await; + + match &self.expectation { + Expectation::Result { + expected_value, + save_as_entity, + } => { + let opt_entity = result.unwrap_or_else(|e| { + panic!( + "[{}] {} should succeed, but failed with the following error: {}", + description, self.name, e + ) + }); + if expected_value.is_some() || save_as_entity.is_some() { + let entity = opt_entity.unwrap_or_else(|| { + panic!("[{}] {} did not return an entity", description, self.name) + }); + if let Some(expected_bson) = expected_value { + if let Entity::Bson(actual) = &entity { + if let Err(e) = results_match( + Some(actual), + expected_bson, + self.returns_root_documents(), + Some(&*test_runner.entities.read().await), + ) { + panic!( + "[{}] result mismatch, expected = {:#?} actual = \ + {:#?}\nmismatch detail: {}", + description, expected_bson, actual, e + ); + } + } else { + panic!( + "[{}] Incorrect entity type returned from {}, expected \ + BSON", + description, self.name + ); + } + } + if let Some(id) = save_as_entity { + test_runner.insert_entity(id, entity).await; + } + } + } + Expectation::Error(expect_error) => { + let error = result.expect_err(&format!( + "{}: {} should return an error", + description, self.name + )); + expect_error.verify_result(&error, description).unwrap(); + } + Expectation::Ignore => (), + } + } + } + } } #[derive(Debug)] -pub enum OperationObject { +pub(crate) enum OperationObject { TestRunner, Entity(String), } @@ -119,7 +221,7 @@ impl<'de> Deserialize<'de> for OperationObject { } #[derive(Debug)] -pub enum Expectation { +pub(crate) enum Expectation { Result { expected_value: Option, save_as_entity: Option, @@ -139,14 +241,14 @@ impl<'de> Deserialize<'de> for Operation { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] struct OperationDefinition { - pub name: String, - pub object: OperationObject, + pub(crate) name: String, + pub(crate) object: OperationObject, #[serde(default = "default_arguments")] - pub arguments: Bson, - pub expect_error: 
Option, - pub expect_result: Option, - pub save_result_as_entity: Option, - pub ignore_result_and_error: Option, + pub(crate) arguments: Bson, + pub(crate) expect_error: Option, + pub(crate) expect_result: Option, + pub(crate) save_result_as_entity: Option, + pub(crate) ignore_result_and_error: Option, } fn default_arguments() -> Bson { @@ -225,6 +327,17 @@ impl<'de> Deserialize<'de> for Operation { "createChangeStream" => deserialize_op::(definition.arguments), "rename" => deserialize_op::(definition.arguments), "loop" => deserialize_op::(definition.arguments), + "waitForEvent" => deserialize_op::(definition.arguments), + "assertEventCount" => deserialize_op::(definition.arguments), + "runOnThread" => deserialize_op::(definition.arguments), + "waitForThread" => deserialize_op::(definition.arguments), + "recordTopologyDescription" => { + deserialize_op::(definition.arguments) + } + "assertTopologyType" => deserialize_op::(definition.arguments), + "waitForPrimaryChange" => deserialize_op::(definition.arguments), + "wait" => deserialize_op::(definition.arguments), + "createEntities" => deserialize_op::(definition.arguments), _ => Ok(Box::new(UnimplementedOperation) as Box), } .map_err(|e| serde::de::Error::custom(format!("{}", e)))?; @@ -286,10 +399,10 @@ impl TestOperation for DeleteMany { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .delete_many(self.filter.clone(), self.options.clone()) .await?; @@ -316,16 +429,22 @@ impl TestOperation for DeleteOne { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - collection - .delete_one_with_session(self.filter.clone(), self.options.clone(), session) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .delete_one_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? } None => { collection @@ -378,9 +497,9 @@ impl Find { async fn get_cursor<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> Result { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; // `FindOptions` is constructed without the use of `..Default::default()` to enforce at // compile-time that any new fields added there need to be considered here. 
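The `with_mut_session!` macro introduced earlier in this patch exists because session operations need `&mut ClientSession` while the entity map now sits behind a shared `RwLock`. As a rough standalone sketch of that take-out/put-back pattern (assuming `tokio`, and using simplified stand-in types rather than the driver's real `Entity` and `ClientSession`):

use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

// Stand-ins for the runner's entity types; the real map also stores clients,
// collections, cursors, etc. This only illustrates the locking pattern.
#[derive(Debug)]
#[allow(dead_code)]
enum Entity {
    Session(String),
    Bson(i64),
}

type EntityMap = Arc<RwLock<HashMap<String, Entity>>>;

#[tokio::main]
async fn main() {
    let entities: EntityMap = Arc::new(RwLock::new(HashMap::new()));
    entities
        .write()
        .await
        .insert("session0".into(), Entity::Session("lsid-0".into()));

    // Take the session out of the map so no lock guard is held across the await
    // points of the operation that uses it.
    let mut session = match entities.write().await.remove("session0") {
        Some(Entity::Session(s)) => s,
        other => panic!("expected session0 to be a session entity, got {:?}", other),
    };

    // Use it mutably; in the runner this is where e.g. delete_one_with_session runs.
    session.push_str("/used");

    // Put it back so later operations and assertions can still look it up.
    entities
        .write()
        .await
        .insert("session0".into(), Entity::Session(session));

    assert!(entities.read().await.contains_key("session0"));
}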
let comment = if let Some(Bson::String(s)) = &self.comment { @@ -414,10 +533,12 @@ impl Find { }; match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - let cursor = collection - .find_with_session(self.filter.clone(), options, session) - .await?; + let cursor = with_mut_session!(test_runner, session_id, |session| async { + collection + .find_with_session(self.filter.clone(), options, session) + .await + }) + .await?; Ok(TestCursor::Session { cursor, session_id: session_id.clone(), @@ -435,7 +556,7 @@ impl TestOperation for Find { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { let result = match self.get_cursor(id, test_runner).await? { @@ -443,11 +564,10 @@ impl TestOperation for Find { mut cursor, session_id, } => { - let session = test_runner.get_mut_session(&session_id); - cursor - .stream(session) - .try_collect::>() - .await? + with_mut_session!(test_runner, session_id.as_str(), |s| async { + cursor.stream(s).try_collect::>().await + }) + .await? } TestCursor::Normal(cursor) => { let cursor = cursor.into_inner(); @@ -500,7 +620,7 @@ impl TestOperation for CreateFindCursor { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { let find = Find { @@ -553,20 +673,25 @@ impl TestOperation for InsertMany { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - collection - .insert_many_with_session( - self.documents.clone(), - self.options.clone(), - session, - ) - .await? + with_mut_session!(test_runner, session_id, |session| { + async move { + collection + .insert_many_with_session( + self.documents.clone(), + self.options.clone(), + session, + ) + .await + } + .boxed() + }) + .await? } None => { collection @@ -602,19 +727,22 @@ impl TestOperation for InsertOne { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - collection - .insert_one_with_session( - self.document.clone(), - self.options.clone(), - test_runner.get_mut_session(session_id), - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .insert_one_with_session( + self.document.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? 
} None => { collection @@ -645,10 +773,10 @@ impl TestOperation for UpdateMany { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .update_many( self.filter.clone(), @@ -680,20 +808,23 @@ impl TestOperation for UpdateOne { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - collection - .update_one_with_session( - self.filter.clone(), - self.update.clone(), - self.options.clone(), - test_runner.get_mut_session(session_id), - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .update_one_with_session( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? } None => { collection @@ -728,7 +859,7 @@ impl TestOperation for Aggregate { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { let result = match &self.session { @@ -738,41 +869,41 @@ impl TestOperation for Aggregate { Database(Database), Other(String), } - let entity = match test_runner.entities.get(id).unwrap() { + let entity = match test_runner.entities.read().await.get(id).unwrap() { Entity::Collection(c) => AggregateEntity::Collection(c.clone()), Entity::Database(d) => AggregateEntity::Database(d.clone()), other => AggregateEntity::Other(format!("{:?}", other)), }; - let session = test_runner.get_mut_session(session_id); - let mut cursor = match entity { - AggregateEntity::Collection(collection) => { - collection - .aggregate_with_session( + with_mut_session!(test_runner, session_id, |session| async { + let mut cursor = match entity { + AggregateEntity::Collection(collection) => { + collection + .aggregate_with_session( + self.pipeline.clone(), + self.options.clone(), + session, + ) + .await? + } + AggregateEntity::Database(db) => { + db.aggregate_with_session( self.pipeline.clone(), self.options.clone(), session, ) .await? - } - AggregateEntity::Database(db) => { - db.aggregate_with_session( - self.pipeline.clone(), - self.options.clone(), - session, - ) - .await? - } - AggregateEntity::Other(debug) => { - panic!("Cannot execute aggregate on {}", &debug) - } - }; - cursor - .stream(session) - .try_collect::>() - .await? + } + AggregateEntity::Other(debug) => { + panic!("Cannot execute aggregate on {}", &debug) + } + }; + cursor.stream(session).try_collect::>().await + }) + .await? 
} None => { - let cursor = match test_runner.entities.get(id).unwrap() { + let entities = test_runner.entities.read().await; + let cursor = match entities.get(id).unwrap() { Entity::Collection(collection) => { collection .aggregate(self.pipeline.clone(), self.options.clone()) @@ -811,21 +942,23 @@ impl TestOperation for Distinct { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - collection - .distinct_with_session( - &self.field_name, - self.filter.clone(), - self.options.clone(), - session, - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .distinct_with_session( + &self.field_name, + self.filter.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? } None => { collection @@ -855,20 +988,22 @@ impl TestOperation for CountDocuments { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - collection - .count_documents_with_session( - self.filter.clone(), - self.options.clone(), - session, - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .count_documents_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? } None => { collection @@ -893,10 +1028,10 @@ impl TestOperation for EstimatedDocumentCount { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .estimated_document_count(self.options.clone()) .await?; @@ -918,10 +1053,10 @@ impl TestOperation for FindOne { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .find_one(self.filter.clone(), self.options.clone()) .await?; @@ -947,20 +1082,22 @@ impl TestOperation for ListDatabases { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let client = test_runner.get_client(id).clone(); + let client = test_runner.get_client(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - client - .list_databases_with_session( - self.filter.clone(), - self.options.clone(), - session, - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + client + .list_databases_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? 
} None => { client @@ -986,10 +1123,10 @@ impl TestOperation for ListDatabaseNames { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let client = test_runner.get_client(id); + let client = test_runner.get_client(id).await; let result = client .list_database_names(self.filter.clone(), self.options.clone()) .await?; @@ -1013,21 +1150,23 @@ impl TestOperation for ListCollections { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let db = test_runner.get_database(id).clone(); + let db = test_runner.get_database(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - let mut cursor = db - .list_collections_with_session( - self.filter.clone(), - self.options.clone(), - session, - ) - .await?; - cursor.stream(session).try_collect::>().await? + with_mut_session!(test_runner, session_id, |session| async { + let mut cursor = db + .list_collections_with_session( + self.filter.clone(), + self.options.clone(), + session, + ) + .await?; + cursor.stream(session).try_collect::>().await + }) + .await? } None => { let cursor = db @@ -1056,10 +1195,10 @@ impl TestOperation for ListCollectionNames { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let db = test_runner.get_database(id); + let db = test_runner.get_database(id).await; let result = db.list_collection_names(self.filter.clone()).await?; let result: Vec = result.iter().map(|s| Bson::String(s.to_string())).collect(); Ok(Some(Bson::from(result).into())) @@ -1084,10 +1223,10 @@ impl TestOperation for ReplaceOne { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .replace_one( self.filter.clone(), @@ -1119,21 +1258,23 @@ impl TestOperation for FindOneAndUpdate { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - collection - .find_one_and_update_with_session( - self.filter.clone(), - self.update.clone(), - self.options.clone(), - session, - ) - .await? + with_mut_session!(test_runner, session_id, |session| async { + collection + .find_one_and_update_with_session( + self.filter.clone(), + self.update.clone(), + self.options.clone(), + session, + ) + .await + }) + .await? 
} None => { collection @@ -1168,10 +1309,10 @@ impl TestOperation for FindOneAndReplace { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .find_one_and_replace( self.filter.clone(), @@ -1202,10 +1343,10 @@ impl TestOperation for FindOneAndDelete { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id); + let collection = test_runner.get_collection(id).await; let result = collection .find_one_and_delete(self.filter.clone(), self.options.clone()) .await?; @@ -1226,17 +1367,17 @@ pub(super) struct FailPointCommand { impl TestOperation for FailPointCommand { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let client = test_runner.get_client(&self.client); + let client = test_runner.get_client(&self.client).await; let guard = self .fail_point .clone() - .enable(client, Some(ReadPreference::Primary.into())) + .enable(&client, Some(ReadPreference::Primary.into())) .await .unwrap(); - test_runner.fail_point_guards.push(guard); + test_runner.fail_point_guards.write().await.push(guard); } .boxed() } @@ -1252,21 +1393,28 @@ pub(super) struct TargetedFailPoint { impl TestOperation for TargetedFailPoint { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let session = test_runner.get_session(&self.session); - let selection_criteria = session - .transaction - .pinned_mongos() - .cloned() - .unwrap_or_else(|| panic!("ClientSession not pinned")); + let selection_criteria = + with_mut_session!(test_runner, self.session.as_str(), |session| async { + session + .transaction + .pinned_mongos() + .cloned() + .unwrap_or_else(|| panic!("ClientSession not pinned")) + }) + .await; let fail_point_guard = test_runner .internal_client .enable_failpoint(self.fail_point.clone(), Some(selection_criteria)) .await .unwrap(); - test_runner.fail_point_guards.push(fail_point_guard); + test_runner + .fail_point_guards + .write() + .await + .push(fail_point_guard); } .boxed() } @@ -1282,7 +1430,7 @@ pub(super) struct AssertCollectionExists { impl TestOperation for AssertCollectionExists { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { let db = test_runner.internal_client.database(&self.database_name); @@ -1303,7 +1451,7 @@ pub(super) struct AssertCollectionNotExists { impl TestOperation for AssertCollectionNotExists { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { let db = test_runner.internal_client.database(&self.database_name); @@ -1327,19 +1475,22 @@ impl TestOperation for CreateCollection { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let database = test_runner.get_database(id).clone(); + let database = test_runner.get_database(id).await; if let Some(session_id) = &self.session { - database 
- .create_collection_with_session( - &self.collection, - self.options.clone(), - test_runner.get_mut_session(session_id), - ) - .await?; + with_mut_session!(test_runner, session_id, |session| async { + database + .create_collection_with_session( + &self.collection, + self.options.clone(), + session, + ) + .await + }) + .await?; } else { database .create_collection(&self.collection, self.options.clone()) @@ -1364,19 +1515,19 @@ impl TestOperation for DropCollection { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let database = test_runner.entities.get(id).unwrap().as_database(); + let database = test_runner.get_database(id).await; let collection = database.collection::(&self.collection).clone(); if let Some(session_id) = &self.session { - collection - .drop_with_session( - self.options.clone(), - test_runner.get_mut_session(session_id), - ) - .await?; + with_mut_session!(test_runner, session_id, |session| async { + collection + .drop_with_session(self.options.clone(), session) + .await + }) + .await?; } else { collection.drop(self.options.clone()).await?; } @@ -1404,7 +1555,7 @@ impl TestOperation for RunCommand { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { let mut command = self.command.clone(); @@ -1415,12 +1566,14 @@ impl TestOperation for RunCommand { command.insert("writeConcern", write_concern.clone()); } - let db = test_runner.get_database(id).clone(); + let db = test_runner.get_database(id).await; let result = match &self.session { Some(session_id) => { - let session = test_runner.get_mut_session(session_id); - db.run_command_with_session(command, self.read_preference.clone(), session) - .await? + with_mut_session!(test_runner, session_id, |session| async { + db.run_command_with_session(command, self.read_preference.clone(), session) + .await + }) + .await? 
} None => { db.run_command(command, self.read_preference.clone()) @@ -1442,11 +1595,13 @@ impl TestOperation for EndSession { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let session = test_runner.get_mut_session(id).client_session.take(); - drop(session); + with_mut_session!(test_runner, id, |session| async { + session.client_session.take(); + }) + .await; runtime::delay_for(Duration::from_secs(1)).await; Ok(None) } @@ -1464,17 +1619,20 @@ pub(super) struct AssertSessionTransactionState { impl TestOperation for AssertSessionTransactionState { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let session: &ClientSession = test_runner.get_session(&self.session); - let session_state = match &session.transaction.state { - TransactionState::None => "none", - TransactionState::Starting => "starting", - TransactionState::InProgress => "inprogress", - TransactionState::Committed { data_committed: _ } => "committed", - TransactionState::Aborted => "aborted", - }; + let session_state = + with_mut_session!(test_runner, self.session.as_str(), |session| async { + match &session.transaction.state { + TransactionState::None => "none", + TransactionState::Starting => "starting", + TransactionState::InProgress => "inprogress", + TransactionState::Committed { data_committed: _ } => "committed", + TransactionState::Aborted => "aborted", + } + }) + .await; assert_eq!(session_state, self.state); } .boxed() @@ -1490,14 +1648,15 @@ pub(super) struct AssertSessionPinned { impl TestOperation for AssertSessionPinned { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - assert!(test_runner - .get_session(&self.session) - .transaction - .pinned_mongos() - .is_some()); + let is_pinned = + with_mut_session!(test_runner, self.session.as_str(), |session| async { + session.transaction.pinned_mongos().is_some() + }) + .await; + assert!(is_pinned); } .boxed() } @@ -1512,14 +1671,14 @@ pub(super) struct AssertSessionUnpinned { impl TestOperation for AssertSessionUnpinned { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - assert!(test_runner - .get_session(&self.session) - .transaction - .pinned_mongos() - .is_none()); + let is_pinned = with_mut_session!(test_runner, self.session.as_str(), |session| { + async move { session.transaction.pinned_mongos().is_some() } + }) + .await; + assert!(!is_pinned); } .boxed() } @@ -1534,10 +1693,11 @@ pub(super) struct AssertDifferentLsidOnLastTwoCommands { impl TestOperation for AssertDifferentLsidOnLastTwoCommands { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let client = test_runner.entities.get(&self.client).unwrap().as_client(); + let entities = test_runner.entities.read().await; + let client = entities.get(&self.client).unwrap().as_client(); let events = client.get_all_command_started_events(); let lsid1 = events[events.len() - 1].command.get("lsid").unwrap(); @@ -1557,10 +1717,11 @@ pub(super) struct AssertSameLsidOnLastTwoCommands { impl TestOperation for AssertSameLsidOnLastTwoCommands { fn execute_test_runner_operation<'a>( &'a self, - 
test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let client = test_runner.entities.get(&self.client).unwrap().as_client(); + let entities = test_runner.entities.read().await; + let client = entities.get(&self.client).unwrap().as_client(); client.sync_workers().await; let events = client.get_all_command_started_events(); @@ -1581,11 +1742,14 @@ pub(super) struct AssertSessionDirty { impl TestOperation for AssertSessionDirty { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let session: &ClientSession = test_runner.get_session(&self.session); - assert!(session.is_dirty()); + let dirty = with_mut_session!(test_runner, self.session.as_str(), |session| { + async move { session.is_dirty() }.boxed() + }) + .await; + assert!(dirty); } .boxed() } @@ -1600,11 +1764,14 @@ pub(super) struct AssertSessionNotDirty { impl TestOperation for AssertSessionNotDirty { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let session: &ClientSession = test_runner.get_session(&self.session); - assert!(!session.is_dirty()); + let dirty = with_mut_session!(test_runner, self.session.as_str(), |session| { + async move { session.is_dirty() } + }) + .await; + assert!(!dirty); } .boxed() } @@ -1618,11 +1785,13 @@ impl TestOperation for StartTransaction { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let session: &mut ClientSession = test_runner.get_mut_session(id); - session.start_transaction(None).await?; + with_mut_session!(test_runner, id, |session| { + async move { session.start_transaction(None).await } + }) + .await?; Ok(None) } .boxed() @@ -1637,11 +1806,13 @@ impl TestOperation for CommitTransaction { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let session: &mut ClientSession = test_runner.get_mut_session(id); - session.commit_transaction().await?; + with_mut_session!(test_runner, id, |session| { + async move { session.commit_transaction().await } + }) + .await?; Ok(None) } .boxed() @@ -1656,11 +1827,13 @@ impl TestOperation for AbortTransaction { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let session: &mut ClientSession = test_runner.get_mut_session(id); - session.abort_transaction().await?; + with_mut_session!(test_runner, id, |session| { + async move { session.abort_transaction().await } + }) + .await?; Ok(None) } .boxed() @@ -1679,7 +1852,7 @@ impl TestOperation for CreateIndex { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { let options = IndexOptions::builder().name(self.name.clone()).build(); @@ -1688,14 +1861,18 @@ impl TestOperation for CreateIndex { .options(options) .build(); - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let name = match self.session { - Some(ref session) => { - let session = test_runner.get_mut_session(session); - collection - .create_index_with_session(index, None, session) - 
.await? - .index_name + Some(ref session_id) => { + with_mut_session!(test_runner, session_id, |session| { + async move { + collection + .create_index_with_session(index, None, session) + .await + .map(|model| model.index_name) + } + }) + .await? } None => collection.create_index(index, None).await?.index_name, }; @@ -1716,19 +1893,23 @@ impl TestOperation for ListIndexes { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let indexes: Vec = match self.session { Some(ref session) => { - let session = test_runner.get_mut_session(session); - collection - .list_indexes_with_session(self.options.clone(), session) - .await? - .stream(session) - .try_collect() - .await? + with_mut_session!(test_runner, session, |session| { + async { + collection + .list_indexes_with_session(self.options.clone(), session) + .await? + .stream(session) + .try_collect() + .await + } + }) + .await? } None => { collection @@ -1757,14 +1938,16 @@ impl TestOperation for ListIndexNames { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let collection = test_runner.get_collection(id).clone(); + let collection = test_runner.get_collection(id).await; let names = match self.session { Some(ref session) => { - let session = test_runner.get_mut_session(session); - collection.list_index_names_with_session(session).await? + with_mut_session!(test_runner, session.as_str(), |s| { + async move { collection.list_index_names_with_session(s).await } + }) + .await? } None => collection.list_index_names().await?, }; @@ -1785,7 +1968,7 @@ pub(super) struct AssertIndexExists { impl TestOperation for AssertIndexExists { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { let coll = test_runner @@ -1810,7 +1993,7 @@ pub(super) struct AssertIndexNotExists { impl TestOperation for AssertIndexNotExists { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { let coll = test_runner @@ -1835,20 +2018,35 @@ impl TestOperation for IterateUntilDocumentOrError { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { // A `SessionCursor` also requires a `&mut Session`, which would cause conflicting // borrows, so take the cursor from the map and return it after execution instead. 
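The conflicting-borrow problem mentioned in the comment below comes from the shape of the public session-cursor API: both creating and draining the cursor mutably borrow the `ClientSession`. A minimal hedged sketch of that API shape (not part of this patch; database and collection names are placeholders, and it assumes the 2.x API used here):

use futures::stream::TryStreamExt;
use mongodb::{bson::Document, error::Result, Client};

// Both `find_with_session` and `stream` take `&mut ClientSession`, so a cursor
// stored in the shared entity map cannot be advanced while the map is also
// borrowed to fetch the session; hence the remove-then-reinsert workaround.
async fn drain_with_session(client: &Client) -> Result<Vec<Document>> {
    let coll = client.database("db").collection::<Document>("coll");
    let mut session = client.start_session(None).await?;
    let mut cursor = coll.find_with_session(None, None, &mut session).await?;
    // `stream` holds the `&mut session` borrow for as long as the stream is alive.
    cursor.stream(&mut session).try_collect().await
}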
- let mut cursor = test_runner.entities.remove(id).unwrap().into_cursor(); + let mut cursor = test_runner + .entities + .write() + .await + .remove(id) + .unwrap() + .into_cursor(); let next = match &mut cursor { TestCursor::Normal(cursor) => { let mut cursor = cursor.lock().await; cursor.next().await } TestCursor::Session { cursor, session_id } => { - let session = test_runner.get_mut_session(session_id); - cursor.next(session).await + cursor + .next( + test_runner + .entities + .write() + .await + .get_mut(session_id) + .unwrap() + .as_mut_session_entity(), + ) + .await } TestCursor::ChangeStream(stream) => { let mut stream = stream.lock().await; @@ -1863,6 +2061,8 @@ impl TestOperation for IterateUntilDocumentOrError { }; test_runner .entities + .write() + .await .insert(id.to_string(), Entity::Cursor(cursor)); next.transpose() .map(|opt| opt.map(|doc| Entity::Bson(Bson::Document(doc)))) @@ -1883,12 +2083,14 @@ impl TestOperation for Close { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let cursor = test_runner.get_mut_find_cursor(id); + let mut entities = test_runner.entities.write().await; + let cursor = entities.get_mut(id).unwrap().as_mut_cursor(); let rx = cursor.make_kill_watcher().await; *cursor = TestCursor::Closed; + drop(entities); let _ = rx.await; Ok(None) } @@ -1906,10 +2108,10 @@ pub(super) struct AssertNumberConnectionsCheckedOut { impl TestOperation for AssertNumberConnectionsCheckedOut { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { - let client = test_runner.get_client(&self.client); + let client = test_runner.get_client(&self.client).await; client.sync_workers().await; assert_eq!(client.connections_checked_out(), self.connections); } @@ -1929,10 +2131,11 @@ impl TestOperation for CreateChangeStream { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let target = test_runner.entities.get(id).unwrap(); + let entities = test_runner.entities.read().await; + let target = entities.get(id).unwrap(); let stream = match target { Entity::Client(ce) => { ce.watch(self.pipeline.clone(), self.options.clone()) @@ -1966,10 +2169,10 @@ impl TestOperation for RenameCollection { fn execute_entity_operation<'a>( &'a self, id: &'a str, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, Result>> { async move { - let target = test_runner.get_collection(id); + let target = test_runner.get_collection(id).await; let ns = target.namespace(); let mut to_ns = ns.clone(); to_ns.coll = self.to.clone(); @@ -1986,19 +2189,19 @@ impl TestOperation for RenameCollection { } macro_rules! report_error { - ($loop:expr, $error:expr, $test_runner:expr) => {{ + ($loop:expr, $error:expr, $entities:expr) => {{ let error = format!("{:?}", $error); report_error_or_failure!( $loop.store_errors_as_entity, $loop.store_failures_as_entity, error, - $test_runner + $entities ); }}; } macro_rules! report_failure { - ($loop:expr, $name:expr, $actual:expr, $expected:expr, $test_runner:expr) => {{ + ($loop:expr, $name:expr, $actual:expr, $expected:expr, $entities:expr) => {{ let error = format!( "{} error: got {:?}, expected {:?}", $name, $actual, $expected @@ -2007,13 +2210,13 @@ macro_rules! 
report_failure { $loop.store_failures_as_entity, $loop.store_errors_as_entity, error, - $test_runner + $entities ); }}; } macro_rules! report_error_or_failure { - ($first_option:expr, $second_option:expr, $error:expr, $test_runner:expr) => {{ + ($first_option:expr, $second_option:expr, $error:expr, $entities:expr) => {{ let id = if let Some(ref id) = $first_option { id } else if let Some(ref id) = $second_option { @@ -2025,7 +2228,7 @@ macro_rules! report_error_or_failure { ); }; - match $test_runner.entities.get_mut(id) { + match $entities.get_mut(id) { Some(Entity::Bson(Bson::Array(array))) => { let doc = doc! { "error": $error, @@ -2054,24 +2257,24 @@ pub(super) struct Loop { impl TestOperation for Loop { fn execute_test_runner_operation<'a>( &'a self, - test_runner: &'a mut TestRunner, + test_runner: &'a TestRunner, ) -> BoxFuture<'a, ()> { async move { if let Some(id) = &self.store_errors_as_entity { let errors = Bson::Array(vec![]); - test_runner.insert_entity(id, errors.into()); + test_runner.insert_entity(id, errors).await; } if let Some(id) = &self.store_failures_as_entity { let failures = Bson::Array(vec![]); - test_runner.insert_entity(id, failures.into()); + test_runner.insert_entity(id, failures).await; } if let Some(id) = &self.store_successes_as_entity { let successes = Bson::Int64(0); - test_runner.insert_entity(id, successes.into()); + test_runner.insert_entity(id, successes).await; } if let Some(id) = &self.store_iterations_as_entity { let iterations = Bson::Int64(0); - test_runner.insert_entity(id, iterations.into()); + test_runner.insert_entity(id, iterations).await; } let continue_looping = Arc::new(AtomicBool::new(true)); @@ -2091,6 +2294,8 @@ impl TestOperation for Loop { operation.execute_entity_operation(id, test_runner).await } }; + + let mut entities = test_runner.entities.write().await; match (result, &operation.expectation) { ( Ok(entity), @@ -2109,7 +2314,7 @@ impl TestOperation for Loop { &operation.name, entity, expected_value, - test_runner + &mut entities ); } }; @@ -2117,25 +2322,25 @@ impl TestOperation for Loop { actual_value, expected_value, operation.returns_root_documents(), - Some(&test_runner.entities), + Some(&entities), ) .is_ok() { - self.report_success(test_runner); + self.report_success(&mut entities); } else { report_failure!( self, &operation.name, actual_value, expected_value, - test_runner + &mut entities ); } } else { - self.report_success(test_runner); + self.report_success(&mut entities); } if let (Some(entity), Some(id)) = (entity, save_as_entity) { - test_runner.insert_entity(id, entity); + entities.insert(id.to_string(), entity); } } (Ok(result), Expectation::Error(ref expected_error)) => { @@ -2144,32 +2349,30 @@ impl TestOperation for Loop { &operation.name, result, expected_error, - test_runner + &mut entities ); } (Ok(_), Expectation::Ignore) => { - self.report_success(test_runner); + self.report_success(&mut entities); } (Err(error), Expectation::Error(ref expected_error)) => { - match catch_unwind(AssertUnwindSafe(|| { - expected_error.verify_result(&error); - })) { - Ok(_) => self.report_success(test_runner), - Err(_) => report_failure!( - self, - &operation.name, - error, - expected_error, - test_runner + match expected_error.verify_result(&error, operation.name.as_str()) { + Ok(_) => self.report_success(&mut entities), + Err(e) => report_error_or_failure!( + self.store_failures_as_entity, + self.store_errors_as_entity, + e, + &mut entities ), } } (Err(error), Expectation::Result { .. 
} | Expectation::Ignore) => { - report_error!(self, error, test_runner); + report_error!(self, error, &mut entities); } } } - self.report_iteration(test_runner); + let mut entities = test_runner.entities.write().await; + self.report_iteration(&mut entities); } } .boxed() @@ -2177,23 +2380,239 @@ impl TestOperation for Loop { } impl Loop { - fn report_iteration(&self, test_runner: &mut TestRunner) { - Self::increment_count(self.store_iterations_as_entity.as_ref(), test_runner); + fn report_iteration(&self, entities: &mut EntityMap) { + Self::increment_count(self.store_iterations_as_entity.as_ref(), entities) } - fn report_success(&self, test_runner: &mut TestRunner) { - Self::increment_count(self.store_successes_as_entity.as_ref(), test_runner); + fn report_success(&self, test_runner: &mut EntityMap) { + Self::increment_count(self.store_successes_as_entity.as_ref(), test_runner) } - fn increment_count(id: Option<&String>, test_runner: &mut TestRunner) { + fn increment_count(id: Option<&String>, entities: &mut EntityMap) { if let Some(id) = id { - match test_runner.entities.get_mut(id) { + match entities.get_mut(id) { Some(Entity::Bson(Bson::Int64(count))) => *count += 1, _ => panic!("Test runner should contain a Bson::Int64 entity for {}", id), } } } } +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct RunOnThread { + thread: String, + operation: Arc, +} + +impl TestOperation for RunOnThread { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let thread = test_runner.get_thread(self.thread.as_str()).await; + thread.run_operation(self.operation.clone()); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct WaitForThread { + thread: String, +} + +impl TestOperation for WaitForThread { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let thread = test_runner.get_thread(self.thread.as_str()).await; + assert!( + thread.wait().await, + "thread {:?} did not exit successfully", + self.thread + ); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct AssertEventCount { + client: String, + event: ExpectedEvent, + count: usize, +} + +impl TestOperation for AssertEventCount { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let client = test_runner.get_client(self.client.as_str()).await; + let entities = test_runner.entities.clone(); + let actual_count = client + .observer + .lock() + .await + .matching_event_count(&self.event, entities) + .await; + assert_eq!(actual_count, self.count); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct WaitForEvent { + client: String, + event: ExpectedEvent, + count: usize, +} + +impl TestOperation for WaitForEvent { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let client = test_runner.get_client(self.client.as_str()).await; + let entities = test_runner.entities.clone(); + client + .observer + .lock() + .await + .wait_for_matching_events(&self.event, self.count, entities) + .await + .unwrap(); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", 
deny_unknown_fields)] +pub(super) struct RecordTopologyDescription { + id: String, + client: String, +} + +impl TestOperation for RecordTopologyDescription { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let client = test_runner.get_client(&self.client).await; + let description = client.topology_description(); + test_runner.insert_entity(&self.id, description).await; + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct AssertTopologyType { + topology_description: String, + topology_type: TopologyType, +} + +impl TestOperation for AssertTopologyType { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async { + let td = test_runner + .get_topology_description(&self.topology_description) + .await; + assert_eq!(td.topology_type, self.topology_type); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct WaitForPrimaryChange { + client: String, + prior_topology_description: String, + #[serde(rename = "timeoutMS")] + timeout_ms: Option, +} + +impl TestOperation for WaitForPrimaryChange { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + async move { + let client = test_runner.get_client(&self.client).await; + let td = test_runner + .get_topology_description(&self.prior_topology_description) + .await; + let old_primary = td.servers_with_type(&[ServerType::RsPrimary]).next(); + let timeout = Duration::from_millis(self.timeout_ms.unwrap_or(10_000)); + + runtime::timeout(timeout, async { + let mut watcher = client.topology().watch(); + + loop { + let latest = watcher.observe_latest(); + if let Some(primary) = latest.description.primary() { + if Some(primary) != old_primary { + return; + } + } + watcher.wait_for_update(Duration::MAX).await; + } + }) + .await + .unwrap(); + } + .boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct Wait { + ms: u64, +} + +impl TestOperation for Wait { + fn execute_test_runner_operation<'a>( + &'a self, + _test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + runtime::delay_for(Duration::from_millis(self.ms)).boxed() + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(super) struct CreateEntities { + entities: Vec, +} + +impl TestOperation for CreateEntities { + fn execute_test_runner_operation<'a>( + &'a self, + test_runner: &'a TestRunner, + ) -> BoxFuture<'a, ()> { + test_runner + .populate_entity_map(&self.entities[..], "createEntities operation") + .boxed() + } +} + #[derive(Debug, Deserialize)] pub(super) struct UnimplementedOperation; diff --git a/src/test/spec/unified_runner/test_event.rs b/src/test/spec/unified_runner/test_event.rs index b1216eec1..933a8bab4 100644 --- a/src/test/spec/unified_runner/test_event.rs +++ b/src/test/spec/unified_runner/test_event.rs @@ -2,20 +2,21 @@ use crate::{ bson::Document, event::cmap::{ConnectionCheckoutFailedReason, ConnectionClosedReason}, test::{CmapEvent, CommandEvent, Event}, + ServerType, }; use serde::Deserialize; -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize)] #[serde(untagged, deny_unknown_fields, rename_all = "camelCase")] -pub enum ExpectedEvent { +pub(crate) enum ExpectedEvent { Cmap(ExpectedCmapEvent), 
Command(ExpectedCommandEvent), - Sdam, + Sdam(Box), } -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] -pub enum ExpectedCommandEvent { +pub(crate) enum ExpectedCommandEvent { #[serde(rename = "commandStartedEvent", rename_all = "camelCase")] Started { command_name: Option, @@ -39,9 +40,9 @@ pub enum ExpectedCommandEvent { }, } -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] -pub enum ExpectedCmapEvent { +pub(crate) enum ExpectedCmapEvent { #[serde(rename = "poolCreatedEvent")] PoolCreated {}, #[serde(rename = "poolReadyEvent")] @@ -70,8 +71,25 @@ pub enum ExpectedCmapEvent { ConnectionCheckedIn {}, } -#[derive(Copy, Clone, PartialEq, Eq, Debug, Deserialize)] -pub enum ObserveEvent { +#[derive(Debug, Deserialize)] +pub(crate) enum ExpectedSdamEvent { + #[serde(rename = "serverDescriptionChangedEvent", rename_all = "camelCase")] + ServerDescriptionChanged { + #[allow(unused)] + previous_description: Option, + new_description: Option, + }, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct TestServerDescription { + #[serde(rename = "type")] + pub(crate) server_type: Option, +} + +#[derive(Copy, Clone, Debug, Deserialize)] +pub(crate) enum ObserveEvent { #[serde(rename = "commandStartedEvent")] CommandStarted, #[serde(rename = "commandSucceededEvent")] @@ -100,10 +118,12 @@ pub enum ObserveEvent { ConnectionCheckedOut, #[serde(rename = "connectionCheckedInEvent")] ConnectionCheckedIn, + #[serde(rename = "serverDescriptionChangedEvent")] + ServerDescriptionChanged, } impl ObserveEvent { - pub fn matches(&self, event: &Event) -> bool { + pub(crate) fn matches(&self, event: &Event) -> bool { #[allow(clippy::match_like_matches_macro)] match (self, event) { (Self::CommandStarted, Event::Command(CommandEvent::Started(_))) => true, diff --git a/src/test/spec/unified_runner/test_file.rs b/src/test/spec/unified_runner/test_file.rs index a16dcfa29..c1b3bd438 100644 --- a/src/test/spec/unified_runner/test_file.rs +++ b/src/test/spec/unified_runner/test_file.rs @@ -1,7 +1,8 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use semver::{Version, VersionReq}; use serde::{Deserialize, Deserializer}; +use tokio::sync::oneshot; use super::{results_match, ExpectedEvent, ObserveEvent, Operation}; @@ -25,15 +26,14 @@ use crate::{ #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct TestFile { - pub description: String, +pub(crate) struct TestFile { + pub(crate) description: String, #[serde(deserialize_with = "deserialize_schema_version")] - pub schema_version: Version, - pub run_on_requirements: Option>, - pub allow_multiple_mongoses: Option, - pub create_entities: Option>, - pub initial_data: Option>, - pub tests: Vec, + pub(crate) schema_version: Version, + pub(crate) run_on_requirements: Option>, + pub(crate) create_entities: Option>, + pub(crate) initial_data: Option>, + pub(crate) tests: Vec, // We don't need to use this field, but it needs to be included during deserialization so that // we can use the deny_unknown_fields tag. 
#[serde(rename = "_yamlAnchors")] @@ -58,7 +58,7 @@ where #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct RunOnRequirement { +pub(crate) struct RunOnRequirement { min_server_version: Option, max_server_version: Option, topologies: Option>, @@ -69,7 +69,7 @@ pub struct RunOnRequirement { #[derive(Debug, Deserialize, PartialEq)] #[serde(rename_all = "lowercase", deny_unknown_fields)] -pub enum Topology { +pub(crate) enum Topology { Single, ReplicaSet, Sharded, @@ -80,7 +80,7 @@ pub enum Topology { } impl RunOnRequirement { - pub async fn can_run_on(&self, client: &TestClient) -> bool { + pub(crate) async fn can_run_on(&self, client: &TestClient) -> bool { if let Some(ref min_version) = self.min_server_version { let req = VersionReq::parse(&format!(">= {}", &min_version)).unwrap(); if !req.matches(&client.server_version) { @@ -126,12 +126,13 @@ impl RunOnRequirement { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub enum TestFileEntity { +pub(crate) enum TestFileEntity { Client(Client), Database(Database), Collection(Collection), Session(Session), Bucket(Bucket), + Thread(Thread), } #[derive(Debug, Deserialize)] @@ -143,20 +144,20 @@ pub struct StoreEventsAsEntity { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Client { - pub id: String, - pub uri_options: Option, - pub use_multiple_mongoses: Option, - pub observe_events: Option>, - pub ignore_command_monitoring_events: Option>, +pub(crate) struct Client { + pub(crate) id: String, + pub(crate) uri_options: Option, + pub(crate) use_multiple_mongoses: Option, + pub(crate) observe_events: Option>, + pub(crate) ignore_command_monitoring_events: Option>, #[serde(default)] - pub observe_sensitive_commands: Option, + pub(crate) observe_sensitive_commands: Option, #[serde(default, deserialize_with = "deserialize_server_api_test_format")] - pub server_api: Option, - pub store_events_as_entities: Option>, + pub(crate) server_api: Option, + pub(crate) store_events_as_entities: Option>, } -pub fn deserialize_server_api_test_format<'de, D>( +pub(crate) fn deserialize_server_api_test_format<'de, D>( deserializer: D, ) -> std::result::Result, D::Error> where @@ -178,7 +179,7 @@ where })) } -pub fn merge_uri_options(given_uri: &str, uri_options: Option<&Document>) -> String { +pub(crate) fn merge_uri_options(given_uri: &str, uri_options: Option<&Document>) -> String { let uri_options = match uri_options { Some(opts) => opts, None => return given_uri.to_string(), @@ -222,56 +223,64 @@ pub fn merge_uri_options(given_uri: &str, uri_options: Option<&Document>) -> Str #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Database { - pub id: String, - pub client: String, - pub database_name: String, - pub database_options: Option, +pub(crate) struct Database { + pub(crate) id: String, + pub(crate) client: String, + pub(crate) database_name: String, + pub(crate) database_options: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Collection { - pub id: String, - pub database: String, - pub collection_name: String, - pub collection_options: Option, +pub(crate) struct Collection { + pub(crate) id: String, + pub(crate) database: String, + pub(crate) collection_name: String, + pub(crate) collection_options: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Session { - 
pub id: String, - pub client: String, - pub session_options: Option, +pub(crate) struct Session { + pub(crate) id: String, + pub(crate) client: String, + pub(crate) session_options: Option, } +// TODO: RUST-527 remove the unused annotation #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Bucket { - pub id: String, - pub database: String, - pub bucket_options: Option, +#[allow(unused)] +pub(crate) struct Bucket { + pub(crate) id: String, + pub(crate) database: String, + pub(crate) bucket_options: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct Stream { - pub id: String, - pub hex_bytes: String, +pub(crate) struct Thread { + pub(crate) id: String, +} + +/// Messages used for communicating with test runner "threads". +#[derive(Debug)] +pub(crate) enum ThreadMessage { + ExecuteOperation(Arc), + Stop(oneshot::Sender<()>), } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct CollectionOrDatabaseOptions { - pub read_concern: Option, +pub(crate) struct CollectionOrDatabaseOptions { + pub(crate) read_concern: Option, #[serde(rename = "readPreference")] - pub selection_criteria: Option, - pub write_concern: Option, + pub(crate) selection_criteria: Option, + pub(crate) write_concern: Option, } impl CollectionOrDatabaseOptions { - pub fn as_database_options(&self) -> DatabaseOptions { + pub(crate) fn as_database_options(&self) -> DatabaseOptions { DatabaseOptions { read_concern: self.read_concern.clone(), selection_criteria: self.selection_criteria.clone(), @@ -279,7 +288,7 @@ impl CollectionOrDatabaseOptions { } } - pub fn as_collection_options(&self) -> CollectionOptions { + pub(crate) fn as_collection_options(&self) -> CollectionOptions { CollectionOptions { read_concern: self.read_concern.clone(), selection_criteria: self.selection_criteria.clone(), @@ -290,35 +299,35 @@ impl CollectionOrDatabaseOptions { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct CollectionData { - pub collection_name: String, - pub database_name: String, - pub documents: Vec, +pub(crate) struct CollectionData { + pub(crate) collection_name: String, + pub(crate) database_name: String, + pub(crate) documents: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct TestCase { - pub description: String, - pub run_on_requirements: Option>, - pub skip_reason: Option, - pub operations: Vec, - pub expect_events: Option>, - pub outcome: Option>, +pub(crate) struct TestCase { + pub(crate) description: String, + pub(crate) run_on_requirements: Option>, + pub(crate) skip_reason: Option, + pub(crate) operations: Vec, + pub(crate) expect_events: Option>, + pub(crate) outcome: Option>, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct ExpectedEvents { - pub client: String, - pub events: Vec, - pub event_type: Option, - pub ignore_extra_events: Option, +pub(crate) struct ExpectedEvents { + pub(crate) client: String, + pub(crate) events: Vec, + pub(crate) event_type: Option, + pub(crate) ignore_extra_events: Option, } #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub enum ExpectedEventType { +pub(crate) enum ExpectedEventType { Command, Cmap, // TODO RUST-1055 Remove this when connection usage is serialized. 
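A `Thread` entity can be modeled as a task driven by a channel of `ThreadMessage`s: `runOnThread` queues an operation, and waiting for the thread is modeled by sending `Stop` with a `oneshot` sender and awaiting its acknowledgement. The following is a minimal sketch under those assumptions (using `tokio`, with `String` standing in for `Arc<Operation>`); it simplifies whatever the runner's actual `ThreadEntity` does:

use tokio::sync::{mpsc, oneshot};

// Simplified stand-in for the ThreadMessage enum defined in this patch.
#[derive(Debug)]
enum ThreadMessage {
    ExecuteOperation(String),
    Stop(oneshot::Sender<()>),
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<ThreadMessage>();

    // The "thread" is an async task that executes queued operations in order.
    let worker = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            match msg {
                ThreadMessage::ExecuteOperation(op) => println!("executing {}", op),
                ThreadMessage::Stop(ack) => {
                    // Acknowledge only after everything queued earlier has run.
                    let _ = ack.send(());
                    break;
                }
            }
        }
    });

    tx.send(ThreadMessage::ExecuteOperation("insertOne".into())).unwrap();

    // waitForThread: request a stop and wait for the acknowledgement.
    let (ack_tx, ack_rx) = oneshot::channel();
    tx.send(ThreadMessage::Stop(ack_tx)).unwrap();
    assert!(ack_rx.await.is_ok(), "thread did not exit successfully");
    worker.await.unwrap();
}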
@@ -328,64 +337,117 @@ pub enum ExpectedEventType { #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub enum EventMatch { +pub(crate) enum EventMatch { Exact, Prefix, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct ExpectError { - pub is_error: Option, - pub is_client_error: Option, - pub error_contains: Option, - pub error_code: Option, - pub error_code_name: Option, - pub error_labels_contain: Option>, - pub error_labels_omit: Option>, - pub expect_result: Option, +pub(crate) struct ExpectError { + #[allow(unused)] + pub(crate) is_error: Option, + pub(crate) is_client_error: Option, + pub(crate) error_contains: Option, + pub(crate) error_code: Option, + pub(crate) error_code_name: Option, + pub(crate) error_labels_contain: Option>, + pub(crate) error_labels_omit: Option>, + pub(crate) expect_result: Option, } impl ExpectError { - pub fn verify_result(&self, error: &Error) { + pub(crate) fn verify_result( + &self, + error: &Error, + description: impl AsRef, + ) -> std::result::Result<(), String> { + let description = description.as_ref(); + if let Some(is_client_error) = self.is_client_error { - assert_eq!(is_client_error, !error.is_server_error()); + if is_client_error != !error.is_server_error() { + return Err(format!( + "{}: expected client error but got {:?}", + description, error + )); + } } if let Some(error_contains) = &self.error_contains { match &error.message() { - Some(msg) => assert!(msg.contains(error_contains)), - None => panic!("{} should include message field", error), + Some(msg) if msg.contains(error_contains) => (), + _ => { + return Err(format!( + "{}: \"{}\" should include message field", + description, error + )) + } } } if let Some(error_code) = self.error_code { match &error.code() { - Some(code) => assert_eq!( - *code, error_code, - "error {:?} did not match expected error code {}", - error, error_code - ), - None => panic!("{} should include code", error), + Some(code) => { + if code != &error_code { + return Err(format!( + "{}: error code {} ({:?}) did not match expected error code {}", + description, + code, + error.code_name(), + error_code + )); + } + } + None => { + return Err(format!( + "{}: {:?} was expected to include code {} but had no code", + description, error, error_code + )) + } } } - if let Some(error_code_name) = &self.error_code_name { - match &error.code_name() { - Some(name) => assert_eq!(&error_code_name, name), - None => panic!("{} should include code name", error), + + if let Some(expected_code_name) = &self.error_code_name { + match error.code_name() { + Some(name) => { + if name != expected_code_name { + return Err(format!( + "{}: error code name \"{}\" did not match expected error code name \ + \"{}\"", + description, name, expected_code_name, + )); + } + } + None => { + return Err(format!( + "{}: {:?} was expected to include code name \"{}\" but had no code name", + description, error, expected_code_name + )) + } } } if let Some(error_labels_contain) = &self.error_labels_contain { for label in error_labels_contain { - assert!(error.labels().contains(label)); + if !error.contains_label(label) { + return Err(format!( + "{}: expected {:?} to contain label \"{}\"", + description, error, label + )); + } } } if let Some(error_labels_omit) = &self.error_labels_omit { for label in error_labels_omit { - assert!(!error.labels().contains(label)); + if error.contains_label(label) { + return 
Err(format!( + "{}: expected {:?} to omit label \"{}\"", + description, error, label + )); + } } } if self.expect_result.is_some() { // TODO RUST-260: match against partial results } + Ok(()) } } diff --git a/src/test/spec/unified_runner/test_runner.rs b/src/test/spec/unified_runner/test_runner.rs index 7122658c8..0d8acd644 100644 --- a/src/test/spec/unified_runner/test_runner.rs +++ b/src/test/spec/unified_runner/test_runner.rs @@ -1,7 +1,9 @@ -use std::{collections::HashMap, fs::File, io::BufWriter, sync::Arc, time::Duration}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; use futures::TryStreamExt; +use tokio::sync::{mpsc, RwLock}; + use crate::{ bson::{doc, Document}, client::options::ClientOptions, @@ -15,12 +17,12 @@ use crate::{ SelectionCriteria, }, runtime, + sdam::TopologyDescription, test::{ log_uncaptured, spec::unified_runner::{ entity::EventList, - matcher::{events_match, results_match}, - operation::{Expectation, OperationObject}, + matcher::events_match, test_file::{ExpectedEventType, TestCase, TestFile}, }, update_options_for_testing, @@ -40,12 +42,14 @@ use crate::{ }; use super::{ + entity::ThreadEntity, + file_level_log, merge_uri_options, + test_file::ThreadMessage, ClientEntity, CollectionData, Entity, SessionEntity, - TestCursor, TestFileEntity, }; @@ -60,36 +64,48 @@ const SKIPPED_OPERATIONS: &[&str] = &[ "watch", ]; -pub type EntityMap = HashMap; +pub(crate) type EntityMap = HashMap; -pub struct TestRunner { - pub internal_client: TestClient, - pub entities: EntityMap, - pub fail_point_guards: Vec, +#[derive(Clone)] +pub(crate) struct TestRunner { + pub(crate) internal_client: TestClient, + pub(crate) entities: Arc>, + pub(crate) fail_point_guards: Arc>>, } impl TestRunner { - pub async fn new() -> Self { + pub(crate) async fn new() -> Self { Self { internal_client: TestClient::new().await, - entities: HashMap::new(), - fail_point_guards: Vec::new(), + entities: Default::default(), + fail_point_guards: Default::default(), } } - pub async fn new_with_connection_string(connection_string: &str) -> Self { + pub(crate) async fn new_with_connection_string(connection_string: &str) -> Self { #[cfg(all(not(feature = "sync"), not(feature = "tokio-sync")))] let options = ClientOptions::parse(connection_string).await.unwrap(); #[cfg(any(feature = "sync", feature = "tokio-sync"))] let options = ClientOptions::parse(connection_string).unwrap(); Self { internal_client: TestClient::with_options(Some(options)).await, - entities: HashMap::new(), - fail_point_guards: Vec::new(), + entities: Arc::new(RwLock::new(EntityMap::new())), + fail_point_guards: Arc::new(RwLock::new(Vec::new())), } } - pub async fn run_test(&mut self, test_file: TestFile, pred: impl Fn(&TestCase) -> bool) { + pub(crate) async fn run_test( + &self, + path: impl Into>, + test_file: TestFile, + pred: impl Fn(&TestCase) -> bool, + ) { + let path = path.into(); + let file_title = path + .as_ref() + .map(|p| p.display().to_string()) + .unwrap_or_else(|| test_file.description.clone()); + if let Some(requirements) = test_file.run_on_requirements { let mut can_run_on = false; for requirement in requirements { @@ -98,20 +114,23 @@ impl TestRunner { } } if !can_run_on { - log_uncaptured(format!( - "Skipping file {}; client topology not compatible with test", - test_file.description + file_level_log(format!( + "Skipping file {}: client topology not compatible with test", + file_title )); return; } } - log_uncaptured(format!("Running file {}", test_file.description)); + log_uncaptured(format!( + 
"\n------------\nRunning tests from {}\n", + file_title + )); for test_case in test_file.tests { if let Some(skip_reason) = test_case.skip_reason { log_uncaptured(format!( - "Skipping test case {}: {}", + "Skipping test case {:?}: {}", &test_case.description, skip_reason )); continue; @@ -124,7 +143,7 @@ impl TestRunner { .map(|op| op.name.as_str()) { log_uncaptured(format!( - "Skipping test case {}: unsupported operation {}", + "Skipping test case {:?}: unsupported operation {}", &test_case.description, op )); continue; @@ -132,7 +151,7 @@ impl TestRunner { if !pred(&test_case) { log_uncaptured(format!( - "Skipping test case {}: predicate failed", + "Skipping test case {:?}: predicate failed", test_case.description )); continue; @@ -147,14 +166,14 @@ impl TestRunner { } if !can_run_on { log_uncaptured(format!( - "Skipping test case {}: client topology not compatible with test", + "Skipping test case {:?}: client topology not compatible with test", &test_case.description )); continue; } } - log_uncaptured(format!("Running {}", &test_case.description)); + log_uncaptured(format!("Executing {:?}", &test_case.description)); if let Some(ref initial_data) = test_file.initial_data { for data in initial_data { @@ -162,77 +181,15 @@ impl TestRunner { } } + self.entities.write().await.clear(); if let Some(ref create_entities) = test_file.create_entities { - self.populate_entity_map(create_entities).await; + self.populate_entity_map(create_entities, &test_case.description) + .await; } for operation in test_case.operations { self.sync_workers().await; - match operation.object { - OperationObject::TestRunner => { - operation.execute_test_runner_operation(self).await; - } - OperationObject::Entity(ref id) => { - let result = operation.execute_entity_operation(id, self).await; - - match &operation.expectation { - Expectation::Result { - expected_value, - save_as_entity, - } => { - let desc = &test_case.description; - let opt_entity = result.unwrap_or_else(|e| { - panic!( - "[{}] {} should succeed, but failed with the following \ - error: {}", - desc, operation.name, e - ) - }); - if expected_value.is_some() || save_as_entity.is_some() { - let entity = opt_entity.unwrap_or_else(|| { - panic!( - "[{}] {} did not return an entity", - desc, operation.name - ) - }); - if let Some(expected_bson) = expected_value { - if let Entity::Bson(actual) = &entity { - if let Err(e) = results_match( - Some(actual), - expected_bson, - operation.returns_root_documents(), - Some(&self.entities), - ) { - panic!( - "[{}] result mismatch, expected = {:#?} \ - actual = {:#?}\nmismatch detail: {}", - desc, expected_bson, actual, e - ); - } - } else { - panic!( - "[{}] Incorrect entity type returned from {}, \ - expected BSON", - desc, operation.name - ); - } - } - if let Some(id) = save_as_entity { - self.insert_entity(id, entity); - } - } - } - Expectation::Error(expect_error) => { - let error = result.expect_err(&format!( - "{}: {} should return an error", - test_case.description, operation.name - )); - expect_error.verify_result(&error); - } - Expectation::Ignore => (), - } - } - } + operation.execute(self, &test_case.description).await; // This test (in src/test/spec/json/sessions/server-support.json) runs two // operations with implicit sessions in sequence and then checks to see if they // used the same lsid. 
We delay for one second to ensure that the @@ -245,7 +202,8 @@ impl TestRunner { if let Some(ref events) = test_case.expect_events { for expected in events { - let entity = self.entities.get(&expected.client).unwrap(); + let entities = self.entities.read().await; + let entity = entities.get(&expected.client).unwrap(); let client = entity.as_client(); client.sync_workers().await; let event_type = expected.event_type.unwrap_or(ExpectedEventType::Command); @@ -273,7 +231,7 @@ impl TestRunner { } for (actual, expected) in actual_events.iter().zip(expected_events) { - if let Err(e) = events_match(actual, expected, Some(&self.entities)) { + if let Err(e) = events_match(actual, expected, Some(&entities)) { panic!( "event mismatch: expected = {:#?}, actual = {:#?}\nall \ expected:\n{:#?}\nall actual:\n{:#?}\nmismatch detail: {}", @@ -284,7 +242,7 @@ impl TestRunner { } } - self.fail_point_guards.clear(); + self.fail_point_guards.write().await.clear(); if let Some(ref outcome) = test_case.outcome { for expected_data in outcome { @@ -315,12 +273,10 @@ impl TestRunner { assert_eq!(expected_data.documents, actual_data); } } - - println!("{} succeeded", &test_case.description); } } - pub async fn insert_initial_data(&self, data: &CollectionData) { + pub(crate) async fn insert_initial_data(&self, data: &CollectionData) { let write_concern = WriteConcern::builder().w(Acknowledgment::Majority).build(); if !data.documents.is_empty() { @@ -352,9 +308,11 @@ impl TestRunner { } } - pub async fn populate_entity_map(&mut self, create_entities: &[TestFileEntity]) { - self.entities.clear(); - + pub(crate) async fn populate_entity_map( + &self, + create_entities: &[TestFileEntity], + description: impl AsRef, + ) { for entity in create_entities { let (id, entity) = match entity { TestFileEntity::Client(client) => { @@ -364,7 +322,8 @@ impl TestRunner { client_id: client.id.clone(), event_names: store_events_as_entity.events.clone(), }; - self.insert_entity(&store_events_as_entity.id, event_list.into()); + self.insert_entity(&store_events_as_entity.id, event_list) + .await; } } @@ -374,7 +333,6 @@ impl TestRunner { let observe_sensitive_commands = client.observe_sensitive_commands.unwrap_or(false); let server_api = client.server_api.clone().or_else(|| SERVER_API.clone()); - let observer = Arc::new(EventHandler::new()); let given_uri = if CLIENT_OPTIONS.get().await.load_balanced.unwrap_or(false) { // for serverless testing, ignore use_multiple_mongoses. 
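The runner's shared state now sits behind `Arc<RwLock<...>>` so a cloned `TestRunner` can be moved into spawned tasks, with async getters that take a read lock and clone entities out rather than returning references. A minimal sketch of that pattern, assuming tokio's `RwLock`; the entity types here are simplified stand-ins rather than the driver's actual ones.

use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

#[derive(Clone, Debug)]
struct ClientEntity(String);

type EntityMap = HashMap<String, ClientEntity>;

#[derive(Clone, Default)]
struct TestRunner {
    entities: Arc<RwLock<EntityMap>>,
}

impl TestRunner {
    async fn insert_entity(&self, id: &str, entity: ClientEntity) {
        if self.entities.write().await.insert(id.to_string(), entity).is_some() {
            panic!("Entity with id {} already present in entity map", id);
        }
    }

    // Returning a reference would borrow the read guard, so getters clone instead.
    async fn get_client(&self, id: &str) -> ClientEntity {
        self.entities.read().await.get(id).unwrap().clone()
    }
}

#[tokio::main]
async fn main() {
    let runner = TestRunner::default();
    runner.insert_entity("client0", ClientEntity("client0".into())).await;

    // A clone shares the same map, so spawned "thread" tasks see the same entities.
    let for_task = runner.clone();
    let handle = tokio::spawn(async move { for_task.get_client("client0").await });
    println!("{:?}", handle.await.unwrap());
}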
@@ -391,11 +349,23 @@ impl TestRunner { &DEFAULT_URI }; let uri = merge_uri_options(given_uri, client.uri_options.as_ref()); - let mut options = ClientOptions::parse_uri(&uri, None).await.unwrap(); + let mut options = + ClientOptions::parse_uri(&uri, None) + .await + .unwrap_or_else(|e| { + panic!( + "[{}] invalid client URI: {}, error: {}", + description.as_ref(), + uri, + e + ) + }); update_options_for_testing(&mut options); - options.command_event_handler = Some(observer.clone()); - options.cmap_event_handler = Some(observer.clone()); - options.sdam_event_handler = Some(observer.clone()); + let handler = Arc::new(EventHandler::new()); + options.command_event_handler = Some(handler.clone()); + options.cmap_event_handler = Some(handler.clone()); + options.sdam_event_handler = Some(handler.clone()); + options.server_api = server_api; if let Some(use_multiple_mongoses) = client.use_multiple_mongoses { @@ -403,7 +373,8 @@ impl TestRunner { if use_multiple_mongoses { assert!( options.hosts.len() > 1, - "Test requires multiple mongos hosts" + "[{}]: Test requires multiple mongos hosts", + description.as_ref() ); } else { options.hosts.drain(1..); @@ -417,7 +388,7 @@ impl TestRunner { id, Entity::Client(ClientEntity::new( client, - observer, + handler, observe_events, ignore_command_names, observe_sensitive_commands, @@ -426,7 +397,7 @@ impl TestRunner { } TestFileEntity::Database(database) => { let id = database.id.clone(); - let client = self.entities.get(&database.client).unwrap().as_client(); + let client = self.get_client(&database.client).await; let database = if let Some(ref options) = database.database_options { let options = options.as_database_options(); client.database_with_options(&database.database_name, options) @@ -437,11 +408,7 @@ impl TestRunner { } TestFileEntity::Collection(collection) => { let id = collection.id.clone(); - let database = self - .entities - .get(&collection.database) - .unwrap() - .as_database(); + let database = self.get_database(&collection.database).await; let collection = if let Some(ref options) = collection.collection_options { let options = options.as_collection_options(); database.collection_with_options(&collection.collection_name, options) @@ -452,7 +419,7 @@ impl TestRunner { } TestFileEntity::Session(session) => { let id = session.id.clone(); - let client = self.get_client(&session.client); + let client = self.get_client(&session.client).await; let client_session = client .start_session(session.session_options.clone()) .await @@ -462,62 +429,108 @@ impl TestRunner { TestFileEntity::Bucket(_) => { panic!("GridFS not implemented"); } + TestFileEntity::Thread(thread) => { + let (sender, mut receiver) = mpsc::unbounded_channel::(); + let runner = self.clone(); + let d = description.as_ref().to_string(); + runtime::execute(async move { + while let Some(msg) = receiver.recv().await { + match msg { + ThreadMessage::ExecuteOperation(op) => { + op.execute(&runner, d.as_str()).await; + } + ThreadMessage::Stop(sender) => { + // This returns an error if the waitForThread operation stopped + // listening (e.g. due to timeout). The waitForThread operation + // will handle reporting that error, so we can ignore it here. 
+ let _ = sender.send(()); + break; + } + } + } + }); + (thread.id.clone(), Entity::Thread(ThreadEntity { sender })) + } }; - self.insert_entity(&id, entity); + self.insert_entity(&id, entity).await; } } - pub fn insert_entity(&mut self, id: &str, entity: Entity) { - if self.entities.insert(id.to_string(), entity).is_some() { - panic!("Entity with id {} already present in entity map", id); + pub(crate) async fn insert_entity(&self, id: impl AsRef, entity: impl Into) { + if self + .entities + .write() + .await + .insert(id.as_ref().to_string(), entity.into()) + .is_some() + { + panic!( + "Entity with id {} already present in entity map", + id.as_ref() + ); } } - pub async fn sync_workers(&self) { + pub(crate) async fn sync_workers(&self) { self.internal_client.sync_workers().await; - for entity in self.entities.values() { + let entities = self.entities.read().await; + for entity in entities.values() { if let Entity::Client(client) = entity { client.sync_workers().await; } } } - pub fn get_client(&self, id: &str) -> &ClientEntity { - self.entities.get(id).unwrap().as_client() - } - - pub fn get_database(&self, id: &str) -> &Database { - self.entities.get(id).unwrap().as_database() - } - - pub fn get_collection(&self, id: &str) -> &Collection { - self.entities.get(id).unwrap().as_collection() + pub(crate) async fn get_client(&self, id: &str) -> ClientEntity { + self.entities + .read() + .await + .get(id) + .unwrap() + .as_client() + .clone() } - pub fn get_session(&self, id: &str) -> &SessionEntity { - self.entities.get(id).unwrap().as_session_entity() + pub(crate) async fn get_database(&self, id: &str) -> Database { + self.entities + .read() + .await + .get(id) + .unwrap() + .as_database() + .clone() } - pub fn get_mut_session(&mut self, id: &str) -> &mut SessionEntity { - self.entities.get_mut(id).unwrap().as_mut_session_entity() + pub(crate) async fn get_collection(&self, id: &str) -> Collection { + self.entities + .read() + .await + .get(id) + .unwrap() + .as_collection() + .clone() } - pub fn get_mut_find_cursor(&mut self, id: &str) -> &mut TestCursor { - self.entities.get_mut(id).unwrap().as_mut_cursor() + pub(crate) async fn get_thread(&self, id: &str) -> ThreadEntity { + self.entities + .read() + .await + .get(id) + .unwrap() + .as_thread() + .clone() } - pub fn write_events_list_to_file(&self, id: &str, writer: &mut BufWriter) { - let event_list_entity = match self.entities.get(id) { - Some(entity) => entity.as_event_list(), - None => return, - }; - let client = self.get_client(&event_list_entity.client_id); - let names: Vec<&str> = event_list_entity - .event_names - .iter() - .map(String::as_ref) - .collect(); - - client.write_events_list_to_file(&names, writer); + pub(crate) async fn get_topology_description( + &self, + id: impl AsRef, + ) -> TopologyDescription { + self.entities + .read() + .await + .get(id.as_ref()) + .unwrap() + .as_topology_description() + .clone() } } diff --git a/src/test/spec/v2_runner/mod.rs b/src/test/spec/v2_runner/mod.rs index f872e6759..47c61a53c 100644 --- a/src/test/spec/v2_runner/mod.rs +++ b/src/test/spec/v2_runner/mod.rs @@ -1,6 +1,6 @@ -pub mod operation; -pub mod test_event; -pub mod test_file; +pub(crate) mod operation; +pub(crate) mod test_event; +pub(crate) mod test_file; use std::{ops::Deref, sync::Arc, time::Duration}; @@ -38,7 +38,7 @@ const SKIPPED_OPERATIONS: &[&str] = &[ "mapReduce", ]; -pub async fn run_v2_test(test_file: TestFile) { +pub(crate) async fn run_v2_test(test_file: TestFile) { let internal_client = TestClient::new().await; 
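For the `Stop` handshake used by the thread entity above, the waiting side would send a fresh oneshot sender and await the acknowledgement under a deadline, which is why the worker can safely ignore a failed `send`. One way that caller could look, sketched as a free function; the `wait_for_thread` name and the 10-second deadline are placeholders, not the runner's actual values.

use tokio::sync::{mpsc, oneshot};
use tokio::time::{timeout, Duration};

enum ThreadMessage {
    Stop(oneshot::Sender<()>),
}

async fn wait_for_thread(sender: &mpsc::UnboundedSender<ThreadMessage>) -> Result<(), String> {
    let (ack_tx, ack_rx) = oneshot::channel();
    sender
        .send(ThreadMessage::Stop(ack_tx))
        .map_err(|_| "thread has already shut down".to_string())?;
    match timeout(Duration::from_secs(10), ack_rx).await {
        Ok(Ok(())) => Ok(()),
        Ok(Err(_)) => Err("thread dropped the acknowledgement sender".to_string()),
        // On timeout, the worker's later `ack.send(())` fails; the worker ignores
        // that error, matching the comment in the hunk above.
        Err(_) => Err("timed out waiting for thread to stop".to_string()),
    }
}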
if let Some(requirements) = test_file.run_on { diff --git a/src/test/spec/v2_runner/operation.rs b/src/test/spec/v2_runner/operation.rs index e816389e0..f6f02b250 100644 --- a/src/test/spec/v2_runner/operation.rs +++ b/src/test/spec/v2_runner/operation.rs @@ -42,7 +42,7 @@ use crate::{ IndexModel, }; -pub trait TestOperation: Debug { +pub(crate) trait TestOperation: Debug { fn execute_on_collection<'a>( &'a self, _collection: &'a Collection, @@ -75,20 +75,20 @@ pub trait TestOperation: Debug { } #[derive(Debug)] -pub struct Operation { +pub(crate) struct Operation { operation: Box, - pub name: String, - pub object: Option, - pub collection_options: Option, - pub database_options: Option, - pub error: Option, - pub result: Option, - pub session: Option, + pub(crate) name: String, + pub(crate) object: Option, + pub(crate) collection_options: Option, + pub(crate) database_options: Option, + pub(crate) error: Option, + pub(crate) result: Option, + pub(crate) session: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub enum OperationObject { +pub(crate) enum OperationObject { Database, Collection, Client, @@ -101,19 +101,19 @@ pub enum OperationObject { #[derive(Debug, Deserialize)] #[serde(untagged)] -pub enum OperationResult { +pub(crate) enum OperationResult { Error(OperationError), Success(Bson), } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct OperationError { - pub error_contains: Option, - pub error_code_name: Option, - pub error_code: Option, - pub error_labels_contain: Option>, - pub error_labels_omit: Option>, +pub(crate) struct OperationError { + pub(crate) error_contains: Option, + pub(crate) error_code_name: Option, + pub(crate) error_code: Option, + pub(crate) error_labels_contain: Option>, + pub(crate) error_labels_omit: Option>, } impl<'de> Deserialize<'de> for Operation { @@ -121,18 +121,18 @@ impl<'de> Deserialize<'de> for Operation { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] struct OperationDefinition { - pub name: String, - pub object: Option, - pub collection_options: Option, - pub database_options: Option, + pub(crate) name: String, + pub(crate) object: Option, + pub(crate) collection_options: Option, + pub(crate) database_options: Option, #[serde(default = "default_arguments")] - pub arguments: Document, - pub error: Option, - pub result: Option, + pub(crate) arguments: Document, + pub(crate) error: Option, + pub(crate) result: Option, // We don't need to use this field, but it needs to be included during deserialization // so that we can use the deny_unknown_fields tag. 
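The comment above explains why the otherwise-unused `command_name` field that follows stays declared: with `deny_unknown_fields`, serde rejects any key that has no matching struct field. A small sketch of that behavior, with simplified field types (`serde_json::Value` instead of `Document`) purely for illustration.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
struct OperationDefinition {
    name: String,
    #[serde(default)]
    arguments: serde_json::Value,
    // Kept (and renamed) purely so the key is recognized and ignored.
    #[serde(rename = "command_name")]
    _command_name: Option<String>,
}

fn main() {
    let op: OperationDefinition =
        serde_json::from_str(r#"{ "name": "insertOne", "command_name": "insert" }"#).unwrap();
    println!("{:?}", op);

    // A key that is neither declared nor renamed fails to deserialize.
    let err = serde_json::from_str::<OperationDefinition>(r#"{ "name": "find", "extra": 1 }"#);
    assert!(err.is_err());
}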
#[serde(rename = "command_name")] - pub _command_name: Option, + pub(crate) _command_name: Option, } fn default_arguments() -> Document { diff --git a/src/test/spec/v2_runner/test_file.rs b/src/test/spec/v2_runner/test_file.rs index 33fc86dec..8b1831ca1 100644 --- a/src/test/spec/v2_runner/test_file.rs +++ b/src/test/spec/v2_runner/test_file.rs @@ -15,27 +15,28 @@ use super::{operation::Operation, test_event::CommandStartedEvent}; #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub struct TestFile { +pub(crate) struct TestFile { #[serde(rename = "runOn")] - pub run_on: Option>, - pub database_name: Option, - pub collection_name: Option, - pub bucket_name: Option, - pub data: Option, - pub tests: Vec, + pub(crate) run_on: Option>, + pub(crate) database_name: Option, + pub(crate) collection_name: Option, + #[allow(unused)] + pub(crate) bucket_name: Option, + pub(crate) data: Option, + pub(crate) tests: Vec, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase", deny_unknown_fields)] -pub struct RunOn { - pub min_server_version: Option, - pub max_server_version: Option, - pub topology: Option>, +pub(crate) struct RunOn { + pub(crate) min_server_version: Option, + pub(crate) max_server_version: Option, + pub(crate) topology: Option>, pub(crate) serverless: Option, } impl RunOn { - pub fn can_run_on(&self, client: &TestClient) -> bool { + pub(crate) fn can_run_on(&self, client: &TestClient) -> bool { if let Some(ref min_version) = self.min_server_version { let req = VersionReq::parse(&format!(">= {}", &min_version)).unwrap(); if !req.matches(&client.server_version) { @@ -64,29 +65,29 @@ impl RunOn { #[derive(Debug, Deserialize)] #[serde(untagged)] -pub enum TestData { +pub(crate) enum TestData { Single(Vec), Many(HashMap>), } #[derive(Deserialize)] #[serde(rename_all = "camelCase")] -pub struct Test { - pub description: String, - pub skip_reason: Option, - pub use_multiple_mongoses: Option, +pub(crate) struct Test { + pub(crate) description: String, + pub(crate) skip_reason: Option, + pub(crate) use_multiple_mongoses: Option, #[serde( default, deserialize_with = "deserialize_uri_options_to_uri_string_option", rename = "clientOptions" )] - pub client_uri: Option, - pub fail_point: Option, - pub session_options: Option>, - pub operations: Vec, + pub(crate) client_uri: Option, + pub(crate) fail_point: Option, + pub(crate) session_options: Option>, + pub(crate) operations: Vec, #[serde(default, deserialize_with = "deserialize_command_started_events")] - pub expectations: Option>, - pub outcome: Option, + pub(crate) expectations: Option>, + pub(crate) outcome: Option, } fn deserialize_uri_options_to_uri_string_option<'de, D>( @@ -100,12 +101,12 @@ where } #[derive(Debug, Deserialize)] -pub struct Outcome { - pub collection: CollectionOutcome, +pub(crate) struct Outcome { + pub(crate) collection: CollectionOutcome, } impl Outcome { - pub async fn matches_actual( + pub(crate) async fn matches_actual( self, db_name: String, coll_name: String, @@ -133,9 +134,9 @@ impl Outcome { } #[derive(Debug, Deserialize)] -pub struct CollectionOutcome { - pub name: Option, - pub data: Vec, +pub(crate) struct CollectionOutcome { + pub(crate) name: Option, + pub(crate) data: Vec, } fn deserialize_command_started_events<'de, D>( diff --git a/src/test/spec/versioned_api.rs b/src/test/spec/versioned_api.rs index 3cfb1c721..8c2fd0c18 100644 --- a/src/test/spec/versioned_api.rs +++ b/src/test/spec/versioned_api.rs @@ -1,12 +1,12 @@ use tokio::sync::RwLockWriteGuard; -use crate::test::{run_spec_test, LOCK}; 
+use crate::test::LOCK; -use super::run_unified_format_test; +use super::{run_spec_test_with_path, run_unified_format_test}; #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] async fn run() { let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await; - run_spec_test(&["versioned-api"], run_unified_format_test).await; + run_spec_test_with_path(&["versioned-api"], run_unified_format_test).await; } diff --git a/src/test/util/event.rs b/src/test/util/event.rs index 92380581b..f8b2eb39c 100644 --- a/src/test/util/event.rs +++ b/src/test/util/event.rs @@ -55,8 +55,8 @@ use crate::{ test::{spec::ExpectedEventType, LOCK}, }; -pub type EventQueue = Arc>>; -pub type CmapEvent = crate::cmap::test::event::Event; +pub(crate) type EventQueue = Arc>>; +pub(crate) type CmapEvent = crate::cmap::test::event::Event; fn add_event_to_queue(event_queue: &EventQueue, event: T) { event_queue @@ -67,14 +67,14 @@ fn add_event_to_queue(event_queue: &EventQueue, event: T) { #[derive(Clone, Debug, From)] #[allow(clippy::large_enum_variant)] -pub enum Event { +pub(crate) enum Event { Cmap(CmapEvent), Command(CommandEvent), Sdam(SdamEvent), } impl Event { - pub fn unwrap_sdam_event(self) -> SdamEvent { + pub(crate) fn unwrap_sdam_event(self) -> SdamEvent { if let Event::Sdam(e) = self { e } else { @@ -85,7 +85,7 @@ impl Event { #[derive(Clone, Debug, Serialize)] #[serde(untagged)] -pub enum SdamEvent { +pub(crate) enum SdamEvent { ServerDescriptionChanged(Box), ServerOpening(ServerOpeningEvent), ServerClosed(ServerClosedEvent), @@ -116,7 +116,7 @@ impl SdamEvent { #[derive(Clone, Debug, Serialize)] #[allow(clippy::large_enum_variant)] #[serde(untagged)] -pub enum CommandEvent { +pub(crate) enum CommandEvent { Started(CommandStartedEvent), Succeeded(CommandSucceededEvent), Failed(CommandFailedEvent), @@ -131,7 +131,7 @@ impl CommandEvent { } } - pub fn command_name(&self) -> &str { + pub(crate) fn command_name(&self) -> &str { match self { CommandEvent::Started(event) => event.command_name.as_str(), CommandEvent::Failed(event) => event.command_name.as_str(), @@ -147,14 +147,14 @@ impl CommandEvent { } } - pub fn as_command_started(&self) -> Option<&CommandStartedEvent> { + pub(crate) fn as_command_started(&self) -> Option<&CommandStartedEvent> { match self { CommandEvent::Started(e) => Some(e), _ => None, } } - pub fn as_command_succeeded(&self) -> Option<&CommandSucceededEvent> { + pub(crate) fn as_command_succeeded(&self) -> Option<&CommandSucceededEvent> { match self { CommandEvent::Succeeded(e) => Some(e), _ => None, @@ -163,7 +163,7 @@ impl CommandEvent { } #[derive(Clone, Debug)] -pub struct EventHandler { +pub(crate) struct EventHandler { command_events: EventQueue, sdam_events: EventQueue, cmap_events: EventQueue, @@ -172,7 +172,7 @@ pub struct EventHandler { } impl EventHandler { - pub fn new() -> Self { + pub(crate) fn new() -> Self { let (event_broadcaster, _) = tokio::sync::broadcast::channel(10_000); Self { command_events: Default::default(), @@ -189,15 +189,22 @@ impl EventHandler { self.event_broadcaster.send(event.into()); } - pub fn subscribe(&self) -> EventSubscriber { + pub(crate) fn subscribe(&self) -> EventSubscriber { EventSubscriber { _handler: self, receiver: self.event_broadcaster.subscribe(), } } + pub(crate) fn broadcaster(&self) -> &tokio::sync::broadcast::Sender { + &self.event_broadcaster + } + /// Gets all of the command started events for the specified command names. 
- pub fn get_command_started_events(&self, command_names: &[&str]) -> Vec { + pub(crate) fn get_command_started_events( + &self, + command_names: &[&str], + ) -> Vec { let events = self.command_events.read().unwrap(); events .iter() @@ -215,7 +222,7 @@ impl EventHandler { } /// Gets all of the command started events, excluding configureFailPoint events. - pub fn get_all_command_started_events(&self) -> Vec { + pub(crate) fn get_all_command_started_events(&self) -> Vec { let events = self.command_events.read().unwrap(); events .iter() @@ -228,7 +235,11 @@ impl EventHandler { .collect() } - pub fn get_filtered_events(&self, event_type: ExpectedEventType, filter: F) -> Vec + pub(crate) fn get_filtered_events( + &self, + event_type: ExpectedEventType, + filter: F, + ) -> Vec where F: Fn(&Event) -> bool, { @@ -259,7 +270,7 @@ impl EventHandler { } } - pub fn write_events_list_to_file(&self, names: &[&str], writer: &mut BufWriter) { + pub(crate) fn write_events_list_to_file(&self, names: &[&str], writer: &mut BufWriter) { let mut add_comma = false; let mut write_json = |mut event: Document, name: &str, time: &OffsetDateTime| { event.insert("name", name); @@ -296,11 +307,11 @@ impl EventHandler { } } - pub fn connections_checked_out(&self) -> u32 { + pub(crate) fn connections_checked_out(&self) -> u32 { *self.connections_checked_out.lock().unwrap() } - pub fn clear_cached_events(&self) { + pub(crate) fn clear_cached_events(&self) { self.command_events.write().unwrap().clear(); self.cmap_events.write().unwrap().clear(); self.sdam_events.write().unwrap().clear(); @@ -454,7 +465,7 @@ impl CommandEventHandler for EventHandler { } #[derive(Debug)] -pub struct EventSubscriber<'a> { +pub(crate) struct EventSubscriber<'a> { /// A reference to the handler this subscriber is receiving events from. /// Stored here to ensure this subscriber cannot outlive the handler that is generating its /// events. @@ -463,7 +474,11 @@ pub struct EventSubscriber<'a> { } impl EventSubscriber<'_> { - pub async fn wait_for_event(&mut self, timeout: Duration, mut filter: F) -> Option + pub(crate) async fn wait_for_event( + &mut self, + timeout: Duration, + mut filter: F, + ) -> Option where F: FnMut(&Event) -> bool, { @@ -485,7 +500,7 @@ impl EventSubscriber<'_> { .flatten() } - pub async fn collect_events(&mut self, timeout: Duration, mut filter: F) -> Vec + pub(crate) async fn collect_events(&mut self, timeout: Duration, mut filter: F) -> Vec where F: FnMut(&Event) -> bool, { @@ -498,7 +513,7 @@ impl EventSubscriber<'_> { } #[derive(Clone, Debug)] -pub struct EventClient { +pub(crate) struct EventClient { client: TestClient, pub(crate) handler: Arc, } @@ -518,7 +533,7 @@ impl std::ops::DerefMut for EventClient { } impl EventClient { - pub async fn new() -> Self { + pub(crate) async fn new() -> Self { EventClient::with_options(None).await } @@ -535,11 +550,11 @@ impl EventClient { Self { client, handler } } - pub async fn with_options(options: impl Into>) -> Self { + pub(crate) async fn with_options(options: impl Into>) -> Self { Self::with_options_and_handler(options, None).await } - pub async fn with_additional_options( + pub(crate) async fn with_additional_options( options: impl Into>, heartbeat_freq: Option, use_multiple_mongoses: Option, @@ -558,7 +573,7 @@ impl EventClient { /// events before and between them. /// /// Panics if the command failed or could not be found in the events. 
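The `EventHandler`/`EventSubscriber` pair above fans events out on a tokio broadcast channel and polls the receiver under a deadline, returning the first event that satisfies the caller's filter. A self-contained sketch of that wait-with-filter pattern; the `Event` variants and channel capacity here are illustrative, not the driver's.

use std::time::Duration;
use tokio::sync::broadcast;

#[derive(Clone, Debug)]
enum Event {
    CommandStarted(String),
    CommandSucceeded(String),
}

async fn wait_for_event(
    receiver: &mut broadcast::Receiver<Event>,
    timeout: Duration,
    mut filter: impl FnMut(&Event) -> bool,
) -> Option<Event> {
    tokio::time::timeout(timeout, async {
        loop {
            match receiver.recv().await {
                Ok(event) if filter(&event) => return Some(event),
                Ok(_) => continue,
                // Lagged or closed channel: give up.
                Err(_) => return None,
            }
        }
    })
    .await
    .ok()
    .flatten()
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel(16);
    tx.send(Event::CommandStarted("insert".into())).unwrap();
    tx.send(Event::CommandSucceeded("insert".into())).unwrap();
    let hit = wait_for_event(&mut rx, Duration::from_millis(100), |e| {
        matches!(e, Event::CommandSucceeded(_))
    })
    .await;
    println!("{:?}", hit);
}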
- pub fn get_successful_command_execution( + pub(crate) fn get_successful_command_execution( &self, command_name: &str, ) -> (CommandStartedEvent, CommandSucceededEvent) { @@ -595,16 +610,19 @@ impl EventClient { } /// Gets all of the command started events for the specified command names. - pub fn get_command_started_events(&self, command_names: &[&str]) -> Vec { + pub(crate) fn get_command_started_events( + &self, + command_names: &[&str], + ) -> Vec { self.handler.get_command_started_events(command_names) } /// Gets all command started events, excluding configureFailPoint events. - pub fn get_all_command_started_events(&self) -> Vec { + pub(crate) fn get_all_command_started_events(&self) -> Vec { self.handler.get_all_command_started_events() } - pub fn get_command_events(&self, command_names: &[&str]) -> Vec { + pub(crate) fn get_command_events(&self, command_names: &[&str]) -> Vec { self.handler .command_events .write() @@ -615,7 +633,7 @@ impl EventClient { .collect() } - pub fn count_pool_cleared_events(&self) -> usize { + pub(crate) fn count_pool_cleared_events(&self) -> usize { let mut out = 0; for (event, _) in self.handler.cmap_events.read().unwrap().iter() { if matches!(event, CmapEvent::PoolCleared(_)) { @@ -626,11 +644,11 @@ impl EventClient { } #[allow(dead_code)] - pub fn subscribe_to_events(&self) -> EventSubscriber<'_> { + pub(crate) fn subscribe_to_events(&self) -> EventSubscriber<'_> { self.handler.subscribe() } - pub fn clear_cached_events(&self) { + pub(crate) fn clear_cached_events(&self) { self.handler.clear_cached_events() } } diff --git a/src/test/util/mod.rs b/src/test/util/mod.rs index 81ca39868..4dcbb7811 100644 --- a/src/test/util/mod.rs +++ b/src/test/util/mod.rs @@ -3,7 +3,7 @@ mod failpoint; mod lock; mod matchable; -pub use self::{ +pub(crate) use self::{ event::{CmapEvent, CommandEvent, Event, EventClient, EventHandler, SdamEvent}, failpoint::{FailCommandOptions, FailPoint, FailPointGuard, FailPointMode}, lock::TestLock, @@ -38,12 +38,12 @@ use crate::{ }; #[derive(Clone, Debug)] -pub struct TestClient { +pub(crate) struct TestClient { client: Client, - pub options: ClientOptions, + pub(crate) options: ClientOptions, pub(crate) server_info: HelloCommandResponse, - pub server_version: Version, - pub server_parameters: Document, + pub(crate) server_version: Version, + pub(crate) server_parameters: Document, } impl std::ops::Deref for TestClient { @@ -55,15 +55,15 @@ impl std::ops::Deref for TestClient { } impl TestClient { - pub async fn new() -> Self { + pub(crate) async fn new() -> Self { Self::with_options(None).await } - pub async fn with_options(options: Option) -> Self { + pub(crate) async fn with_options(options: Option) -> Self { Self::with_handler(None, options).await } - pub async fn with_handler( + pub(crate) async fn with_handler( event_handler: Option>, options: impl Into>, ) -> Self { @@ -123,12 +123,12 @@ impl TestClient { } } - pub async fn with_additional_options(options: Option) -> Self { + pub(crate) async fn with_additional_options(options: Option) -> Self { let options = Self::options_for_multiple_mongoses(options, false).await; Self::with_options(Some(options)).await } - pub async fn create_user( + pub(crate) async fn create_user( &self, user: &str, pwd: impl Into>, @@ -152,7 +152,7 @@ impl TestClient { Ok(()) } - pub async fn drop_and_create_user( + pub(crate) async fn drop_and_create_user( &self, user: &str, pwd: impl Into>, @@ -175,17 +175,25 @@ impl TestClient { self.create_user(user, pwd, roles, mechanisms, db).await } - pub fn 
get_coll(&self, db_name: &str, coll_name: &str) -> Collection { + pub(crate) fn get_coll(&self, db_name: &str, coll_name: &str) -> Collection { self.database(db_name).collection(coll_name) } - pub async fn init_db_and_coll(&self, db_name: &str, coll_name: &str) -> Collection { + pub(crate) async fn init_db_and_coll( + &self, + db_name: &str, + coll_name: &str, + ) -> Collection { let coll = self.get_coll(db_name, coll_name); drop_collection(&coll).await; coll } - pub async fn init_db_and_typed_coll(&self, db_name: &str, coll_name: &str) -> Collection + pub(crate) async fn init_db_and_typed_coll( + &self, + db_name: &str, + coll_name: &str, + ) -> Collection where T: Serialize + DeserializeOwned + Unpin + Debug, { @@ -194,7 +202,7 @@ impl TestClient { coll } - pub fn get_coll_with_options( + pub(crate) fn get_coll_with_options( &self, db_name: &str, coll_name: &str, @@ -204,7 +212,7 @@ impl TestClient { .collection_with_options(coll_name, options) } - pub async fn init_db_and_coll_with_options( + pub(crate) async fn init_db_and_coll_with_options( &self, db_name: &str, coll_name: &str, @@ -215,7 +223,7 @@ impl TestClient { coll } - pub async fn create_fresh_collection( + pub(crate) async fn create_fresh_collection( &self, db_name: &str, coll_name: &str, @@ -230,7 +238,7 @@ impl TestClient { self.get_coll(db_name, coll_name) } - pub fn supports_fail_command(&self) -> bool { + pub(crate) fn supports_fail_command(&self) -> bool { let version = if self.is_sharded() { VersionReq::parse(">= 4.1.5").unwrap() } else { @@ -239,7 +247,7 @@ impl TestClient { version.matches(&self.server_version) } - pub fn supports_block_connection(&self) -> bool { + pub(crate) fn supports_block_connection(&self) -> bool { let version = VersionReq::parse(">= 4.2.9").unwrap(); version.matches(&self.server_version) } @@ -248,7 +256,7 @@ impl TestClient { /// only when it uses a specified appName. /// /// See SERVER-49336 for more info. 
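These `supports_*` helpers all follow the same semver-gating shape: parse one or more `VersionReq` ranges and test the connected server's version against them. A sketch of that shape with an illustrative helper name; the two ranges are the ones visible in the hunk that follows, and the driver's full list may contain more.

use semver::{Version, VersionReq};

fn supports_feature(server_version: &Version) -> bool {
    let requirements = [
        VersionReq::parse(">= 4.2.15, < 4.3.0").unwrap(),
        VersionReq::parse(">= 4.4.7, < 4.5.0").unwrap(),
    ];
    // The capability is enabled if any requirement range matches.
    requirements.iter().any(|req| req.matches(server_version))
}

fn main() {
    assert!(supports_feature(&Version::parse("4.4.9").unwrap()));
    assert!(!supports_feature(&Version::parse("4.3.1").unwrap()));
}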
- pub fn supports_fail_command_appname_initial_handshake(&self) -> bool { + pub(crate) fn supports_fail_command_appname_initial_handshake(&self) -> bool { let requirements = [ VersionReq::parse(">= 4.2.15, < 4.3.0").unwrap(), VersionReq::parse(">= 4.4.7, < 4.5.0").unwrap(), @@ -259,12 +267,12 @@ impl TestClient { .any(|req| req.matches(&self.server_version)) } - pub fn supports_transactions(&self) -> bool { + pub(crate) fn supports_transactions(&self) -> bool { self.is_replica_set() && self.server_version_gte(4, 0) || self.is_sharded() && self.server_version_gte(4, 2) } - pub async fn enable_failpoint( + pub(crate) async fn enable_failpoint( &self, fp: FailPoint, criteria: impl Into>, @@ -272,53 +280,53 @@ impl TestClient { fp.enable(self, criteria).await } - pub fn auth_enabled(&self) -> bool { + pub(crate) fn auth_enabled(&self) -> bool { self.options.credential.is_some() } - pub fn is_standalone(&self) -> bool { + pub(crate) fn is_standalone(&self) -> bool { self.base_topology() == Topology::Single } - pub fn is_replica_set(&self) -> bool { + pub(crate) fn is_replica_set(&self) -> bool { self.base_topology() == Topology::ReplicaSet } - pub fn is_sharded(&self) -> bool { + pub(crate) fn is_sharded(&self) -> bool { self.base_topology() == Topology::Sharded } - pub fn is_load_balanced(&self) -> bool { + pub(crate) fn is_load_balanced(&self) -> bool { self.base_topology() == Topology::LoadBalanced } - pub fn server_version_eq(&self, major: u64, minor: u64) -> bool { + pub(crate) fn server_version_eq(&self, major: u64, minor: u64) -> bool { self.server_version.major == major && self.server_version.minor == minor } #[allow(dead_code)] - pub fn server_version_gt(&self, major: u64, minor: u64) -> bool { + pub(crate) fn server_version_gt(&self, major: u64, minor: u64) -> bool { self.server_version.major > major || (self.server_version.major == major && self.server_version.minor > minor) } - pub fn server_version_gte(&self, major: u64, minor: u64) -> bool { + pub(crate) fn server_version_gte(&self, major: u64, minor: u64) -> bool { self.server_version.major > major || (self.server_version.major == major && self.server_version.minor >= minor) } - pub fn server_version_lt(&self, major: u64, minor: u64) -> bool { + pub(crate) fn server_version_lt(&self, major: u64, minor: u64) -> bool { self.server_version.major < major || (self.server_version.major == major && self.server_version.minor < minor) } #[allow(dead_code)] - pub fn server_version_lte(&self, major: u64, minor: u64) -> bool { + pub(crate) fn server_version_lte(&self, major: u64, minor: u64) -> bool { self.server_version.major < major || (self.server_version.major == major && self.server_version.minor <= minor) } - pub async fn drop_collection(&self, db_name: &str, coll_name: &str) { + pub(crate) async fn drop_collection(&self, db_name: &str, coll_name: &str) { let coll = self.get_coll(db_name, coll_name); drop_collection(&coll).await; } @@ -338,7 +346,7 @@ impl TestClient { Topology::Single } - pub async fn topology(&self) -> Topology { + pub(crate) async fn topology(&self) -> Topology { let bt = self.base_topology(); if let Topology::Sharded = bt { let shard_info = self @@ -358,7 +366,7 @@ impl TestClient { bt } - pub fn topology_string(&self) -> String { + pub(crate) fn topology_string(&self) -> String { match self.base_topology() { Topology::LoadBalanced => "load-balanced", Topology::Sharded | Topology::ShardedReplicaSet => "sharded", @@ -368,7 +376,7 @@ impl TestClient { .to_string() } - pub async fn options_for_multiple_mongoses( + 
pub(crate) async fn options_for_multiple_mongoses( options: Option, use_multiple_mongoses: bool, ) -> ClientOptions { @@ -408,7 +416,7 @@ impl TestClient { } } -pub async fn drop_collection(coll: &Collection) +pub(crate) async fn drop_collection(coll: &Collection) where T: Serialize + DeserializeOwned + Unpin + Debug, { @@ -425,7 +433,7 @@ struct BuildInfo { version: String, } -pub fn get_default_name(description: &str) -> String { +pub(crate) fn get_default_name(description: &str) -> String { let mut db_name = description .replace('$', "%") .replace(' ', "_") @@ -436,7 +444,7 @@ pub fn get_default_name(description: &str) -> String { } /// Log a message on stderr that won't be captured by `cargo test`. Panics if the write fails. -pub fn log_uncaptured>(text: S) { +pub(crate) fn log_uncaptured>(text: S) { use std::io::Write; let mut stderr = std::io::stderr();
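`log_uncaptured` works because the libtest harness only captures output routed through the `print!`/`eprint!` macro family; writes made directly to the process's stderr handle still reach the terminal during `cargo test`. A rough standalone sketch of that idea, with simplified error handling, not the driver's exact implementation.

use std::io::Write;

fn log_uncaptured(text: impl AsRef<str>) {
    // Writing straight to the stderr handle bypasses test output capture.
    let mut stderr = std::io::stderr();
    stderr
        .write_all(text.as_ref().as_bytes())
        .expect("write to stderr failed");
    stderr.write_all(b"\n").expect("write to stderr failed");
}

fn main() {
    log_uncaptured("visible even when run under `cargo test`");
}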