Skip to content

Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference #4305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 8, 2018
183 changes: 96 additions & 87 deletions libraries/ESP8266WiFi/src/WiFiClientSecure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
class SSLContext
{
public:
SSLContext()
SSLContext(bool isServer = false)
{
if (_ssl_ctx_refcnt == 0) {
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
_isServer = isServer;
if (!_isServer) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These if(server) else client statements hint at separating SSLContext into SSLServerContext and SSLClientContext. However, I suspect that doing that would require even more changes, and right now we need to improve stability. So let's handle that later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be straightforward to make the SSLServerContext from the unmodified SSLContext, because there's only 3 or spots where you'll see this check or a different codepath dependent on this. That said, I'm not sure it helps understanding by breaking it into another subclass...frankly most of my time was spent going through the inherited classes to see WTH was going on instead of just being able to peek at one file/class and see the inner workings. We can revisit if axtls sticks around. BearSSL doesn't need this kind of abstraction at all, so the code there is simpler.

As for stability, this code spent 5 hours being beaten by a while(true) mosquiit_pub sending it data, it mqtt publishing to a SSL mosquitto server using a Client cert every 5 seconds, and a while(true) wget https://esp8266 loop to make it serve web pages. Not even a hiccup noted...

Wish I could package this setup as a test, but it needs a mosquitto server, fixed IPs (for the certs), and a couple monitors to make sure the mqtt mesages and the wget continue working. I wouldn't know where to start...

if (_ssl_client_ctx_refcnt == 0) {
_ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
}
++_ssl_client_ctx_refcnt;
} else {
if (_ssl_svr_ctx_refcnt == 0) {
_ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0);
}
++_ssl_svr_ctx_refcnt;
}
++_ssl_ctx_refcnt;
}

~SSLContext()
{
if (_ssl) {
ssl_free(_ssl);
_ssl = nullptr;
if (io_ctx) {
io_ctx->unref();
io_ctx = nullptr;
}

--_ssl_ctx_refcnt;
if (_ssl_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_ctx);
_ssl = nullptr;
if (!_isServer) {
--_ssl_client_ctx_refcnt;
if (_ssl_client_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_client_ctx);
_ssl_client_ctx = nullptr;
}
} else {
--_ssl_svr_ctx_refcnt;
if (_ssl_svr_ctx_refcnt == 0) {
ssl_ctx_free(_ssl_svr_ctx);
_ssl_svr_ctx = nullptr;
}
}
}

void ref()
{
++_refcnt;
}

void unref()
static void _delete_shared_SSL(SSL *_to_del)
{
if (--_refcnt == 0) {
delete this;
}
ssl_free(_to_del);
}

void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms)
Expand All @@ -116,50 +126,67 @@ class SSLContext
ssl_free will want to send a close notify alert, but the old TCP connection
is already gone at this point, so reset io_ctx. */
io_ctx = nullptr;
ssl_free(_ssl);
_ssl = nullptr;
_available = 0;
_read_ptr = nullptr;
}
io_ctx = ctx;
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
ctx->ref();

// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast<int>(this), nullptr, 0, ext);
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
_ssl = _new_ssl_shared;

uint32_t t = millis();

while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc < SSL_OK) {
ssl_display_error(rc);
break;
}
}
}

void connectServer(ClientContext *ctx) {
void connectServer(ClientContext *ctx, uint32_t timeout_ms)
{
io_ctx = ctx;
_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast<int>(this));
_isServer = true;
ctx->ref();

// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast<int>(this));
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL);
_ssl = _new_ssl_shared;

uint32_t timeout_ms = 5000;
uint32_t t = millis();

while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) {
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) {
uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc < SSL_OK) {
ssl_display_error(rc);
break;
}
}
}

void stop()
{
if (io_ctx) {
io_ctx->unref();
}
io_ctx = nullptr;
}

bool connected()
{
if (_isServer) return _ssl != nullptr;
else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK;
if (_isServer) {
return _ssl != nullptr;
} else {
return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK;
}
}

int read(uint8_t* dst, size_t size)
Expand Down Expand Up @@ -289,10 +316,9 @@ class SSLContext
return loadObject(type, buf.get(), size);
}


bool loadObject(int type, const uint8_t* data, size_t size)
{
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr);
int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast<int>(size), nullptr);
if (rc != SSL_OK) {
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc);
return false;
Expand All @@ -302,7 +328,7 @@ class SSLContext

bool verifyCert()
{
int rc = ssl_verify_cert(_ssl);
int rc = ssl_verify_cert(_ssl.get());
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) {
DEBUGV("Allowing self-signed certificate\n");
return true;
Expand All @@ -321,12 +347,16 @@ class SSLContext

operator SSL*()
{
return _ssl;
return _ssl.get();
}

static ClientContext* getIOContext(int fd)
{
return reinterpret_cast<SSLContext*>(fd)->io_ctx;
if (fd) {
SSLContext *thisSSL = reinterpret_cast<SSLContext*>(fd);
return thisSSL->io_ctx;
}
return nullptr;
}

protected:
Expand All @@ -339,10 +369,9 @@ class SSLContext
optimistic_yield(100);

uint8_t* data;
int rc = ssl_read(_ssl, &data);
int rc = ssl_read(_ssl.get(), &data);
if (rc <= 0) {
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
ssl_free(_ssl);
_ssl = nullptr;
}
return 0;
Expand All @@ -359,7 +388,7 @@ class SSLContext
return 0;
}

