diff --git a/Cargo.lock b/Cargo.lock
index 33870708..e1d93ea4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2080,6 +2080,7 @@ dependencies = [
  "simln-lib",
  "simple_logger",
  "tokio",
+ "tokio-util",
 ]

 [[package]]
@@ -2107,6 +2108,7 @@ dependencies = [
  "serde_millis",
  "thiserror",
  "tokio",
+ "tokio-util",
  "tonic 0.8.3",
  "triggered",
 ]
@@ -2374,6 +2376,8 @@ dependencies = [
  "bytes",
  "futures-core",
  "futures-sink",
+ "futures-util",
+ "hashbrown 0.14.5",
  "pin-project-lite",
  "tokio",
 ]
diff --git a/sim-cli/Cargo.toml b/sim-cli/Cargo.toml
index 8730eaca..cde4e9ed 100644
--- a/sim-cli/Cargo.toml
+++ b/sim-cli/Cargo.toml
@@ -25,6 +25,7 @@ rand = "0.8.5"
 hex = {version = "0.4.3"}
 futures = "0.3.30"
 console-subscriber = { version = "0.4.0", optional = true}
+tokio-util = { version = "0.7.13", features = ["rt"] }

 [features]
 dev = ["console-subscriber"]
diff --git a/sim-cli/src/main.rs b/sim-cli/src/main.rs
index 616756fb..1baba0da 100644
--- a/sim-cli/src/main.rs
+++ b/sim-cli/src/main.rs
@@ -12,6 +12,7 @@ use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::Mutex;
+use tokio_util::task::TaskTracker;

 /// The default directory where the simulation files are stored and where the results will be written to.
 pub const DEFAULT_DATA_DIR: &str = ".";
@@ -209,6 +210,7 @@ async fn main() -> anyhow::Result<()> {
         None
     };

+    let tasks = TaskTracker::new();
     let sim = Simulation::new(
         SimulationCfg::new(
             cli.total_time,
@@ -219,6 +221,7 @@
         ),
         clients,
         validated_activities,
+        tasks,
     );

     let sim2 = sim.clone();
diff --git a/simln-lib/Cargo.toml b/simln-lib/Cargo.toml
index fc36ab6e..e1a41c78 100644
--- a/simln-lib/Cargo.toml
+++ b/simln-lib/Cargo.toml
@@ -32,6 +32,7 @@ serde_millis = "0.1.1"
 rand_distr = "0.4.3"
 mockall = "0.12.1"
 rand_chacha = "0.3.1"
+tokio-util = { version = "0.7.13", features = ["rt"] }

 [dev-dependencies]
 ntest = "0.9.0"
diff --git a/simln-lib/src/lib.rs b/simln-lib/src/lib.rs
index 13d2ef78..7cb53ae1 100644
--- a/simln-lib/src/lib.rs
+++ b/simln-lib/src/lib.rs
@@ -19,8 +19,8 @@ use std::{collections::HashMap, sync::Arc, time::SystemTime};
 use thiserror::Error;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use tokio::sync::Mutex;
-use tokio::task::JoinSet;
 use tokio::{select, time, time::Duration};
+use tokio_util::task::TaskTracker;
 use triggered::{Listener, Trigger};

 use self::defined_activity::DefinedPaymentActivity;
@@ -518,6 +518,9 @@ pub struct Simulation {
     activity: Vec<ActivityDefinition>,
     /// Results logger that holds the simulation statistics.
     results: Arc<Mutex<PaymentResultLogger>>,
+    /// Track all tasks spawned for use in the simulation. When used in the `run` method, it will wait for
+    /// these tasks to complete before returning.
+    tasks: TaskTracker,
     /// High level triggers used to manage simulation tasks and shutdown.
     shutdown_trigger: Trigger,
     shutdown_listener: Listener,
@@ -546,6 +549,7 @@ impl Simulation {
         cfg: SimulationCfg,
         nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>>,
         activity: Vec<ActivityDefinition>,
+        tasks: TaskTracker,
     ) -> Self {
         let (shutdown_trigger, shutdown_listener) = triggered::trigger();
         Self {
@@ -553,6 +557,7 @@ impl Simulation {
             nodes,
             activity,
             results: Arc::new(Mutex::new(PaymentResultLogger::new())),
+            tasks,
             shutdown_trigger,
             shutdown_listener,
         }
@@ -644,7 +649,19 @@ impl Simulation {
         Ok(())
     }

+    /// Run until the simulation completes or we hit an error.
+    /// Note that it will wait for the tasks in self.tasks to complete
+    /// before returning.
     pub async fn run(&self) -> Result<(), SimulationError> {
+        self.internal_run().await?;
+        // Close our TaskTracker and wait for any background tasks
+        // spawned during internal_run to complete.
+        self.tasks.close();
+        self.tasks.wait().await;
+        Ok(())
+    }
+
+    async fn internal_run(&self) -> Result<(), SimulationError> {
         if let Some(total_time) = self.cfg.total_time {
             log::info!("Running the simulation for {}s.", total_time.as_secs());
         } else {
@@ -659,7 +676,6 @@ impl Simulation {
             self.activity.len(),
             self.nodes.len()
         );
-        let mut tasks = JoinSet::new();

         // Before we start the simulation up, start tasks that will be responsible for gathering simulation data.
         // The event channels are shared across our functionality:
@@ -668,21 +684,15 @@ impl Simulation {
         // - Event Receiver: used by data reporting to receive events that have been simulated that need to be
         //   tracked and recorded.
         let (event_sender, event_receiver) = channel(1);
-        self.run_data_collection(event_receiver, &mut tasks);
+        self.run_data_collection(event_receiver, &self.tasks);

         // Get an execution kit per activity that we need to generate and spin up consumers for each source node.
         let activities = match self.activity_executors().await {
             Ok(a) => a,
             Err(e) => {
                 // If we encounter an error while setting up the activity_executors,
-                // we need to shutdown and wait for tasks to finish. We have started background tasks in the
-                // run_data_collection function, so we should shut those down before returning.
+                // we need to shutdown and return.
                 self.shutdown();
-                while let Some(res) = tasks.join_next().await {
-                    if let Err(e) = res {
-                        log::error!("Task exited with error: {e}.");
-                    }
-                }
                 return Err(e);
             },
         };
@@ -692,27 +702,20 @@ impl Simulation {
                 .map(|generator| generator.source_info.pubkey)
                 .collect(),
             event_sender.clone(),
-            &mut tasks,
+            &self.tasks,
         );

         // Next, we'll spin up our actual producers that will be responsible for triggering the configured activity.
-        // The producers will use their own JoinSet so that the simulation can be shutdown if they all finish.
-        let mut producer_tasks = JoinSet::new();
+        // The producers will use their own TaskTracker so that the simulation can be shut down if they all finish.
+        let producer_tasks = TaskTracker::new();
         match self
-            .dispatch_producers(activities, consumer_channels, &mut producer_tasks)
+            .dispatch_producers(activities, consumer_channels, &producer_tasks)
             .await
         {
             Ok(_) => {},
             Err(e) => {
-                // If we encounter an error in dispatch_producers, we need to shutdown and wait for tasks to finish.
-                // We have started background tasks in the run_data_collection function,
-                // so we should shut those down before returning.
+                // If we encounter an error in dispatch_producers, we need to shutdown and return.
                 self.shutdown();
-                while let Some(res) = tasks.join_next().await {
-                    if let Err(e) = res {
-                        log::error!("Task exited with error: {e}.");
-                    }
-                }
                 return Err(e);
             },
         }
@@ -720,12 +723,9 @@ impl Simulation {
         // Start a task that waits for the producers to finish.
         // If all producers finish, then there is nothing left to do and the simulation can be shutdown.
         let producer_trigger = self.shutdown_trigger.clone();
-        tasks.spawn(async move {
-            while let Some(res) = producer_tasks.join_next().await {
-                if let Err(e) = res {
-                    log::error!("Producer exited with error: {e}.");
-                }
-            }
+        self.tasks.spawn(async move {
+            producer_tasks.close();
+            producer_tasks.wait().await;
             log::info!("All producers finished. Shutting down.");
Shutting down."); producer_trigger.trigger() }); @@ -735,7 +735,7 @@ impl Simulation { let t = self.shutdown_trigger.clone(); let l = self.shutdown_listener.clone(); - tasks.spawn(async move { + self.tasks.spawn(async move { if time::timeout(total_time, l).await.is_err() { log::info!( "Simulation run for {}s. Shutting down.", @@ -746,18 +746,7 @@ impl Simulation { }); } - // We always want to wait for all threads to exit, so we wait for all of them to exit and track any errors - // that surface. It's okay if there are multiple and one is overwritten, we just want to know whether we - // exited with an error or not. - let mut success = true; - while let Some(res) = tasks.join_next().await { - if let Err(e) = res { - log::error!("Task exited with error: {e}."); - success = false; - } - } - - success.then_some(()).ok_or(SimulationError::TaskError) + Ok(()) } pub fn shutdown(&self) { @@ -777,7 +766,7 @@ impl Simulation { fn run_data_collection( &self, output_receiver: Receiver, - tasks: &mut JoinSet<()>, + tasks: &TaskTracker, ) { let listener = self.shutdown_listener.clone(); let shutdown = self.shutdown_trigger.clone(); @@ -790,11 +779,17 @@ impl Simulation { // psr: produce simulation results let psr_listener = listener.clone(); let psr_shutdown = shutdown.clone(); + let psr_tasks = tasks.clone(); tasks.spawn(async move { log::debug!("Starting simulation results producer."); - if let Err(e) = - produce_simulation_results(nodes, output_receiver, results_sender, psr_listener) - .await + if let Err(e) = produce_simulation_results( + nodes, + output_receiver, + results_sender, + psr_listener, + &psr_tasks, + ) + .await { psr_shutdown.trigger(); log::error!("Produce simulation results exited with error: {e:?}."); @@ -939,7 +934,7 @@ impl Simulation { &self, consuming_nodes: HashSet, output_sender: Sender, - tasks: &mut JoinSet<()>, + tasks: &TaskTracker, ) -> HashMap> { let mut channels = HashMap::new(); @@ -984,7 +979,7 @@ impl Simulation { &self, executors: Vec, producer_channels: HashMap>, - tasks: &mut JoinSet<()>, + tasks: &TaskTracker, ) -> Result<(), SimulationError> { for executor in executors { let sender = producer_channels.get(&executor.source_info.pubkey).ok_or( @@ -1350,9 +1345,8 @@ async fn produce_simulation_results( mut output_receiver: Receiver, results: Sender<(Payment, PaymentResult)>, listener: Listener, + tasks: &TaskTracker, ) -> Result<(), SimulationError> { - let mut set = tokio::task::JoinSet::new(); - let result = loop { tokio::select! 
             biased;
@@ -1365,7 +1359,7 @@ async fn produce_simulation_results(
                     match simulation_output{
                         SimulationOutput::SendPaymentSuccess(payment) => {
                             if let Some(source_node) = nodes.get(&payment.source) {
-                                set.spawn(track_payment_result(
+                                tasks.spawn(track_payment_result(
                                     source_node.clone(), results.clone(), payment, listener.clone()
                                 ));
                             } else {
@@ -1396,11 +1390,6 @@
     };

     log::debug!("Simulation results producer exiting.");
-    while let Some(res) = set.join_next().await {
-        if let Err(e) = res {
-            log::error!("Simulation results producer task exited with error: {e}.");
-        }
-    }

     result
 }
@@ -1476,6 +1465,7 @@ mod tests {
     use std::sync::Arc;
     use std::time::Duration;
     use tokio::sync::Mutex;
+    use tokio_util::task::TaskTracker;

     #[test]
     fn create_seeded_mut_rng() {
@@ -1619,6 +1609,7 @@
             crate::SimulationCfg::new(Some(0), 0, 0.0, None, None),
             clients,
             vec![activity_definition],
+            TaskTracker::new(),
         );
         assert!(simulation.validate_activity().await.is_err());
     }
diff --git a/simln-lib/src/sim_node.rs b/simln-lib/src/sim_node.rs
index 81a8b1a9..6d06b8ad 100644
--- a/simln-lib/src/sim_node.rs
+++ b/simln-lib/src/sim_node.rs
@@ -9,6 +9,7 @@ use lightning::ln::chan_utils::make_funding_redeemscript;
 use std::collections::{hash_map::Entry, HashMap};
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
+use tokio_util::task::TaskTracker;

 use lightning::ln::features::{ChannelFeatures, NodeFeatures};
 use lightning::ln::msgs::{
@@ -24,7 +25,6 @@ use thiserror::Error;
 use tokio::select;
 use tokio::sync::oneshot::{channel, Receiver, Sender};
 use tokio::sync::Mutex;
-use tokio::task::JoinSet;
 use triggered::{Listener, Trigger};

 use crate::ShortChannelID;
@@ -638,8 +638,9 @@ pub struct SimGraph {
     /// channels maps the scid of a channel to its current simulation state.
     channels: Arc<Mutex<HashMap<ShortChannelID, SimulatedChannel>>>,

-    /// track all tasks spawned to process payments in the graph.
-    tasks: JoinSet<()>,
+    /// track all tasks spawned to process payments in the graph. Note that handling the shutdown of tasks
+    /// in this tracker must be done externally.
+    tasks: TaskTracker,

     /// trigger shutdown if a critical error occurs.
     shutdown_trigger: Trigger,
@@ -649,6 +650,7 @@ impl SimGraph {
     /// Creates a graph on which to simulate payments.
     pub fn new(
         graph_channels: Vec<SimulatedChannel>,
+        tasks: TaskTracker,
         shutdown_trigger: Trigger,
     ) -> Result<Self, SimulationError> {
         let mut nodes: HashMap<PublicKey, Vec<ShortChannelID>> = HashMap::new();
@@ -682,24 +684,10 @@ impl SimGraph {
         Ok(SimGraph {
             nodes,
             channels: Arc::new(Mutex::new(channels)),
-            tasks: JoinSet::new(),
+            tasks,
             shutdown_trigger,
         })
     }
-
-    /// Blocks until all tasks created by the simulator have shut down. This function does not trigger shutdown,
-    /// because it expects erroring-out tasks to handle their own shutdown triggering.
-    pub async fn wait_for_shutdown(&mut self) {
-        log::debug!("Waiting for simulated graph to shutdown.");
-
-        while let Some(res) = self.tasks.join_next().await {
-            if let Err(e) = res {
-                log::error!("Graph task exited with error: {e}");
-            }
-        }
-
-        log::debug!("Simulated graph shutdown.");
-    }
 }

 /// Produces a map of node public key to lightning node implementation to be used for simulations.
@@ -1579,7 +1567,7 @@ mod tests {
         nodes.push(channels.last().unwrap().node_2.policy.pubkey);

         let kit = DispatchPaymentTestKit {
-            graph: SimGraph::new(channels.clone(), shutdown.clone())
+            graph: SimGraph::new(channels.clone(), TaskTracker::new(), shutdown.clone())
                 .expect("could not create test graph"),
             nodes,
             routing_graph: populate_network_graph(channels).unwrap(),
@@ -1712,7 +1700,8 @@
         assert_eq!(test_kit.channel_balances().await, expected_balances);

         test_kit.shutdown.trigger();
-        test_kit.graph.wait_for_shutdown().await;
+        test_kit.graph.tasks.close();
+        test_kit.graph.tasks.wait().await;
     }

     /// Tests successful dispatch of a multi-hop payment.
@@ -1741,7 +1730,8 @@
         assert_eq!(test_kit.channel_balances().await, expected_balances);

         test_kit.shutdown.trigger();
-        test_kit.graph.wait_for_shutdown().await;
+        test_kit.graph.tasks.close();
+        test_kit.graph.tasks.wait().await;
     }

     /// Tests success and failure for single hop payments, which are an edge case in our state machine.
@@ -1772,7 +1762,8 @@
         assert_eq!(test_kit.channel_balances().await, expected_balances);

         test_kit.shutdown.trigger();
-        test_kit.graph.wait_for_shutdown().await;
+        test_kit.graph.tasks.close();
+        test_kit.graph.tasks.wait().await;
     }

     /// Tests failing back of multi-hop payments at various failure indexes.
@@ -1812,6 +1803,7 @@
         assert_eq!(test_kit.channel_balances().await, expected_balances);

         test_kit.shutdown.trigger();
-        test_kit.graph.wait_for_shutdown().await;
+        test_kit.graph.tasks.close();
+        test_kit.graph.tasks.wait().await;
     }
 }
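Reviewer note: the lifecycle this diff adopts throughout (spawn through a shared `TaskTracker`, then `close()` and `wait()`) is sketched below as a minimal, self-contained example. It assumes `tokio` plus `tokio-util` 0.7 with the `rt` feature enabled, matching the Cargo.toml changes above; the worker tasks are illustrative and not part of this codebase.

```rust
use std::time::Duration;
use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();

    // Spawning through the tracker registers the task. Clones of a
    // TaskTracker share the same underlying set, which is why the diff can
    // hand `&self.tasks` (or a clone like `psr_tasks`) to helper functions.
    for i in 0..3u64 {
        tracker.spawn(async move {
            tokio::time::sleep(Duration::from_millis(10 * i)).await;
            println!("worker {i} finished");
        });
    }

    // close() stops the tracker from accepting new tasks; wait() then
    // resolves once the tracker is closed and every tracked task is done.
    tracker.close();
    tracker.wait().await;
    println!("all tracked tasks done");
}
```

One behavioural difference worth flagging in review: unlike `JoinSet::join_next()`, `TaskTracker::wait()` does not surface `JoinError`s, so the per-task error-logging loops are dropped rather than ported, and failing tasks must signal shutdown themselves (as the erroring paths here already do via the shutdown trigger).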