Skip to content

Commit 9df4d7b

Browse files
feat(SHA): Borrow the SHA peripheral instead of unsafely stealing
1 parent e5965e4 commit 9df4d7b

12 files changed

+33
-9
lines changed

esp-mbedtls/src/compat/edge_nal_compat.rs

+7
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ use core::{
44
cell::{Cell, RefCell, UnsafeCell},
55
mem::MaybeUninit,
66
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
7+
ops::DerefMut,
78
ptr::NonNull,
89
};
910

11+
use crate::{hal::peripheral::PeripheralRef, SHA};
12+
1013
use edge_nal::TcpBind;
1114
use edge_nal_embassy::{Tcp, TcpAccept, TcpSocket};
1215

@@ -23,6 +26,7 @@ pub struct TlsAcceptor<
2326
owns_rsa: bool,
2427
tls_buffers: &'d TlsBuffers<RX_SZ, TX_SZ>,
2528
tls_buffers_ptr: RefCell<NonNull<([u8; RX_SZ], [u8; TX_SZ])>>,
29+
sha: RefCell<PeripheralRef<'d, SHA>>,
2630
}
2731

2832
impl<'d, D, const N: usize, const RX_SZ: usize, const TX_SZ: usize> Drop
@@ -55,6 +59,7 @@ where
5559
port: u16,
5660
version: TlsVersion,
5761
certificates: Certificates<'d>,
62+
sha: impl Peripheral<P = SHA> + 'd,
5863
) -> Self {
5964
let acceptor = tcp
6065
.bind(SocketAddr::V4(SocketAddrV4::new(
@@ -73,6 +78,7 @@ where
7378
owns_rsa: false,
7479
tls_buffers,
7580
tls_buffers_ptr: RefCell::new(socket_buffers),
81+
sha: sha.into_ref().into(),
7682
}
7783
}
7884

@@ -127,6 +133,7 @@ where
127133
self.certificates,
128134
rx,
129135
tx,
136+
self.sha.borrow_mut().reborrow().deref_mut(),
130137
)?;
131138

132139
log::debug!("Establishing SSL connection");

esp-mbedtls/src/lib.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -488,10 +488,13 @@ impl<T> Session<T> {
488488
mode: Mode,
489489
min_version: TlsVersion,
490490
certificates: Certificates,
491+
sha: impl Peripheral<P = SHA>,
491492
) -> Result<Self, TlsError> {
492-
// TODO: Take peripheral from user
493-
let sha = Sha::new(unsafe { SHA::steal() });
494-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
493+
critical_section::with(|cs| {
494+
SHARED_SHA
495+
.borrow_ref_mut(cs)
496+
.replace(unsafe { core::mem::transmute(Sha::new(sha)) })
497+
});
495498

496499
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
497500
certificates.init_ssl(servername, mode, min_version)?;
@@ -634,6 +637,7 @@ impl<T> Drop for Session<T> {
634637
if self.owns_rsa {
635638
RSA_REF = core::mem::transmute(None::<RSA>);
636639
}
640+
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).take());
637641
mbedtls_ssl_close_notify(self.ssl_context);
638642
mbedtls_ctr_drbg_free(self.drbg_context);
639643
mbedtls_ssl_config_free(self.ssl_config);
@@ -744,10 +748,13 @@ pub mod asynch {
744748

745749
rx_buffer: &'a mut [u8; RX_SIZE],
746750
tx_buffer: &'a mut [u8; TX_SIZE],
751+
sha: impl Peripheral<P = SHA>,
747752
) -> Result<Self, TlsError> {
748-
// TODO: Take peripheral from user
749-
let sha = Sha::new(unsafe { SHA::steal() });
750-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
753+
critical_section::with(|cs| {
754+
SHARED_SHA
755+
.borrow_ref_mut(cs)
756+
.replace(unsafe { core::mem::transmute(Sha::new(sha)) })
757+
});
751758

752759
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
753760
certificates.init_ssl(servername, mode, min_version)?;
@@ -790,6 +797,7 @@ pub mod asynch {
790797
if self.owns_rsa {
791798
RSA_REF = core::mem::transmute(None::<RSA>);
792799
}
800+
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).take());
793801
mbedtls_ssl_close_notify(self.ssl_context);
794802
mbedtls_ctr_drbg_free(self.drbg_context);
795803
mbedtls_ssl_config_free(self.ssl_config);

esp-mbedtls/src/sha/sha1.rs

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ pub struct mbedtls_sha1_context {
1010

1111
#[no_mangle]
1212
pub unsafe extern "C" fn mbedtls_sha1_init(ctx: *mut mbedtls_sha1_context) {
13-
let sha = crate::Sha::new(unsafe { crate::SHA::steal() });
14-
critical_section::with(|cs| SHARED_SHA.borrow_ref_mut(cs).replace(sha));
1513
let hasher_mem =
1614
crate::calloc(1, core::mem::size_of::<Context<Sha1>>() as u32) as *mut Context<Sha1>;
1715
core::ptr::write(hasher_mem, Context::<Sha1>::new());

examples/async_client.rs

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ async fn main(spawner: Spawner) -> ! {
140140
},
141141
mk_static!([u8; 4096], [0; 4096]),
142142
mk_static!([u8; 4096], [0; 4096]),
143+
peripherals.SHA,
143144
)
144145
.unwrap()
145146
.with_hardware_rsa(peripherals.RSA);

examples/async_client_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ async fn main(spawner: Spawner) -> ! {
146146
certificates,
147147
mk_static!([u8; 4096], [0; 4096]),
148148
mk_static!([u8; 4096], [0; 4096]),
149+
peripherals.SHA,
149150
)
150151
.unwrap()
151152
.with_hardware_rsa(peripherals.RSA);

examples/async_server.rs

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ async fn main(spawner: Spawner) -> ! {
164164
},
165165
tls_rx_buffer,
166166
tls_tx_buffer,
167+
&mut peripherals.SHA,
167168
)
168169
.unwrap()
169170
.with_hardware_rsa(&mut peripherals.RSA);

examples/async_server_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ async fn main(spawner: Spawner) -> ! {
182182
},
183183
tls_rx_buffer,
184184
tls_tx_buffer,
185+
&mut peripherals.SHA,
185186
)
186187
.unwrap()
187188
.with_hardware_rsa(&mut peripherals.RSA);

