Skip to content

Commit ca202fe

Browse files
authored
Rollup merge of rust-lang#38983 - APTy:udp-peek, r=aturon
Add peek APIs to std::net Adds "peek" APIs to `std::net` sockets, including: - `UdpSocket.peek()` - `UdpSocket.peek_from()` - `TcpStream.peek()` These methods enable socket reads without side-effects. That is, repeated calls to `peek()` return identical data. This is accomplished by providing the POSIX flag `MSG_PEEK` to the underlying socket read operations. This also moves the current implementation of `recv_from` out of the platform-independent `sys_common` and into respective `sys/windows` and `sys/unix` implementations. This allows for more platform-dependent implementations where necessary. Fixes rust-lang#38980
2 parents c4c6c49 + a40be08 commit ca202fe

File tree

8 files changed

+251
-17
lines changed

8 files changed

+251
-17
lines changed

src/libstd/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@
277277
#![feature(oom)]
278278
#![feature(optin_builtin_traits)]
279279
#![feature(panic_unwind)]
280+
#![feature(peek)]
280281
#![feature(placement_in_syntax)]
281282
#![feature(prelude_import)]
282283
#![feature(pub_restricted)]

src/libstd/net/tcp.rs

+54
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,29 @@ impl TcpStream {
296296
self.0.write_timeout()
297297
}
298298

299+
/// Receives data on the socket from the remote adress to which it is
300+
/// connected, without removing that data from the queue. On success,
301+
/// returns the number of bytes peeked.
302+
///
303+
/// Successive calls return the same data. This is accomplished by passing
304+
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
305+
///
306+
/// # Examples
307+
///
308+
/// ```no_run
309+
/// #![feature(peek)]
310+
/// use std::net::TcpStream;
311+
///
312+
/// let stream = TcpStream::connect("127.0.0.1:8000")
313+
/// .expect("couldn't bind to address");
314+
/// let mut buf = [0; 10];
315+
/// let len = stream.peek(&mut buf).expect("peek failed");
316+
/// ```
317+
#[unstable(feature = "peek", issue = "38980")]
318+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
319+
self.0.peek(buf)
320+
}
321+
299322
/// Sets the value of the `TCP_NODELAY` option on this socket.
300323
///
301324
/// If set, this option disables the Nagle algorithm. This means that
@@ -1406,4 +1429,35 @@ mod tests {
14061429
Err(e) => panic!("unexpected error {}", e),
14071430
}
14081431
}
1432+
1433+
#[test]
1434+
fn peek() {
1435+
each_ip(&mut |addr| {
1436+
let (txdone, rxdone) = channel();
1437+
1438+
let srv = t!(TcpListener::bind(&addr));
1439+
let _t = thread::spawn(move|| {
1440+
let mut cl = t!(srv.accept()).0;
1441+
cl.write(&[1,3,3,7]).unwrap();
1442+
t!(rxdone.recv());
1443+
});
1444+
1445+
let mut c = t!(TcpStream::connect(&addr));
1446+
let mut b = [0; 10];
1447+
for _ in 1..3 {
1448+
let len = c.peek(&mut b).unwrap();
1449+
assert_eq!(len, 4);
1450+
}
1451+
let len = c.read(&mut b).unwrap();
1452+
assert_eq!(len, 4);
1453+
1454+
t!(c.set_nonblocking(true));
1455+
match c.peek(&mut b) {
1456+
Ok(_) => panic!("expected error"),
1457+
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {}
1458+
Err(e) => panic!("unexpected error {}", e),
1459+
}
1460+
t!(txdone.send(()));
1461+
})
1462+
}
14091463
}

src/libstd/net/udp.rs

+97
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,30 @@ impl UdpSocket {
8383
self.0.recv_from(buf)
8484
}
8585

86+
/// Receives data from the socket, without removing it from the queue.
87+
///
88+
/// Successive calls return the same data. This is accomplished by passing
89+
/// `MSG_PEEK` as a flag to the underlying `recvfrom` system call.
90+
///
91+
/// On success, returns the number of bytes peeked and the address from
92+
/// whence the data came.
93+
///
94+
/// # Examples
95+
///
96+
/// ```no_run
97+
/// #![feature(peek)]
98+
/// use std::net::UdpSocket;
99+
///
100+
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
101+
/// let mut buf = [0; 10];
102+
/// let (number_of_bytes, src_addr) = socket.peek_from(&mut buf)
103+
/// .expect("Didn't receive data");
104+
/// ```
105+
#[unstable(feature = "peek", issue = "38980")]
106+
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
107+
self.0.peek_from(buf)
108+
}
109+
86110
/// Sends data on the socket to the given address. On success, returns the
87111
/// number of bytes written.
88112
///
@@ -579,6 +603,37 @@ impl UdpSocket {
579603
self.0.recv(buf)
580604
}
581605

