Skip to content

Commit 8ca4091

Browse files
abonanderThomasdezeeuw
authored andcommitted
Add Socket::peek_sender()
Signed-off-by: Austin Bonander <[email protected]>
1 parent 2c4468f commit 8ca4091

File tree

4 files changed

+81
-0
lines changed

4 files changed

+81
-0
lines changed

src/socket.rs

+25
Original file line numberDiff line numberDiff line change
@@ -596,11 +596,36 @@ impl Socket {
596596
/// `peek_from` makes the same safety guarantees regarding the `buf`fer as
597597
/// [`recv`].
598598
///
599+
/// # Note: Datagram Sockets
600+
/// For datagram sockets, the behavior of this method when `buf` is smaller than
601+
/// the datagram at the head of the receive queue differs between Windows and
602+
/// Unix-like platforms (Linux, macOS, BSDs, etc: colloquially termed "*nix").
603+
///
604+
/// On *nix platforms, the datagram is truncated to the length of `buf`.
605+
///
606+
/// On Windows, an error corresponding to `WSAEMSGSIZE` will be returned.
607+
///
608+
/// For consistency between platforms, be sure to provide a sufficiently large buffer to avoid
609+
/// truncation; the exact size required depends on the underlying protocol.
610+
///
611+
/// If you just want to know the sender of the data, try [`peek_sender`].
612+
///
599613
/// [`recv`]: Socket::recv
614+
/// [`peek_sender`]: Socket::peek_sender
600615
pub fn peek_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
601616
self.recv_from_with_flags(buf, sys::MSG_PEEK)
602617
}
603618

619+
/// Retrieve the sender for the data at the head of the receive queue.
620+
///
621+
/// This is equivalent to calling [`peek_from`] with a zero-sized buffer,
622+
/// but suppresses the `WSAEMSGSIZE` error on Windows.
623+
///
624+
/// [`peek_from`]: Socket::peek_from
625+
pub fn peek_sender(&self) -> io::Result<SockAddr> {
626+
sys::peek_sender(self.as_raw())
627+
}
628+
604629
/// Sends data on the socket to a connected peer.
605630
///
606631
/// This is typically used on TCP sockets or datagram sockets which have

src/sys/unix.rs

+9
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,15 @@ pub(crate) fn recv_from(
749749
}
750750
}
751751

752+
pub(crate) fn peek_sender(fd: Socket) -> io::Result<SockAddr> {
753+
// Unix-like platforms simply truncate the returned data, so this implementation is trivial.
754+
// However, for Windows this requires suppressing the `WSAEMSGSIZE` error,
755+
// so that requires a different approach.
756+
// NOTE: macOS does not populate `sockaddr` if you pass a zero-sized buffer.
757+
let (_, sender) = recv_from(fd, &mut [MaybeUninit::uninit(); 8], MSG_PEEK)?;
758+
Ok(sender)
759+
}
760+
752761
#[cfg(not(target_os = "redox"))]
753762
pub(crate) fn recv_vectored(
754763
fd: Socket,

src/sys/windows.rs

+32
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,38 @@ pub(crate) fn recv_from(
469469
}
470470
}
471471

472+
pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
473+
// Safety: `recvfrom` initialises the `SockAddr` for us.
474+
let ((), sender) = unsafe {
475+
SockAddr::try_init(|storage, addrlen| {
476+
let res = syscall!(
477+
recvfrom(
478+
socket,
479+
// Windows *appears* not to care if you pass a null pointer.
480+
ptr::null_mut(),
481+
0,
482+
MSG_PEEK,
483+
storage.cast(),
484+
addrlen,
485+
),
486+
PartialEq::eq,
487+
SOCKET_ERROR
488+
);
489+
match res {
490+
Ok(_n) => Ok(()),
491+
Err(e) => match e.raw_os_error() {
492+
Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => {
493+
Ok(())
494+
}
495+
_ => Err(e),
496+
},
497+
}
498+
})
499+
}?;
500+
501+
Ok(sender)
502+
}
503+
472504
pub(crate) fn recv_from_vectored(
473505
socket: Socket,
474506
bufs: &mut [crate::MaybeUninitSlice<'_>],

tests/socket.rs

+15
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,21 @@ fn out_of_band() {
531531
assert_eq!(unsafe { assume_init(&buf[..n]) }, DATA);
532532
}
533533

534+
#[test]
535+
#[cfg(not(target_os = "redox"))] // cfg of `udp_pair_unconnected()`
536+
fn udp_peek_sender() {
537+
let (socket_a, socket_b) = udp_pair_unconnected();
538+
539+
let socket_a_addr = socket_a.local_addr().unwrap();
540+
let socket_b_addr = socket_b.local_addr().unwrap();
541+
542+
socket_b.send_to(b"Hello, world!", &socket_a_addr).unwrap();
543+
544+
let sender_addr = socket_a.peek_sender().unwrap();
545+
546+
assert_eq!(sender_addr.as_socket(), socket_b_addr.as_socket());
547+
}
548+
534549
#[test]
535550
#[cfg(not(target_os = "redox"))]
536551
fn send_recv_vectored() {

0 commit comments

Comments
 (0)