Skip to content

Commit d15a5fe

Browse files
committed
Change StreamExt::scan to pass state to closure by value
1 parent c359ebf commit d15a5fe

File tree

5 files changed

+84
-47
lines changed

5 files changed

+84
-47
lines changed

futures-util/src/sink/unfold.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ pin_project! {
1010
/// Sink for the [`unfold`] function.
1111
#[derive(Debug)]
1212
#[must_use = "sinks do nothing unless polled"]
13-
pub struct Unfold<T, F, R> {
13+
pub struct Unfold<T, F, Fut> {
1414
function: F,
1515
#[pin]
16-
state: UnfoldState<T, R>,
16+
state: UnfoldState<T, Fut>,
1717
}
1818
}
1919

@@ -36,18 +36,18 @@ pin_project! {
3636
/// unfold.send(5).await?;
3737
/// # Ok::<(), std::convert::Infallible>(()) }).unwrap();
3838
/// ```
39-
pub fn unfold<T, F, R, Item, E>(init: T, function: F) -> Unfold<T, F, R>
39+
pub fn unfold<T, F, Fut, Item, E>(init: T, function: F) -> Unfold<T, F, Fut>
4040
where
41-
F: FnMut(T, Item) -> R,
42-
R: Future<Output = Result<T, E>>,
41+
F: FnMut(T, Item) -> Fut,
42+
Fut: Future<Output = Result<T, E>>,
4343
{
4444
assert_sink::<Item, E, _>(Unfold { function, state: UnfoldState::Value { value: init } })
4545
}
4646

47-
impl<T, F, R, Item, E> Sink<Item> for Unfold<T, F, R>
47+
impl<T, F, Fut, Item, E> Sink<Item> for Unfold<T, F, Fut>
4848
where
49-
F: FnMut(T, Item) -> R,
50-
R: Future<Output = Result<T, E>>,
49+
F: FnMut(T, Item) -> Fut,
50+
Fut: Future<Output = Result<T, E>>,
5151
{
5252
type Error = E;
5353

futures-util/src/stream/stream/mod.rs

+13-6
Original file line numberDiff line numberDiff line change
@@ -715,27 +715,34 @@ pub trait StreamExt: Stream {
715715
/// of the stream until provided closure returns `None`. Once `None` is
716716
/// returned, stream will be terminated.
717717
///
718+
/// Unlike [`Iterator::scan`], the closure takes the state by value instead of
719+
/// mutable reference to avoid [the limitation of the async
720+
/// block](https://github.com/rust-lang/futures-rs/issues/2171).
721+
///
718722
/// # Examples
719723
///
720724
/// ```
721725
/// # futures::executor::block_on(async {
722-
/// use futures::future;
723726
/// use futures::stream::{self, StreamExt};
724727
///
725728
/// let stream = stream::iter(1..=10);
726729
///
727-
/// let stream = stream.scan(0, |state, x| {
728-
/// *state += x;
729-
/// future::ready(if *state < 10 { Some(x) } else { None })
730+
/// let stream = stream.scan(0, |mut state, x| async move {
731+
/// state += x;
732+
/// if state < 10 {
733+
/// Some((state, x))
734+
/// } else {
735+
/// None
736+
/// }
730737
/// });
731738
///
732739
/// assert_eq!(vec![1, 2, 3], stream.collect::<Vec<_>>().await);
733740
/// # });
734741
/// ```
735742
fn scan<S, B, Fut, F>(self, initial_state: S, f: F) -> Scan<Self, S, Fut, F>
736743
where
737-
F: FnMut(&mut S, Self::Item) -> Fut,
738-
Fut: Future<Output = Option<B>>,
744+
F: FnMut(S, Self::Item) -> Fut,
745+
Fut: Future<Output = Option<(S, B)>>,
739746
Self: Sized,
740747
{
741748
assert_stream::<B, _>(Scan::new(self, initial_state, f))

futures-util/src/stream/stream/scan.rs

+25-26
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::unfold_state::UnfoldState;
12
use core::fmt;
23
use core::pin::Pin;
34
use futures_core::future::Future;
@@ -8,20 +9,15 @@ use futures_core::task::{Context, Poll};
89
use futures_sink::Sink;
910
use pin_project_lite::pin_project;
1011

11-
struct StateFn<S, F> {
12-
state: S,
13-
f: F,
14-
}
15-
1612
pin_project! {
1713
/// Stream for the [`scan`](super::StreamExt::scan) method.
1814
#[must_use = "streams do nothing unless polled"]
1915
pub struct Scan<St: Stream, S, Fut, F> {
2016
#[pin]
2117
stream: St,
22-
state_f: Option<StateFn<S, F>>,
18+
f: F,
2319
#[pin]
24-
future: Option<Fut>,
20+
state: UnfoldState<S, Fut>,
2521
}
2622
}
2723

@@ -35,8 +31,7 @@ where
3531
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3632
f.debug_struct("Scan")
3733
.field("stream", &self.stream)
38-
.field("state", &self.state_f.as_ref().map(|s| &s.state))
39-
.field("future", &self.future)
34+
.field("state", &self.state)
4035
.field("done_taking", &self.is_done_taking())
4136
.finish()
4237
}
@@ -45,18 +40,18 @@ where
4540
impl<St: Stream, S, Fut, F> Scan<St, S, Fut, F> {
4641
/// Checks if internal state is `None`.
4742
fn is_done_taking(&self) -> bool {
48-
self.state_f.is_none()
43+
self.state.is_empty()
4944
}
5045
}
5146

5247
impl<B, St, S, Fut, F> Scan<St, S, Fut, F>
5348
where
5449
St: Stream,
55-
F: FnMut(&mut S, St::Item) -> Fut,
56-
Fut: Future<Output = Option<B>>,
50+
F: FnMut(S, St::Item) -> Fut,
51+
Fut: Future<Output = Option<(S, B)>>,
5752
{
5853
pub(super) fn new(stream: St, initial_state: S, f: F) -> Self {
59-
Self { stream, state_f: Some(StateFn { state: initial_state, f }), future: None }
54+
Self { stream, f, state: UnfoldState::Value { value: initial_state } }
6055
}
6156

6257
delegate_access_inner!(stream, St, ());
@@ -65,8 +60,8 @@ where
6560
impl<B, St, S, Fut, F> Stream for Scan<St, S, Fut, F>
6661
where
6762
St: Stream,
68-
F: FnMut(&mut S, St::Item) -> Fut,
69-
Fut: Future<Output = Option<B>>,
63+
F: FnMut(S, St::Item) -> Fut,
64+
Fut: Future<Output = Option<(S, B)>>,
7065
{
7166
type Item = B;
7267

@@ -78,18 +73,22 @@ where
7873
let mut this = self.project();
7974

8075
Poll::Ready(loop {
81-
if let Some(fut) = this.future.as_mut().as_pin_mut() {
76+
if let Some(fut) = this.state.as_mut().project_future() {
8277
let item = ready!(fut.poll(cx));
83-
this.future.set(None);
8478

85-
if item.is_none() {
86-
*this.state_f = None;
79+
match item {
80+
None => {
81+
this.state.set(UnfoldState::Empty);
82+
break None;
83+
}
84+
Some((state, item)) => {
85+
this.state.set(UnfoldState::Value { value: state });
86+
break Some(item);
87+
}
8788
}
88-
89-
break item;
9089
} else if let Some(item) = ready!(this.stream.as_mut().poll_next(cx)) {
91-
let state_f = this.state_f.as_mut().unwrap();
92-
this.future.set(Some((state_f.f)(&mut state_f.state, item)))
90+
let state = this.state.as_mut().take_value().unwrap();
91+
this.state.set(UnfoldState::Future { future: (this.f)(state, item) })
9392
} else {
9493
break None;
9594
}
@@ -108,11 +107,11 @@ where
108107
impl<B, St, S, Fut, F> FusedStream for Scan<St, S, Fut, F>
109108
where
110109
St: FusedStream,
111-
F: FnMut(&mut S, St::Item) -> Fut,
112-
Fut: Future<Output = Option<B>>,
110+
F: FnMut(S, St::Item) -> Fut,
111+
Fut: Future<Output = Option<(S, B)>>,
113112
{
114113
fn is_terminated(&self) -> bool {
115-
self.is_done_taking() || self.future.is_none() && self.stream.is_terminated()
114+
self.is_done_taking() || !self.state.is_future() && self.stream.is_terminated()
116115
}
117116
}
118117

futures-util/src/unfold_state.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,34 @@ pin_project! {
77
#[project = UnfoldStateProj]
88
#[project_replace = UnfoldStateProjReplace]
99
#[derive(Debug)]
10-
pub(crate) enum UnfoldState<T, R> {
10+
pub(crate) enum UnfoldState<T, Fut> {
1111
Value {
1212
value: T,
1313
},
1414
Future {
1515
#[pin]
16-
future: R,
16+
future: Fut,
1717
},
1818
Empty,
1919
}
2020
}
2121

22-
impl<T, R> UnfoldState<T, R> {
23-
pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut R>> {
22+
impl<T, Fut> UnfoldState<T, Fut> {
23+
pub(crate) fn is_empty(&self) -> bool {
24+
match self {
25+
Self::Empty => true,
26+
_ => false,
27+
}
28+
}
29+
30+
pub(crate) fn is_future(&self) -> bool {
31+
match self {
32+
Self::Future { .. } => true,
33+
_ => false,
34+
}
35+
}
36+
37+
pub(crate) fn project_future(self: Pin<&mut Self>) -> Option<Pin<&mut Fut>> {
2438
match self.project() {
2539
UnfoldStateProj::Future { future } => Some(future),
2640
_ => None,

futures/tests/stream.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,32 @@ fn flat_map() {
3838
fn scan() {
3939
block_on(async {
4040
let values = stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2])
41-
.scan(1, |state, e| {
42-
*state += 1;
43-
futures::future::ready(if e < *state { Some(e) } else { None })
41+
.scan(1, |mut state, e| async move {
42+
state += 1;
43+
if e < state {
44+
Some((state, e))
45+
} else {
46+
None
47+
}
4448
})
4549
.collect::<Vec<_>>()
4650
.await;
4751

4852
assert_eq!(values, vec![1u8, 2, 3, 4]);
4953
});
54+
55+
block_on(async {
56+
let mut state = vec![];
57+
let values = stream::iter(vec![1u8, 2, 3, 4, 6, 8, 2])
58+
.scan(&mut state, |state, e| async move {
59+
state.push(e);
60+
Some((state, e))
61+
})
62+
.collect::<Vec<_>>()
63+
.await;
64+
65+
assert_eq!(values, state);
66+
});
5067
}
5168

5269
#[test]

0 commit comments

Comments
 (0)