Skip to content

Commit b71ec09

Browse files
dongcarlNatnatee Dokmai
and
Natnatee Dokmai
committed
Add PSK support
Co-authored-by: Natnatee Dokmai <[email protected]>
1 parent 56b1dcf commit b71ec09

File tree

5 files changed

+123
-3
lines changed

5 files changed

+123
-3
lines changed

mbedtls/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ required-features = ["std"]
8585
name = "ssl_conf_ca_cb"
8686
required-features = ["std"]
8787

88+
[[test]]
89+
name = "ssl_conf_psk_cb"
90+
required-features = ["std"]
91+
8892
[[test]]
8993
name = "ssl_conf_verify"
9094
required-features = ["std"]

mbedtls/src/ssl/config.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ define!(
9999
callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ());
100100
callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>);
101101
callback!(CaCallback: Fn(&MbedtlsList<Certificate>) -> Result<MbedtlsList<Certificate>>);
102+
callback!(PskCallback: Fn(&mut HandshakeContext, &str) -> Result<()>);
102103

103104

104105
#[repr(transparent)]
@@ -164,6 +165,7 @@ define!(
164165
sni_callback: Option<Arc<dyn SniCallback + 'static>>,
165166
ticket_callback: Option<Arc<dyn TicketCallback + 'static>>,
166167
ca_callback: Option<Arc<dyn CaCallback + 'static>>,
168+
psk_callback: Option<Arc<dyn PskCallback + 'static>>,
167169
};
168170
const drop: fn(&mut Self) = ssl_config_free;
169171
impl<'a> Into<ptr> {}
@@ -199,6 +201,7 @@ impl Config {
199201
sni_callback: None,
200202
ticket_callback: None,
201203
ca_callback: None,
204+
psk_callback: None,
202205
}
203206
}
204207

@@ -457,6 +460,42 @@ impl Config {
457460
self.dbg_callback = Some(Arc::new(cb));
458461
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
459462
}
463+
464+
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
465+
unsafe { ssl_conf_psk(&mut self.inner,
466+
psk.as_ptr(), psk.len(),
467+
psk_identity.as_ptr(), psk_identity.len())
468+
.into_result().map(|_| ())
469+
}
470+
}
471+
472+
#[cfg(feature = "std")]
473+
pub fn set_psk_callback<F>(&mut self, cb: F)
474+
where
475+
F: PskCallback + 'static,
476+
{
477+
unsafe extern "C" fn psk_callback<F>(
478+
closure: *mut c_void,
479+
ctx: *mut ssl_context,
480+
psk_identity: *const c_uchar,
481+
identity_len: size_t,
482+
) -> c_int
483+
where
484+
F: PskCallback + 'static,
485+
{
486+
let cb = &mut *(closure as *mut F);
487+
let ctx = UnsafeFrom::from(ctx).unwrap();
488+
489+
let psk_identity = std::str::from_utf8_unchecked(from_raw_parts(psk_identity, identity_len));
490+
match cb(ctx, psk_identity) {
491+
Ok(()) => 0,
492+
Err(e) => e.to_int(),
493+
}
494+
}
495+
496+
self.psk_callback = Some(Arc::new(cb));
497+
unsafe { ssl_conf_psk_cb(self.into(), Some(psk_callback::<F>), &**self.psk_callback.as_mut().unwrap() as *const _ as *mut c_void) }
498+
}
460499
}
461500

462501
// TODO
@@ -466,8 +505,6 @@ impl Config {
466505
// ssl_conf_dtls_badmac_limit
467506
// ssl_conf_handshake_timeout
468507
// ssl_conf_session_cache
469-
// ssl_conf_psk
470-
// ssl_conf_psk_cb
471508
// ssl_conf_sig_hashes
472509
// ssl_conf_alpn_protocols
473510
// ssl_conf_fallback

mbedtls/src/ssl/context.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,14 @@ impl HandshakeContext {
408408

409409
Ok(())
410410
}
411+
412+
pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> {
413+
unsafe {
414+
ssl_set_hs_psk(self.into(), psk.as_ptr(), psk.len()).into_result()?;
415+
}
416+
417+
Ok(())
418+
}
411419
}
412420

413421
#[cfg(test)]
@@ -485,7 +493,6 @@ mod tests {
485493
// ssl_renegotiate
486494
// ssl_send_alert_message
487495
// ssl_set_client_transport_id
488-
// ssl_set_hs_psk
489496
// ssl_set_timer_cb
490497
//
491498
// ssl_handshake_step

mbedtls/tests/ssl_conf_psk_cb.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
2+
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
3+
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
4+
* option. This file may not be copied, modified, or distributed except
5+
* according to those terms. */
6+
7+
#![cfg(not(target_env = "sgx"))]
8+
9+
// needed to have common code for `mod support` in unit and integrations tests
10+
extern crate mbedtls;
11+
12+
use std::net::TcpStream;
13+
use std::sync::Arc;
14+
15+
use mbedtls::rng::CtrDrbg;
16+
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
17+
use mbedtls::ssl::{Config, Context};
18+
use mbedtls::ssl::context::HandshakeContext;
19+
use mbedtls::Result as TlsResult;
20+
use mbedtls::ssl::config::PskCallback;
21+
22+
mod support;
23+
use support::entropy::entropy_new;
24+
25+
fn client(conn: TcpStream, psk: &[u8]) -> TlsResult<()>
26+
{
27+
let entropy = Arc::new(entropy_new());
28+
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
29+
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
30+
config.set_rng(rng);
31+
config.set_psk(psk, "Client_identity")?;
32+
let mut ctx = Context::new(Arc::new(config));
33+
ctx.establish(conn, None).map(|_| ())
34+
}
35+
36+
fn server<F>(conn: TcpStream, psk_callback: F) -> TlsResult<()>
37+
where
38+
F: PskCallback + Send + 'static,
39+
{
40+
let entropy = Arc::new(entropy_new());
41+
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
42+
let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default);
43+
config.set_rng(rng);
44+
config.set_psk_callback(psk_callback);
45+
let mut ctx = Context::new(Arc::new(config));
46+
ctx.establish(conn, None).map(|_| ())
47+
}
48+
49+
mod test {
50+
use super::*;
51+
use std::thread;
52+
use crate::support::net::create_tcp_pair;
53+
use crate::support::keys;
54+
55+
#[test]
56+
fn callback_standard_psk() {
57+
let (c, s) = create_tcp_pair().unwrap();
58+
59+
let psk_callback =
60+
|ctx: &mut HandshakeContext, _: &str| {
61+
ctx.set_psk(keys::PRESHARED_KEY)
62+
};
63+
let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap());
64+
let s = thread::spawn(move || super::server(s, psk_callback).unwrap());
65+
c.join().unwrap();
66+
s.join().unwrap();
67+
}
68+
}

mbedtls/tests/support/keys.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,7 @@ pub const ROOT_CA_KEY: &'static str = concat!(include_str!("./keys/ca.key"),"\0"
9292
pub const EXPIRED_CERT_SUBJECT: &'static str = "CN=ExpiredNode";
9393
pub const EXPIRED_CERT: &'static str = concat!(include_str!("./keys/expired.crt"),"\0");
9494
pub const EXPIRED_KEY: &'static str = concat!(include_str!("./keys/expired.key"),"\0");
95+
96+
pub const PRESHARED_KEY: &'static [u8] = &[
97+
234, 206, 151, 23, 219, 21, 71, 144,
98+
107, 42, 23, 67, 249, 173, 182, 224 ];

0 commit comments

Comments
 (0)