Skip to content

Commit daf1991

Browse files
committed
feat(futures-util/stream): implement stream.unzip adapter
closes: rust-lang#2234
1 parent a570781 commit daf1991

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed

futures-util/src/stream/stream/mod.rs

+38
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ mod zip;
146146
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
147147
pub use self::zip::Zip;
148148

149+
mod unzip;
150+
#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
151+
pub use self::unzip::{unzip, UnzipLeft, UnzipRight};
152+
149153
#[cfg(feature = "alloc")]
150154
mod chunks;
151155
#[cfg(feature = "alloc")]
@@ -1182,6 +1186,40 @@ pub trait StreamExt: Stream {
11821186
assert_stream::<(Self::Item, St::Item), _>(Zip::new(self, other))
11831187
}
11841188

1189+
/// An adapter for unzipping a stream of tuples (T1, T2).
1190+
///
1191+
/// Returns two streams, left stream<Item = T1> and right stream<Item = T2>.
1192+
/// You can drop one of them and the other will still work. Underlying stream
1193+
/// Will be dropped only when both of the child streams are dropped.
1194+
///
1195+
/// # Examples
1196+
///
1197+
/// ```
1198+
/// # futures::executor::block_on(async {
1199+
/// use futures::stream::{self, StreamExt};
1200+
///
1201+
/// let stream = stream::iter(vec![(1, 2), (3, 4), (5, 6), (7, 8)]);
1202+
///
1203+
/// let (left, right) = stream.unzip();
1204+
/// let left = left.collect::<Vec<_>>().await;
1205+
/// let right = right.collect::<Vec<_>>().await;
1206+
/// assert_eq!(vec![1, 3, 5, 7], left);
1207+
/// assert_eq!(vec![2, 4, 6, 8], right);
1208+
/// # });
1209+
/// ```
1210+
///
1211+
fn unzip<T1, T2>(self) -> (UnzipLeft<Self, T1, T2>, UnzipRight<Self, T1, T2>)
1212+
where
1213+
Self: Stream<Item = (T1, T2)>,
1214+
Self: Sized,
1215+
{
1216+
let (left, right) = unzip(self);
1217+
(
1218+
assert_stream::<T1, _>(left),
1219+
assert_stream::<T2, _>(right),
1220+
)
1221+
}
1222+
11851223
/// Adapter for chaining two streams.
11861224
///
11871225
/// The resulting stream emits elements from the first stream, and when
+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
use crate::task::AtomicWaker;
2+
use alloc::sync::{Arc, Weak};
3+
use core::pin::Pin;
4+
use futures_core::stream::{FusedStream, Stream};
5+
use futures_core::task::{Context, Poll};
6+
use pin_project::{pin_project, pinned_drop};
7+
use std::sync::mpsc;
8+
9+
/// SAFETY: safe because only one of two unzipped streams is guaranteed
10+
/// to be accessing underlying stream. This is guaranteed by mpsc. Right
11+
/// stream will access underlying stream only if Sender (or left stream)
12+
/// is dropped in which case try_recv returns disconnected error.
13+
unsafe fn poll_unzipped<S, T1, T2>(
14+
stream: Pin<&mut Arc<S>>,
15+
cx: &mut Context<'_>,
16+
) -> Poll<Option<S::Item>>
17+
where
18+
S: Stream<Item = (T1, T2)>,
19+
{
20+
stream
21+
.map_unchecked_mut(|x| &mut *(Arc::as_ptr(x) as *mut S))
22+
.poll_next(cx)
23+
}
24+
25+
#[pin_project(PinnedDrop)]
26+
#[derive(Debug)]
27+
#[must_use = "streams do nothing unless polled"]
28+
pub struct UnzipLeft<S, T1, T2>
29+
where
30+
S: Stream<Item = (T1, T2)>,
31+
{
32+
#[pin]
33+
stream: Arc<S>,
34+
right_waker: Weak<AtomicWaker>,
35+
right_queue: mpsc::Sender<Option<T2>>,
36+
}
37+
38+
impl<S, T1, T2> UnzipLeft<S, T1, T2>
39+
where
40+
S: Stream<Item = (T1, T2)>,
41+
{
42+
fn send_to_right(&self, value: Option<T2>) {
43+
if let Some(right_waker) = self.right_waker.upgrade() {
44+
// if right_waker.upgrade() succeeds, then right is not
45+
// dropped so send won't fail.
46+
let _ = self.right_queue.send(value);
47+
right_waker.wake();
48+
}
49+
}
50+
}
51+
52+
impl<S, T1, T2> FusedStream for UnzipLeft<S, T1, T2>
53+
where
54+
S: Stream<Item = (T1, T2)> + FusedStream,
55+
{
56+
fn is_terminated(&self) -> bool {
57+
self.stream.as_ref().is_terminated()
58+
}
59+
}
60+
61+
impl<S, T1, T2> Stream for UnzipLeft<S, T1, T2>
62+
where
63+
S: Stream<Item = (T1, T2)>,
64+
{
65+
type Item = T1;
66+
67+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
68+
let this = self.as_mut().project();
69+
70+
// SAFETY: for safety details see comment for function: poll_unzipped
71+
if let Some(value) = ready!(unsafe { poll_unzipped(this.stream, cx) }) {
72+
self.send_to_right(Some(value.1));
73+
return Poll::Ready(Some(value.0));
74+
}
75+
self.send_to_right(None);
76+
Poll::Ready(None)
77+
}
78+
}
79+
80+
#[pinned_drop]
81+
impl<S, T1, T2> PinnedDrop for UnzipLeft<S, T1, T2>
82+
where
83+
S: Stream<Item = (T1, T2)>,
84+
{
85+
fn drop(self: Pin<&mut Self>) {
86+
let this = self.project();
87+
// wake right stream if it isn't dropped
88+
if let Some(right_waker) = this.right_waker.upgrade() {
89+
drop(this.stream);
90+
right_waker.wake();
91+
}
92+
}
93+
}
94+
95+
#[pin_project]
96+
#[derive(Debug)]
97+
#[must_use = "streams do nothing unless polled"]
98+
pub struct UnzipRight<S, T1, T2>
99+
where
100+
S: Stream<Item = (T1, T2)>,
101+
{
102+
#[pin]
103+
stream: Arc<S>,
104+
waker: Arc<AtomicWaker>,
105+
queue: mpsc::Receiver<Option<T2>>,
106+
}
107+
108+
impl<S, T1, T2> FusedStream for UnzipRight<S, T1, T2>
109+
where
110+
S: FusedStream<Item = (T1, T2)>,
111+
{
112+
fn is_terminated(&self) -> bool {
113+
self.stream.as_ref().is_terminated()
114+
}
115+
}
116+
117+
impl<S, T1, T2> Stream for UnzipRight<S, T1, T2>
118+
where
119+
S: Stream<Item = (T1, T2)>,
120+
{
121+
type Item = T2;
122+
123+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124+
let this = self.project();
125+
this.waker.register(&cx.waker().clone());
126+
127+
match this.queue.try_recv() {
128+
Ok(value) => {
129+
// can't know if more items are in the queue so wake the task
130+
// again while there are items. Will cause extra wake though.
131+
cx.waker().clone().wake();
132+
Poll::Ready(value)
133+
}
134+
Err(mpsc::TryRecvError::Disconnected) => {
135+
// if left is dropped, it is no longer polling the base stream
136+
// so right should poll it instead.
137+
// SAFETY: for safety details see comment for function: poll_unzipped
138+
if let Some(value) = ready!(unsafe { poll_unzipped(this.stream, cx) }) {
139+
return Poll::Ready(Some(value.1));
140+
}
141+
Poll::Ready(None)
142+
}
143+
_ => Poll::Pending,
144+
}
145+
}
146+
}
147+
148+
pub fn unzip<S, T1, T2>(stream: S) -> (UnzipLeft<S, T1, T2>, UnzipRight<S, T1, T2>)
149+
where
150+
S: Stream<Item = (T1, T2)>,
151+
{
152+
let base_stream = Arc::new(stream);
153+
let waker = Arc::new(AtomicWaker::new());
154+
let (tx, rx) = mpsc::channel::<Option<T2>>();
155+
156+
(
157+
UnzipLeft {
158+
stream: base_stream.clone(),
159+
right_waker: Arc::downgrade(&waker),
160+
right_queue: tx,
161+
},
162+
UnzipRight {
163+
stream: base_stream.clone(),
164+
waker: waker,
165+
queue: rx,
166+
},
167+
)
168+
}

0 commit comments

Comments
 (0)