examples/edge_server.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use esp_wifi::wifi::{
3434
WifiState,
3535
};
3636
use esp_wifi::{init, EspWifiInitFor};
37-
use hal::{prelude::*, rng::Rng, timer::timg::TimerGroup};
37+
use hal::{peripherals::SHA, prelude::*, rng::Rng, timer::timg::TimerGroup};
3838

3939
// Patch until https://github.com/embassy-rs/static-cell/issues/16 is fixed
4040
macro_rules! mk_static {
@@ -161,13 +161,16 @@ async fn main(spawner: Spawner) -> ! {
161161
..Default::default()
162162
};
163163

164+
let sha = mk_static!(SHA, peripherals.SHA);
165+
164166
loop {
165167
let tls_acceptor = esp_mbedtls::asynch::TlsAcceptor::new(
166168
tcp,
167169
tls_buffers,
168170
443,
169171
TlsVersion::Tls1_2,
170172
certificates,
173+
&mut *sha,
171174
)
172175
.await
173176
.with_hardware_rsa(&mut peripherals.RSA);

examples/sync_client.rs

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ fn main() -> ! {
116116
.ok(),
117117
..Default::default()
118118
},
119+
peripherals.SHA,
119120
)
120121
.unwrap()
121122
.with_hardware_rsa(peripherals.RSA);

examples/sync_client_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ fn main() -> ! {
123123
Mode::Client,
124124
TlsVersion::Tls1_3,
125125
certificates,
126+
peripherals.SHA,
126127
)
127128
.unwrap()
128129
.with_hardware_rsa(peripherals.RSA);

examples/sync_server.rs

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ fn main() -> ! {
137137
.ok(),
138138
..Default::default()
139139
},
140+
&mut peripherals.SHA,
140141
)
141142
.unwrap()
142143
.with_hardware_rsa(&mut peripherals.RSA);

examples/sync_server_mTLS.rs

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ fn main() -> ! {
158158
.ok(),
159159
..Default::default()
160160
},
161+
&mut peripherals.SHA,
161162
)
162163
.unwrap()
163164
.with_hardware_rsa(&mut peripherals.RSA);

0 commit comments

Comments
 (0)