|
| 1 | +//! A global process-aborting timeout system, mainly intended for testing. |
| 2 | +
|
| 3 | +use std::cmp::Reverse; |
| 4 | +use std::collections::BinaryHeap; |
| 5 | +use std::sync::LazyLock; |
| 6 | +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; |
| 7 | +use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, channel}; |
| 8 | +use std::time::Duration; |
| 9 | + |
| 10 | +use polars::prelude::{InitHashMaps, PlHashSet}; |
| 11 | +use polars_utils::priority::Priority; |
| 12 | + |
| 13 | +static TIMEOUT_REQUEST_HANDLER: LazyLock<Sender<TimeoutRequest>> = LazyLock::new(|| { |
| 14 | + let (send, recv) = channel(); |
| 15 | + std::thread::Builder::new() |
| 16 | + .name("polars-timeout".to_string()) |
| 17 | + .spawn(move || timeout_thread(recv)) |
| 18 | + .unwrap(); |
| 19 | + send |
| 20 | +}); |
| 21 | + |
/// A message sent to the timeout thread.
enum TimeoutRequest {
    /// Begin tracking a timeout of the given length under the given id.
    Start(
        Duration,
        u64,
    ),
    /// Stop tracking the timeout registered under the given id.
    Cancel(
        u64,
    ),
}
| 26 | + |
/// Returns the timeout configured through the `POLARS_TIMEOUT_MS` environment
/// variable, if any.
///
/// The value is interpreted as a whole number of milliseconds; leading and
/// trailing whitespace is ignored.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or cannot
/// be parsed as an unsigned integer. Any of these outcomes is remembered in a
/// process-wide flag, so subsequent calls return `None` without consulting the
/// environment again. A parse failure is additionally reported on stderr
/// before the flag is set, instead of being re-reported on every call.
///
/// Because of that flag, `POLARS_TIMEOUT_MS` must be set (to a valid value)
/// before the first call to this function.
pub fn get_timeout() -> Option<Duration> {
    static TIMEOUT_DISABLED: AtomicBool = AtomicBool::new(false);

    // Fast path so we don't have to keep checking environment variables.
    if TIMEOUT_DISABLED.load(Ordering::Relaxed) {
        return None;
    }

    let Ok(timeout) = std::env::var("POLARS_TIMEOUT_MS") else {
        TIMEOUT_DISABLED.store(true, Ordering::Relaxed);
        return None;
    };

    match timeout.trim().parse::<u64>() {
        Ok(ms) => Some(Duration::from_millis(ms)),
        Err(e) => {
            eprintln!("failed to parse POLARS_TIMEOUT_MS: {e:?}");
            // Latch the failure: without this every call would re-read the
            // environment and print the same error again.
            TIMEOUT_DISABLED.store(true, Ordering::Relaxed);
            None
        },
    }
}
| 50 | + |
| 51 | +fn timeout_thread(recv: Receiver<TimeoutRequest>) { |
| 52 | + let mut active_timeouts: PlHashSet<u64> = PlHashSet::new(); |
| 53 | + let mut shortest_timeout: BinaryHeap<Priority<Reverse<Duration>, u64>> = BinaryHeap::new(); |
| 54 | + loop { |
| 55 | + // Remove cancelled requests. |
| 56 | + while let Some(Priority(_, id)) = shortest_timeout.peek() { |
| 57 | + if active_timeouts.contains(id) { |
| 58 | + break; |
| 59 | + } |
| 60 | + shortest_timeout.pop(); |
| 61 | + } |
| 62 | + |
| 63 | + let request = if let Some(Priority(timeout, _)) = shortest_timeout.peek() { |
| 64 | + match recv.recv_timeout(timeout.0) { |
| 65 | + Err(RecvTimeoutError::Timeout) => { |
| 66 | + eprintln!("exiting the process, POLARS_TIMEOUT_MS exceeded"); |
| 67 | + std::thread::sleep(Duration::from_secs_f64(1.0)); |
| 68 | + std::process::exit(1); |
| 69 | + }, |
| 70 | + r => r.unwrap(), |
| 71 | + } |
| 72 | + } else { |
| 73 | + recv.recv().unwrap() |
| 74 | + }; |
| 75 | + |
| 76 | + match request { |
| 77 | + TimeoutRequest::Start(duration, id) => { |
| 78 | + shortest_timeout.push(Priority(Reverse(duration), id)); |
| 79 | + active_timeouts.insert(id); |
| 80 | + }, |
| 81 | + TimeoutRequest::Cancel(id) => { |
| 82 | + active_timeouts.remove(&id); |
| 83 | + }, |
| 84 | + } |
| 85 | + } |
| 86 | +} |
| 87 | + |
| 88 | +pub fn schedule_polars_timeout() -> Option<u64> { |
| 89 | + static TIMEOUT_ID: AtomicU64 = AtomicU64::new(0); |
| 90 | + |
| 91 | + let timeout = get_timeout()?; |
| 92 | + let id = TIMEOUT_ID.fetch_add(1, Ordering::Relaxed); |
| 93 | + TIMEOUT_REQUEST_HANDLER |
| 94 | + .send(TimeoutRequest::Start(timeout, id)) |
| 95 | + .unwrap(); |
| 96 | + Some(id) |
| 97 | +} |
| 98 | + |
| 99 | +pub fn cancel_polars_timeout(opt_id: Option<u64>) { |
| 100 | + if let Some(id) = opt_id { |
| 101 | + TIMEOUT_REQUEST_HANDLER |
| 102 | + .send(TimeoutRequest::Cancel(id)) |
| 103 | + .unwrap(); |
| 104 | + } |
| 105 | +} |
0 commit comments