From e5ac363f7f7ad37c8b14d4ed06ac791a3eb7d160 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Wed, 8 Jun 2022 14:33:39 +0200 Subject: [PATCH 01/10] Add serde/alloc as no_std dependency --- mbedtls/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 5a39a7abc..3f5749fa8 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -57,7 +57,7 @@ cc = "1.0" default = ["std", "aesni", "time", "padlock"] std = ["mbedtls-sys-auto/std", "serde/std", "yasna"] debug = ["mbedtls-sys-auto/debug"] -no_std_deps = ["core_io", "spin"] +no_std_deps = ["core_io", "spin", "serde/alloc"] force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"] mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"] rdrand = [] From 11b7799c5734ba372fab7ae6a05014141c817b36 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Wed, 8 Jun 2022 14:35:02 +0200 Subject: [PATCH 02/10] Implement ssl_conf_psk to set PSK and PSK identity --- mbedtls/src/ssl/config.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index fadbdbdc5..38bc4ca39 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -457,6 +457,21 @@ impl Config { self.dbg_callback = Some(Arc::new(cb)); unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) } } + + /// Sets the PSK and the PSK-Identity + /// + /// Only a single entry is supported at the moment. If another one was set before, it will be + /// overridden. + pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> { + unsafe { + // This allocates and copies the buffers and does not store any pointer to them + let psk_identity = psk_identity.as_bytes(); + ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len()) + .into_result() + .map(|_| ())?; + } + Ok(()) + } } // TODO @@ -466,7 +481,6 @@ impl Config { // ssl_conf_dtls_badmac_limit // ssl_conf_handshake_timeout // ssl_conf_session_cache -// ssl_conf_psk // ssl_conf_psk_cb // ssl_conf_sig_hashes // ssl_conf_alpn_protocols From b4968c4e9945ca380916e0613164869c572de076 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Wed, 8 Jun 2022 14:36:10 +0200 Subject: [PATCH 03/10] Add a PSK example client --- mbedtls/examples/client_psk.rs | 52 ++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 mbedtls/examples/client_psk.rs diff --git a/mbedtls/examples/client_psk.rs b/mbedtls/examples/client_psk.rs new file mode 100644 index 000000000..609dd099c --- /dev/null +++ b/mbedtls/examples/client_psk.rs @@ -0,0 +1,52 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +// needed to have common code for `mod support` in unit and integrations tests +extern crate mbedtls; + +use std::io::{self, stdin, stdout, Write}; +use std::net::TcpStream; +use std::sync::Arc; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::Result as TlsResult; + +#[path = "../tests/support/mod.rs"] +mod support; +use support::entropy::entropy_new; + +fn result_main(addr: &str) -> TlsResult<()> { + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client").unwrap(); + let mut ctx = Context::new(Arc::new(config)); + + let conn = TcpStream::connect(addr).unwrap(); + ctx.establish(conn, None)?; + + let mut line = String::new(); + stdin().read_line(&mut line).unwrap(); + ctx.write_all(line.as_bytes()).unwrap(); + io::copy(&mut ctx, &mut stdout()).unwrap(); + Ok(()) +} + +fn main() { + let mut args = std::env::args(); + args.next(); + result_main( + &args + .next() + .expect("supply destination in command-line argument"), + ) + .unwrap(); +} From c2d237e6c5fb9a908817b7158631081f84bf5582 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Thu, 9 Jun 2022 15:15:08 +0200 Subject: [PATCH 04/10] Add timer which is required for DTLS and an IoCallback impl for UDP --- mbedtls/examples/client_dtls.rs | 57 ++++++++++++++ mbedtls/src/ssl/context.rs | 129 +++++++++++++++++++++++++++++++- 2 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 mbedtls/examples/client_dtls.rs diff --git a/mbedtls/examples/client_dtls.rs b/mbedtls/examples/client_dtls.rs new file mode 100644 index 000000000..65b7d5069 --- /dev/null +++ b/mbedtls/examples/client_dtls.rs @@ -0,0 +1,57 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +// needed to have common code for `mod support` in unit and integrations tests +extern crate mbedtls; + +use std::io::{self, stdin, stdout, Write}; +use std::net::UdpSocket; +use std::sync::Arc; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::x509::Certificate; +use mbedtls::Result as TlsResult; + +#[path = "../tests/support/mod.rs"] +mod support; +use support::entropy::entropy_new; +use support::keys; + +fn result_main(addr: &str) -> TlsResult<()> { + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let cert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let mut config = Config::new(Endpoint::Client, Transport::Datagram, Preset::Default); + config.set_rng(rng); + config.set_ca_list(cert, None); + let mut ctx = Context::new(Arc::new(config)); + ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new())); + + let sock = UdpSocket::bind("localhost:12345").unwrap(); + let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap(); + let mut res = ctx.establish(sock, None).unwrap(); + + let mut line = String::new(); + stdin().read_line(&mut line).unwrap(); + ctx.write_all(line.as_bytes()).unwrap(); + io::copy(&mut ctx, &mut stdout()).unwrap(); + Ok(()) +} + +fn main() { + let mut args = std::env::args(); + args.next(); + result_main( + &args + .next() + .expect("supply destination in command-line argument"), + ) + .unwrap(); +} diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 1fd6746c4..bffe11fb0 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -67,6 +67,121 @@ impl IoCallback for IO { } } +#[cfg(feature = "std")] +pub struct ConnectedUdpSocket { + socket: std::net::UdpSocket, +} + +#[cfg(feature = "std")] +impl ConnectedUdpSocket { + pub fn connect(socket: std::net::UdpSocket, addr: A) -> std::result::Result { + match socket.connect(addr) { + Ok(_) => Ok(ConnectedUdpSocket { + socket, + }), + Err(e) => Err((e, socket)), + } + } +} + +#[cfg(feature = "std")] +impl IoCallback for ConnectedUdpSocket { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) { + Ok(i) => i as c_int, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0, + Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + } + } + + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) { + Ok(i) => i as c_int, + Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, + } + } + + fn data_ptr(&mut self) -> *mut c_void { + self as *mut ConnectedUdpSocket as *mut c_void + } +} + +pub trait TimerCallback: Sync { + unsafe extern "C" fn set_timer( + p_timer: *mut c_void, + int_ms: u32, + fin_ms: u32, + ) where Self: Sized; + + unsafe extern "C" fn get_timer( + p_timer: *mut c_void, + ) -> c_int where Self: Sized; + + fn data_ptr(&mut self) -> *mut c_void; +} + +#[cfg(feature = "std")] +pub struct Timer { + timer_start: std::time::Instant, + timer_int_ms: u32, + timer_fin_ms: u32, +} + +#[cfg(feature = "std")] +impl Timer { + pub fn new() -> Self { + Timer { + timer_start: std::time::Instant::now(), + timer_int_ms: 0, + timer_fin_ms: 0, + } + } +} + +#[cfg(feature = "std")] +impl TimerCallback for Timer { + unsafe extern "C" fn set_timer( + p_timer: *mut c_void, + int_ms: u32, + fin_ms: u32, + ) where Self: Sized { + let slf = (p_timer as *mut Timer).as_mut().unwrap(); + slf.timer_start = std::time::Instant::now(); + slf.timer_int_ms = int_ms; + slf.timer_fin_ms = fin_ms; + } + + unsafe extern "C" fn get_timer( + p_timer: *mut c_void, + ) -> c_int where Self: Sized { + let slf = (p_timer as *mut Timer).as_mut().unwrap(); + if slf.timer_int_ms == 0 || slf.timer_fin_ms == 0 { + return 0; + } + let passed = std::time::Instant::now() - slf.timer_start; + if passed.as_millis() >= slf.timer_fin_ms.into() { + 2 + } else if passed.as_millis() >= slf.timer_int_ms.into() { + 1 + } else { + 0 + } + } + + fn data_ptr(&mut self) -> *mut mbedtls_sys::types::raw_types::c_void { + self as *mut _ as *mut _ + } +} define!( #[c_ty(ssl_context)] @@ -89,11 +204,13 @@ pub struct Context { // Base structure used in SNI callback where we cannot determine the io type. inner: HandshakeContext, - // config is used read-only for mutliple contexts and is immutable once configured. + // config is used read-only for multiple contexts and is immutable once configured. config: Arc, // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. io: Option>, + + timer_callback: Option>, } impl<'a, T> Into<*const ssl_context> for &'a Context { @@ -128,6 +245,7 @@ impl Context { }, config: config.clone(), io: None, + timer_callback: None, } } @@ -157,7 +275,7 @@ impl Context { ); self.io = Some(io); - self.inner.reset_handshake(); + self.inner.reset_handshake(); } match self.handshake() { @@ -298,6 +416,13 @@ impl Context { } } } + + pub fn set_timer_callback(&mut self, mut cb: Box) { + unsafe { + ssl_set_timer_cb(self.into(), cb.data_ptr(), Some(F::set_timer), Some(F::get_timer)); + } + self.timer_callback = Some(cb); + } } impl Drop for Context { From 06b38b8b8c27feb66bf27cd11cb059360f9c543e Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Fri, 10 Jun 2022 09:24:47 +0200 Subject: [PATCH 05/10] Fix build warnings and test errors --- mbedtls/examples/client_dtls.rs | 2 +- mbedtls/src/ssl/context.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mbedtls/examples/client_dtls.rs b/mbedtls/examples/client_dtls.rs index 65b7d5069..09f616bcb 100644 --- a/mbedtls/examples/client_dtls.rs +++ b/mbedtls/examples/client_dtls.rs @@ -36,7 +36,7 @@ fn result_main(addr: &str) -> TlsResult<()> { let sock = UdpSocket::bind("localhost:12345").unwrap(); let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap(); - let mut res = ctx.establish(sock, None).unwrap(); + ctx.establish(sock, None).unwrap(); let mut line = String::new(); stdin().read_line(&mut line).unwrap(); diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index bffe11fb0..170a8c80b 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -116,7 +116,7 @@ impl IoCallback for ConnectedUdpSocket { } } -pub trait TimerCallback: Sync { +pub trait TimerCallback: Send + Sync { unsafe extern "C" fn set_timer( p_timer: *mut c_void, int_ms: u32, From 828546a73822fea0f0ba58e79d95fefe4512f79b Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Mon, 13 Jun 2022 13:16:40 +0200 Subject: [PATCH 06/10] Require 'std' feature for new client examples --- mbedtls/Cargo.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 3f5749fa8..079887274 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -73,6 +73,14 @@ pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"] name = "client" required-features = ["std"] +[[example]] +name = "client_dtls" +required-features = ["std"] + +[[example]] +name = "client_psk" +required-features = ["std"] + [[example]] name = "server" required-features = ["std"] From 2c7ee98a9cf76f97c23b9dc6089c0893b24785d9 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Fri, 16 Sep 2022 15:50:03 +0200 Subject: [PATCH 07/10] Use already imported StdResult and import std::io::Error as IoError --- mbedtls/src/ssl/context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 170a8c80b..c81471e79 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -10,7 +10,7 @@ use core::result::Result as StdResult; #[cfg(feature = "std")] use { - std::io::{Read, Write, Result as IoResult}, + std::io::{Read, Write, Result as IoResult, Error as IoError}, std::sync::Arc, }; @@ -74,7 +74,7 @@ pub struct ConnectedUdpSocket { #[cfg(feature = "std")] impl ConnectedUdpSocket { - pub fn connect(socket: std::net::UdpSocket, addr: A) -> std::result::Result { + pub fn connect(socket: std::net::UdpSocket, addr: A) -> StdResult { match socket.connect(addr) { Ok(_) => Ok(ConnectedUdpSocket { socket, From 790c2edc78dcf7f120d7a7258e95e2b2c5599f8c Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Wed, 28 Sep 2022 16:51:18 +0200 Subject: [PATCH 08/10] Implement DTLS server side with all required preconditions and add appropriate tests This requires to make the handshake method public because it needs to be called again after the initial handshake attempt has failed due to the server responding with a HelloVerifyRequest. --- mbedtls/Cargo.toml | 1 + mbedtls/src/ssl/config.rs | 13 +++- mbedtls/src/ssl/context.rs | 77 ++++++++++++++++++++-- mbedtls/src/ssl/cookie.rs | 116 +++++++++++++++++++++++++++++++++ mbedtls/src/ssl/mod.rs | 3 + mbedtls/tests/client_server.rs | 102 ++++++++++++++++++++++++++--- 6 files changed, 294 insertions(+), 18 deletions(-) create mode 100644 mbedtls/src/ssl/cookie.rs diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 079887274..3f42f1320 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -68,6 +68,7 @@ padlock = ["mbedtls-sys-auto/padlock"] dsa = ["std", "yasna", "num-bigint", "bit-vec"] pkcs12 = ["std", "yasna"] pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"] +legacy_protocols = ["mbedtls-sys-auto/legacy_protocols"] [[example]] name = "client" diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 38bc4ca39..8880cb785 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -26,6 +26,7 @@ use crate::pk::Pk; use crate::pk::dhparam::Dhm; use crate::private::UnsafeFrom; use crate::rng::RngCallback; +use crate::ssl::cookie::CookieCallback; use crate::ssl::context::HandshakeContext; use crate::ssl::ticket::TicketCallback; use crate::x509::{self, Certificate, Crl, Profile, VerifyCallback}; @@ -164,6 +165,7 @@ define!( sni_callback: Option>, ticket_callback: Option>, ca_callback: Option>, + dtls_cookies: Option>, }; const drop: fn(&mut Self) = ssl_config_free; impl<'a> Into {} @@ -199,6 +201,7 @@ impl Config { sni_callback: None, ticket_callback: None, ca_callback: None, + dtls_cookies: None, } } @@ -465,12 +468,16 @@ impl Config { pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> { unsafe { // This allocates and copies the buffers and does not store any pointer to them - let psk_identity = psk_identity.as_bytes(); ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len()) .into_result() - .map(|_| ())?; + .map(|_| ()) } - Ok(()) + } + + /// Sets the cookie context and callbacks which are required for DTLS servers + pub fn set_dtls_cookies(&mut self, dtls_cookies: Arc) { + unsafe { ssl_conf_dtls_cookies(self.into(), Some(T::cookie_write), Some(T::cookie_check), dtls_cookies.data_ptr()) }; + self.dtls_cookies = Some(dtls_cookies); } } diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index c81471e79..c284c8884 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -211,6 +211,11 @@ pub struct Context { io: Option>, timer_callback: Option>, + + /// Stores the client identification on the DTLS server-side for the current connection. Must + /// be stored in [`Context`] first so that it can be set after the `ssl_session_reset` in the + /// [`establish`] call. + client_transport_id: Option>, } impl<'a, T> Into<*const ssl_context> for &'a Context { @@ -246,6 +251,7 @@ impl Context { config: config.clone(), io: None, timer_callback: None, + client_transport_id: None, } } @@ -264,6 +270,9 @@ impl Context { let mut io = Box::new(io); ssl_session_reset(self.into()).into_result()?; self.set_hostname(hostname)?; + if let Some(client_id) = self.client_transport_id.take() { + self.set_client_transport_id(&client_id)?; + } let ptr = &mut *io as *mut _ as *mut c_void; ssl_set_bio( @@ -278,21 +287,52 @@ impl Context { self.inner.reset_handshake(); } - match self.handshake() { + self.handshake() + } +} + +impl Context { + /// Try to complete the handshake procedure to set up a (D)TLS connection + /// + /// In general, this should not be called directly. Instead, [`establish`] should be used which + /// properly sets up the [`IoCallback`] and resets any previous sessions. + /// + /// This should only be used directly if the handshake could not be completed successfully in + /// [`establish`], i.e.: + /// - If using nonblocking operation and [`establish`] failed with [`Error::SslWantRead`] or + /// [`Error::SslWantWrite`] + /// - If running a DTLS server and it answers the first `ClientHello` (without cookie) with a + /// `HelloVerifyRequest`, i.e. [`establish`] failed with [`Error::SslHelloVerifyRequired`] + pub fn handshake(&mut self) -> Result<()> { + match self.inner_handshake() { Ok(()) => Ok(()), Err(Error::SslWantRead) => Err(Error::SslWantRead), Err(Error::SslWantWrite) => Err(Error::SslWantWrite), + Err(Error::SslHelloVerifyRequired) => { + unsafe { + // `ssl_session_reset` resets the client ID but the user will call handshake + // again in this case and the client ID is required for a DTLS connection setup + // on the server side. So we extract it before and set it after + // `ssl_session_reset`. + let mut client_transport_id = None; + if !self.inner.handle().cli_id.is_null() { + client_transport_id = Some(Vec::from(core::slice::from_raw_parts(self.inner.handle().cli_id, self.inner.handle().cli_id_len))); + } + ssl_session_reset(self.into()).into_result()?; + if let Some(client_id) = client_transport_id.take() { + self.set_client_transport_id(&client_id)?; + } + } + Err(Error::SslHelloVerifyRequired) + } Err(e) => { self.close(); - Err(e) + Err(e) }, } - } -} -impl Context { - fn handshake(&mut self) -> Result<()> { + fn inner_handshake(&mut self) -> Result<()> { unsafe { ssl_flush_output(self.into()).into_result()?; ssl_handshake(self.into()).into_result_discard() @@ -423,6 +463,29 @@ impl Context { } self.timer_callback = Some(cb); } + + /// Set client's transport-level identification info (dtls server only) + /// + /// See `mbedtls_ssl_set_client_transport_id` + fn set_client_transport_id(&mut self, info: &[u8]) -> Result<()> { + unsafe { + ssl_set_client_transport_id(self.into(), info.as_ptr(), info.len()) + .into_result() + .map(|_| ()) + } + } + + /// Set client's transport-level identification info (dtls server only) + /// + /// See `mbedtls_ssl_set_client_transport_id` + /// + /// The `info` is used only for the next connection, i.e. it will be used for the next + /// [`establish`] call. Afterwards, it will be unset again. This is to ensure that no client + /// identification is accidentally reused if this [`Context`] is reused for further + /// connections. + pub fn set_client_transport_id_once(&mut self, info: &[u8]) { + self.client_transport_id = Some(info.into()); + } } impl Drop for Context { @@ -459,7 +522,7 @@ impl Write for Context { } // // Class exists only during SNI callback that is configured from Config. -// SNI Callback must provide input whos lifetime exceed the SNI closure to avoid memory corruptions. +// SNI Callback must provide input whose lifetime exceeds the SNI closure to avoid memory corruptions. // That can be achieved easily by storing certificate chains/crls inside the closure for the lifetime of the closure. // // That is due to SNI being held by an Arc inside Config. diff --git a/mbedtls/src/ssl/cookie.rs b/mbedtls/src/ssl/cookie.rs new file mode 100644 index 000000000..e9ac60539 --- /dev/null +++ b/mbedtls/src/ssl/cookie.rs @@ -0,0 +1,116 @@ +#[cfg(not(feature = "std"))] +use crate::alloc_prelude::*; +#[cfg(feature = "std")] +use std::sync::Arc; + +use mbedtls_sys::types::raw_types::*; +use mbedtls_sys::types::size_t; +use mbedtls_sys::*; + +use crate::error::{IntoResult, Result}; +use crate::rng::RngCallback; + +pub trait CookieCallback { + /* + typedef int mbedtls_ssl_cookie_write_t( void *ctx, + unsigned char **p, unsigned char *end, + const unsigned char *info, size_t ilen ); + */ + unsafe extern "C" fn cookie_write( + ctx: *mut c_void, + p: *mut *mut c_uchar, + end: *mut c_uchar, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized; + /* + typedef int mbedtls_ssl_cookie_check_t( void *ctx, + const unsigned char *cookie, size_t clen, + const unsigned char *info, size_t ilen ); + */ + unsafe extern "C" fn cookie_check( + ctx: *mut c_void, + cookie: *const c_uchar, + clen: size_t, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized; + + /// Returns a mutable pointer to this shared reference which will be used as first argument to + /// the other two methods + /// + /// A mutable pointer is required because the underlying cookie implementation should be + /// allowed to store some information, e.g. mbedtls' implementation uses an internal counter. + /// We only have a shared reference because in general, the `CookieCallback` will be behind an + /// `Arc` (in [`Config`]). So we need to remove const-ness here which is + /// unsafe in general. Each respective implementation has to guarantee that shared accesses are + /// safe. mbedtls' implementation uses internal mutexes in multithreaded contexts (when the + /// `threading` feature is activated) to do so. + fn data_ptr(&self) -> *mut c_void; +} + +define!( + #[c_ty(ssl_cookie_ctx)] + #[repr(C)] + struct CookieContext { + // We set rng from constructor, we never read it directly. It is only used to ensure rng lives as long as we need. + #[allow(dead_code)] + rng: Arc, + }; + const drop: fn(&mut Self) = ssl_cookie_free; + impl<'a> Into {} +); + +unsafe impl Sync for CookieContext {} + +impl CookieContext { + pub fn new(rng: Arc) -> Result { + let mut ret = CookieContext { + inner: ssl_cookie_ctx::default(), + rng, + }; + + unsafe { + ssl_cookie_init(&mut ret.inner); + ssl_cookie_setup(&mut ret.inner, Some(T::call), ret.rng.data_ptr()).into_result()?; + } + + Ok(ret) + } +} + +impl CookieCallback for CookieContext { + unsafe extern "C" fn cookie_write( + ctx: *mut c_void, + p: *mut *mut c_uchar, + end: *mut c_uchar, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized, + { + ssl_cookie_write(ctx, p, end, info, ilen) + } + + unsafe extern "C" fn cookie_check( + ctx: *mut c_void, + cookie: *const c_uchar, + clen: size_t, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized, + { + ssl_cookie_check(ctx, cookie, clen, info, ilen) + } + + fn data_ptr(&self) -> *mut c_void { + self.handle() as *const _ as *mut _ + } +} diff --git a/mbedtls/src/ssl/mod.rs b/mbedtls/src/ssl/mod.rs index 430d439ea..1bfc078cf 100644 --- a/mbedtls/src/ssl/mod.rs +++ b/mbedtls/src/ssl/mod.rs @@ -9,6 +9,7 @@ pub mod ciphersuites; pub mod config; pub mod context; +pub mod cookie; pub mod ticket; #[doc(inline)] @@ -18,4 +19,6 @@ pub use self::config::{Config, Version, UseSessionTickets}; #[doc(inline)] pub use self::context::Context; #[doc(inline)] +pub use self::cookie::CookieContext; +#[doc(inline)] pub use self::ticket::TicketContext; diff --git a/mbedtls/tests/client_server.rs b/mbedtls/tests/client_server.rs index 0f499569e..1231c32a8 100644 --- a/mbedtls/tests/client_server.rs +++ b/mbedtls/tests/client_server.rs @@ -17,18 +17,50 @@ use std::net::TcpStream; use mbedtls::pk::Pk; use mbedtls::rng::CtrDrbg; use mbedtls::ssl::config::{Endpoint, Preset, Transport}; -use mbedtls::ssl::{Config, Context, Version}; +use mbedtls::ssl::context::{ConnectedUdpSocket, IoCallback, Timer}; +use mbedtls::ssl::{Config, Context, CookieContext, Version}; use mbedtls::x509::{Certificate, VerifyError}; use mbedtls::Error; use mbedtls::Result as TlsResult; use std::sync::Arc; +use mbedtls_sys::types::raw_types::*; +use mbedtls_sys::types::size_t; + mod support; use support::entropy::entropy_new; use support::keys; +/// Simple type to unify TCP and UDP connections, to support both TLS and DTLS +enum Connection { + Tcp(TcpStream), + Udp(ConnectedUdpSocket), +} + +impl IoCallback for Connection { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let conn = &mut *(user_data as *mut Connection); + match conn { + Connection::Tcp(c) => TcpStream::call_recv(c.data_ptr(), data, len), + Connection::Udp(c) => ConnectedUdpSocket::call_recv(c.data_ptr(), data, len), + } + } + + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let conn = &mut *(user_data as *mut Connection); + match conn { + Connection::Tcp(c) => TcpStream::call_send(c.data_ptr(), data, len), + Connection::Udp(c) => ConnectedUdpSocket::call_send(c.data_ptr(), data, len), + } + } + + fn data_ptr(&mut self) -> *mut c_void { + self as *mut Connection as *mut c_void + } +} + fn client( - conn: TcpStream, + conn: Connection, min_version: Version, max_version: Version, exp_version: Option) -> TlsResult<()> { @@ -51,7 +83,10 @@ fn client( //so removing this flag here prevents the connections from failing with VerifyError Ok(()) }; - let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + let mut config = match conn { + Connection::Tcp(_) => Config::new(Endpoint::Client, Transport::Stream, Preset::Default), + Connection::Udp(_) => Config::new(Endpoint::Client, Transport::Datagram, Preset::Default), + }; config.set_rng(rng); config.set_verify_callback(verify_callback); config.set_ca_list(cacert, None); @@ -59,6 +94,11 @@ fn client( config.set_max_version(max_version)?; let mut ctx = Context::new(Arc::new(config)); + // For DTLS, timers are required to support retransmissions + if let Connection::Udp(_) = conn { + ctx.set_timer_callback(Box::new(Timer::new())); + } + match ctx.establish(conn, None) { Ok(()) => { assert_eq!(ctx.version(), exp_version.unwrap()); @@ -83,7 +123,7 @@ fn client( } fn server( - conn: TcpStream, + conn: Connection, min_version: Version, max_version: Version, exp_version: Option, @@ -92,14 +132,40 @@ fn server( let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); - let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + let mut config = match conn { + Connection::Tcp(_) => Config::new(Endpoint::Server, Transport::Stream, Preset::Default), + Connection::Udp(_) => { + let mut config = Config::new(Endpoint::Server, Transport::Datagram, Preset::Default); + // For DTLS, we need a cookie context to work against DoS attacks + let cookies = CookieContext::new(rng.clone())?; + config.set_dtls_cookies(Arc::new(cookies)); + config + } + }; config.set_rng(rng); config.set_min_version(min_version)?; config.set_max_version(max_version)?; config.push_cert(cert, key)?; let mut ctx = Context::new(Arc::new(config)); - match ctx.establish(conn, None) { + let res = if let Connection::Udp(_) = conn { + // For DTLS, timers are required to support retransmissions and the DTLS server needs a client + // ID to create individual cookies per client + ctx.set_timer_callback(Box::new(Timer::new())); + ctx.set_client_transport_id_once(b"127.0.0.1:12341"); + // The first connection setup attempt will fail because the ClientHello is received without + // a cookie + match ctx.establish(conn, None) { + Err(Error::SslHelloVerifyRequired) => {} + Ok(()) => panic!("SslHelloVerifyRequired expected, got Ok instead"), + Err(e) => panic!("SslHelloVerifyRequired expected, got {} instead", e), + } + ctx.handshake() + } else { + ctx.establish(conn, None) // TLS + }; + + match res { Ok(()) => { assert_eq!(ctx.version(), exp_version.unwrap()); } @@ -130,6 +196,8 @@ mod test { #[test] fn client_server_test() { use mbedtls::ssl::Version; + use std::net::UdpSocket; + use mbedtls::ssl::context::ConnectedUdpSocket; #[derive(Copy,Clone)] struct TestConfig { @@ -168,12 +236,30 @@ mod test { continue; } + // TLS tests + let (c, s) = crate::support::net::create_tcp_pair().unwrap(); - let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver.clone()).unwrap()); - let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver).unwrap()); + let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver.clone()).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver).unwrap()); c.join().unwrap(); s.join().unwrap(); + + // DTLS tests (DTLS 1.0 corresponds to TSL 1.1) + + if min_c < Version::Tls1_1 || min_s < Version::Tls1_1 || exp_ver.is_none() { + continue; + } + + let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); + let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); + let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver).unwrap()); + let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); + let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); + let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver.clone()).unwrap()); + + s.join().unwrap(); + c.join().unwrap(); } } } From 16222fa12cb260d3d7d9aa71e6b7a833b4d8c926 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Wed, 28 Sep 2022 17:04:00 +0200 Subject: [PATCH 09/10] Fix documentation links --- mbedtls/src/ssl/context.rs | 18 +++++++++--------- mbedtls/src/ssl/cookie.rs | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index c284c8884..28e48e854 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -214,7 +214,7 @@ pub struct Context { /// Stores the client identification on the DTLS server-side for the current connection. Must /// be stored in [`Context`] first so that it can be set after the `ssl_session_reset` in the - /// [`establish`] call. + /// [`establish`](Context::establish) call. client_transport_id: Option>, } @@ -294,15 +294,15 @@ impl Context { impl Context { /// Try to complete the handshake procedure to set up a (D)TLS connection /// - /// In general, this should not be called directly. Instead, [`establish`] should be used which - /// properly sets up the [`IoCallback`] and resets any previous sessions. + /// In general, this should not be called directly. Instead, [`establish`](Context::establish) + /// should be used which properly sets up the [`IoCallback`] and resets any previous sessions. /// /// This should only be used directly if the handshake could not be completed successfully in - /// [`establish`], i.e.: - /// - If using nonblocking operation and [`establish`] failed with [`Error::SslWantRead`] or + /// `establish`, i.e.: + /// - If using nonblocking operation and `establish` failed with [`Error::SslWantRead`] or /// [`Error::SslWantWrite`] /// - If running a DTLS server and it answers the first `ClientHello` (without cookie) with a - /// `HelloVerifyRequest`, i.e. [`establish`] failed with [`Error::SslHelloVerifyRequired`] + /// `HelloVerifyRequest`, i.e. `establish` failed with [`Error::SslHelloVerifyRequired`] pub fn handshake(&mut self) -> Result<()> { match self.inner_handshake() { Ok(()) => Ok(()), @@ -480,9 +480,9 @@ impl Context { /// See `mbedtls_ssl_set_client_transport_id` /// /// The `info` is used only for the next connection, i.e. it will be used for the next - /// [`establish`] call. Afterwards, it will be unset again. This is to ensure that no client - /// identification is accidentally reused if this [`Context`] is reused for further - /// connections. + /// [`establish`](Context::establish) call. Afterwards, it will be unset again. This is to + /// ensure that no client identification is accidentally reused if this [`Context`] is reused + /// for further connections. pub fn set_client_transport_id_once(&mut self, info: &[u8]) { self.client_transport_id = Some(info.into()); } diff --git a/mbedtls/src/ssl/cookie.rs b/mbedtls/src/ssl/cookie.rs index e9ac60539..17a50ca9b 100644 --- a/mbedtls/src/ssl/cookie.rs +++ b/mbedtls/src/ssl/cookie.rs @@ -46,10 +46,10 @@ pub trait CookieCallback { /// A mutable pointer is required because the underlying cookie implementation should be /// allowed to store some information, e.g. mbedtls' implementation uses an internal counter. /// We only have a shared reference because in general, the `CookieCallback` will be behind an - /// `Arc` (in [`Config`]). So we need to remove const-ness here which is - /// unsafe in general. Each respective implementation has to guarantee that shared accesses are - /// safe. mbedtls' implementation uses internal mutexes in multithreaded contexts (when the - /// `threading` feature is activated) to do so. + /// `Arc` (in [`Config`](crate::ssl::Config)). So we need to remove + /// const-ness here which is unsafe in general. Each respective implementation has to + /// guarantee that shared accesses are safe. mbedtls' implementation uses internal mutexes in + /// multithreaded contexts (when the `threading` feature is activated) to do so. fn data_ptr(&self) -> *mut c_void; } From 6909bf035b188cfdaddf33810dc80a1ca9b1a250 Mon Sep 17 00:00:00 2001 From: Tobias Naumann Date: Thu, 29 Sep 2022 15:55:52 +0200 Subject: [PATCH 10/10] Add PSK-based operation to the client_server integration test --- mbedtls/tests/client_server.rs | 129 +++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 47 deletions(-) diff --git a/mbedtls/tests/client_server.rs b/mbedtls/tests/client_server.rs index 1231c32a8..da3f5a4ac 100644 --- a/mbedtls/tests/client_server.rs +++ b/mbedtls/tests/client_server.rs @@ -63,14 +63,22 @@ fn client( conn: Connection, min_version: Version, max_version: Version, - exp_version: Option) -> TlsResult<()> { + exp_version: Option, + use_psk: bool) -> TlsResult<()> { let entropy = Arc::new(entropy_new()); let rng = Arc::new(CtrDrbg::new(entropy, None)?); - let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); - let expected_flags = VerifyError::empty(); - #[cfg(feature = "time")] - let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; - { + let mut config = match conn { + Connection::Tcp(_) => Config::new(Endpoint::Client, Transport::Stream, Preset::Default), + Connection::Udp(_) => Config::new(Endpoint::Client, Transport::Datagram, Preset::Default), + }; + config.set_rng(rng); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + if !use_psk { // for certificate-based operation, set up ca and verification callback + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let expected_flags = VerifyError::empty(); + #[cfg(feature = "time")] + let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; let verify_callback = move |crt: &Certificate, depth: i32, verify_flags: &mut VerifyError| { match (crt.subject().unwrap().as_str(), depth, &verify_flags) { @@ -83,42 +91,37 @@ fn client( //so removing this flag here prevents the connections from failing with VerifyError Ok(()) }; - let mut config = match conn { - Connection::Tcp(_) => Config::new(Endpoint::Client, Transport::Stream, Preset::Default), - Connection::Udp(_) => Config::new(Endpoint::Client, Transport::Datagram, Preset::Default), - }; - config.set_rng(rng); config.set_verify_callback(verify_callback); config.set_ca_list(cacert, None); - config.set_min_version(min_version)?; - config.set_max_version(max_version)?; - let mut ctx = Context::new(Arc::new(config)); + } else { // for psk-based operation, only PSK required + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client")?; + } + let mut ctx = Context::new(Arc::new(config)); - // For DTLS, timers are required to support retransmissions - if let Connection::Udp(_) = conn { - ctx.set_timer_callback(Box::new(Timer::new())); - } + // For DTLS, timers are required to support retransmissions + if let Connection::Udp(_) = conn { + ctx.set_timer_callback(Box::new(Timer::new())); + } - match ctx.establish(conn, None) { - Ok(()) => { - assert_eq!(ctx.version(), exp_version.unwrap()); - } - Err(e) => { - match e { - Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, - Error::SslFatalAlertMessage => {}, - e => panic!("Unexpected error {}", e), - }; - return Ok(()); - } - }; + match ctx.establish(conn, None) { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, + Error::SslFatalAlertMessage => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; - let ciphersuite = ctx.ciphersuite().unwrap(); - ctx.write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()).unwrap(); - let mut buf = [0u8; 13 + 4 + 1]; - ctx.read_exact(&mut buf).unwrap(); - assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); - } + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx.write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()).unwrap(); + let mut buf = [0u8; 13 + 4 + 1]; + ctx.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); Ok(()) } @@ -127,11 +130,10 @@ fn server( min_version: Version, max_version: Version, exp_version: Option, + use_psk: bool, ) -> TlsResult<()> { let entropy = entropy_new(); let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); - let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); - let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); let mut config = match conn { Connection::Tcp(_) => Config::new(Endpoint::Server, Transport::Stream, Preset::Default), Connection::Udp(_) => { @@ -145,7 +147,13 @@ fn server( config.set_rng(rng); config.set_min_version(min_version)?; config.set_max_version(max_version)?; - config.push_cert(cert, key)?; + if !use_psk { // for certificate-based operation, set up certificates + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); + config.push_cert(cert, key)?; + } else { // for psk-based operation, only PSK required + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client")?; + } let mut ctx = Context::new(Arc::new(config)); let res = if let Connection::Udp(_) = conn { @@ -162,7 +170,7 @@ fn server( } ctx.handshake() } else { - ctx.establish(conn, None) // TLS + ctx.establish(conn, None) // For TLS, establish the connection which should just work }; match res { @@ -236,27 +244,54 @@ mod test { continue; } - // TLS tests + // TLS tests using certificates let (c, s) = crate::support::net::create_tcp_pair().unwrap(); - let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver.clone()).unwrap()); - let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver).unwrap()); + let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, false).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, false).unwrap()); c.join().unwrap(); s.join().unwrap(); - // DTLS tests (DTLS 1.0 corresponds to TSL 1.1) + // TLS tests using PSK + + let (c, s) = crate::support::net::create_tcp_pair().unwrap(); + let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, true).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, true).unwrap()); + + c.join().unwrap(); + s.join().unwrap(); + // DTLS tests using certificates + + // DTLS 1.0 is based on TSL 1.1 if min_c < Version::Tls1_1 || min_s < Version::Tls1_1 || exp_ver.is_none() { continue; } let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); - let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, false).unwrap()); + let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); + let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); + let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, false).unwrap()); + + s.join().unwrap(); + c.join().unwrap(); + + // TODO There seems to be a race condition which does not allow us to directly reuse + // the UDP address? Without a short delay here, the DTLS tests using PSK fail with + // NetRecvFailed in some cases. + std::thread::sleep(std::time::Duration::from_millis(10)); + + // DTLS tests using PSK + + let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); + let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); + let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, true).unwrap()); let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); - let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver.clone()).unwrap()); + let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, true).unwrap()); s.join().unwrap(); c.join().unwrap();