Skip to content

fix memory buildup from JoinSet #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sim-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ rand = "0.8.5"
hex = {version = "0.4.3"}
futures = "0.3.30"
console-subscriber = { version = "0.4.0", optional = true}
tokio-util = { version = "0.7.13", features = ["rt"] }

[features]
dev = ["console-subscriber"]
3 changes: 3 additions & 0 deletions sim-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_util::task::TaskTracker;

/// The default directory where the simulation files are stored and where the results will be written to.
pub const DEFAULT_DATA_DIR: &str = ".";
Expand Down Expand Up @@ -209,6 +210,7 @@ async fn main() -> anyhow::Result<()> {
None
};

let tasks = TaskTracker::new();
let sim = Simulation::new(
SimulationCfg::new(
cli.total_time,
Expand All @@ -219,6 +221,7 @@ async fn main() -> anyhow::Result<()> {
),
clients,
validated_activities,
tasks,
);
let sim2 = sim.clone();

Expand Down
1 change: 1 addition & 0 deletions simln-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde_millis = "0.1.1"
rand_distr = "0.4.3"
mockall = "0.12.1"
rand_chacha = "0.3.1"
tokio-util = { version = "0.7.13", features = ["rt"] }

[dev-dependencies]
ntest = "0.9.0"
101 changes: 46 additions & 55 deletions simln-lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use std::{collections::HashMap, sync::Arc, time::SystemTime};
use thiserror::Error;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tokio::{select, time, time::Duration};
use tokio_util::task::TaskTracker;
use triggered::{Listener, Trigger};

use self::defined_activity::DefinedPaymentActivity;
Expand Down Expand Up @@ -518,6 +518,9 @@ pub struct Simulation {
activity: Vec<ActivityDefinition>,
/// Results logger that holds the simulation statistics.
results: Arc<Mutex<PaymentResultLogger>>,
/// Track all tasks spawned for use in the simulation. When used in the `run` method, it will wait for
/// these tasks to complete before returning.
tasks: TaskTracker,
/// High level triggers used to manage simulation tasks and shutdown.
shutdown_trigger: Trigger,
shutdown_listener: Listener,
Expand Down Expand Up @@ -546,13 +549,15 @@ impl Simulation {
cfg: SimulationCfg,
nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>>,
activity: Vec<ActivityDefinition>,
tasks: TaskTracker,
) -> Self {
let (shutdown_trigger, shutdown_listener) = triggered::trigger();
Self {
cfg,
nodes,
activity,
results: Arc::new(Mutex::new(PaymentResultLogger::new())),
tasks,
shutdown_trigger,
shutdown_listener,
}
Expand Down Expand Up @@ -644,7 +649,19 @@ impl Simulation {
Ok(())
}

/// run until the simulation completes or we hit an error.
/// Note that it will wait for the tasks in self.tasks to complete
/// before returning.
pub async fn run(&self) -> Result<(), SimulationError> {
self.internal_run().await?;
// Close our TaskTracker and wait for any background tasks
// spawned during internal_run to complete.
Comment on lines +652 to +658
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: line wrapping on comments

self.tasks.close();
self.tasks.wait().await;
Ok(())
}

async fn internal_run(&self) -> Result<(), SimulationError> {
if let Some(total_time) = self.cfg.total_time {
log::info!("Running the simulation for {}s.", total_time.as_secs());
} else {
Expand All @@ -659,7 +676,6 @@ impl Simulation {
self.activity.len(),
self.nodes.len()
);
let mut tasks = JoinSet::new();

// Before we start the simulation up, start tasks that will be responsible for gathering simulation data.
// The event channels are shared across our functionality:
Expand All @@ -668,21 +684,15 @@ impl Simulation {
// - Event Receiver: used by data reporting to receive events that have been simulated that need to be
// tracked and recorded.
let (event_sender, event_receiver) = channel(1);
self.run_data_collection(event_receiver, &mut tasks);
self.run_data_collection(event_receiver, &self.tasks);

// Get an execution kit per activity that we need to generate and spin up consumers for each source node.
let activities = match self.activity_executors().await {
Ok(a) => a,
Err(e) => {
// If we encounter an error while setting up the activity_executors,
// we need to shutdown and wait for tasks to finish. We have started background tasks in the
// run_data_collection function, so we should shut those down before returning.
// we need to shutdown and return.
self.shutdown();
while let Some(res) = tasks.join_next().await {
if let Err(e) = res {
log::error!("Task exited with error: {e}.");
}
}
return Err(e);
},
};
Expand All @@ -692,40 +702,30 @@ impl Simulation {
.map(|generator| generator.source_info.pubkey)
.collect(),
event_sender.clone(),
&mut tasks,
&self.tasks,
);

// Next, we'll spin up our actual producers that will be responsible for triggering the configured activity.
// The producers will use their own JoinSet so that the simulation can be shutdown if they all finish.
let mut producer_tasks = JoinSet::new();
// The producers will use their own TaskTracker so that the simulation can be shutdown if they all finish.
let producer_tasks = TaskTracker::new();
match self
.dispatch_producers(activities, consumer_channels, &mut producer_tasks)
.dispatch_producers(activities, consumer_channels, &producer_tasks)
.await
{
Ok(_) => {},
Err(e) => {
// If we encounter an error in dispatch_producers, we need to shutdown and wait for tasks to finish.
// We have started background tasks in the run_data_collection function,
// so we should shut those down before returning.
// If we encounter an error in dispatch_producers, we need to shutdown and return.
self.shutdown();
while let Some(res) = tasks.join_next().await {
if let Err(e) = res {
log::error!("Task exited with error: {e}.");
}
}
return Err(e);
},
}

// Start a task that waits for the producers to finish.
// If all producers finish, then there is nothing left to do and the simulation can be shutdown.
let producer_trigger = self.shutdown_trigger.clone();
tasks.spawn(async move {
while let Some(res) = producer_tasks.join_next().await {
if let Err(e) = res {
log::error!("Producer exited with error: {e}.");
}
}
self.tasks.spawn(async move {
producer_tasks.close();
producer_tasks.wait().await;
log::info!("All producers finished. Shutting down.");
producer_trigger.trigger()
});
Expand All @@ -735,7 +735,7 @@ impl Simulation {
let t = self.shutdown_trigger.clone();
let l = self.shutdown_listener.clone();

tasks.spawn(async move {
self.tasks.spawn(async move {
if time::timeout(total_time, l).await.is_err() {
log::info!(
"Simulation run for {}s. Shutting down.",
Expand All @@ -746,18 +746,7 @@ impl Simulation {
});
}

// We always want to wait for all threads to exit, so we wait for all of them to exit and track any errors
// that surface. It's okay if there are multiple and one is overwritten, we just want to know whether we
// exited with an error or not.
let mut success = true;
while let Some(res) = tasks.join_next().await {
if let Err(e) = res {
log::error!("Task exited with error: {e}.");
success = false;
}
}

success.then_some(()).ok_or(SimulationError::TaskError)
Ok(())
}

pub fn shutdown(&self) {
Expand All @@ -777,7 +766,7 @@ impl Simulation {
fn run_data_collection(
&self,
output_receiver: Receiver<SimulationOutput>,
tasks: &mut JoinSet<()>,
tasks: &TaskTracker,
) {
let listener = self.shutdown_listener.clone();
let shutdown = self.shutdown_trigger.clone();
Expand All @@ -790,11 +779,17 @@ impl Simulation {
// psr: produce simulation results
let psr_listener = listener.clone();
let psr_shutdown = shutdown.clone();
let psr_tasks = tasks.clone();
tasks.spawn(async move {
log::debug!("Starting simulation results producer.");
if let Err(e) =
produce_simulation_results(nodes, output_receiver, results_sender, psr_listener)
.await
if let Err(e) = produce_simulation_results(
nodes,
output_receiver,
results_sender,
psr_listener,
&psr_tasks,
)
.await
{
psr_shutdown.trigger();
log::error!("Produce simulation results exited with error: {e:?}.");
Expand Down Expand Up @@ -939,7 +934,7 @@ impl Simulation {
&self,
consuming_nodes: HashSet<PublicKey>,
output_sender: Sender<SimulationOutput>,
tasks: &mut JoinSet<()>,
tasks: &TaskTracker,
) -> HashMap<PublicKey, Sender<SimulationEvent>> {
let mut channels = HashMap::new();

Expand Down Expand Up @@ -984,7 +979,7 @@ impl Simulation {
&self,
executors: Vec<ExecutorKit>,
producer_channels: HashMap<PublicKey, Sender<SimulationEvent>>,
tasks: &mut JoinSet<()>,
tasks: &TaskTracker,
) -> Result<(), SimulationError> {
for executor in executors {
let sender = producer_channels.get(&executor.source_info.pubkey).ok_or(
Expand Down Expand Up @@ -1350,9 +1345,8 @@ async fn produce_simulation_results(
mut output_receiver: Receiver<SimulationOutput>,
results: Sender<(Payment, PaymentResult)>,
listener: Listener,
tasks: &TaskTracker,
) -> Result<(), SimulationError> {
let mut set = tokio::task::JoinSet::new();

let result = loop {
tokio::select! {
biased;
Expand All @@ -1365,7 +1359,7 @@ async fn produce_simulation_results(
match simulation_output{
SimulationOutput::SendPaymentSuccess(payment) => {
if let Some(source_node) = nodes.get(&payment.source) {
set.spawn(track_payment_result(
tasks.spawn(track_payment_result(
source_node.clone(), results.clone(), payment, listener.clone()
));
} else {
Expand Down Expand Up @@ -1396,11 +1390,6 @@ async fn produce_simulation_results(
};

log::debug!("Simulation results producer exiting.");
while let Some(res) = set.join_next().await {
if let Err(e) = res {
log::error!("Simulation results producer task exited with error: {e}.");
}
}

result
}
Expand Down Expand Up @@ -1476,6 +1465,7 @@ mod tests {
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio_util::task::TaskTracker;

#[test]
fn create_seeded_mut_rng() {
Expand Down Expand Up @@ -1619,6 +1609,7 @@ mod tests {
crate::SimulationCfg::new(Some(0), 0, 0.0, None, None),
clients,
vec![activity_definition],
TaskTracker::new(),
);
assert!(simulation.validate_activity().await.is_err());
}
Expand Down
Loading