Skip to content

Commit 29a5eda

Browse files
dongcarlNatnatee Dokmai
and
Natnatee Dokmai
committed
Add PSK support
Co-authored-by: Natnatee Dokmai <[email protected]>
1 parent 178cc37 commit 29a5eda

File tree

5 files changed

+117
-3
lines changed

5 files changed

+117
-3
lines changed

mbedtls/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ required-features = ["std"]
8585
name = "ssl_conf_ca_cb"
8686
required-features = ["std"]
8787

88+
[[test]]
89+
name = "ssl_conf_psk_cb"
90+
required-features = ["std"]
91+
8892
[[test]]
8993
name = "ssl_conf_verify"
9094
required-features = ["std"]

mbedtls/src/ssl/config.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ define!(
9999
callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ());
100100
callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>);
101101
callback!(CaCallback: Fn(&MbedtlsList<Certificate>) -> Result<MbedtlsList<Certificate>>);
102+
callback!(PskCallback: Fn(&mut HandshakeContext, &str) -> Result<()>);
102103

103104

104105
#[repr(transparent)]
@@ -164,6 +165,7 @@ define!(
164165
sni_callback: Option<Arc<dyn SniCallback + 'static>>,
165166
ticket_callback: Option<Arc<dyn TicketCallback + 'static>>,
166167
ca_callback: Option<Arc<dyn CaCallback + 'static>>,
168+
psk_callback: Option<Arc<dyn PskCallback + 'static>>,
167169
};
168170
const drop: fn(&mut Self) = ssl_config_free;
169171
impl<'a> Into<ptr> {}
@@ -199,6 +201,7 @@ impl Config {
199201
sni_callback: None,
200202
ticket_callback: None,
201203
ca_callback: None,
204+
psk_callback: None,
202205
}
203206
}
204207

@@ -457,6 +460,43 @@ impl Config {
457460
self.dbg_callback = Some(Arc::new(cb));
458461
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
459462
}
463+
464+
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
465+
unsafe { ssl_conf_psk(&mut self.inner,
466+
psk.as_ptr(), psk.len(),
467+
psk_identity.as_ptr(), psk_identity.len())
468+
.into_result().map(|_| ())
469+
}
470+
}
471+
472+
pub fn set_psk_callback<F>(&mut self, cb: F)
473+
where
474+
F: PskCallback + 'static,
475+
{
476+
unsafe extern "C" fn psk_callback<F>(
477+
closure: *mut c_void,
478+
ctx: *mut ssl_context,
479+
psk_identity: *const c_uchar,
480+
identity_len: size_t,
481+
) -> c_int
482+
where
483+
F: PskCallback + 'static,
484+
{
485+
let cb = &mut *(closure as *mut F);
486+
let context = UnsafeFrom::from(ctx).unwrap();
487+
488+
let mut ctx = HandshakeContext::init(context);
489+
490+
let psk_identity = std::str::from_utf8_unchecked(from_raw_parts(psk_identity, identity_len));
491+
match cb(&mut ctx, psk_identity) {
492+
Ok(()) => 0,
493+
Err(e) => e.to_int(),
494+
}
495+
}
496+
497+
self.psk_callback = Some(Arc::new(cb));
498+
unsafe { ssl_conf_psk_cb(self.into(), Some(psk_callback::<F>), &**self.psk_callback.as_mut().unwrap() as *const _ as *mut c_void) }
499+
}
460500
}
461501

462502
// TODO
@@ -466,8 +506,6 @@ impl Config {
466506
// ssl_conf_dtls_badmac_limit
467507
// ssl_conf_handshake_timeout
468508
// ssl_conf_session_cache
469-
// ssl_conf_psk
470-
// ssl_conf_psk_cb
471509
// ssl_conf_sig_hashes
472510
// ssl_conf_alpn_protocols
473511
// ssl_conf_fallback

mbedtls/src/ssl/context.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,14 @@ impl HandshakeContext {
408408

409409
Ok(())
410410
}
411+
412+
pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> {
413+
unsafe {
414+
ssl_set_hs_psk(self.context.into(), psk.as_ptr(), psk.len()).into_result()?;
415+
}
416+
417+
Ok(())
418+
}
411419
}
412420

413421
#[cfg(test)]
@@ -485,7 +493,6 @@ mod tests {
485493
// ssl_renegotiate
486494
// ssl_send_alert_message
487495
// ssl_set_client_transport_id
488-
// ssl_set_hs_psk
489496
// ssl_set_timer_cb
490497
//
491498
// ssl_handshake_step

mbedtls/tests/ssl_conf_psk_cb.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#![allow(dead_code)]
2+
extern crate mbedtls;
3+
4+
use std::net::TcpStream;
5+
use std::sync::Arc;
6+
7+
use mbedtls::rng::CtrDrbg;
8+
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
9+
use mbedtls::ssl::{Config, Context};
10+
use mbedtls::ssl::context::HandshakeContext;
11+
use mbedtls::Result as TlsResult;
12+
use mbedtls::ssl::config::PskCallback;
13+
14+
mod support;
15+
use support::entropy::entropy_new;
16+
17+
fn client(conn: TcpStream, psk: &[u8]) -> TlsResult<()>
18+
{
19+
let entropy = Arc::new(entropy_new());
20+
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
21+
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
22+
config.set_rng(rng);
23+
config.set_psk(psk, "Client_identity")?;
24+
let mut ctx = Context::new(Arc::new(config));
25+
ctx.establish(conn, None).map(|_| ())
26+
}
27+
28+
fn server<F>(conn: TcpStream, psk_callback: F) -> TlsResult<()>
29+
where
30+
F: PskCallback + Send + 'static,
31+
{
32+
let entropy = Arc::new(entropy_new());
33+
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
34+
let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default);
35+
config.set_rng(rng);
36+
config.set_psk_callback(psk_callback);
37+
let mut ctx = Context::new(Arc::new(config));
38+
ctx.establish(conn, None).map(|_| ())
39+
}
40+
41+
#[cfg(unix)]
42+
mod test {
43+
use super::*;
44+
use std::thread;
45+
use crate::support::net::create_tcp_pair;
46+
use crate::support::keys;
47+
48+
#[test]
49+
fn callback_standard_psk() {
50+
let (c, s) = create_tcp_pair().unwrap();
51+
52+
let psk_callback =
53+
|ctx: &mut HandshakeContext, _: &str| {
54+
ctx.set_psk(keys::PRESHARED_KEY)
55+
};
56+
let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap());
57+
let s = thread::spawn(move || super::server(s, psk_callback).unwrap());
58+
c.join().unwrap();
59+
s.join().unwrap();
60+
}
61+
}

mbedtls/tests/support/keys.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,7 @@ pub const ROOT_CA_KEY: &'static str = concat!(include_str!("./keys/ca.key"),"\0"
9292
pub const EXPIRED_CERT_SUBJECT: &'static str = "CN=ExpiredNode";
9393
pub const EXPIRED_CERT: &'static str = concat!(include_str!("./keys/expired.crt"),"\0");
9494
pub const EXPIRED_KEY: &'static str = concat!(include_str!("./keys/expired.key"),"\0");
95+
96+
pub const PRESHARED_KEY: &'static [u8] = &[
97+
234, 206, 151, 23, 219, 21, 71, 144,
98+
107, 42, 23, 67, 249, 173, 182, 224 ];

0 commit comments

Comments
 (0)