606+
/// Receives data on the socket from the remote adress to which it is
607+
/// connected, without removing that data from the queue. On success,
608+
/// returns the number of bytes peeked.
609+
///
610+
/// Successive calls return the same data. This is accomplished by passing
611+
/// `MSG_PEEK` as a flag to the underlying `recv` system call.
612+
///
613+
/// # Errors
614+
///
615+
/// This method will fail if the socket is not connected. The `connect` method
616+
/// will connect this socket to a remote address.
617+
///
618+
/// # Examples
619+
///
620+
/// ```no_run
621+
/// #![feature(peek)]
622+
/// use std::net::UdpSocket;
623+
///
624+
/// let socket = UdpSocket::bind("127.0.0.1:34254").expect("couldn't bind to address");
625+
/// socket.connect("127.0.0.1:8080").expect("connect function failed");
626+
/// let mut buf = [0; 10];
627+
/// match socket.peek(&mut buf) {
628+
/// Ok(received) => println!("received {} bytes", received),
629+
/// Err(e) => println!("peek function failed: {:?}", e),
630+
/// }
631+
/// ```
632+
#[unstable(feature = "peek", issue = "38980")]
633+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
634+
self.0.peek(buf)
635+
}
636+
582637
/// Moves this UDP socket into or out of nonblocking mode.
583638
///
584639
/// On Unix this corresponds to calling fcntl, and on Windows this
@@ -869,6 +924,48 @@ mod tests {
869924
assert_eq!(b"hello world", &buf[..]);
870925
}
871926

927+
#[test]
928+
fn connect_send_peek_recv() {
929+
each_ip(&mut |addr, _| {
930+
let socket = t!(UdpSocket::bind(&addr));
931+
t!(socket.connect(addr));
932+
933+
t!(socket.send(b"hello world"));
934+
935+
for _ in 1..3 {
936+
let mut buf = [0; 11];
937+
let size = t!(socket.peek(&mut buf));
938+
assert_eq!(b"hello world", &buf[..]);
939+
assert_eq!(size, 11);
940+
}
941+
942+
let mut buf = [0; 11];
943+
let size = t!(socket.recv(&mut buf));
944+
assert_eq!(b"hello world", &buf[..]);
945+
assert_eq!(size, 11);
946+
})
947+
}
948+
949+
#[test]
950+
fn peek_from() {
951+
each_ip(&mut |addr, _| {
952+
let socket = t!(UdpSocket::bind(&addr));
953+
t!(socket.send_to(b"hello world", &addr));
954+
955+
for _ in 1..3 {
956+
let mut buf = [0; 11];
957+
let (size, _) = t!(socket.peek_from(&mut buf));
958+
assert_eq!(b"hello world", &buf[..]);
959+
assert_eq!(size, 11);
960+
}
961+
962+
let mut buf = [0; 11];
963+
let (size, _) = t!(socket.recv_from(&mut buf));
964+
assert_eq!(b"hello world", &buf[..]);
965+
assert_eq!(size, 11);
966+
})
967+
}
968+
872969
#[test]
873970
fn ttl() {
874971
let ttl = 100;

src/libstd/sys/unix/net.rs

+42-3
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010

1111
use ffi::CStr;
1212
use io;
13-
use libc::{self, c_int, size_t, sockaddr, socklen_t, EAI_SYSTEM};
13+
use libc::{self, c_int, c_void, size_t, sockaddr, socklen_t, EAI_SYSTEM, MSG_PEEK};
14+
use mem;
1415
use net::{SocketAddr, Shutdown};
1516
use str;
1617
use sys::fd::FileDesc;
1718
use sys_common::{AsInner, FromInner, IntoInner};
18-
use sys_common::net::{getsockopt, setsockopt};
19+
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
1920
use time::Duration;
2021

2122
pub use sys::{cvt, cvt_r};
@@ -155,8 +156,46 @@ impl Socket {
155156
self.0.duplicate().map(Socket)
156157
}
157158

159+
fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
160+
let ret = cvt(unsafe {
161+
libc::recv(self.0.raw(),
162+
buf.as_mut_ptr() as *mut c_void,
163+
buf.len(),
164+
flags)
165+
})?;
166+
Ok(ret as usize)
167+
}
168+
158169
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
159-
self.0.read(buf)
170+
self.recv_with_flags(buf, 0)
171+
}
172+
173+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
174+
self.recv_with_flags(buf, MSG_PEEK)
175+
}
176+
177+
fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
178+
-> io::Result<(usize, SocketAddr)> {
179+
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
180+
let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
181+
182+
let n = cvt(unsafe {
183+
libc::recvfrom(self.0.raw(),
184+
buf.as_mut_ptr() as *mut c_void,
185+
buf.len(),
186+
flags,
187+
&mut storage as *mut _ as *mut _,
188+
&mut addrlen)
189+
})?;
190+
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
191+
}
192+
193+
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
194+
self.recv_from_with_flags(buf, 0)
195+
}
196+
197+
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
198+
self.recv_from_with_flags(buf, MSG_PEEK)
160199
}
161200

