diff --git a/examples/a-chat/server.rs b/examples/a-chat/server.rs index e049a490e..85e5027d2 100644 --- a/examples/a-chat/server.rs +++ b/examples/a-chat/server.rs @@ -1,7 +1,4 @@ -use std::{ - collections::hash_map::{Entry, HashMap}, - sync::Arc, -}; +use std::collections::hash_map::{Entry, HashMap}; use futures::{channel::mpsc, select, FutureExt, SinkExt}; @@ -40,8 +37,7 @@ async fn accept_loop(addr: impl ToSocketAddrs) -> Result<()> { } async fn connection_loop(mut broker: Sender, stream: TcpStream) -> Result<()> { - let stream = Arc::new(stream); - let reader = BufReader::new(&*stream); + let reader = BufReader::new(&stream); let mut lines = reader.lines(); let name = match lines.next().await { @@ -52,7 +48,7 @@ async fn connection_loop(mut broker: Sender, stream: TcpStream) -> Result broker .send(Event::NewPeer { name: name.clone(), - stream: Arc::clone(&stream), + stream: stream.try_clone()?, shutdown: shutdown_receiver, }) .await @@ -85,10 +81,9 @@ async fn connection_loop(mut broker: Sender, stream: TcpStream) -> Result async fn connection_writer_loop( messages: &mut Receiver, - stream: Arc, + stream: &mut TcpStream, mut shutdown: Receiver, ) -> Result<()> { - let mut stream = &*stream; loop { select! { msg = messages.next().fuse() => match msg { @@ -108,7 +103,7 @@ async fn connection_writer_loop( enum Event { NewPeer { name: String, - stream: Arc, + stream: TcpStream, shutdown: Receiver, }, Message { @@ -146,7 +141,7 @@ async fn broker_loop(mut events: Receiver) { } Event::NewPeer { name, - stream, + mut stream, shutdown, } => match peers.entry(name.clone()) { Entry::Occupied(..) => (), @@ -156,7 +151,8 @@ async fn broker_loop(mut events: Receiver) { let mut disconnect_sender = disconnect_sender.clone(); spawn_and_log_error(async move { let res = - connection_writer_loop(&mut client_receiver, stream, shutdown).await; + connection_writer_loop(&mut client_receiver, &mut stream, shutdown) + .await; disconnect_sender .send((name, client_receiver)) .await diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index ae8ca7dc8..46895c5ac 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -261,6 +261,29 @@ impl TcpStream { self.watcher.get_ref().set_nodelay(nodelay) } + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned TcpStream is a reference to the same stream that this object references. + /// Both handles will read and write the same stream of data, and options set on one stream + /// will be propagated to the other stream. + /// + /// # Examples + /// + /// ```no_run + /// # fn main() -> std::io::Result<()> { async_std::task::block_on(async { + /// # + /// use async_std::net::TcpStream; + /// + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// let cloned_stream = stream.try_clone()?; + /// # + /// # Ok(()) }) } + pub fn try_clone(&self) -> io::Result { + Ok(TcpStream { + watcher: Watcher::new(self.watcher.get_ref().try_clone()?) + }) + } + /// Shuts down the read, write, or both halves of this connection. /// /// This method will cause all pending and future I/O on the specified portions to return diff --git a/tests/tcp.rs b/tests/tcp.rs index 00fa3a045..e53fb8a05 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -94,3 +94,25 @@ fn smoke_async_stream_to_std_listener() -> io::Result<()> { Ok(()) } + +#[test] +fn cloned_streams() -> io::Result<()> { + task::block_on(async { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let mut stream = TcpStream::connect(&addr).await?; + let mut cloned_stream = stream.try_clone()?; + let mut incoming = listener.incoming(); + let mut write_stream = incoming.next().await.unwrap()?; + write_stream.write_all(b"Each your doing").await?; + + let mut buf = [0; 15]; + stream.read_exact(&mut buf[..8]).await?; + cloned_stream.read_exact(&mut buf[8..]).await?; + + assert_eq!(&buf[..15], b"Each your doing"); + + Ok(()) + }) +}