Skip to content

Commit 4f90061

Browse files
feat(SHA): Borrow the SHA peripheral instead of unsafely stealing
1 parent 49040f2 commit 4f90061

12 files changed

+33
-9
lines changed

esp-mbedtls/src/compat/edge_nal_compat.rs

+7
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ use core::{
44
cell::{Cell, RefCell, UnsafeCell},
55
mem::MaybeUninit,
66
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
7+
ops::DerefMut,
78
ptr::NonNull,
89
};
910

11+
use crate::{hal::peripheral::PeripheralRef, SHA};
1012
use edge_nal::{Close, TcpBind};
13+
1114
use edge_nal_embassy::{Tcp, TcpAccept, TcpSocket};
1215

1316
pub struct TlsAcceptor<
@@ -23,6 +26,7 @@ pub struct TlsAcceptor<
2326
owns_rsa: bool,
2427
tls_buffers: &'d TlsBuffers<RX_SZ, TX_SZ>,
2528
tls_buffers_ptr: RefCell<NonNull<([u8; RX_SZ], [u8; TX_SZ])>>,
29+
sha: RefCell<PeripheralRef<'d, SHA>>,
2630
}
2731

2832
impl<'d, D, const N: usize, const RX_SZ: usize, const TX_SZ: usize> Drop
@@ -55,6 +59,7 @@ where
5559
port: u16,
5660
version: TlsVersion,
5761
certificates: Certificates<'d>,
62+
sha: impl Peripheral<P = SHA> + 'd,
5863
) -> Self {
5964
let acceptor = tcp
6065
.bind(SocketAddr::V4(SocketAddrV4::new(
@@ -73,6 +78,7 @@ where
7378
owns_rsa: false,
7479
tls_buffers,
7580
tls_buffers_ptr: RefCell::new(socket_buffers),
81+
sha: sha.into_ref().into(),
7682
}
7783
}
7884

@@ -146,6 +152,7 @@ where
146152
self.certificates,
147153
rx,
148154
tx,
155+
self.sha.borrow_mut().reborrow().deref_mut(),
149156
)?;
150157

151158
log::debug!("Establishing SSL connection");

esp-mbedtls/src/lib.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -495,10 +495,13 @@ impl<T> Session<T> {
495495
mode: Mode,
496496
min_version: TlsVersion,
497497
certificates: Certificates,
498+
sha: impl Peripheral<P = SHA>,
498499
) -> Result<Self, TlsError> {
499-
// TODO: Take peripheral from user
500-
let sha = Sha::new(unsafe { SHA::steal() });
501-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
500+
critical_section::with(|cs| {
501+
SHARED_SHA
502+
.borrow_ref_mut(cs)
503+
.replace(unsafe { core::mem::transmute(Sha::new(sha)) })
504+
});
502505

503506
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
504507
certificates.init_ssl(servername, mode, min_version)?;
@@ -641,6 +644,7 @@ impl<T> Drop for Session<T> {
641644
if self.owns_rsa {
642645
RSA_REF = core::mem::transmute(None::<RSA>);
643646
}
647+
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).take());
644648
mbedtls_ssl_close_notify(self.ssl_context);
645649
mbedtls_ctr_drbg_free(self.drbg_context);
646650
mbedtls_ssl_config_free(self.ssl_config);
@@ -751,10 +755,13 @@ pub mod asynch {
751755

752756
rx_buffer: &'a mut [u8; RX_SIZE],
753757
tx_buffer: &'a mut [u8; TX_SIZE],
758+
sha: impl Peripheral<P = SHA>,
754759
) -> Result<Self, TlsError> {
755-
// TODO: Take peripheral from user
756-
let sha = Sha::new(unsafe { SHA::steal() });
757-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
760+
critical_section::with(|cs| {
761+
SHARED_SHA
762+
.borrow_ref_mut(cs)
763+
.replace(unsafe { core::mem::transmute(Sha::new(sha)) })
764+
});
758765

759766
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
760767
certificates.init_ssl(servername, mode, min_version)?;
@@ -797,6 +804,7 @@ pub mod asynch {
797804
if self.owns_rsa {
798805
RSA_REF = core::mem::transmute(None::<RSA>);
799806
}
807+
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).take());
800808
mbedtls_ssl_close_notify(self.ssl_context);
801809
mbedtls_ctr_drbg_free(self.drbg_context);
802810
mbedtls_ssl_config_free(self.ssl_config);

esp-mbedtls/src/sha/sha1.rs

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ pub struct mbedtls_sha1_context {
1010

1111
#[no_mangle]
1212
pub unsafe extern "C" fn mbedtls_sha1_init(ctx: *mut mbedtls_sha1_context) {
13-
let sha = crate::Sha::new(unsafe { crate::SHA::steal() });
14-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
1513
let hasher_mem =
1614
crate::calloc(1, core::mem::size_of::<Context<Sha1>>() as u32) as *mut Context<Sha1>;
1715
core::ptr::write(hasher_mem, Context::<Sha1>::new());

examples/async_client.rs

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async fn main(spawner: Spawner) -> ! {
140140
},
141141
mk_static!([u8; 4096], [0; 4096]),
142142
mk_static!([u8; 4096], [0; 4096]),
143+
peripherals.SHA,
143144
)
144145
.unwrap()
145146
.with_hardware_rsa(peripherals.RSA);

examples/async_client_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ async fn main(spawner: Spawner) -> ! {
146146
certificates,
147147
mk_static!([u8; 4096], [0; 4096]),
148148
mk_static!([u8; 4096], [0; 4096]),
149+
peripherals.SHA,
149150
)
150151
.unwrap()
151152
.with_hardware_rsa(peripherals.RSA);

examples/async_server.rs

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ async fn main(spawner: Spawner) -> ! {
164164
},
165165
tls_rx_buffer,
166166
tls_tx_buffer,
167+
&mut peripherals.SHA,
167168
)
168169
.unwrap()
169170
.with_hardware_rsa(&mut peripherals.RSA);

examples/async_server_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ async fn main(spawner: Spawner) -> ! {
182182
},
183183
tls_rx_buffer,
184184
tls_tx_buffer,
185+
&mut peripherals.SHA,
185186
)
186187
.unwrap()
187188
.with_hardware_rsa(&mut peripherals.RSA);

