Skip to content

Commit e12cf80

Browse files
fix: allow for recursive block-on calls
Fixes #798,#795,#760
1 parent 631105b commit e12cf80

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

Diff for: Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ futures-timer = { version = "3.0.2", optional = true }
7575
surf = { version = "1.0.3", optional = true }
7676

7777
[target.'cfg(not(target_os = "unknown"))'.dependencies]
78-
smol = { version = "0.1.10", optional = true }
78+
smol = { version = "0.1.11", optional = true }
7979

8080
[target.'cfg(target_arch = "wasm32")'.dependencies]
8181
futures-timer = { version = "3.0.2", optional = true, features = ["wasm-bindgen"] }
@@ -103,3 +103,4 @@ required-features = ["unstable"]
103103
[[example]]
104104
name = "surf-web"
105105
required-features = ["surf"]
106+

Diff for: src/task/builder.rs

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::cell::Cell;
12
use std::future::Future;
23
use std::pin::Pin;
34
use std::sync::Arc;
@@ -150,8 +151,31 @@ impl Builder {
150151
parent_task_id: TaskLocalsWrapper::get_current(|t| t.id().0).unwrap_or(0),
151152
});
152153

154+
thread_local! {
155+
/// Tracks the number of nested block_on calls.
156+
static NUM_NESTED_BLOCKING: Cell<usize> = Cell::new(0);
157+
}
158+
153159
// Run the future as a task.
154-
unsafe { TaskLocalsWrapper::set_current(&wrapped.tag, || smol::run(wrapped)) }
160+
NUM_NESTED_BLOCKING.with(|num_nested_blocking| {
161+
let count = num_nested_blocking.get();
162+
let should_run = count == 0;
163+
// increase the count
164+
num_nested_blocking.replace(count + 1);
165+
166+
unsafe {
167+
TaskLocalsWrapper::set_current(&wrapped.tag, || {
168+
let res = if should_run {
169+
// The first call should use run.
170+
smol::run(wrapped)
171+
} else {
172+
smol::block_on(wrapped)
173+
};
174+
num_nested_blocking.replace(num_nested_blocking.get() - 1);
175+
res
176+
})
177+
}
178+
})
155179
}
156180
}
157181

Diff for: tests/block_on.rs

+48-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,63 @@
11
#![cfg(not(target_os = "unknown"))]
22

3-
use async_std::task;
3+
use async_std::{future::ready, task::block_on};
44

55
#[test]
66
fn smoke() {
7-
let res = task::block_on(async { 1 + 2 });
7+
let res = block_on(async { 1 + 2 });
88
assert_eq!(res, 3);
99
}
1010

1111
#[test]
1212
#[should_panic = "boom"]
1313
fn panic() {
14-
task::block_on(async {
14+
block_on(async {
1515
// This panic should get propagated into the parent thread.
1616
panic!("boom");
1717
});
1818
}
19+
20+
#[cfg(feature = "unstable")]
21+
#[test]
22+
fn nested_block_on_local() {
23+
use async_std::task::spawn_local;
24+
25+
let x = block_on(async {
26+
let a = block_on(async { block_on(async { ready(3).await }) });
27+
let b = spawn_local(async { block_on(async { ready(2).await }) }).await;
28+
let c = block_on(async { block_on(async { ready(1).await }) });
29+
a + b + c
30+
});
31+
32+
assert_eq!(x, 3 + 2 + 1);
33+
34+
let y = block_on(async {
35+
let a = block_on(async { block_on(async { ready(3).await }) });
36+
let b = spawn_local(async { block_on(async { ready(2).await }) }).await;
37+
let c = block_on(async { block_on(async { ready(1).await }) });
38+
a + b + c
39+
});
40+
41+
assert_eq!(y, 3 + 2 + 1);
42+
}
43+
44+
#[test]
45+
fn nested_block_on() {
46+
let x = block_on(async {
47+
let a = block_on(async { block_on(async { ready(3).await }) });
48+
let b = block_on(async { block_on(async { ready(2).await }) });
49+
let c = block_on(async { block_on(async { ready(1).await }) });
50+
a + b + c
51+
});
52+
53+
assert_eq!(x, 3 + 2 + 1);
54+
55+
let y = block_on(async {
56+
let a = block_on(async { block_on(async { ready(3).await }) });
57+
let b = block_on(async { block_on(async { ready(2).await }) });
58+
let c = block_on(async { block_on(async { ready(1).await }) });
59+
a + b + c
60+
});
61+
62+
assert_eq!(y, 3 + 2 + 1);
63+
}

0 commit comments

Comments
 (0)