Skip to content

Commit 03ca6c6

Browse files
committed
make UnixStream Clone
1 parent e20b0f0 commit 03ca6c6

File tree

3 files changed

+43
-11
lines changed

3 files changed

+43
-11
lines changed

Diff for: src/os/unix/net/listener.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::io;
1313
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
1414
use crate::path::Path;
1515
use crate::stream::Stream;
16+
use crate::sync::Arc;
1617
use crate::task::{Context, Poll};
1718

1819
/// A Unix domain socket server, listening for connections.
@@ -92,7 +93,7 @@ impl UnixListener {
9293
pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
9394
let (stream, addr) = self.watcher.accept().await?;
9495

95-
Ok((UnixStream { watcher: stream }, addr))
96+
Ok((UnixStream { watcher: Arc::new(stream) }, addr))
9697
}
9798

9899
/// Returns a stream of incoming connections.

Diff for: src/os/unix/net/stream.rs

+16-10
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use super::SocketAddr;
1111
use crate::io::{self, Read, Write};
1212
use crate::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
1313
use crate::path::Path;
14+
use crate::sync::Arc;
1415
use crate::task::{Context, Poll};
1516

1617
/// A Unix stream socket.
@@ -36,8 +37,9 @@ use crate::task::{Context, Poll};
3637
/// #
3738
/// # Ok(()) }) }
3839
/// ```
40+
#[derive(Clone)]
3941
pub struct UnixStream {
40-
pub(super) watcher: Async<StdUnixStream>,
42+
pub(super) watcher: Arc<Async<StdUnixStream>>,
4143
}
4244

4345
impl UnixStream {
@@ -56,7 +58,7 @@ impl UnixStream {
5658
/// ```
5759
pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
5860
let path = path.as_ref().to_owned();
59-
let stream = Async::<StdUnixStream>::connect(path).await?;
61+
let stream = Arc::new(Async::<StdUnixStream>::connect(path).await?);
6062

6163
Ok(UnixStream { watcher: stream })
6264
}
@@ -78,8 +80,12 @@ impl UnixStream {
7880
/// ```
7981
pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
8082
let (a, b) = Async::<StdUnixStream>::pair()?;
81-
let a = UnixStream { watcher: a };
82-
let b = UnixStream { watcher: b };
83+
let a = UnixStream {
84+
watcher: Arc::new(a),
85+
};
86+
let b = UnixStream {
87+
watcher: Arc::new(b),
88+
};
8389
Ok((a, b))
8490
}
8591

@@ -158,7 +164,7 @@ impl Read for &UnixStream {
158164
cx: &mut Context<'_>,
159165
buf: &mut [u8],
160166
) -> Poll<io::Result<usize>> {
161-
Pin::new(&mut &self.watcher).poll_read(cx, buf)
167+
Pin::new(&mut &*self.watcher).poll_read(cx, buf)
162168
}
163169
}
164170

@@ -186,15 +192,15 @@ impl Write for &UnixStream {
186192
cx: &mut Context<'_>,
187193
buf: &[u8],
188194
) -> Poll<io::Result<usize>> {
189-
Pin::new(&mut &self.watcher).poll_write(cx, buf)
195+
Pin::new(&mut &*self.watcher).poll_write(cx, buf)
190196
}
191197

192198
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
193-
Pin::new(&mut &self.watcher).poll_flush(cx)
199+
Pin::new(&mut &*self.watcher).poll_flush(cx)
194200
}
195201

196202
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197-
Pin::new(&mut &self.watcher).poll_close(cx)
203+
Pin::new(&mut &*self.watcher).poll_close(cx)
198204
}
199205
}
200206

@@ -219,7 +225,7 @@ impl From<StdUnixStream> for UnixStream {
219225
/// Converts a `std::os::unix::net::UnixStream` into its asynchronous equivalent.
220226
fn from(stream: StdUnixStream) -> UnixStream {
221227
let stream = Async::new(stream).expect("UnixStream is known to be good");
222-
UnixStream { watcher: stream }
228+
UnixStream { watcher: Arc::new(stream) }
223229
}
224230
}
225231

@@ -238,6 +244,6 @@ impl FromRawFd for UnixStream {
238244

239245
impl IntoRawFd for UnixStream {
240246
fn into_raw_fd(self) -> RawFd {
241-
self.watcher.into_raw_fd()
247+
self.as_raw_fd()
242248
}
243249
}

Diff for: tests/uds.rs

+25
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,28 @@ async fn ping_pong_client(socket: &std::path::PathBuf, iterations: u32) -> std::
9494
}
9595
Ok(())
9696
}
97+
98+
99+
#[test]
100+
fn uds_clone() -> io::Result<()> {
101+
task::block_on(async {
102+
let tmp_dir = TempDir::new("socket_ping_pong").expect("Temp dir not created");
103+
let sock_path = tmp_dir.as_ref().join("sock");
104+
let input = UnixListener::bind(&sock_path).await?;
105+
106+
let mut writer = UnixStream::connect(&sock_path).await?;
107+
let mut reader = input.incoming().next().await.unwrap()?;
108+
109+
writer.write(b"original").await.unwrap();
110+
let mut original_buf = [0;8];
111+
reader.read(&mut original_buf).await?;
112+
assert_eq!(&original_buf, b"original");
113+
114+
writer.clone().write(b"clone").await.unwrap();
115+
let mut clone_buf = [0;5];
116+
reader.clone().read(&mut clone_buf).await?;
117+
assert_eq!(&clone_buf, b"clone");
118+
119+
Ok(())
120+
})
121+
}

0 commit comments

Comments
 (0)