examples/edge_server.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use esp_wifi::wifi::{
3434
WifiState,
3535
};
3636
use esp_wifi::{init, EspWifiInitFor};
37-
use hal::{prelude::*, rng::Rng, timer::timg::TimerGroup};
37+
use hal::{peripherals::SHA, prelude::*, rng::Rng, timer::timg::TimerGroup};
3838

3939
// Patch until https://github.com/embassy-rs/static-cell/issues/16 is fixed
4040
macro_rules! mk_static {
@@ -164,13 +164,16 @@ async fn main(spawner: Spawner) -> ! {
164164
..Default::default()
165165
};
166166

167+
let sha = mk_static!(SHA, peripherals.SHA);
168+
167169
loop {
168170
let tls_acceptor = esp_mbedtls::asynch::TlsAcceptor::new(
169171
tcp,
170172
tls_buffers,
171173
443,
172174
TlsVersion::Tls1_2,
173175
certificates,
176+
&mut *sha,
174177
)
175178
.await
176179
.with_hardware_rsa(&mut peripherals.RSA);

examples/sync_client.rs

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ fn main() -> ! {
116116
.ok(),
117117
..Default::default()
118118
},
119+
peripherals.SHA,
119120
)
120121
.unwrap()
121122
.with_hardware_rsa(peripherals.RSA);

examples/sync_client_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ fn main() -> ! {
123123
Mode::Client,
124124
TlsVersion::Tls1_3,
125125
certificates,
126+
peripherals.SHA,
126127
)
127128
.unwrap()
128129
.with_hardware_rsa(peripherals.RSA);

examples/sync_server.rs

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ fn main() -> ! {
137137
.ok(),
138138
..Default::default()
139139
},
140+
&mut peripherals.SHA,
140141
)
141142
.unwrap()
142143
.with_hardware_rsa(&mut peripherals.RSA);

examples/sync_server_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ fn main() -> ! {
158158
.ok(),
159159
..Default::default()
160160
},
161+
&mut peripherals.SHA,
161162
)
162163
.unwrap()
163164
.with_hardware_rsa(&mut peripherals.RSA);

0 commit comments

Comments
 (0)