Skip to content

Commit 9711c45

Browse files
taiki-ecramertj
authored andcommitted
Add AsyncReadExt::chain
1 parent 52499d4 commit 9711c45

File tree

3 files changed

+193
-7
lines changed

3 files changed

+193
-7
lines changed

futures-util/src/io/chain.rs

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
use futures_core::task::{Context, Poll};
2+
use futures_io::{AsyncBufRead, AsyncRead, Initializer, IoSliceMut};
3+
use pin_utils::{unsafe_pinned, unsafe_unpinned};
4+
use std::fmt;
5+
use std::io;
6+
use std::pin::Pin;
7+
8+
/// Stream for the [`chain`](super::AsyncReadExt::chain) method.
9+
#[must_use = "streams do nothing unless polled"]
10+
pub struct Chain<T, U> {
11+
first: T,
12+
second: U,
13+
done_first: bool,
14+
}
15+
16+
impl<T, U> Unpin for Chain<T, U>
17+
where
18+
T: Unpin,
19+
U: Unpin,
20+
{
21+
}
22+
23+
impl<T, U> Chain<T, U>
24+
where
25+
T: AsyncRead,
26+
U: AsyncRead,
27+
{
28+
unsafe_pinned!(first: T);
29+
unsafe_pinned!(second: U);
30+
unsafe_unpinned!(done_first: bool);
31+
32+
pub(super) fn new(first: T, second: U) -> Self {
33+
Self {
34+
first,
35+
second,
36+
done_first: false,
37+
}
38+
}
39+
40+
/// Consumes the `Chain`, returning the wrapped readers.
41+
pub fn into_inner(self) -> (T, U) {
42+
(self.first, self.second)
43+
}
44+
45+
/// Gets references to the underlying readers in this `Chain`.
46+
pub fn get_ref(&self) -> (&T, &U) {
47+
(&self.first, &self.second)
48+
}
49+
50+
/// Gets mutable references to the underlying readers in this `Chain`.
51+
///
52+
/// Care should be taken to avoid modifying the internal I/O state of the
53+
/// underlying readers as doing so may corrupt the internal state of this
54+
/// `Chain`.
55+
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
56+
(&mut self.first, &mut self.second)
57+
}
58+
}
59+
60+
impl<T, U> fmt::Debug for Chain<T, U>
61+
where
62+
T: fmt::Debug,
63+
U: fmt::Debug,
64+
{
65+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66+
f.debug_struct("Chain")
67+
.field("t", &self.first)
68+
.field("u", &self.second)
69+
.finish()
70+
}
71+
}
72+
73+
impl<T, U> AsyncRead for Chain<T, U>
74+
where
75+
T: AsyncRead,
76+
U: AsyncRead,
77+
{
78+
fn poll_read(
79+
mut self: Pin<&mut Self>,
80+
cx: &mut Context<'_>,
81+
buf: &mut [u8],
82+
) -> Poll<io::Result<usize>> {
83+
if !self.done_first {
84+
match ready!(self.as_mut().first().poll_read(cx, buf)?) {
85+
0 if !buf.is_empty() => *self.as_mut().done_first() = true,
86+
n => return Poll::Ready(Ok(n)),
87+
}
88+
}
89+
self.second().poll_read(cx, buf)
90+
}
91+
92+
fn poll_read_vectored(
93+
mut self: Pin<&mut Self>,
94+
cx: &mut Context<'_>,
95+
bufs: &mut [IoSliceMut<'_>],
96+
) -> Poll<io::Result<usize>> {
97+
if !self.done_first {
98+
let n = ready!(self.as_mut().first().poll_read_vectored(cx, bufs)?);
99+
if n == 0 && bufs.iter().any(|b| !b.is_empty()) {
100+
*self.as_mut().done_first() = true
101+
} else {
102+
return Poll::Ready(Ok(n));
103+
}
104+
}
105+
self.second().poll_read_vectored(cx, bufs)
106+
}
107+
108+
unsafe fn initializer(&self) -> Initializer {
109+
let initializer = self.first.initializer();
110+
if initializer.should_initialize() {
111+
initializer
112+
} else {
113+
self.second.initializer()
114+
}
115+
}
116+
}
117+
118+
impl<T, U> AsyncBufRead for Chain<T, U>
119+
where
120+
T: AsyncBufRead,
121+
U: AsyncBufRead,
122+
{
123+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
124+
let Self {
125+
first,
126+
second,
127+
done_first,
128+
} = unsafe { self.get_unchecked_mut() };
129+
let first = unsafe { Pin::new_unchecked(first) };
130+
let second = unsafe { Pin::new_unchecked(second) };
131+
132+
if !*done_first {
133+
match ready!(first.poll_fill_buf(cx)?) {
134+
buf if buf.is_empty() => {
135+
*done_first = true;
136+
}
137+
buf => return Poll::Ready(Ok(buf)),
138+
}
139+
}
140+
second.poll_fill_buf(cx)
141+
}
142+
143+
fn consume(self: Pin<&mut Self>, amt: usize) {
144+
if !self.done_first {
145+
self.first().consume(amt)
146+
} else {
147+
self.second().consume(amt)
148+
}
149+
}
150+
}