int rc = ssl_write(_ssl, src, size);
int rc = ssl_write(_ssl.get(), src, size);
if (rc >= 0) {
return rc;
}
Expand Down Expand Up @@ -404,19 +433,22 @@ class SSLContext
}

bool _isServer = false;
static SSL_CTX* _ssl_ctx;
static int _ssl_ctx_refcnt;
SSL* _ssl = nullptr;
int _refcnt = 0;
static SSL_CTX* _ssl_client_ctx;
static int _ssl_client_ctx_refcnt;
static SSL_CTX* _ssl_svr_ctx;
static int _ssl_svr_ctx_refcnt;
std::shared_ptr<SSL> _ssl = nullptr;
const uint8_t* _read_ptr = nullptr;
size_t _available = 0;
BufferList _writeBuffers;
bool _allowSelfSignedCerts = false;
ClientContext* io_ctx = nullptr;
};

SSL_CTX* SSLContext::_ssl_ctx = nullptr;
int SSLContext::_ssl_ctx_refcnt = 0;
SSL_CTX* SSLContext::_ssl_client_ctx = nullptr;
int SSLContext::_ssl_client_ctx_refcnt = 0;
SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr;
int SSLContext::_ssl_svr_ctx_refcnt = 0;

WiFiClientSecure::WiFiClientSecure()
{
Expand All @@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()

WiFiClientSecure::~WiFiClientSecure()
{
if (_ssl) {
_ssl->unref();
}
}

WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other)
: WiFiClient(static_cast<const WiFiClient&>(other))
{
_ssl = other._ssl;
if (_ssl) {
_ssl->ref();
}
}

WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs)
{
(WiFiClient&) *this = rhs;
_ssl = rhs._ssl;
if (_ssl) {
_ssl->ref();
}
return *this;
_ssl = nullptr;
}

// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen)
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM,
const uint8_t *rsakey, int rsakeyLen,
const uint8_t *cert, int certLen)
{
// TLS handshake may take more than the 5 second default timeout
_timeout = 15000;

// We've been given the client context from the available() call
_client = client;
if (_ssl) {
_ssl->unref();
_ssl = nullptr;
}
_client->ref();

_ssl = new SSLContext;
_ssl->ref();
// Make the "_ssl" SSLContext, in the constructor there should be none yet
SSLContext *_new_ssl = new SSLContext(true);
std::shared_ptr<SSLContext> _new_ssl_shared(_new_ssl);
_ssl = _new_ssl_shared;

if (usePMEM) {
if (rsakey && rsakeyLen) {
Expand All @@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
_ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen);
}
}
_client->ref();
_ssl->connectServer(client);
_ssl->connectServer(client, _timeout);
}

int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
Expand Down Expand Up @@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
int WiFiClientSecure::_connectSSL(const char* hostName)
{
if (!_ssl) {
_ssl = new SSLContext;
_ssl->ref();
_ssl = std::make_shared<SSLContext>();
}
_ssl->connect(_client, hostName, _timeout);

auto status = ssl_handshake_status(*_ssl);
if (status != SSL_OK) {
_ssl->unref();
_ssl = nullptr;
return 0;
}
Expand All @@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
}

if (rc != SSL_CLOSE_NOTIFY) {
_ssl->unref();
_ssl = nullptr;
}

Expand Down Expand Up @@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
{
if (_ssl) {
_ssl->stop();
_ssl->unref();
_ssl = nullptr;
}
WiFiClient::stop();
}
Expand Down Expand Up @@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
String domain_name_str(domain_name);
domain_name_str.toLowerCase();

const char* san = NULL;
const char* san = nullptr;
int i = 0;
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) {
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) {
String san_str(san);
san_str.toLowerCase();
if (matchName(san_str, domain_name_str)) {
Expand Down Expand Up @@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
void WiFiClientSecure::_initSSLContext()
{
if (!_ssl) {
_ssl = new SSLContext;
_ssl->ref();
_ssl = std::make_shared<SSLContext>();
}
}

Expand Down
4 changes: 1 addition & 3 deletions libraries/ESP8266WiFi/src/WiFiClientSecure.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class WiFiClientSecure : public WiFiClient {
public:
WiFiClientSecure();
~WiFiClientSecure() override;
WiFiClientSecure(const WiFiClientSecure&);
WiFiClientSecure& operator=(const WiFiClientSecure&);

int connect(IPAddress ip, uint16_t port) override;
int connect(const String host, uint16_t port) override;
Expand Down Expand Up @@ -91,7 +89,7 @@ friend class WiFiServerSecure; // Needs access to custom constructor below
int _connectSSL(const char* hostName);
bool _verifyDN(const char* name);

SSLContext* _ssl = nullptr;
std::shared_ptr<SSLContext> _ssl = nullptr;
};

#endif //wificlientsecure_h