Unverified Commit 48072ee0 authored by Dirk-Willem van Gulik's avatar Dirk-Willem van Gulik Committed by GitHub

Support for STARTLS/STARTSSL in-band transport upgrades/renegotation (#9100)

* Split start_ssl_client into two phases; to allow the implementation of protocols that use some sort of in-band STARTTLS or STARTSSL signal to upgrade a plaint text connection to SSL/TLS. Examples of these protocols are XMPP, SMTP and various database TCP connections.

* Remove removed setTimeout that was accidentally included (was removed for IDF >=5), bring timeout inline with the other timeouts (ints), fix cert/key checks to look if there is actually something there (all issues caught by the CI/CD on windows-latest

* Quell compiler warning; use the right timeout

* Newer versions of MBEDTLS make the client key struct private (and most of the x509 struct too), so absent of a non-null pointer we cannot check wether it is populated. Solve this by looking at the version (as 0 is not a valid x509 version).

* Fix another \(rightfull\) compiler warning iwth the version pointer

* Quell CI/CD runs on non-WiFi supporting hardare

* Quell CI/CD runs on non-WiFi supporting hardare

* Fix typo in directory name

* Apply suggestions from code review
Co-authored-by: default avatarJan Procházka <90197375+P-R-O-C-H-Y@users.noreply.github.com>

* Rename Files

* Remove leftover file

---------
Co-authored-by: default avatarMe No Dev <me-no-dev@users.noreply.github.com>
Co-authored-by: default avatarJan Procházka <90197375+P-R-O-C-H-Y@users.noreply.github.com>
Co-authored-by: default avatarLucas Saavedra Vaz <32426024+lucasssvaz@users.noreply.github.com>
parent 13fac087
/* STARTSSL example
Inline upgrading from a clear-text connection to an SSL/TLS connection.
Some protocols such as SMTP, XMPP, Mysql, Postgress and others allow, or require,
that you start the connection without encryption; and then send a command to switch
over to encryption.
E.g. a typical SMTP submission would entail a dialogue such as this:
1. client connects to server in the clear
2. server says hello
3. client sents a EHLO
4. server tells the client that it supports SSL/TLS
5. client sends a 'STARTTLS' to make use of this faciltiy
6. client/server negiotiate a SSL or TLS connection.
7. client sends another EHLO
8. server now tells the client what (else) is supported; such as additional authentication options.
... conversation continues encrypted.
This can be enabled in WiFiClientSecure by telling it to start in plaintext:
client.setPlainStart();
and client is than a plain, TCP, connection (just as WiFiClient would be); until the client calls
the method:
client.startTLS(); // returns zero on error; non zero on success.
After which things switch to TLS/SSL.
*/
#include <WiFiClientSecure.h>
#ifndef WIFI_NETWORK
#define WIFI_NETWORK "YOUR Wifi SSID"
#endif
#ifndef WIFI_PASSWD
#define WIFI_PASSWD "your-secret-password"
#endif
#ifndef SMTP_HOST
#define SMTP_HOST "smtp.gmail.com"
#endif
#ifndef SMTP_PORT
#define SMTP_PORT (587) // Standard (plaintext) submission port
#endif
const char* ssid = WIFI_NETWORK; // your network SSID (name of wifi network)
const char* password = WIFI_PASSWD; // your network password
const char* server = SMTP_HOST; // Server URL
const int submission_port = SMTP_PORT; // submission port.
WiFiClientSecure client;
static bool readAllSMTPLines();
void setup() {
int ret;
//Initialize serial and wait for port to open:
Serial.begin(115200);
delay(100);
Serial.print("Attempting to connect to SSID: ");
Serial.print(ssid);
WiFi.begin(ssid, password);
// attempt to connect to Wifi network:
while (WiFi.status() != WL_CONNECTED) {
Serial.print(".");
// wait 1 second for re-trying
delay(1000);
}
Serial.print("Connected to ");
Serial.println(ssid);
Serial.printf("\nStarting connection to server: %s:%d\n", server, submission_port);
// skip verification for this demo. In production one should at the very least
// enable TOFU; or ideally hardcode a (CA) certificate that is trusted.
client.setInsecure();
// Enable a plain-test start.
client.setPlainStart();
if (!client.connect(server, SMTP_PORT)) {
Serial.println("Connection failed!");
return;
};
Serial.println("Connected to server (in the clear, in plaintest)");
if (!readAllSMTPLines()) goto err;
Serial.println("Sending : EHLO\t\tin the clear");
client.print("EHLO there\r\n");
if (!readAllSMTPLines()) goto err;
Serial.println("Sending : STARTTLS\t\tin the clear");
client.print("STARTTLS\r\n");
if (!readAllSMTPLines()) goto err;
Serial.println("Upgrading connection to TLS");
if ((ret=client.startTLS()) <= 0) {
Serial.printf("Upgrade connection failed: err %d\n", ret);
goto err;
}
Serial.println("Sending : EHLO again\t\tover the now encrypted connection");
client.print("EHLO again\r\n");
if (!readAllSMTPLines()) goto err;
// normally, as this point - we'd be authenticating and then be submitting
// an email. This has been left out of this example.
Serial.println("Sending : QUIT\t\t\tover the now encrypted connection");
client.print("QUIT\r\n");
if (!readAllSMTPLines()) goto err;
Serial.println("Completed OK\n");
err:
Serial.println("Closing connection");
client.stop();
}
// SMTP command repsponse start with three digits and a space;
// or, for continuation, with three digits and a '-'.
static bool readAllSMTPLines() {
String s = "";
int i;
// blocking read; we cannot rely on a timeout
// of a WiFiClientSecure read; as it is non
// blocking.
const unsigned long timeout = 15 * 1000;
unsigned long start = millis(); // the timeout is for the entire CMD block response; not per character/line.
while (1) {
while ((i = client.available()) == 0 && millis() - start < timeout) {
/* .. wait */
};
if (i == 0) {
Serial.println("Timeout reading SMTP response");
return false;
};
if (i < 0)
break;
i = client.read();
if (i < 0)
break;
if (i > 31 && i < 128) s += (char)i;
if (i == 0x0A) {
Serial.print("Receiving: ");
Serial.println(s);
if (s.charAt(3) == ' ')
return true;
s = "";
}
}
Serial.printf("Error reading SMTP command response line: %d\n", i);
return false;
}
void loop() {
// do nothing
}
......@@ -140,9 +140,16 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *CA_ce
int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *host, const char *CA_cert, const char *cert, const char *private_key)
{
int ret = start_ssl_client(sslclient, ip, port, host, _timeout, CA_cert, _use_ca_bundle, cert, private_key, NULL, NULL, _use_insecure, _alpn_protos);
if (ret >=0 && ! _stillinPlainStart)
ret = ssl_starttls_handshake(sslclient);
else
log_i("Actual TLS start posponed.");
_lastError = ret;
if (ret < 0) {
log_e("start_ssl_client: %d", ret);
log_e("start_ssl_client: connect failed: %d", ret);
stop();
return 0;
}
......@@ -150,6 +157,23 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *host, con
return 1;
}
int WiFiClientSecure::startTLS()
{
int ret = 1;
if (_stillinPlainStart) {
log_i("startTLS: starting TLS/SSL on this dplain connection");
ret = ssl_starttls_handshake(sslclient);
if (ret < 0) {
log_e("startTLS: %d", ret);
stop();
return 0;
};
_stillinPlainStart = false;
} else
log_i("startTLS: ignoring StartTLS - as we should be secure already");
return 1;
}
int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey) {
return connect(ip.toString().c_str(), port, pskIdent, psKey);
}
......@@ -164,7 +188,7 @@ int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskId
int ret = start_ssl_client(sslclient, address, port, host, _timeout, NULL, false, NULL, NULL, pskIdent, psKey, _use_insecure, _alpn_protos);
_lastError = ret;
if (ret < 0) {
log_e("start_ssl_client: %d", ret);
log_e("start_ssl_client: connect failed %d", ret);
stop();
return 0;
}
......@@ -189,10 +213,7 @@ int WiFiClientSecure::read()
{
uint8_t data = -1;
int res = read(&data, 1);
if (res < 0) {
return res;
}
return data;
return res < 0 ? res: data;
}
size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
......@@ -200,6 +221,10 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
if (!_connected) {
return 0;
}
if (_stillinPlainStart)
return send_net_data(sslclient, buf, size);
if(_lastWriteTimeout != _timeout){
struct timeval timeout_tv;
timeout_tv.tv_sec = _timeout / 1000;
......@@ -209,9 +234,9 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
_lastWriteTimeout = _timeout;
}
}
int res = send_ssl_data(sslclient, buf, size);
if (res < 0) {
log_e("Closing connection on failed write");
stop();
res = 0;
}
......@@ -220,6 +245,9 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
int WiFiClientSecure::read(uint8_t *buf, size_t size)
{
if(_stillinPlainStart)
return get_net_receive(sslclient, buf, size);
if(_lastReadTimeout != _timeout){
if(fd() >= 0){
struct timeval timeout_tv;
......@@ -232,7 +260,7 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size)
}
}
int peeked = 0;
int peeked = 0, res = -1;
int avail = available();
if ((!buf && size) || avail <= 0) {
return -1;
......@@ -251,9 +279,10 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size)
buf++;
peeked = 1;
}
res = get_ssl_receive(sslclient, buf, size);
int res = get_ssl_receive(sslclient, buf, size);
if (res < 0) {
log_e("Closing connection on failed read");
stop();
return peeked?peeked:res;
}
......@@ -262,12 +291,17 @@ int WiFiClientSecure::read(uint8_t *buf, size_t size)
int WiFiClientSecure::available()
{
int peeked = (_peek >= 0);
if (_stillinPlainStart)
return peek_net_receive(sslclient,0);
int peeked = (_peek >= 0), res = -1;
if (!_connected) {
return peeked;
}
int res = data_to_read(sslclient);
if (res < 0) {
res = data_to_read(sslclient);
if (res < 0 && !_stillinPlainStart) {
log_e("Closing connection on failed available check");
stop();
return peeked?peeked:res;
}
......@@ -403,3 +437,4 @@ int WiFiClientSecure::fd() const
{
return sslclient->socket;
}
......@@ -34,6 +34,7 @@ protected:
int _peek = -1;
int _timeout;
bool _use_insecure;
bool _stillinPlainStart = false;
const char *_CA_cert;
const char *_cert;
const char *_private_key;
......@@ -78,6 +79,17 @@ public:
bool verify(const char* fingerprint, const char* domain_name);
void setHandshakeTimeout(unsigned long handshake_timeout);
void setAlpnProtocols(const char **alpn_protos);
// Certain protocols start in plain-text; and then have the client
// give some STARTSSL command to `upgrade' the connection to TLS
// or SSL. Setting PlainStart to true (the default is false) enables
// this. It is up to the application code to then call 'startTLS()'
// at the right point to initialise the SSL or TLS upgrade.
void setPlainStart() { _stillinPlainStart = true; };
bool stillInPlainStart() { return _stillinPlainStart; };
int startTLS();
const mbedtls_x509_crt* getPeerCertificate() { return mbedtls_ssl_get_peer_cert(&sslclient->ssl_ctx); };
bool getFingerprintSHA256(uint8_t sha256_result[32]) { return get_peer_fingerprint(sslclient, sha256_result); };
int fd() const;
......
......@@ -55,8 +55,7 @@ void ssl_init(sslclient_context *ssl_client)
int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos)
{
char buf[512];
int ret, flags;
int ret;
int enable = 1;
log_v("Free internal heap before TLS %u", ESP.getFreeHeap());
......@@ -226,6 +225,9 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_
return -1;
}
// Note - this check for BOTH key and cert is relied on
// later during cleanup.
if (!insecure && cli_cert != NULL && cli_key != NULL) {
mbedtls_x509_crt_init(&ssl_client->client_cert);
mbedtls_pk_init(&ssl_client->client_key);
......@@ -267,6 +269,13 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_
}
mbedtls_ssl_set_bio(&ssl_client->ssl_ctx, &ssl_client->socket, mbedtls_net_send, mbedtls_net_recv, NULL );
return ssl_client->socket;
}
int ssl_starttls_handshake(sslclient_context *ssl_client)
{
char buf[512];
int ret, flags;
log_v("Performing the SSL/TLS handshake...");
unsigned long handshake_start_time=millis();
......@@ -280,7 +289,7 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_
}
if (cli_cert != NULL && cli_key != NULL) {
if (ssl_client->client_cert.version) {
log_d("Protocol is %s Ciphersuite is %s", mbedtls_ssl_get_version(&ssl_client->ssl_ctx), mbedtls_ssl_get_ciphersuite(&ssl_client->ssl_ctx));
if ((ret = mbedtls_ssl_get_record_expansion(&ssl_client->ssl_ctx)) >= 0) {
log_d("Record expansion is %d", ret);
......@@ -300,15 +309,16 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_
log_v("Certificate verified.");
}
if (rootCABuff != NULL) {
if (ssl_client->ca_cert.version) {
mbedtls_x509_crt_free(&ssl_client->ca_cert);
}
if (cli_cert != NULL) {
// We know that we always have a client cert/key pair -- and we
// cannot look into the private client_key pk struct for newer
// versions of mbedtls. So rely on a public field of the cert
// and infer that there is a key too.
if (ssl_client->client_cert.version) {
mbedtls_x509_crt_free(&ssl_client->client_cert);
}
if (cli_key != NULL) {
mbedtls_pk_free(&ssl_client->client_key);
}
......@@ -317,7 +327,6 @@ int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_
return ssl_client->socket;
}
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key)
{
log_v("Cleaning SSL connection.");
......@@ -328,13 +337,13 @@ void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, cons
}
// avoid memory leak if ssl connection attempt failed
//if (ssl_client->ssl_conf.ca_chain != NULL) {
// if (ssl_client->ssl_conf.ca_chain != NULL) {
mbedtls_x509_crt_free(&ssl_client->ca_cert);
//}
//if (ssl_client->ssl_conf.key_cert != NULL) {
// }
// if (ssl_client->ssl_conf.key_cert != NULL) {
mbedtls_x509_crt_free(&ssl_client->client_cert);
mbedtls_pk_free(&ssl_client->client_key);
//}
// }
mbedtls_ssl_free(&ssl_client->ssl_ctx);
mbedtls_ssl_config_free(&ssl_client->ssl_conf);
mbedtls_ctr_drbg_free(&ssl_client->drbg_ctx);
......@@ -368,10 +377,8 @@ int data_to_read(sslclient_context *ssl_client)
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len)
{
log_v("Writing HTTP request with %d bytes...", len); //for low level debug
int ret = -1;
unsigned long write_start_time=millis();
int ret = -1;
while ((ret = mbedtls_ssl_write(&ssl_client->ssl_ctx, data, len)) <= 0) {
if((millis()-write_start_time)>ssl_client->socket_timeout) {
......@@ -391,14 +398,60 @@ int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len
return ret;
}
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length)
// Some protocols, such as SMTP, XMPP, MySQL/Posgress and various others
// do a 'in-line' upgrade from plaintext to SSL or TLS (usually with some
// sort of 'STARTTLS' textual command from client to sever). For this
// we need to have access to the 'raw' socket; i.e. without TLS/SSL state
// handling before the handshake starts; but after setting up the TLS
// connection.
//
int peek_net_receive(sslclient_context *ssl_client, int timeout) {
#if MBEDTLS_FIXED_LINKING_NET_POLL
int ret = mbedtls_net_poll((mbedtls_net_context*)ssl_client, MBEDTLS_NET_POLL_READ, timeout);
ret == MBEDTLS_NET_POLL_READ ? 1 : ret;
#else
// We should be using mbedtls_net_poll(); which is part of mbedtls and
// included in the EspressifSDK. Unfortunately - it did not make it into
// the statically linked library file. So, for now, we replace it by
// substancially similar code.
//
struct timeval tv = { .tv_sec = timeout / 1000, .tv_usec = (timeout % 1000) * 1000 };
fd_set fdset;
FD_SET(ssl_client->socket, &fdset);
int ret = select(ssl_client->socket + 1, &fdset, nullptr, nullptr, timeout<0 ? nullptr : &tv);
if (ret < 0) {
log_e("select on read fd %d, errno: %d, \"%s\"", ssl_client->socket, errno, strerror(errno));
lwip_close(ssl_client->socket);
ssl_client->socket = -1;
return -1;
};
#endif
return ret;
};
int get_net_receive(sslclient_context *ssl_client, uint8_t *data, int length)
{
//log_d( "Reading HTTP response..."); //for low level debug
int ret = -1;
int ret = peek_net_receive(ssl_client,ssl_client->socket_timeout);
if (ret > 0)
ret = mbedtls_net_recv(ssl_client, data, length);
ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length);
// log_v( "%d bytes NET read of %d", ret, length); //for low level debug
return ret;
}
//log_v( "%d bytes read", ret); //for low level debug
int send_net_data(sslclient_context *ssl_client, const uint8_t *data, size_t len) {
int ret = mbedtls_net_send(ssl_client, data, len);
// log_v("Net sending %d btes->ret %d", len, ret); //for low level debug
return ret;
}
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length)
{
int ret = mbedtls_ssl_read(&ssl_client->ssl_ctx, data, length);
// log_v( "%d bytes SSL read", ret); //for low level debug
return ret;
}
......
......@@ -31,10 +31,14 @@ typedef struct sslclient_context {
void ssl_init(sslclient_context *ssl_client);
int start_ssl_client(sslclient_context *ssl_client, const IPAddress& ip, uint32_t port, const char* hostname, int timeout, const char *rootCABuff, bool useRootCABundle, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey, bool insecure, const char **alpn_protos);
int ssl_starttls_handshake(sslclient_context *ssl_client);
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
int get_ssl_receive(sslclient_context *ssl_client, uint8_t *data, int length);
int send_net_data(sslclient_context *ssl_client, const uint8_t *data, size_t len);
int get_net_receive(sslclient_context *ssl_client, uint8_t *data, int length);
int peek_net_receive(sslclient_context *ssl_client, int timeout);
bool verify_ssl_fingerprint(sslclient_context *ssl_client, const char* fp, const char* domain_name);
bool verify_ssl_dn(sslclient_context *ssl_client, const char* domain_name);
bool get_peer_fingerprint(sslclient_context *ssl_client, uint8_t sha256[32]);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment