Skip to content

Commit 3851430

Browse files
kolapapaThomasdezeeuw
kolapapa
andcommitted
Add Socket::(bind_)device
Co-authored-by: Thomas de Zeeuw <[email protected]>
1 parent 9140e2a commit 3851430

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

Diff for: src/sys/unix.rs

+71
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
// except according to those terms.
88

99
use std::cmp::min;
10+
#[cfg(all(feature = "all", target_os = "linux"))]
11+
use std::ffi::{CStr, CString};
1012
#[cfg(not(target_os = "redox"))]
1113
use std::io::{IoSlice, IoSliceMut};
1214
use std::mem::{self, size_of, MaybeUninit};
@@ -19,6 +21,8 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
1921
use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
2022
#[cfg(feature = "all")]
2123
use std::path::Path;
24+
#[cfg(all(feature = "all", target_os = "linux"))]
25+
use std::slice;
2226
use std::time::Duration;
2327
use std::{io, ptr};
2428

@@ -867,6 +871,73 @@ impl crate::Socket {
867871
unsafe { setsockopt::<c_int>(self.inner, libc::SOL_SOCKET, libc::SO_MARK, mark as c_int) }
868872
}
869873

874+
/// Gets the value for the `SO_BINDTODEVICE` option on this socket.
875+
///
876+
/// This value gets the socket binded device's interface name.
877+
///
878+
/// This function is only available on Linux.
879+
#[cfg(all(feature = "all", target_os = "linux"))]
880+
pub fn device(&self) -> io::Result<Option<CString>> {
881+
// TODO: replace with `MaybeUninit::uninit_array` once stable.
882+
let mut buf: [MaybeUninit<u8>; libc::IFNAMSIZ] =
883+
unsafe { MaybeUninit::<[MaybeUninit<u8>; libc::IFNAMSIZ]>::uninit().assume_init() };
884+
let mut len = buf.len() as libc::socklen_t;
885+
unsafe {
886+
syscall!(getsockopt(
887+
self.inner,
888+
libc::SOL_SOCKET,
889+
libc::SO_BINDTODEVICE,
890+
buf.as_mut_ptr().cast(),
891+
&mut len,
892+
))?;
893+
}
894+
if len == 0 {
895+
Ok(None)
896+
} else {
897+
// Allocate a buffer for `CString` with the length including the
898+
// null terminator.
899+
let len = len as usize;
900+
let mut name = Vec::with_capacity(len);
901+
902+
// TODO: use `MaybeUninit::slice_assume_init_ref` once stable.
903+
// Safety: `len` bytes are writen by the OS, this includes a null
904+
// terminator. However we don't copy the null terminator because
905+
// `CString::from_vec_unchecked` adds its own null terminator.
906+
let buf = unsafe { slice::from_raw_parts(buf.as_ptr().cast(), len - 1) };
907+
name.extend_from_slice(buf);
908+
909+
// Safety: the OS initialised the string for us, which shouldn't
910+
// include any null bytes.
911+
Ok(Some(unsafe { CString::from_vec_unchecked(name) }))
912+
}
913+
}
914+
915+
/// Sets the value for the `SO_BINDTODEVICE` option on this socket.
916+
///
917+
/// If a socket is bound to an interface, only packets received from that
918+
/// particular interface are processed by the socket. Note that this only
919+
/// works for some socket types, particularly `AF_INET` sockets.
920+
///
921+
/// If `interface` is `None` or an empty string it removes the binding.
922+
///
923+
/// This function is only available on Linux.
924+
#[cfg(all(feature = "all", target_os = "linux"))]
925+
pub fn bind_device(&self, interface: Option<&CStr>) -> io::Result<()> {
926+
let (value, len) = if let Some(interface) = interface {
927+
(interface.as_ptr(), interface.to_bytes_with_nul().len())
928+
} else {
929+
(ptr::null(), 0)
930+
};
931+
syscall!(setsockopt(
932+
self.inner,
933+
libc::SOL_SOCKET,
934+
libc::SO_BINDTODEVICE,
935+
value.cast(),
936+
len as libc::socklen_t,
937+
))
938+
.map(|_| ())
939+
}
940+
870941
/// Get the value of the `SO_REUSEPORT` option on this socket.
871942
///
872943
/// For more information about this option, see [`set_reuse_port`].

Diff for: tests/socket.rs

+18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#[cfg(all(feature = "all", target_os = "linux"))]
2+
use std::ffi::CStr;
13
#[cfg(any(windows, target_vendor = "apple"))]
24
use std::io;
35
#[cfg(unix)]
@@ -271,3 +273,19 @@ fn keepalive() {
271273
))]
272274
assert_eq!(socket.keepalive_retries().unwrap(), 10);
273275
}
276+
277+
#[cfg(all(feature = "all", target_os = "linux"))]
278+
#[test]
279+
fn device() {
280+
const INTERFACE: &str = "lo0\0";
281+
let interface = CStr::from_bytes_with_nul(INTERFACE.as_bytes()).unwrap();
282+
let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
283+
284+
assert_eq!(socket.device().unwrap(), None);
285+
286+
socket.bind_device(Some(interface)).unwrap();
287+
assert_eq!(socket.device().unwrap().as_deref(), Some(interface));
288+
289+
socket.bind_device(None).unwrap();
290+
assert_eq!(socket.device().unwrap(), None);
291+
}

0 commit comments

Comments
 (0)