Skip to content

Add PSK support #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ required-features = ["std"]
name = "ssl_conf_ca_cb"
required-features = ["std"]

[[test]]
name = "ssl_conf_psk_cb"
required-features = ["std"]

[[test]]
name = "ssl_conf_verify"
required-features = ["std"]
Expand Down
41 changes: 39 additions & 2 deletions mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ define!(
callback!(DbgCallback: Fn(i32, Cow<'_, str>, i32, Cow<'_, str>) -> ());
callback!(SniCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>);
callback!(CaCallback: Fn(&MbedtlsList<Certificate>) -> Result<MbedtlsList<Certificate>>);
callback!(PskCallback: Fn(&mut HandshakeContext, &str) -> Result<()>);


#[repr(transparent)]
Expand Down Expand Up @@ -164,6 +165,7 @@ define!(
sni_callback: Option<Arc<dyn SniCallback + 'static>>,
ticket_callback: Option<Arc<dyn TicketCallback + 'static>>,
ca_callback: Option<Arc<dyn CaCallback + 'static>>,
psk_callback: Option<Arc<dyn PskCallback + 'static>>,
};
const drop: fn(&mut Self) = ssl_config_free;
impl<'a> Into<ptr> {}
Expand Down Expand Up @@ -199,6 +201,7 @@ impl Config {
sni_callback: None,
ticket_callback: None,
ca_callback: None,
psk_callback: None,
}
}

Expand Down Expand Up @@ -457,6 +460,42 @@ impl Config {
self.dbg_callback = Some(Arc::new(cb));
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
}

pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
unsafe { ssl_conf_psk(&mut self.inner,
psk.as_ptr(), psk.len(),
psk_identity.as_ptr(), psk_identity.len())
.into_result().map(|_| ())
}
}

#[cfg(feature = "std")]
pub fn set_psk_callback<F>(&mut self, cb: F)
where
F: PskCallback + 'static,
{
unsafe extern "C" fn psk_callback<F>(
closure: *mut c_void,
ctx: *mut ssl_context,
psk_identity: *const c_uchar,
identity_len: size_t,
) -> c_int
where
F: PskCallback + 'static,
{
let cb = &mut *(closure as *mut F);
let ctx = UnsafeFrom::from(ctx).unwrap();

let psk_identity = std::str::from_utf8_unchecked(from_raw_parts(psk_identity, identity_len));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There doesn't seem to be a guarantee that psk_identity is always a valid utf8 string.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the best thing to do here? unwrap a checked from_utf8 so the error is more apparent?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No you could change the type of PskCallback to

callback!(PskCallback: Fn(&mut HandshakeContext, &[u8]) -> Result<()>);

so there isn't a need to convert it.

match cb(ctx, psk_identity) {
Ok(()) => 0,
Err(e) => e.to_int(),
}
}

self.psk_callback = Some(Arc::new(cb));
unsafe { ssl_conf_psk_cb(self.into(), Some(psk_callback::<F>), &**self.psk_callback.as_mut().unwrap() as *const _ as *mut c_void) }
}
}

// TODO
Expand All @@ -466,8 +505,6 @@ impl Config {
// ssl_conf_dtls_badmac_limit
// ssl_conf_handshake_timeout
// ssl_conf_session_cache
// ssl_conf_psk
// ssl_conf_psk_cb
// ssl_conf_sig_hashes
// ssl_conf_alpn_protocols
// ssl_conf_fallback
Expand Down
9 changes: 8 additions & 1 deletion mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ impl HandshakeContext {

Ok(())
}

pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> {
unsafe {
ssl_set_hs_psk(self.into(), psk.as_ptr(), psk.len()).into_result()?;
}

Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -485,7 +493,6 @@ mod tests {
// ssl_renegotiate
// ssl_send_alert_message
// ssl_set_client_transport_id
// ssl_set_hs_psk
// ssl_set_timer_cb
//
// ssl_handshake_step
Expand Down
68 changes: 68 additions & 0 deletions mbedtls/tests/ssl_conf_psk_cb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

#![cfg(not(target_env = "sgx"))]

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::net::TcpStream;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::ssl::context::HandshakeContext;
use mbedtls::Result as TlsResult;
use mbedtls::ssl::config::PskCallback;

mod support;
use support::entropy::entropy_new;

fn client(conn: TcpStream, psk: &[u8]) -> TlsResult<()>
{
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_psk(psk, "Client_identity")?;
let mut ctx = Context::new(Arc::new(config));
ctx.establish(conn, None).map(|_| ())
}

fn server<F>(conn: TcpStream, psk_callback: F) -> TlsResult<()>
where
F: PskCallback + Send + 'static,
{
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_psk_callback(psk_callback);
let mut ctx = Context::new(Arc::new(config));
ctx.establish(conn, None).map(|_| ())
}

mod test {
use super::*;
use std::thread;
use crate::support::net::create_tcp_pair;
use crate::support::keys;

#[test]
fn callback_standard_psk() {
let (c, s) = create_tcp_pair().unwrap();

let psk_callback =
|ctx: &mut HandshakeContext, _: &str| {
ctx.set_psk(keys::PRESHARED_KEY)
};
let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap());
let s = thread::spawn(move || super::server(s, psk_callback).unwrap());
c.join().unwrap();
s.join().unwrap();
}
}
4 changes: 4 additions & 0 deletions mbedtls/tests/support/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,7 @@ pub const ROOT_CA_KEY: &'static str = concat!(include_str!("./keys/ca.key"),"\0"
pub const EXPIRED_CERT_SUBJECT: &'static str = "CN=ExpiredNode";
pub const EXPIRED_CERT: &'static str = concat!(include_str!("./keys/expired.crt"),"\0");
pub const EXPIRED_KEY: &'static str = concat!(include_str!("./keys/expired.key"),"\0");

pub const PRESHARED_KEY: &'static [u8] = &[
234, 206, 151, 23, 219, 21, 71, 144,
107, 42, 23, 67, 249, 173, 182, 224 ];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you pick a preshared key that is not valid utf8 and also contains a zero byte? It may trigger some weird code paths.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, could we add a case with a leading zero byte? I'm certain it is correctly handled, but I have seen way too many leading zero bugs in the wild/

(Orthogonal but also see https://mbed-tls.readthedocs.io/en/latest/security-advisories/advisories/mbedtls-security-advisory-2020-09-3/).

Copy link
Author

@dongcarl dongcarl Aug 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh! If it's not valid UTF-8, we should expect an error from from_utf8?

What kind of behaviour do we expect for leading zero bytes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a requirement that pre-shared keys are valid UTF-8? At least in RFC4279, only PSK identities are assumed to be UTF-8 (Sec. 5.1). Also Sec. 4 states that PSK are the result of Diffie Hellman computations, so no UTF-8 expected.