diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 5a39a7abc..3f42f1320 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -57,7 +57,7 @@ cc = "1.0" default = ["std", "aesni", "time", "padlock"] std = ["mbedtls-sys-auto/std", "serde/std", "yasna"] debug = ["mbedtls-sys-auto/debug"] -no_std_deps = ["core_io", "spin"] +no_std_deps = ["core_io", "spin", "serde/alloc"] force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"] mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"] rdrand = [] @@ -68,11 +68,20 @@ padlock = ["mbedtls-sys-auto/padlock"] dsa = ["std", "yasna", "num-bigint", "bit-vec"] pkcs12 = ["std", "yasna"] pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"] +legacy_protocols = ["mbedtls-sys-auto/legacy_protocols"] [[example]] name = "client" required-features = ["std"] +[[example]] +name = "client_dtls" +required-features = ["std"] + +[[example]] +name = "client_psk" +required-features = ["std"] + [[example]] name = "server" required-features = ["std"] diff --git a/mbedtls/examples/client_dtls.rs b/mbedtls/examples/client_dtls.rs new file mode 100644 index 000000000..09f616bcb --- /dev/null +++ b/mbedtls/examples/client_dtls.rs @@ -0,0 +1,57 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +// needed to have common code for `mod support` in unit and integrations tests +extern crate mbedtls; + +use std::io::{self, stdin, stdout, Write}; +use std::net::UdpSocket; +use std::sync::Arc; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::x509::Certificate; +use mbedtls::Result as TlsResult; + +#[path = "../tests/support/mod.rs"] +mod support; +use support::entropy::entropy_new; +use support::keys; + +fn result_main(addr: &str) -> TlsResult<()> { + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let cert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let mut config = Config::new(Endpoint::Client, Transport::Datagram, Preset::Default); + config.set_rng(rng); + config.set_ca_list(cert, None); + let mut ctx = Context::new(Arc::new(config)); + ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new())); + + let sock = UdpSocket::bind("localhost:12345").unwrap(); + let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap(); + ctx.establish(sock, None).unwrap(); + + let mut line = String::new(); + stdin().read_line(&mut line).unwrap(); + ctx.write_all(line.as_bytes()).unwrap(); + io::copy(&mut ctx, &mut stdout()).unwrap(); + Ok(()) +} + +fn main() { + let mut args = std::env::args(); + args.next(); + result_main( + &args + .next() + .expect("supply destination in command-line argument"), + ) + .unwrap(); +} diff --git a/mbedtls/examples/client_psk.rs b/mbedtls/examples/client_psk.rs new file mode 100644 index 000000000..609dd099c --- /dev/null +++ b/mbedtls/examples/client_psk.rs @@ -0,0 +1,52 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +// needed to have common code for `mod support` in unit and integrations tests +extern crate mbedtls; + +use std::io::{self, stdin, stdout, Write}; +use std::net::TcpStream; +use std::sync::Arc; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::Result as TlsResult; + +#[path = "../tests/support/mod.rs"] +mod support; +use support::entropy::entropy_new; + +fn result_main(addr: &str) -> TlsResult<()> { + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client").unwrap(); + let mut ctx = Context::new(Arc::new(config)); + + let conn = TcpStream::connect(addr).unwrap(); + ctx.establish(conn, None)?; + + let mut line = String::new(); + stdin().read_line(&mut line).unwrap(); + ctx.write_all(line.as_bytes()).unwrap(); + io::copy(&mut ctx, &mut stdout()).unwrap(); + Ok(()) +} + +fn main() { + let mut args = std::env::args(); + args.next(); + result_main( + &args + .next() + .expect("supply destination in command-line argument"), + ) + .unwrap(); +} diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index fadbdbdc5..8880cb785 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -26,6 +26,7 @@ use crate::pk::Pk; use crate::pk::dhparam::Dhm; use crate::private::UnsafeFrom; use crate::rng::RngCallback; +use crate::ssl::cookie::CookieCallback; use crate::ssl::context::HandshakeContext; use crate::ssl::ticket::TicketCallback; use crate::x509::{self, Certificate, Crl, Profile, VerifyCallback}; @@ -164,6 +165,7 @@ define!( sni_callback: Option>, ticket_callback: Option>, ca_callback: Option>, + dtls_cookies: Option>, }; const drop: fn(&mut Self) = ssl_config_free; impl<'a> Into {} @@ -199,6 +201,7 @@ impl Config { sni_callback: None, ticket_callback: None, ca_callback: None, + dtls_cookies: None, } } @@ -457,6 +460,25 @@ impl Config { self.dbg_callback = Some(Arc::new(cb)); unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) } } + + /// Sets the PSK and the PSK-Identity + /// + /// Only a single entry is supported at the moment. If another one was set before, it will be + /// overridden. + pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> { + unsafe { + // This allocates and copies the buffers and does not store any pointer to them + ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len()) + .into_result() + .map(|_| ()) + } + } + + /// Sets the cookie context and callbacks which are required for DTLS servers + pub fn set_dtls_cookies(&mut self, dtls_cookies: Arc) { + unsafe { ssl_conf_dtls_cookies(self.into(), Some(T::cookie_write), Some(T::cookie_check), dtls_cookies.data_ptr()) }; + self.dtls_cookies = Some(dtls_cookies); + } } // TODO @@ -466,7 +488,6 @@ impl Config { // ssl_conf_dtls_badmac_limit // ssl_conf_handshake_timeout // ssl_conf_session_cache -// ssl_conf_psk // ssl_conf_psk_cb // ssl_conf_sig_hashes // ssl_conf_alpn_protocols diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 1fd6746c4..28e48e854 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -10,7 +10,7 @@ use core::result::Result as StdResult; #[cfg(feature = "std")] use { - std::io::{Read, Write, Result as IoResult}, + std::io::{Read, Write, Result as IoResult, Error as IoError}, std::sync::Arc, }; @@ -67,6 +67,121 @@ impl IoCallback for IO { } } +#[cfg(feature = "std")] +pub struct ConnectedUdpSocket { + socket: std::net::UdpSocket, +} + +#[cfg(feature = "std")] +impl ConnectedUdpSocket { + pub fn connect(socket: std::net::UdpSocket, addr: A) -> StdResult { + match socket.connect(addr) { + Ok(_) => Ok(ConnectedUdpSocket { + socket, + }), + Err(e) => Err((e, socket)), + } + } +} + +#[cfg(feature = "std")] +impl IoCallback for ConnectedUdpSocket { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) { + Ok(i) => i as c_int, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0, + Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, + } + } + + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) { + Ok(i) => i as c_int, + Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, + } + } + + fn data_ptr(&mut self) -> *mut c_void { + self as *mut ConnectedUdpSocket as *mut c_void + } +} + +pub trait TimerCallback: Send + Sync { + unsafe extern "C" fn set_timer( + p_timer: *mut c_void, + int_ms: u32, + fin_ms: u32, + ) where Self: Sized; + + unsafe extern "C" fn get_timer( + p_timer: *mut c_void, + ) -> c_int where Self: Sized; + + fn data_ptr(&mut self) -> *mut c_void; +} + +#[cfg(feature = "std")] +pub struct Timer { + timer_start: std::time::Instant, + timer_int_ms: u32, + timer_fin_ms: u32, +} + +#[cfg(feature = "std")] +impl Timer { + pub fn new() -> Self { + Timer { + timer_start: std::time::Instant::now(), + timer_int_ms: 0, + timer_fin_ms: 0, + } + } +} + +#[cfg(feature = "std")] +impl TimerCallback for Timer { + unsafe extern "C" fn set_timer( + p_timer: *mut c_void, + int_ms: u32, + fin_ms: u32, + ) where Self: Sized { + let slf = (p_timer as *mut Timer).as_mut().unwrap(); + slf.timer_start = std::time::Instant::now(); + slf.timer_int_ms = int_ms; + slf.timer_fin_ms = fin_ms; + } + + unsafe extern "C" fn get_timer( + p_timer: *mut c_void, + ) -> c_int where Self: Sized { + let slf = (p_timer as *mut Timer).as_mut().unwrap(); + if slf.timer_int_ms == 0 || slf.timer_fin_ms == 0 { + return 0; + } + let passed = std::time::Instant::now() - slf.timer_start; + if passed.as_millis() >= slf.timer_fin_ms.into() { + 2 + } else if passed.as_millis() >= slf.timer_int_ms.into() { + 1 + } else { + 0 + } + } + + fn data_ptr(&mut self) -> *mut mbedtls_sys::types::raw_types::c_void { + self as *mut _ as *mut _ + } +} define!( #[c_ty(ssl_context)] @@ -89,11 +204,18 @@ pub struct Context { // Base structure used in SNI callback where we cannot determine the io type. inner: HandshakeContext, - // config is used read-only for mutliple contexts and is immutable once configured. + // config is used read-only for multiple contexts and is immutable once configured. config: Arc, // Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated. io: Option>, + + timer_callback: Option>, + + /// Stores the client identification on the DTLS server-side for the current connection. Must + /// be stored in [`Context`] first so that it can be set after the `ssl_session_reset` in the + /// [`establish`](Context::establish) call. + client_transport_id: Option>, } impl<'a, T> Into<*const ssl_context> for &'a Context { @@ -128,6 +250,8 @@ impl Context { }, config: config.clone(), io: None, + timer_callback: None, + client_transport_id: None, } } @@ -146,6 +270,9 @@ impl Context { let mut io = Box::new(io); ssl_session_reset(self.into()).into_result()?; self.set_hostname(hostname)?; + if let Some(client_id) = self.client_transport_id.take() { + self.set_client_transport_id(&client_id)?; + } let ptr = &mut *io as *mut _ as *mut c_void; ssl_set_bio( @@ -157,24 +284,55 @@ impl Context { ); self.io = Some(io); - self.inner.reset_handshake(); + self.inner.reset_handshake(); } - match self.handshake() { + self.handshake() + } +} + +impl Context { + /// Try to complete the handshake procedure to set up a (D)TLS connection + /// + /// In general, this should not be called directly. Instead, [`establish`](Context::establish) + /// should be used which properly sets up the [`IoCallback`] and resets any previous sessions. + /// + /// This should only be used directly if the handshake could not be completed successfully in + /// `establish`, i.e.: + /// - If using nonblocking operation and `establish` failed with [`Error::SslWantRead`] or + /// [`Error::SslWantWrite`] + /// - If running a DTLS server and it answers the first `ClientHello` (without cookie) with a + /// `HelloVerifyRequest`, i.e. `establish` failed with [`Error::SslHelloVerifyRequired`] + pub fn handshake(&mut self) -> Result<()> { + match self.inner_handshake() { Ok(()) => Ok(()), Err(Error::SslWantRead) => Err(Error::SslWantRead), Err(Error::SslWantWrite) => Err(Error::SslWantWrite), + Err(Error::SslHelloVerifyRequired) => { + unsafe { + // `ssl_session_reset` resets the client ID but the user will call handshake + // again in this case and the client ID is required for a DTLS connection setup + // on the server side. So we extract it before and set it after + // `ssl_session_reset`. + let mut client_transport_id = None; + if !self.inner.handle().cli_id.is_null() { + client_transport_id = Some(Vec::from(core::slice::from_raw_parts(self.inner.handle().cli_id, self.inner.handle().cli_id_len))); + } + ssl_session_reset(self.into()).into_result()?; + if let Some(client_id) = client_transport_id.take() { + self.set_client_transport_id(&client_id)?; + } + } + Err(Error::SslHelloVerifyRequired) + } Err(e) => { self.close(); - Err(e) + Err(e) }, } - } -} -impl Context { - fn handshake(&mut self) -> Result<()> { + fn inner_handshake(&mut self) -> Result<()> { unsafe { ssl_flush_output(self.into()).into_result()?; ssl_handshake(self.into()).into_result_discard() @@ -298,6 +456,36 @@ impl Context { } } } + + pub fn set_timer_callback(&mut self, mut cb: Box) { + unsafe { + ssl_set_timer_cb(self.into(), cb.data_ptr(), Some(F::set_timer), Some(F::get_timer)); + } + self.timer_callback = Some(cb); + } + + /// Set client's transport-level identification info (dtls server only) + /// + /// See `mbedtls_ssl_set_client_transport_id` + fn set_client_transport_id(&mut self, info: &[u8]) -> Result<()> { + unsafe { + ssl_set_client_transport_id(self.into(), info.as_ptr(), info.len()) + .into_result() + .map(|_| ()) + } + } + + /// Set client's transport-level identification info (dtls server only) + /// + /// See `mbedtls_ssl_set_client_transport_id` + /// + /// The `info` is used only for the next connection, i.e. it will be used for the next + /// [`establish`](Context::establish) call. Afterwards, it will be unset again. This is to + /// ensure that no client identification is accidentally reused if this [`Context`] is reused + /// for further connections. + pub fn set_client_transport_id_once(&mut self, info: &[u8]) { + self.client_transport_id = Some(info.into()); + } } impl Drop for Context { @@ -334,7 +522,7 @@ impl Write for Context { } // // Class exists only during SNI callback that is configured from Config. -// SNI Callback must provide input whos lifetime exceed the SNI closure to avoid memory corruptions. +// SNI Callback must provide input whose lifetime exceeds the SNI closure to avoid memory corruptions. // That can be achieved easily by storing certificate chains/crls inside the closure for the lifetime of the closure. // // That is due to SNI being held by an Arc inside Config. diff --git a/mbedtls/src/ssl/cookie.rs b/mbedtls/src/ssl/cookie.rs new file mode 100644 index 000000000..17a50ca9b --- /dev/null +++ b/mbedtls/src/ssl/cookie.rs @@ -0,0 +1,116 @@ +#[cfg(not(feature = "std"))] +use crate::alloc_prelude::*; +#[cfg(feature = "std")] +use std::sync::Arc; + +use mbedtls_sys::types::raw_types::*; +use mbedtls_sys::types::size_t; +use mbedtls_sys::*; + +use crate::error::{IntoResult, Result}; +use crate::rng::RngCallback; + +pub trait CookieCallback { + /* + typedef int mbedtls_ssl_cookie_write_t( void *ctx, + unsigned char **p, unsigned char *end, + const unsigned char *info, size_t ilen ); + */ + unsafe extern "C" fn cookie_write( + ctx: *mut c_void, + p: *mut *mut c_uchar, + end: *mut c_uchar, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized; + /* + typedef int mbedtls_ssl_cookie_check_t( void *ctx, + const unsigned char *cookie, size_t clen, + const unsigned char *info, size_t ilen ); + */ + unsafe extern "C" fn cookie_check( + ctx: *mut c_void, + cookie: *const c_uchar, + clen: size_t, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized; + + /// Returns a mutable pointer to this shared reference which will be used as first argument to + /// the other two methods + /// + /// A mutable pointer is required because the underlying cookie implementation should be + /// allowed to store some information, e.g. mbedtls' implementation uses an internal counter. + /// We only have a shared reference because in general, the `CookieCallback` will be behind an + /// `Arc` (in [`Config`](crate::ssl::Config)). So we need to remove + /// const-ness here which is unsafe in general. Each respective implementation has to + /// guarantee that shared accesses are safe. mbedtls' implementation uses internal mutexes in + /// multithreaded contexts (when the `threading` feature is activated) to do so. + fn data_ptr(&self) -> *mut c_void; +} + +define!( + #[c_ty(ssl_cookie_ctx)] + #[repr(C)] + struct CookieContext { + // We set rng from constructor, we never read it directly. It is only used to ensure rng lives as long as we need. + #[allow(dead_code)] + rng: Arc, + }; + const drop: fn(&mut Self) = ssl_cookie_free; + impl<'a> Into {} +); + +unsafe impl Sync for CookieContext {} + +impl CookieContext { + pub fn new(rng: Arc) -> Result { + let mut ret = CookieContext { + inner: ssl_cookie_ctx::default(), + rng, + }; + + unsafe { + ssl_cookie_init(&mut ret.inner); + ssl_cookie_setup(&mut ret.inner, Some(T::call), ret.rng.data_ptr()).into_result()?; + } + + Ok(ret) + } +} + +impl CookieCallback for CookieContext { + unsafe extern "C" fn cookie_write( + ctx: *mut c_void, + p: *mut *mut c_uchar, + end: *mut c_uchar, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized, + { + ssl_cookie_write(ctx, p, end, info, ilen) + } + + unsafe extern "C" fn cookie_check( + ctx: *mut c_void, + cookie: *const c_uchar, + clen: size_t, + info: *const c_uchar, + ilen: size_t, + ) -> c_int + where + Self: Sized, + { + ssl_cookie_check(ctx, cookie, clen, info, ilen) + } + + fn data_ptr(&self) -> *mut c_void { + self.handle() as *const _ as *mut _ + } +} diff --git a/mbedtls/src/ssl/mod.rs b/mbedtls/src/ssl/mod.rs index 430d439ea..1bfc078cf 100644 --- a/mbedtls/src/ssl/mod.rs +++ b/mbedtls/src/ssl/mod.rs @@ -9,6 +9,7 @@ pub mod ciphersuites; pub mod config; pub mod context; +pub mod cookie; pub mod ticket; #[doc(inline)] @@ -18,4 +19,6 @@ pub use self::config::{Config, Version, UseSessionTickets}; #[doc(inline)] pub use self::context::Context; #[doc(inline)] +pub use self::cookie::CookieContext; +#[doc(inline)] pub use self::ticket::TicketContext; diff --git a/mbedtls/tests/client_server.rs b/mbedtls/tests/client_server.rs index 0f499569e..da3f5a4ac 100644 --- a/mbedtls/tests/client_server.rs +++ b/mbedtls/tests/client_server.rs @@ -17,28 +17,68 @@ use std::net::TcpStream; use mbedtls::pk::Pk; use mbedtls::rng::CtrDrbg; use mbedtls::ssl::config::{Endpoint, Preset, Transport}; -use mbedtls::ssl::{Config, Context, Version}; +use mbedtls::ssl::context::{ConnectedUdpSocket, IoCallback, Timer}; +use mbedtls::ssl::{Config, Context, CookieContext, Version}; use mbedtls::x509::{Certificate, VerifyError}; use mbedtls::Error; use mbedtls::Result as TlsResult; use std::sync::Arc; +use mbedtls_sys::types::raw_types::*; +use mbedtls_sys::types::size_t; + mod support; use support::entropy::entropy_new; use support::keys; +/// Simple type to unify TCP and UDP connections, to support both TLS and DTLS +enum Connection { + Tcp(TcpStream), + Udp(ConnectedUdpSocket), +} + +impl IoCallback for Connection { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let conn = &mut *(user_data as *mut Connection); + match conn { + Connection::Tcp(c) => TcpStream::call_recv(c.data_ptr(), data, len), + Connection::Udp(c) => ConnectedUdpSocket::call_recv(c.data_ptr(), data, len), + } + } + + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let conn = &mut *(user_data as *mut Connection); + match conn { + Connection::Tcp(c) => TcpStream::call_send(c.data_ptr(), data, len), + Connection::Udp(c) => ConnectedUdpSocket::call_send(c.data_ptr(), data, len), + } + } + + fn data_ptr(&mut self) -> *mut c_void { + self as *mut Connection as *mut c_void + } +} + fn client( - conn: TcpStream, + conn: Connection, min_version: Version, max_version: Version, - exp_version: Option) -> TlsResult<()> { + exp_version: Option, + use_psk: bool) -> TlsResult<()> { let entropy = Arc::new(entropy_new()); let rng = Arc::new(CtrDrbg::new(entropy, None)?); - let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); - let expected_flags = VerifyError::empty(); - #[cfg(feature = "time")] - let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; - { + let mut config = match conn { + Connection::Tcp(_) => Config::new(Endpoint::Client, Transport::Stream, Preset::Default), + Connection::Udp(_) => Config::new(Endpoint::Client, Transport::Datagram, Preset::Default), + }; + config.set_rng(rng); + config.set_min_version(min_version)?; + config.set_max_version(max_version)?; + if !use_psk { // for certificate-based operation, set up ca and verification callback + let cacert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?); + let expected_flags = VerifyError::empty(); + #[cfg(feature = "time")] + let expected_flags = expected_flags | VerifyError::CERT_EXPIRED; let verify_callback = move |crt: &Certificate, depth: i32, verify_flags: &mut VerifyError| { match (crt.subject().unwrap().as_str(), depth, &verify_flags) { @@ -51,55 +91,89 @@ fn client( //so removing this flag here prevents the connections from failing with VerifyError Ok(()) }; - let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); - config.set_rng(rng); config.set_verify_callback(verify_callback); config.set_ca_list(cacert, None); - config.set_min_version(min_version)?; - config.set_max_version(max_version)?; - let mut ctx = Context::new(Arc::new(config)); - - match ctx.establish(conn, None) { - Ok(()) => { - assert_eq!(ctx.version(), exp_version.unwrap()); - } - Err(e) => { - match e { - Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, - Error::SslFatalAlertMessage => {}, - e => panic!("Unexpected error {}", e), - }; - return Ok(()); - } - }; + } else { // for psk-based operation, only PSK required + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client")?; + } + let mut ctx = Context::new(Arc::new(config)); - let ciphersuite = ctx.ciphersuite().unwrap(); - ctx.write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()).unwrap(); - let mut buf = [0u8; 13 + 4 + 1]; - ctx.read_exact(&mut buf).unwrap(); - assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); + // For DTLS, timers are required to support retransmissions + if let Connection::Udp(_) = conn { + ctx.set_timer_callback(Box::new(Timer::new())); } + + match ctx.establish(conn, None) { + Ok(()) => { + assert_eq!(ctx.version(), exp_version.unwrap()); + } + Err(e) => { + match e { + Error::SslBadHsProtocolVersion => {assert!(exp_version.is_none())}, + Error::SslFatalAlertMessage => {}, + e => panic!("Unexpected error {}", e), + }; + return Ok(()); + } + }; + + let ciphersuite = ctx.ciphersuite().unwrap(); + ctx.write_all(format!("Client2Server {:4x}", ciphersuite).as_bytes()).unwrap(); + let mut buf = [0u8; 13 + 4 + 1]; + ctx.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); Ok(()) } fn server( - conn: TcpStream, + conn: Connection, min_version: Version, max_version: Version, exp_version: Option, + use_psk: bool, ) -> TlsResult<()> { let entropy = entropy_new(); let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); - let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); - let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); - let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + let mut config = match conn { + Connection::Tcp(_) => Config::new(Endpoint::Server, Transport::Stream, Preset::Default), + Connection::Udp(_) => { + let mut config = Config::new(Endpoint::Server, Transport::Datagram, Preset::Default); + // For DTLS, we need a cookie context to work against DoS attacks + let cookies = CookieContext::new(rng.clone())?; + config.set_dtls_cookies(Arc::new(cookies)); + config + } + }; config.set_rng(rng); config.set_min_version(min_version)?; config.set_max_version(max_version)?; - config.push_cert(cert, key)?; + if !use_psk { // for certificate-based operation, set up certificates + let cert = Arc::new(Certificate::from_pem_multiple(keys::EXPIRED_CERT.as_bytes())?); + let key = Arc::new(Pk::from_private_key(keys::EXPIRED_KEY.as_bytes(), None)?); + config.push_cert(cert, key)?; + } else { // for psk-based operation, only PSK required + config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client")?; + } let mut ctx = Context::new(Arc::new(config)); - match ctx.establish(conn, None) { + let res = if let Connection::Udp(_) = conn { + // For DTLS, timers are required to support retransmissions and the DTLS server needs a client + // ID to create individual cookies per client + ctx.set_timer_callback(Box::new(Timer::new())); + ctx.set_client_transport_id_once(b"127.0.0.1:12341"); + // The first connection setup attempt will fail because the ClientHello is received without + // a cookie + match ctx.establish(conn, None) { + Err(Error::SslHelloVerifyRequired) => {} + Ok(()) => panic!("SslHelloVerifyRequired expected, got Ok instead"), + Err(e) => panic!("SslHelloVerifyRequired expected, got {} instead", e), + } + ctx.handshake() + } else { + ctx.establish(conn, None) // For TLS, establish the connection which should just work + }; + + match res { Ok(()) => { assert_eq!(ctx.version(), exp_version.unwrap()); } @@ -130,6 +204,8 @@ mod test { #[test] fn client_server_test() { use mbedtls::ssl::Version; + use std::net::UdpSocket; + use mbedtls::ssl::context::ConnectedUdpSocket; #[derive(Copy,Clone)] struct TestConfig { @@ -168,12 +244,57 @@ mod test { continue; } + // TLS tests using certificates + + let (c, s) = crate::support::net::create_tcp_pair().unwrap(); + let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, false).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, false).unwrap()); + + c.join().unwrap(); + s.join().unwrap(); + + // TLS tests using PSK + let (c, s) = crate::support::net::create_tcp_pair().unwrap(); - let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver.clone()).unwrap()); - let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver).unwrap()); + let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, true).unwrap()); + let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, true).unwrap()); c.join().unwrap(); s.join().unwrap(); + + // DTLS tests using certificates + + // DTLS 1.0 is based on TSL 1.1 + if min_c < Version::Tls1_1 || min_s < Version::Tls1_1 || exp_ver.is_none() { + continue; + } + + let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); + let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); + let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, false).unwrap()); + let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); + let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); + let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, false).unwrap()); + + s.join().unwrap(); + c.join().unwrap(); + + // TODO There seems to be a race condition which does not allow us to directly reuse + // the UDP address? Without a short delay here, the DTLS tests using PSK fail with + // NetRecvFailed in some cases. + std::thread::sleep(std::time::Duration::from_millis(10)); + + // DTLS tests using PSK + + let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); + let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); + let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, true).unwrap()); + let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); + let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); + let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, true).unwrap()); + + s.join().unwrap(); + c.join().unwrap(); } } }