Skip to content

Add try_clone for TcpStream #642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions examples/a-chat/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
collections::hash_map::{Entry, HashMap},
sync::Arc,
};
use std::collections::hash_map::{Entry, HashMap};

use futures::{channel::mpsc, select, FutureExt, SinkExt};

Expand Down Expand Up @@ -40,8 +37,7 @@ async fn accept_loop(addr: impl ToSocketAddrs) -> Result<()> {
}

async fn connection_loop(mut broker: Sender<Event>, stream: TcpStream) -> Result<()> {
let stream = Arc::new(stream);
let reader = BufReader::new(&*stream);
let reader = BufReader::new(&stream);
let mut lines = reader.lines();

let name = match lines.next().await {
Expand All @@ -52,7 +48,7 @@ async fn connection_loop(mut broker: Sender<Event>, stream: TcpStream) -> Result
broker
.send(Event::NewPeer {
name: name.clone(),
stream: Arc::clone(&stream),
stream: stream.try_clone()?,
shutdown: shutdown_receiver,
})
.await
Expand Down Expand Up @@ -85,10 +81,9 @@ async fn connection_loop(mut broker: Sender<Event>, stream: TcpStream) -> Result

async fn connection_writer_loop(
messages: &mut Receiver<String>,
stream: Arc<TcpStream>,
stream: &mut TcpStream,
mut shutdown: Receiver<Void>,
) -> Result<()> {
let mut stream = &*stream;
loop {
select! {
msg = messages.next().fuse() => match msg {
Expand All @@ -108,7 +103,7 @@ async fn connection_writer_loop(
enum Event {
NewPeer {
name: String,
stream: Arc<TcpStream>,
stream: TcpStream,
shutdown: Receiver<Void>,
},
Message {
Expand Down Expand Up @@ -146,7 +141,7 @@ async fn broker_loop(mut events: Receiver<Event>) {
}
Event::NewPeer {
name,
stream,
mut stream,
shutdown,
} => match peers.entry(name.clone()) {
Entry::Occupied(..) => (),
Expand All @@ -156,7 +151,8 @@ async fn broker_loop(mut events: Receiver<Event>) {
let mut disconnect_sender = disconnect_sender.clone();
spawn_and_log_error(async move {
let res =
connection_writer_loop(&mut client_receiver, stream, shutdown).await;
connection_writer_loop(&mut client_receiver, &mut stream, shutdown)
.await;
disconnect_sender
.send((name, client_receiver))
.await
Expand Down
23 changes: 23 additions & 0 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,29 @@ impl TcpStream {
self.watcher.get_ref().set_nodelay(nodelay)
}

/// Creates a new independently owned handle to the underlying socket.
///
/// The returned TcpStream is a reference to the same stream that this object references.
/// Both handles will read and write the same stream of data, and options set on one stream
/// will be propagated to the other stream.
///
/// # Examples
///
/// ```no_run
/// # fn main() -> std::io::Result<()> { async_std::task::block_on(async {
/// #
/// use async_std::net::TcpStream;
///
/// let stream = TcpStream::connect("127.0.0.1:8080").await?;
/// let cloned_stream = stream.try_clone()?;
/// #
/// # Ok(()) }) }
pub fn try_clone(&self) -> io::Result<TcpStream> {
Ok(TcpStream {
watcher: Watcher::new(self.watcher.get_ref().try_clone()?)
})
}

/// Shuts down the read, write, or both halves of this connection.
///
/// This method will cause all pending and future I/O on the specified portions to return
Expand Down
22 changes: 22 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,25 @@ fn smoke_async_stream_to_std_listener() -> io::Result<()> {

Ok(())
}

#[test]
fn cloned_streams() -> io::Result<()> {
task::block_on(async {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;

let mut stream = TcpStream::connect(&addr).await?;
let mut cloned_stream = stream.try_clone()?;
let mut incoming = listener.incoming();
let mut write_stream = incoming.next().await.unwrap()?;
write_stream.write_all(b"Each your doing").await?;

let mut buf = [0; 15];
stream.read_exact(&mut buf[..8]).await?;
cloned_stream.read_exact(&mut buf[8..]).await?;

assert_eq!(&buf[..15], b"Each your doing");

Ok(())
})
}