futures-util/src/io/mod.rs

+39-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ pub use self::buf_reader::BufReader;
2929
mod buf_writer;
3030
pub use self::buf_writer::BufWriter;
3131

32+
mod chain;
33+
pub use self::chain::Chain;
34+
35+
mod close;
36+
pub use self::close::Close;
37+
3238
mod copy_into;
3339
pub use self::copy_into::CopyInto;
3440

@@ -67,9 +73,6 @@ pub use self::read_to_string::ReadToString;
6773
mod read_until;
6874
pub use self::read_until::ReadUntil;
6975

70-
mod close;
71-
pub use self::close::Close;
72-
7376
mod seek;
7477
pub use self::seek::Seek;
7578

@@ -93,6 +96,39 @@ pub use self::write_all::WriteAll;
9396

9497
/// An extension trait which adds utility methods to `AsyncRead` types.
9598
pub trait AsyncReadExt: AsyncRead {
99+
/// Creates an adaptor which will chain this stream with another.
100+
///
101+
/// The returned `AsyncRead` instance will first read all bytes from this object
102+
/// until EOF is encountered. Afterwards the output is equivalent to the
103+
/// output of `next`.
104+
///
105+
/// # Examples
106+
///
107+
/// ```
108+
/// #![feature(async_await)]
109+
/// # futures::executor::block_on(async {
110+
/// use futures::io::AsyncReadExt;
111+
/// use std::io::Cursor;
112+
///
113+
/// let reader1 = Cursor::new([1, 2, 3, 4]);
114+
/// let reader2 = Cursor::new([5, 6, 7, 8]);
115+
///
116+
/// let mut reader = reader1.chain(reader2);
117+
/// let mut buffer = Vec::new();
118+
///
119+
/// // read the value into a Vec.
120+
/// reader.read_to_end(&mut buffer).await?;
121+
/// assert_eq!(buffer, [1, 2, 3, 4, 5, 6, 7, 8]);
122+
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
123+
/// ```
124+
fn chain<R>(self, next: R) -> Chain<Self, R>
125+
where
126+
Self: Sized,
127+
R: AsyncRead,
128+
{
129+
Chain::new(self, next)
130+
}
131+
96132
/// Creates a future which copies all the bytes from one object to another.
97133
///
98134
/// The returned future will copy all the bytes read from this `AsyncRead` into the

futures/src/lib.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,10 @@ pub mod io {
303303

304304
pub use futures_util::io::{
305305
AsyncReadExt, AsyncWriteExt, AsyncSeekExt, AsyncBufReadExt, AllowStdIo,
306-
BufReader, BufWriter, Close, CopyInto, CopyBufInto, Flush, IntoSink,
307-
Lines, Read, ReadExact, ReadHalf, ReadLine, ReadToEnd, ReadToString,
308-
ReadUntil, ReadVectored, Seek, Window, Write, WriteAll, WriteHalf,
309-
WriteVectored,
306+
BufReader, BufWriter, Chain, Close, CopyInto, CopyBufInto, Flush,
307+
IntoSink, Lines, Read, ReadExact, ReadHalf, ReadLine, ReadToEnd,
308+
ReadToString, ReadUntil, ReadVectored, Seek, Take, Window, Write,
309+
WriteAll, WriteHalf, WriteVectored,
310310
};
311311
}
312312

0 commit comments

Comments
 (0)