@@ -84,27 +84,20 @@ class SSLContext
84
84
85
85
~SSLContext ()
86
86
{
87
- if (_ssl ) {
88
- ssl_free (_ssl );
89
- _ssl = nullptr ;
87
+ if (io_ctx ) {
88
+ io_ctx-> unref ( );
89
+ io_ctx = NULL ;
90
90
}
91
-
91
+ _ssl = nullptr ;
92
92
--_ssl_ctx_refcnt;
93
93
if (_ssl_ctx_refcnt == 0 ) {
94
94
ssl_ctx_free (_ssl_ctx);
95
95
}
96
96
}
97
97
98
- void ref ( )
98
+ static void _delete_shared_SSL (SSL *_to_del )
99
99
{
100
- ++_refcnt;
101
- }
102
-
103
- void unref ()
104
- {
105
- if (--_refcnt == 0 ) {
106
- delete this ;
107
- }
100
+ ssl_free (_to_del);
108
101
}
109
102
110
103
void connect (ClientContext* ctx, const char * hostName, uint32_t timeout_ms)
@@ -116,17 +109,23 @@ class SSLContext
116
109
ssl_free will want to send a close notify alert, but the old TCP connection
117
110
is already gone at this point, so reset io_ctx. */
118
111
io_ctx = nullptr ;
119
- ssl_free ( _ssl) ;
112
+ _ssl = nullptr ;
120
113
_available = 0 ;
121
114
_read_ptr = nullptr ;
122
115
}
123
116
io_ctx = ctx;
124
- _ssl = ssl_client_new (_ssl_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
117
+ ctx->ref ();
118
+
119
+ // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
120
+ SSL *_new_ssl = ssl_client_new (_ssl_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
121
+ std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
122
+ _ssl = _new_ssl_shared;
123
+
125
124
uint32_t t = millis ();
126
125
127
- while (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
126
+ while (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
128
127
uint8_t * data;
129
- int rc = ssl_read (_ssl, &data);
128
+ int rc = ssl_read (_ssl. get () , &data);
130
129
if (rc < SSL_OK) {
131
130
ssl_display_error (rc);
132
131
break ;
@@ -136,30 +135,40 @@ class SSLContext
136
135
137
136
void connectServer (ClientContext *ctx) {
138
137
io_ctx = ctx;
139
- _ssl = ssl_server_new (_ssl_ctx, reinterpret_cast <int >(this ));
138
+ ctx->ref ();
139
+
140
+ // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
141
+ SSL *_new_ssl = ssl_server_new (_ssl_ctx, reinterpret_cast <int >(this ));
142
+ std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
143
+ _ssl = _new_ssl_shared;
144
+
140
145
_isServer = true ;
141
146
142
147
uint32_t timeout_ms = 5000 ;
143
148
uint32_t t = millis ();
144
149
145
- while (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
150
+ while (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
146
151
uint8_t * data;
147
- int rc = ssl_read (_ssl, &data);
152
+ int rc = ssl_read (_ssl. get () , &data);
148
153
if (rc < SSL_OK) {
154
+ ssl_display_error (rc);
149
155
break ;
150
156
}
151
157
}
152
158
}
153
159
154
160
void stop ()
155
161
{
162
+ if (io_ctx) {
163
+ io_ctx->unref ();
164
+ }
156
165
io_ctx = nullptr ;
157
166
}
158
167
159
168
bool connected ()
160
169
{
161
170
if (_isServer) return _ssl != nullptr ;
162
- else return _ssl != nullptr && ssl_handshake_status (_ssl) == SSL_OK;
171
+ else return _ssl != nullptr && ssl_handshake_status (_ssl. get () ) == SSL_OK;
163
172
}
164
173
165
174
int read (uint8_t * dst, size_t size)
@@ -302,7 +311,7 @@ class SSLContext
302
311
303
312
bool verifyCert ()
304
313
{
305
- int rc = ssl_verify_cert (_ssl);
314
+ int rc = ssl_verify_cert (_ssl. get () );
306
315
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR (X509_VFY_ERROR_SELF_SIGNED)) {
307
316
DEBUGV (" Allowing self-signed certificate\n " );
308
317
return true ;
@@ -321,12 +330,14 @@ class SSLContext
321
330
322
331
operator SSL*()
323
332
{
324
- return _ssl;
333
+ return _ssl. get () ;
325
334
}
326
335
327
336
static ClientContext* getIOContext (int fd)
328
337
{
329
- return reinterpret_cast <SSLContext*>(fd)->io_ctx ;
338
+ if (!fd) return NULL ;
339
+ SSLContext *thisSSL = reinterpret_cast <SSLContext*>(fd);
340
+ return thisSSL->io_ctx ;
330
341
}
331
342
332
343
int loadServerX509Cert (const uint8_t *cert, int len) {
@@ -347,10 +358,9 @@ class SSLContext
347
358
optimistic_yield (100 );
348
359
349
360
uint8_t * data;
350
- int rc = ssl_read (_ssl, &data);
361
+ int rc = ssl_read (_ssl. get () , &data);
351
362
if (rc <= 0 ) {
352
363
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
353
- ssl_free (_ssl);
354
364
_ssl = nullptr ;
355
365
}
356
366
return 0 ;
@@ -367,7 +377,7 @@ class SSLContext
367
377
return 0 ;
368
378
}
369
379
370
- int rc = ssl_write (_ssl, src, size);
380
+ int rc = ssl_write (_ssl. get () , src, size);
371
381
if (rc >= 0 ) {
372
382
return rc;
373
383
}
@@ -410,12 +420,11 @@ class SSLContext
410
420
{
411
421
return !_writeBuffers.empty ();
412
422
}
413
-
423
+ public:
414
424
bool _isServer = false ;
415
425
static SSL_CTX* _ssl_ctx;
416
426
static int _ssl_ctx_refcnt;
417
- SSL* _ssl = nullptr ;
418
- int _refcnt = 0 ;
427
+ std::shared_ptr<SSL> _ssl = nullptr ;
419
428
const uint8_t * _read_ptr = nullptr ;
420
429
size_t _available = 0 ;
421
430
BufferList _writeBuffers;
@@ -434,42 +443,28 @@ WiFiClientSecure::WiFiClientSecure()
434
443
435
444
WiFiClientSecure::~WiFiClientSecure ()
436
445
{
437
- if (_ssl) {
438
- _ssl->unref ();
439
- }
446
+ _ssl = nullptr ;
440
447
}
441
448
442
- WiFiClientSecure::WiFiClientSecure (const WiFiClientSecure& other)
443
- : WiFiClient(static_cast <const WiFiClient&>(other))
444
- {
445
- _ssl = other._ssl ;
446
- if (_ssl) {
447
- _ssl->ref ();
449
+ static void _delete_shared_SSLContext (SSLContext *_to_del)
450
+ {
451
+ delete _to_del;
448
452
}
449
- }
450
453
451
- WiFiClientSecure& WiFiClientSecure::operator =(const WiFiClientSecure& rhs)
452
- {
453
- (WiFiClient&) *this = rhs;
454
- _ssl = rhs._ssl ;
455
- if (_ssl) {
456
- _ssl->ref ();
457
- }
458
- return *this ;
459
- }
460
454
461
455
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
462
456
WiFiClientSecure::WiFiClientSecure (ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
463
457
{
458
+ // We've been given the client context from the available() call
464
459
_client = client;
465
- if (_ssl) {
466
- _ssl->unref ();
467
- _ssl = nullptr ;
468
- }
460
+ _client->ref ();
469
461
470
- _ssl = new SSLContext;
471
- _ssl->ref ();
462
+ // Make the "_ssl" SSLContext, in the constructor there should be none yet
463
+ SSLContext *_new_ssl = new SSLContext;
464
+ std::shared_ptr<SSLContext> _new_ssl_shared (_new_ssl, _delete_shared_SSLContext);
465
+ _ssl = _new_ssl_shared;
472
466
467
+ _ssl = std::make_shared<SSLContext>();
473
468
if (usePMEM) {
474
469
// When using PMEM based certs, allocate stack and copy from flash to DRAM, call SSL functions to avoid
475
470
// heap fragmentation that would happen w/malloc()
@@ -490,7 +485,6 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
490
485
_ssl->loadServerX509Cert (cert, certLen);
491
486
}
492
487
}
493
- _client->ref ();
494
488
_ssl->connectServer (client);
495
489
}
496
490
@@ -523,14 +517,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
523
517
int WiFiClientSecure::_connectSSL (const char * hostName)
524
518
{
525
519
if (!_ssl) {
526
- _ssl = new SSLContext;
527
- _ssl->ref ();
520
+ _ssl = std::make_shared<SSLContext>();
528
521
}
529
522
_ssl->connect (_client, hostName, _timeout);
530
523
531
524
auto status = ssl_handshake_status (*_ssl);
532
525
if (status != SSL_OK) {
533
- _ssl->unref ();
534
526
_ssl = nullptr ;
535
527
return 0 ;
536
528
}
@@ -550,7 +542,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
550
542
}
551
543
552
544
if (rc != SSL_CLOSE_NOTIFY) {
553
- _ssl->unref ();
554
545
_ssl = nullptr ;
555
546
}
556
547
@@ -653,8 +644,6 @@ void WiFiClientSecure::stop()
653
644
{
654
645
if (_ssl) {
655
646
_ssl->stop ();
656
- _ssl->unref ();
657
- _ssl = nullptr ;
658
647
}
659
648
WiFiClient::stop ();
660
649
}
@@ -772,8 +761,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
772
761
void WiFiClientSecure::_initSSLContext ()
773
762
{
774
763
if (!_ssl) {
775
- _ssl = new SSLContext;
776
- _ssl->ref ();
764
+ _ssl = std::make_shared<SSLContext>();
777
765
}
778
766
}
779
767
0 commit comments