From b71ec09e04c1e7d769f8e81285bb8d2821818486 Mon Sep 17 00:00:00 2001 From: Carl Dong Date: Sun, 24 Jul 2022 22:32:14 -0400 Subject: [PATCH] Add PSK support Co-authored-by: Natnatee Dokmai --- mbedtls/Cargo.toml | 4 ++ mbedtls/src/ssl/config.rs | 41 ++++++++++++++++++- mbedtls/src/ssl/context.rs | 9 ++++- mbedtls/tests/ssl_conf_psk_cb.rs | 68 ++++++++++++++++++++++++++++++++ mbedtls/tests/support/keys.rs | 4 ++ 5 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 mbedtls/tests/ssl_conf_psk_cb.rs diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index 5a39a7abc..430005156 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -85,6 +85,10 @@ required-features = ["std"] name = "ssl_conf_ca_cb" required-features = ["std"] +[[test]] +name = "ssl_conf_psk_cb" +required-features = ["std"] + [[test]] name = "ssl_conf_verify" required-features = ["std"] diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index fadbdbdc5..962a749aa 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -99,6 +99,7 @@ define!( callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ()); callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>); callback!(CaCallback: Fn(&MbedtlsList) -> Result>); +callback!(PskCallback: Fn(&mut HandshakeContext, &str) -> Result<()>); #[repr(transparent)] @@ -164,6 +165,7 @@ define!( sni_callback: Option>, ticket_callback: Option>, ca_callback: Option>, + psk_callback: Option>, }; const drop: fn(&mut Self) = ssl_config_free; impl<'a> Into {} @@ -199,6 +201,7 @@ impl Config { sni_callback: None, ticket_callback: None, ca_callback: None, + psk_callback: None, } } @@ -457,6 +460,42 @@ impl Config { self.dbg_callback = Some(Arc::new(cb)); unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) } } + + pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> { + unsafe { ssl_conf_psk(&mut self.inner, + psk.as_ptr(), psk.len(), + psk_identity.as_ptr(), psk_identity.len()) + .into_result().map(|_| ()) + } + } + + #[cfg(feature = "std")] + pub fn set_psk_callback(&mut self, cb: F) + where + F: PskCallback + 'static, + { + unsafe extern "C" fn psk_callback( + closure: *mut c_void, + ctx: *mut ssl_context, + psk_identity: *const c_uchar, + identity_len: size_t, + ) -> c_int + where + F: PskCallback + 'static, + { + let cb = &mut *(closure as *mut F); + let ctx = UnsafeFrom::from(ctx).unwrap(); + + let psk_identity = std::str::from_utf8_unchecked(from_raw_parts(psk_identity, identity_len)); + match cb(ctx, psk_identity) { + Ok(()) => 0, + Err(e) => e.to_int(), + } + } + + self.psk_callback = Some(Arc::new(cb)); + unsafe { ssl_conf_psk_cb(self.into(), Some(psk_callback::), &**self.psk_callback.as_mut().unwrap() as *const _ as *mut c_void) } + } } // TODO @@ -466,8 +505,6 @@ impl Config { // ssl_conf_dtls_badmac_limit // ssl_conf_handshake_timeout // ssl_conf_session_cache -// ssl_conf_psk -// ssl_conf_psk_cb // ssl_conf_sig_hashes // ssl_conf_alpn_protocols // ssl_conf_fallback diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 1fd6746c4..9bde60fc3 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -408,6 +408,14 @@ impl HandshakeContext { Ok(()) } + + pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> { + unsafe { + ssl_set_hs_psk(self.into(), psk.as_ptr(), psk.len()).into_result()?; + } + + Ok(()) + } } #[cfg(test)] @@ -485,7 +493,6 @@ mod tests { // ssl_renegotiate // ssl_send_alert_message // ssl_set_client_transport_id -// ssl_set_hs_psk // ssl_set_timer_cb // // ssl_handshake_step diff --git a/mbedtls/tests/ssl_conf_psk_cb.rs b/mbedtls/tests/ssl_conf_psk_cb.rs new file mode 100644 index 000000000..2f32a6354 --- /dev/null +++ b/mbedtls/tests/ssl_conf_psk_cb.rs @@ -0,0 +1,68 @@ +/* Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ + +#![cfg(not(target_env = "sgx"))] + +// needed to have common code for `mod support` in unit and integrations tests +extern crate mbedtls; + +use std::net::TcpStream; +use std::sync::Arc; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::ssl::context::HandshakeContext; +use mbedtls::Result as TlsResult; +use mbedtls::ssl::config::PskCallback; + +mod support; +use support::entropy::entropy_new; + +fn client(conn: TcpStream, psk: &[u8]) -> TlsResult<()> +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_psk(psk, "Client_identity")?; + let mut ctx = Context::new(Arc::new(config)); + ctx.establish(conn, None).map(|_| ()) +} + +fn server(conn: TcpStream, psk_callback: F) -> TlsResult<()> + where + F: PskCallback + Send + 'static, +{ + let entropy = Arc::new(entropy_new()); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng); + config.set_psk_callback(psk_callback); + let mut ctx = Context::new(Arc::new(config)); + ctx.establish(conn, None).map(|_| ()) +} + +mod test { + use super::*; + use std::thread; + use crate::support::net::create_tcp_pair; + use crate::support::keys; + + #[test] + fn callback_standard_psk() { + let (c, s) = create_tcp_pair().unwrap(); + + let psk_callback = + |ctx: &mut HandshakeContext, _: &str| { + ctx.set_psk(keys::PRESHARED_KEY) + }; + let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap()); + let s = thread::spawn(move || super::server(s, psk_callback).unwrap()); + c.join().unwrap(); + s.join().unwrap(); + } +} diff --git a/mbedtls/tests/support/keys.rs b/mbedtls/tests/support/keys.rs index c899f35c6..1c588f3e0 100644 --- a/mbedtls/tests/support/keys.rs +++ b/mbedtls/tests/support/keys.rs @@ -92,3 +92,7 @@ pub const ROOT_CA_KEY: &'static str = concat!(include_str!("./keys/ca.key"),"\0" pub const EXPIRED_CERT_SUBJECT: &'static str = "CN=ExpiredNode"; pub const EXPIRED_CERT: &'static str = concat!(include_str!("./keys/expired.crt"),"\0"); pub const EXPIRED_KEY: &'static str = concat!(include_str!("./keys/expired.key"),"\0"); + +pub const PRESHARED_KEY: &'static [u8] = &[ + 234, 206, 151, 23, 219, 21, 71, 144, + 107, 42, 23, 67, 249, 173, 182, 224 ];