From 8c6ea611d044a68fad3ab188e7f8576535a2808d Mon Sep 17 00:00:00 2001 From: "Earle F. Philhower, III" Date: Mon, 5 Feb 2018 19:44:47 -0800 Subject: [PATCH 1/3] Fix leak on multiple SSL server connections Fixes #4302 The refcnt setup for the WiFiClientSecure's SSLContext and ClientContext had issues in certain conditions, causing a massive memory leak on each SSL server connection. Depending on the state of the machine, after two or three connections it would OOM and crash. This patch replaces most of the refcnt operations with C++11 shared_ptr operations, cleaning up the code substantially and removing the leakage. Also fixes a race condition where ClientContext was free'd before the SSLContext was stopped/shutdown. When the SSLContext tried to do ssl_free, axtls would attempt to send out the real SSL disconnect bits over the wire, however by this time the ClientContext is invalid and it would fault. --- .../ESP8266WiFi/src/WiFiClientSecure.cpp | 116 ++++++++---------- libraries/ESP8266WiFi/src/WiFiClientSecure.h | 4 +- 2 files changed, 53 insertions(+), 67 deletions(-) diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp index 4876a4710c..4e774df16e 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp @@ -84,27 +84,20 @@ class SSLContext ~SSLContext() { - if (_ssl) { - ssl_free(_ssl); - _ssl = nullptr; + if (io_ctx) { + io_ctx->unref(); + io_ctx = NULL; } - + _ssl = nullptr; --_ssl_ctx_refcnt; if (_ssl_ctx_refcnt == 0) { ssl_ctx_free(_ssl_ctx); } } - void ref() + static void _delete_shared_SSL(SSL *_to_del) { - ++_refcnt; - } - - void unref() - { - if (--_refcnt == 0) { - delete this; - } + ssl_free(_to_del); } void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms) @@ -116,17 +109,23 @@ class SSLContext ssl_free will want to send a close notify alert, but the old TCP connection is already gone at this point, so reset io_ctx. */ io_ctx = nullptr; - ssl_free(_ssl); + _ssl = nullptr; _available = 0; _read_ptr = nullptr; } io_ctx = ctx; - _ssl = ssl_client_new(_ssl_ctx, reinterpret_cast(this), nullptr, 0, ext); + ctx->ref(); + + // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free + SSL *_new_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast(this), nullptr, 0, ext); + std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); + _ssl = _new_ssl_shared; + uint32_t t = millis(); - while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { + while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc < SSL_OK) { ssl_display_error(rc); break; @@ -136,16 +135,23 @@ class SSLContext void connectServer(ClientContext *ctx) { io_ctx = ctx; - _ssl = ssl_server_new(_ssl_ctx, reinterpret_cast(this)); + ctx->ref(); + + // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free + SSL *_new_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast(this)); + std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); + _ssl = _new_ssl_shared; + _isServer = true; uint32_t timeout_ms = 5000; uint32_t t = millis(); - while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { + while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc < SSL_OK) { + ssl_display_error(rc); break; } } @@ -153,13 +159,16 @@ class SSLContext void stop() { + if (io_ctx) { + io_ctx->unref(); + } io_ctx = nullptr; } bool connected() { if (_isServer) return _ssl != nullptr; - else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK; + else return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK; } int read(uint8_t* dst, size_t size) @@ -302,7 +311,7 @@ class SSLContext bool verifyCert() { - int rc = ssl_verify_cert(_ssl); + int rc = ssl_verify_cert(_ssl.get()); if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) { DEBUGV("Allowing self-signed certificate\n"); return true; @@ -321,12 +330,14 @@ class SSLContext operator SSL*() { - return _ssl; + return _ssl.get(); } static ClientContext* getIOContext(int fd) { - return reinterpret_cast(fd)->io_ctx; + if (!fd) return NULL; + SSLContext *thisSSL = reinterpret_cast(fd); + return thisSSL->io_ctx; } int loadServerX509Cert(const uint8_t *cert, int len) { @@ -347,10 +358,9 @@ class SSLContext optimistic_yield(100); uint8_t* data; - int rc = ssl_read(_ssl, &data); + int rc = ssl_read(_ssl.get(), &data); if (rc <= 0) { if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) { - ssl_free(_ssl); _ssl = nullptr; } return 0; @@ -367,7 +377,7 @@ class SSLContext return 0; } - int rc = ssl_write(_ssl, src, size); + int rc = ssl_write(_ssl.get(), src, size); if (rc >= 0) { return rc; } @@ -410,12 +420,11 @@ class SSLContext { return !_writeBuffers.empty(); } - +public: bool _isServer = false; static SSL_CTX* _ssl_ctx; static int _ssl_ctx_refcnt; - SSL* _ssl = nullptr; - int _refcnt = 0; + std::shared_ptr _ssl = nullptr; const uint8_t* _read_ptr = nullptr; size_t _available = 0; BufferList _writeBuffers; @@ -434,42 +443,28 @@ WiFiClientSecure::WiFiClientSecure() WiFiClientSecure::~WiFiClientSecure() { - if (_ssl) { - _ssl->unref(); - } + _ssl = nullptr; } -WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other) - : WiFiClient(static_cast(other)) -{ - _ssl = other._ssl; - if (_ssl) { - _ssl->ref(); + static void _delete_shared_SSLContext(SSLContext *_to_del) + { + delete _to_del; } -} -WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs) -{ - (WiFiClient&) *this = rhs; - _ssl = rhs._ssl; - if (_ssl) { - _ssl->ref(); - } - return *this; -} // Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen) { + // We've been given the client context from the available() call _client = client; - if (_ssl) { - _ssl->unref(); - _ssl = nullptr; - } + _client->ref(); - _ssl = new SSLContext; - _ssl->ref(); + // Make the "_ssl" SSLContext, in the constructor there should be none yet + SSLContext *_new_ssl = new SSLContext; + std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSLContext); + _ssl = _new_ssl_shared; + _ssl = std::make_shared(); if (usePMEM) { // When using PMEM based certs, allocate stack and copy from flash to DRAM, call SSL functions to avoid // heap fragmentation that would happen w/malloc() @@ -490,7 +485,6 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui _ssl->loadServerX509Cert(cert, certLen); } } - _client->ref(); _ssl->connectServer(client); } @@ -523,14 +517,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port) int WiFiClientSecure::_connectSSL(const char* hostName) { if (!_ssl) { - _ssl = new SSLContext; - _ssl->ref(); + _ssl = std::make_shared(); } _ssl->connect(_client, hostName, _timeout); auto status = ssl_handshake_status(*_ssl); if (status != SSL_OK) { - _ssl->unref(); _ssl = nullptr; return 0; } @@ -550,7 +542,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) } if (rc != SSL_CLOSE_NOTIFY) { - _ssl->unref(); _ssl = nullptr; } @@ -653,8 +644,6 @@ void WiFiClientSecure::stop() { if (_ssl) { _ssl->stop(); - _ssl->unref(); - _ssl = nullptr; } WiFiClient::stop(); } @@ -772,8 +761,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name) void WiFiClientSecure::_initSSLContext() { if (!_ssl) { - _ssl = new SSLContext; - _ssl->ref(); + _ssl = std::make_shared(); } } diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.h b/libraries/ESP8266WiFi/src/WiFiClientSecure.h index 9b7cf8df10..73ec587f1d 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.h +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.h @@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient { public: WiFiClientSecure(); ~WiFiClientSecure() override; - WiFiClientSecure(const WiFiClientSecure&); - WiFiClientSecure& operator=(const WiFiClientSecure&); int connect(IPAddress ip, uint16_t port) override; int connect(const String host, uint16_t port) override; @@ -91,7 +89,7 @@ friend class WiFiServerSecure; // Needs access to custom constructor below int _connectSSL(const char* hostName); bool _verifyDN(const char* name); - SSLContext* _ssl = nullptr; + std::shared_ptr _ssl = nullptr; }; #endif //wificlientsecure_h From 5e438aa631170646bdbeddaf890720b4c33821f7 Mon Sep 17 00:00:00 2001 From: "Earle F. Philhower, III" Date: Tue, 6 Feb 2018 12:03:08 -0800 Subject: [PATCH 2/3] Separate client and server SSL_CTX, support both Refactor to use a separate client SSL_CTX and server SSL_CTX. This allows for separate certificates to be installed on each, and means that you can now have both a *single* client and a *single* server running in parallel at the same time, as they'll have separate memory areas. Tested using mqtt_esp8266 SSL client with a client certificate and a WebServerSecure with its own custom certificate and key in parallel. --- .../ESP8266WiFi/src/WiFiClientSecure.cpp | 104 +++++++++--------- 1 file changed, 55 insertions(+), 49 deletions(-) diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp index 4e774df16e..c5b2786877 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp @@ -74,24 +74,41 @@ typedef std::list BufferList; class SSLContext { public: - SSLContext() + SSLContext(bool isServer = false) { - if (_ssl_ctx_refcnt == 0) { - _ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + _isServer = isServer; + if (!_isServer) { + if (_ssl_client_ctx_refcnt == 0) { + _ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + } + ++_ssl_client_ctx_refcnt; + } else { + if (_ssl_svr_ctx_refcnt == 0) { + _ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); + } + ++_ssl_svr_ctx_refcnt; } - ++_ssl_ctx_refcnt; } ~SSLContext() { if (io_ctx) { io_ctx->unref(); - io_ctx = NULL; + io_ctx = nullptr; } _ssl = nullptr; - --_ssl_ctx_refcnt; - if (_ssl_ctx_refcnt == 0) { - ssl_ctx_free(_ssl_ctx); + if (!_isServer) { + --_ssl_client_ctx_refcnt; + if (_ssl_client_ctx_refcnt == 0) { + ssl_ctx_free(_ssl_client_ctx); + _ssl_client_ctx = nullptr; + } + } else { + --_ssl_svr_ctx_refcnt; + if (_ssl_svr_ctx_refcnt == 0) { + ssl_ctx_free(_ssl_svr_ctx); + _ssl_svr_ctx = nullptr; + } } } @@ -117,7 +134,7 @@ class SSLContext ctx->ref(); // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free - SSL *_new_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast(this), nullptr, 0, ext); + SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast(this), nullptr, 0, ext); std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); _ssl = _new_ssl_shared; @@ -133,18 +150,16 @@ class SSLContext } } - void connectServer(ClientContext *ctx) { + void connectServer(ClientContext *ctx, uint32_t timeout_ms) + { io_ctx = ctx; ctx->ref(); // Wrap the new SSL with a smart pointer, custom deleter to call ssl_free - SSL *_new_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast(this)); + SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast(this)); std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSL); _ssl = _new_ssl_shared; - _isServer = true; - - uint32_t timeout_ms = 5000; uint32_t t = millis(); while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { @@ -301,7 +316,7 @@ class SSLContext bool loadObject(int type, const uint8_t* data, size_t size) { - int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast(size), nullptr); + int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast(size), nullptr); if (rc != SSL_OK) { DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc); return false; @@ -335,19 +350,11 @@ class SSLContext static ClientContext* getIOContext(int fd) { - if (!fd) return NULL; + if (!fd) return nullptr; SSLContext *thisSSL = reinterpret_cast(fd); return thisSSL->io_ctx; } - int loadServerX509Cert(const uint8_t *cert, int len) { - return ssl_obj_memory_load(SSLContext::_ssl_ctx, SSL_OBJ_X509_CERT, cert, len, NULL); - } - - int loadServerRSAKey(const uint8_t *rsakey, int len) { - return ssl_obj_memory_load(SSLContext::_ssl_ctx, SSL_OBJ_RSA_KEY, rsakey, len, NULL); - } - protected: int _readAll() { @@ -420,10 +427,12 @@ class SSLContext { return !_writeBuffers.empty(); } -public: + bool _isServer = false; - static SSL_CTX* _ssl_ctx; - static int _ssl_ctx_refcnt; + static SSL_CTX* _ssl_client_ctx; + static int _ssl_client_ctx_refcnt; + static SSL_CTX* _ssl_svr_ctx; + static int _ssl_svr_ctx_refcnt; std::shared_ptr _ssl = nullptr; const uint8_t* _read_ptr = nullptr; size_t _available = 0; @@ -432,8 +441,10 @@ class SSLContext ClientContext* io_ctx = nullptr; }; -SSL_CTX* SSLContext::_ssl_ctx = nullptr; -int SSLContext::_ssl_ctx_refcnt = 0; +SSL_CTX* SSLContext::_ssl_client_ctx = nullptr; +int SSLContext::_ssl_client_ctx_refcnt = 0; +SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr; +int SSLContext::_ssl_svr_ctx_refcnt = 0; WiFiClientSecure::WiFiClientSecure() { @@ -446,46 +457,41 @@ WiFiClientSecure::~WiFiClientSecure() _ssl = nullptr; } - static void _delete_shared_SSLContext(SSLContext *_to_del) - { - delete _to_del; - } - - // Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning -WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen) +WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, + const uint8_t *rsakey, int rsakeyLen, + const uint8_t *cert, int certLen) { + // TLS handshake may take more than the 5 second default timeout + _timeout = 15000; + // We've been given the client context from the available() call _client = client; _client->ref(); // Make the "_ssl" SSLContext, in the constructor there should be none yet - SSLContext *_new_ssl = new SSLContext; - std::shared_ptr _new_ssl_shared(_new_ssl, _delete_shared_SSLContext); + SSLContext *_new_ssl = new SSLContext(true); + std::shared_ptr _new_ssl_shared(_new_ssl); _ssl = _new_ssl_shared; - _ssl = std::make_shared(); if (usePMEM) { // When using PMEM based certs, allocate stack and copy from flash to DRAM, call SSL functions to avoid // heap fragmentation that would happen w/malloc() - uint8_t *stackData = (uint8_t*)alloca(max(certLen, rsakeyLen)); if (rsakey && rsakeyLen) { - memcpy_P(stackData, rsakey, rsakeyLen); - _ssl->loadServerRSAKey(stackData, rsakeyLen); + _ssl->loadObject_P(SSL_OBJ_RSA_KEY, rsakey, rsakeyLen); } if (cert && certLen) { - memcpy_P(stackData, cert, certLen); - _ssl->loadServerX509Cert(stackData, certLen); + _ssl->loadObject_P(SSL_OBJ_X509_CERT, cert, certLen); } } else { if (rsakey && rsakeyLen) { - _ssl->loadServerRSAKey(rsakey, rsakeyLen); + _ssl->loadObject(SSL_OBJ_RSA_KEY, rsakey, rsakeyLen); } if (cert && certLen) { - _ssl->loadServerX509Cert(cert, certLen); + _ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen); } } - _ssl->connectServer(client); + _ssl->connectServer(client, _timeout); } int WiFiClientSecure::connect(IPAddress ip, uint16_t port) @@ -725,9 +731,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name) String domain_name_str(domain_name); domain_name_str.toLowerCase(); - const char* san = NULL; + const char* san = nullptr; int i = 0; - while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) { + while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) { String san_str(san); san_str.toLowerCase(); if (matchName(san_str, domain_name_str)) { From a90a884428a5117c49593f4b9ee73bd9fd5af5ca Mon Sep 17 00:00:00 2001 From: "Earle F. Philhower, III" Date: Tue, 6 Feb 2018 18:58:15 -0800 Subject: [PATCH 3/3] Add brackets around a couple if-else clauses --- libraries/ESP8266WiFi/src/WiFiClientSecure.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp index 67bfc53a6e..8a7d71e99f 100644 --- a/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp +++ b/libraries/ESP8266WiFi/src/WiFiClientSecure.cpp @@ -182,8 +182,11 @@ class SSLContext bool connected() { - if (_isServer) return _ssl != nullptr; - else return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK; + if (_isServer) { + return _ssl != nullptr; + } else { + return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK; + } } int read(uint8_t* dst, size_t size) @@ -313,7 +316,6 @@ class SSLContext return loadObject(type, buf.get(), size); } - bool loadObject(int type, const uint8_t* data, size_t size) { int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast(size), nullptr); @@ -350,9 +352,11 @@ class SSLContext static ClientContext* getIOContext(int fd) { - if (!fd) return nullptr; - SSLContext *thisSSL = reinterpret_cast(fd); - return thisSSL->io_ctx; + if (fd) { + SSLContext *thisSSL = reinterpret_cast(fd); + return thisSSL->io_ctx; + } + return nullptr; } protected: