Skip to content

Commit 904d386

Browse files
authored
Merge pull request raspberrypi#744 from wedsonaf/socket
rust: add basic tcp socket support
2 parents e7e9516 + 4cd6ec9 commit 904d386

File tree

1 file changed

+276
-1
lines changed

1 file changed

+276
-1
lines changed

rust/kernel/net.rs

Lines changed: 276 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//! [`include/linux/netdevice.h`](../../../../include/linux/netdevice.h),
77
//! [`include/linux/skbuff.h`](../../../../include/linux/skbuff.h).
88
9-
use crate::{bindings, str::CStr, ARef, AlwaysRefCounted};
9+
use crate::{bindings, str::CStr, to_result, ARef, AlwaysRefCounted, Error, Result};
1010
use core::{cell::UnsafeCell, ptr::NonNull};
1111

1212
#[cfg(CONFIG_NETFILTER)]
@@ -115,3 +115,278 @@ unsafe impl AlwaysRefCounted for SkBuff {
115115
};
116116
}
117117
}
118+
119+
/// An IPv4 address.
120+
///
121+
/// This is equivalent to C's `in_addr`.
122+
#[repr(transparent)]
123+
pub struct Ipv4Addr(bindings::in_addr);
124+
125+
impl Ipv4Addr {
126+
/// A wildcard IPv4 address.
127+
///
128+
/// Binding to this address means binding to all IPv4 addresses.
129+
pub const ANY: Self = Self::new(0, 0, 0, 0);
130+
131+
/// The IPv4 loopback address.
132+
pub const LOOPBACK: Self = Self::new(127, 0, 0, 1);
133+
134+
/// The IPv4 broadcast address.
135+
pub const BROADCAST: Self = Self::new(255, 255, 255, 255);
136+
137+
/// Creates a new IPv4 address with the given components.
138+
pub const fn new(a: u8, b: u8, c: u8, d: u8) -> Self {
139+
Self(bindings::in_addr {
140+
s_addr: u32::from_be_bytes([a, b, c, d]).to_be(),
141+
})
142+
}
143+
}
144+
145+
/// An IPv6 address.
146+
///
147+
/// This is equivalent to C's `in6_addr`.
148+
#[repr(transparent)]
149+
pub struct Ipv6Addr(bindings::in6_addr);
150+
151+
impl Ipv6Addr {
152+
/// A wildcard IPv6 address.
153+
///
154+
/// Binding to this address means binding to all IPv6 addresses.
155+
pub const ANY: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0);
156+
157+
/// The IPv6 loopback address.
158+
pub const LOOPBACK: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 1);
159+
160+
/// Creates a new IPv6 address with the given components.
161+
#[allow(clippy::too_many_arguments)]
162+
pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self {
163+
Self(bindings::in6_addr {
164+
in6_u: bindings::in6_addr__bindgen_ty_1 {
165+
u6_addr16: [
166+
a.to_be(),
167+
b.to_be(),
168+
c.to_be(),
169+
d.to_be(),
170+
e.to_be(),
171+
f.to_be(),
172+
g.to_be(),
173+
h.to_be(),
174+
],
175+
},
176+
})
177+
}
178+
}
179+
180+
/// A socket address.
181+
///
182+
/// It's an enum with either an IPv4 or IPv6 socket address.
183+
pub enum SocketAddr {
184+
/// An IPv4 socket address.
185+
V4(SocketAddrV4),
186+
187+
/// An IPv6 socket address.
188+
V6(SocketAddrV6),
189+
}
190+
191+
/// An IPv4 socket address.
192+
///
193+
/// This is equivalent to C's `sockaddr_in`.
194+
#[repr(transparent)]
195+
pub struct SocketAddrV4(bindings::sockaddr_in);
196+
197+
impl SocketAddrV4 {
198+
/// Creates a new IPv4 socket address.
199+
pub const fn new(addr: Ipv4Addr, port: u16) -> Self {
200+
Self(bindings::sockaddr_in {
201+
sin_family: bindings::AF_INET as _,
202+
sin_port: port.to_be(),
203+
sin_addr: addr.0,
204+
__pad: [0; 8],
205+
})
206+
}
207+
}
208+
209+
/// An IPv6 socket address.
210+
///
211+
/// This is equivalent to C's `sockaddr_in6`.
212+
#[repr(transparent)]
213+
pub struct SocketAddrV6(bindings::sockaddr_in6);
214+
215+
impl SocketAddrV6 {
216+
/// Creates a new IPv6 socket address.
217+
pub const fn new(addr: Ipv6Addr, port: u16, flowinfo: u32, scopeid: u32) -> Self {
218+
Self(bindings::sockaddr_in6 {
219+
sin6_family: bindings::AF_INET6 as _,
220+
sin6_port: port.to_be(),
221+
sin6_addr: addr.0,
222+
sin6_flowinfo: flowinfo,
223+
sin6_scope_id: scopeid,
224+
})
225+
}
226+
}
227+
228+
/// A socket listening on a TCP port.
229+
///
230+
/// # Invariants
231+
///
232+
/// The socket pointer is always non-null and valid.
233+
pub struct TcpListener {
234+
sock: *mut bindings::socket,
235+
}
236+
237+
// SAFETY: `TcpListener` is just a wrapper for a kernel socket, which can be used from any thread.
238+
unsafe impl Send for TcpListener {}
239+
240+
// SAFETY: `TcpListener` is just a wrapper for a kernel socket, which can be used from any thread.
241+
unsafe impl Sync for TcpListener {}
242+
243+
impl TcpListener {
244+
/// Creates a new TCP listener.
245+
///
246+
/// It is configured to listen on the given socket address for the given namespace.
247+
pub fn try_new(ns: &Namespace, addr: &SocketAddr) -> Result<Self> {
248+
let mut socket = core::ptr::null_mut();
249+
let (pf, addr, addrlen) = match addr {
250+
SocketAddr::V4(addr) => (
251+
bindings::PF_INET,
252+
addr as *const _ as _,
253+
core::mem::size_of::<bindings::sockaddr_in>(),
254+
),
255+
SocketAddr::V6(addr) => (
256+
bindings::PF_INET6,
257+
addr as *const _ as _,
258+
core::mem::size_of::<bindings::sockaddr_in6>(),
259+
),
260+
};
261+
262+
// SAFETY: The namespace is valid and the output socket pointer is valid for write.
263+
to_result(|| unsafe {
264+
bindings::sock_create_kern(
265+
ns.0.get(),
266+
pf as _,
267+
bindings::sock_type_SOCK_STREAM as _,
268+
bindings::IPPROTO_TCP as _,
269+
&mut socket,
270+
)
271+
})?;
272+
273+
// INVARIANT: The socket was just created, so it is valid.
274+
let listener = Self { sock: socket };
275+
276+
// SAFETY: The type invariant guarantees that the socket is valid, and `addr` and `addrlen`
277+
// were initialised based on valid values provided in the address enum.
278+
to_result(|| unsafe { bindings::kernel_bind(socket, addr, addrlen as _) })?;
279+
280+
// SAFETY: The socket is valid per the type invariant.
281+
to_result(|| unsafe { bindings::kernel_listen(socket, bindings::SOMAXCONN as _) })?;
282+
283+
Ok(listener)
284+
}
285+
286+
/// Accepts a new connection.
287+
///
288+
/// On success, returns the newly-accepted socket stream.
289+
///
290+
/// If no connection is available to be accepted, one of two behaviours will occur:
291+
/// - If `block` is `false`, returns [`crate::error::code::EAGAIN`];
292+
/// - If `block` is `true`, blocks until an error occurs or some connection can be accepted.
293+
pub fn accept(&self, block: bool) -> Result<TcpStream> {
294+
let mut new = core::ptr::null_mut();
295+
let flags = if block { 0 } else { bindings::O_NONBLOCK };
296+
// SAFETY: The type invariant guarantees that the socket is valid, and the output argument
297+
// is also valid for write.
298+
to_result(|| unsafe { bindings::kernel_accept(self.sock, &mut new, flags as _) })?;
299+
Ok(TcpStream { sock: new })
300+
}
301+
}
302+
303+
impl Drop for TcpListener {
304+
fn drop(&mut self) {
305+
// SAFETY: The type invariant guarantees that the socket is valid.
306+
unsafe { bindings::sock_release(self.sock) };
307+
}
308+
}
309+
310+
/// A connected TCP socket.
311+
///
312+
/// # Invariants
313+
///
314+
/// The socket pointer is always non-null and valid.
315+
pub struct TcpStream {
316+
sock: *mut bindings::socket,
317+
}
318+
319+
// SAFETY: `TcpStream` is just a wrapper for a kernel socket, which can be used from any thread.
320+
unsafe impl Send for TcpStream {}
321+
322+
// SAFETY: `TcpStream` is just a wrapper for a kernel socket, which can be used from any thread.
323+
unsafe impl Sync for TcpStream {}
324+
325+
impl TcpStream {
326+
/// Reads data from a connected socket.
327+
///
328+
/// On success, returns the number of bytes read, which will be zero if the connection is
329+
/// closed.
330+
///
331+
/// If no data is immediately available for reading, one of two behaviours will occur:
332+
/// - If `block` is `false`, returns [`crate::error::code::EAGAIN`];
333+
/// - If `block` is `true`, blocks until an error occurs, the connection is closed, or some
334+
/// becomes readable.
335+
pub fn read(&mut self, buf: &mut [u8], block: bool) -> Result<usize> {
336+
let mut msg = bindings::msghdr::default();
337+
let mut vec = bindings::kvec {
338+
iov_base: buf.as_mut_ptr().cast(),
339+
iov_len: buf.len(),
340+
};
341+
// SAFETY: The type invariant guarantees that the socket is valid, and `vec` was
342+
// initialised with the output buffer.
343+
let r = unsafe {
344+
bindings::kernel_recvmsg(
345+
self.sock,
346+
&mut msg,
347+
&mut vec,
348+
1,
349+
vec.iov_len,
350+
if block { 0 } else { bindings::MSG_DONTWAIT } as _,
351+
)
352+
};
353+
if r < 0 {
354+
Err(Error::from_kernel_errno(r))
355+
} else {
356+
Ok(r as _)
357+
}
358+
}
359+
360+
/// Writes data to the connected socket.
361+
///
362+
/// On success, returns the number of bytes written.
363+
///
364+
/// If the send buffer of the socket is full, one of two behaviours will occur:
365+
/// - If `block` is `false`, returns [`crate::error::code::EAGAIN`];
366+
/// - If `block` is `true`, blocks until an error occurs or some data is written.
367+
pub fn write(&mut self, buf: &[u8], block: bool) -> Result<usize> {
368+
let mut msg = bindings::msghdr {
369+
msg_flags: if block { 0 } else { bindings::MSG_DONTWAIT },
370+
..bindings::msghdr::default()
371+
};
372+
let mut vec = bindings::kvec {
373+
iov_base: buf.as_ptr() as *mut u8 as _,
374+
iov_len: buf.len(),
375+
};
376+
// SAFETY: The type invariant guarantees that the socket is valid, and `vec` was
377+
// initialised with the input buffer.
378+
let r = unsafe { bindings::kernel_sendmsg(self.sock, &mut msg, &mut vec, 1, vec.iov_len) };
379+
if r < 0 {
380+
Err(Error::from_kernel_errno(r))
381+
} else {
382+
Ok(r as _)
383+
}
384+
}
385+
}
386+
387+
impl Drop for TcpStream {
388+
fn drop(&mut self) {
389+
// SAFETY: The type invariant guarantees that the socket is valid.
390+
unsafe { bindings::sock_release(self.sock) };
391+
}
392+
}

0 commit comments

Comments
 (0)