162201
pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {

src/libstd/sys/windows/c.rs

+1
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ pub const IP_ADD_MEMBERSHIP: c_int = 12;
246246
pub const IP_DROP_MEMBERSHIP: c_int = 13;
247247
pub const IPV6_ADD_MEMBERSHIP: c_int = 12;
248248
pub const IPV6_DROP_MEMBERSHIP: c_int = 13;
249+
pub const MSG_PEEK: c_int = 0x2;
249250

250251
#[repr(C)]
251252
pub struct ip_mreq {

src/libstd/sys/windows/net.rs

+42-2
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,59 @@ impl Socket {
147147
Ok(socket)
148148
}
149149

150-
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
150+
fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result<usize> {
151151
// On unix when a socket is shut down all further reads return 0, so we
152152
// do the same on windows to map a shut down socket to returning EOF.
153153
let len = cmp::min(buf.len(), i32::max_value() as usize) as i32;
154154
unsafe {
155-
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, 0) {
155+
match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) {
156156
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0),
157157
-1 => Err(last_error()),
158158
n => Ok(n as usize)
159159
}
160160
}
161161
}
162162

163+
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
164+
self.recv_with_flags(buf, 0)
165+
}
166+
167+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
168+
self.recv_with_flags(buf, c::MSG_PEEK)
169+
}
170+
171+
fn recv_from_with_flags(&self, buf: &mut [u8], flags: c_int)
172+
-> io::Result<(usize, SocketAddr)> {
173+
let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() };
174+
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
175+
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
176+
177+
// On unix when a socket is shut down all further reads return 0, so we
178+
// do the same on windows to map a shut down socket to returning EOF.
179+
unsafe {
180+
match c::recvfrom(self.0,
181+
buf.as_mut_ptr() as *mut c_void,
182+
len,
183+
flags,
184+
&mut storage as *mut _ as *mut _,
185+
&mut addrlen) {
186+
-1 if c::WSAGetLastError() == c::WSAESHUTDOWN => {
187+
Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?))
188+
},
189+
-1 => Err(last_error()),
190+
n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)),
191+
}
192+
}
193+
}
194+
195+
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
196+
self.recv_from_with_flags(buf, 0)
197+
}
198+
199+
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
200+
self.recv_from_with_flags(buf, c::MSG_PEEK)
201+
}
202+
163203
pub fn read_to_end(&self, buf: &mut Vec<u8>) -> io::Result<usize> {
164204
let mut me = self;
165205
(&mut me).read_to_end(buf)

src/libstd/sys_common/net.rs

+13-11
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ fn sockname<F>(f: F) -> io::Result<SocketAddr>
9191
}
9292
}
9393

94-
fn sockaddr_to_addr(storage: &c::sockaddr_storage,
94+
pub fn sockaddr_to_addr(storage: &c::sockaddr_storage,
9595
len: usize) -> io::Result<SocketAddr> {
9696
match storage.ss_family as c_int {
9797
c::AF_INET => {
@@ -222,6 +222,10 @@ impl TcpStream {
222222
self.inner.timeout(c::SO_SNDTIMEO)
223223
}
224224

225+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
226+
self.inner.peek(buf)
227+
}
228+
225229
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
226230
self.inner.read(buf)
227231
}
@@ -441,17 +445,11 @@ impl UdpSocket {
441445
}
442446

443447
pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
444-
let mut storage: c::sockaddr_storage = unsafe { mem::zeroed() };
445-
let mut addrlen = mem::size_of_val(&storage) as c::socklen_t;
446-
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
448+
self.inner.recv_from(buf)
449+
}
447450

448-
let n = cvt(unsafe {
449-
c::recvfrom(*self.inner.as_inner(),
450-
buf.as_mut_ptr() as *mut c_void,
451-
len, 0,
452-
&mut storage as *mut _ as *mut _, &mut addrlen)
453-
})?;
454-
Ok((n as usize, sockaddr_to_addr(&storage, addrlen as usize)?))
451+
pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
452+
self.inner.peek_from(buf)
455453
}
456454

457455
pub fn send_to(&self, buf: &[u8], dst: &SocketAddr) -> io::Result<usize> {
@@ -578,6 +576,10 @@ impl UdpSocket {
578576
self.inner.read(buf)
579577
}
580578

579+
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
580+
self.inner.peek(buf)
581+
}
582+
581583
pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
582584
let len = cmp::min(buf.len(), <wrlen_t>::max_value() as usize) as wrlen_t;
583585
let ret = cvt(unsafe {

0 commit comments

Comments
 (0)