Commit 582e6433 authored by me-no-dev's avatar me-no-dev

Add proper timeout handling to WiFiClientSecure

parent ef07a84a
...@@ -48,6 +48,7 @@ WiFiClientSecure::WiFiClientSecure() ...@@ -48,6 +48,7 @@ WiFiClientSecure::WiFiClientSecure()
WiFiClientSecure::WiFiClientSecure(int sock) WiFiClientSecure::WiFiClientSecure(int sock)
{ {
_connected = false; _connected = false;
_timeout = 0;
sslclient = new sslclient_context; sslclient = new sslclient_context;
ssl_init(sslclient); ssl_init(sslclient);
...@@ -98,6 +99,11 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port) ...@@ -98,6 +99,11 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port)
return connect(ip, port, _CA_cert, _cert, _private_key); return connect(ip, port, _CA_cert, _cert, _private_key);
} }
int WiFiClientSecure::connect(IPAddress ip, uint16_t port, int32_t timeout){
_timeout = timeout;
return connect(ip, port);
}
int WiFiClientSecure::connect(const char *host, uint16_t port) int WiFiClientSecure::connect(const char *host, uint16_t port)
{ {
if (_pskIdent && _psKey) if (_pskIdent && _psKey)
...@@ -105,6 +111,11 @@ int WiFiClientSecure::connect(const char *host, uint16_t port) ...@@ -105,6 +111,11 @@ int WiFiClientSecure::connect(const char *host, uint16_t port)
return connect(host, port, _CA_cert, _cert, _private_key); return connect(host, port, _CA_cert, _cert, _private_key);
} }
int WiFiClientSecure::connect(const char *host, uint16_t port, int32_t timeout){
_timeout = timeout;
return connect(host, port);
}
int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *_CA_cert, const char *_cert, const char *_private_key) int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *_CA_cert, const char *_cert, const char *_private_key)
{ {
return connect(ip.toString().c_str(), port, _CA_cert, _cert, _private_key); return connect(ip.toString().c_str(), port, _CA_cert, _cert, _private_key);
...@@ -112,7 +123,10 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *_CA_cert, ...@@ -112,7 +123,10 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *_CA_cert,
int WiFiClientSecure::connect(const char *host, uint16_t port, const char *_CA_cert, const char *_cert, const char *_private_key) int WiFiClientSecure::connect(const char *host, uint16_t port, const char *_CA_cert, const char *_cert, const char *_private_key)
{ {
int ret = start_ssl_client(sslclient, host, port, _CA_cert, _cert, _private_key, NULL, NULL); if(_timeout > 0){
sslclient->handshake_timeout = _timeout * 1000;
}
int ret = start_ssl_client(sslclient, host, port, _timeout, _CA_cert, _cert, _private_key, NULL, NULL);
_lastError = ret; _lastError = ret;
if (ret < 0) { if (ret < 0) {
log_e("start_ssl_client: %d", ret); log_e("start_ssl_client: %d", ret);
...@@ -129,7 +143,10 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *pskIdent, ...@@ -129,7 +143,10 @@ int WiFiClientSecure::connect(IPAddress ip, uint16_t port, const char *pskIdent,
int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey) { int WiFiClientSecure::connect(const char *host, uint16_t port, const char *pskIdent, const char *psKey) {
log_v("start_ssl_client with PSK"); log_v("start_ssl_client with PSK");
int ret = start_ssl_client(sslclient, host, port, NULL, NULL, NULL, _pskIdent, _psKey); if(_timeout > 0){
sslclient->handshake_timeout = _timeout * 1000;
}
int ret = start_ssl_client(sslclient, host, port, _timeout, NULL, NULL, NULL, _pskIdent, _psKey);
_lastError = ret; _lastError = ret;
if (ret < 0) { if (ret < 0) {
log_e("start_ssl_client: %d", ret); log_e("start_ssl_client: %d", ret);
......
...@@ -32,6 +32,7 @@ protected: ...@@ -32,6 +32,7 @@ protected:
int _lastError = 0; int _lastError = 0;
int _peek = -1; int _peek = -1;
int _timeout = 0;
const char *_CA_cert; const char *_CA_cert;
const char *_cert; const char *_cert;
const char *_private_key; const char *_private_key;
...@@ -44,7 +45,9 @@ public: ...@@ -44,7 +45,9 @@ public:
WiFiClientSecure(int socket); WiFiClientSecure(int socket);
~WiFiClientSecure(); ~WiFiClientSecure();
int connect(IPAddress ip, uint16_t port); int connect(IPAddress ip, uint16_t port);
int connect(IPAddress ip, uint16_t port, int32_t timeout);
int connect(const char *host, uint16_t port); int connect(const char *host, uint16_t port);
int connect(const char *host, uint16_t port, int32_t timeout);
int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); int connect(IPAddress ip, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key); int connect(const char *host, uint16_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey); int connect(IPAddress ip, uint16_t port, const char *pskIdent, const char *psKey);
......
...@@ -45,10 +45,10 @@ void ssl_init(sslclient_context *ssl_client) ...@@ -45,10 +45,10 @@ void ssl_init(sslclient_context *ssl_client)
} }
int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey) int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey)
{ {
char buf[512]; char buf[512];
int ret, flags, timeout; int ret, flags;
int enable = 1; int enable = 1;
log_v("Free internal heap before TLS %u", ESP.getFreeHeap()); log_v("Free internal heap before TLS %u", ESP.getFreeHeap());
...@@ -73,7 +73,10 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p ...@@ -73,7 +73,10 @@ int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t p
serv_addr.sin_port = htons(port); serv_addr.sin_port = htons(port);
if (lwip_connect(ssl_client->socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) { if (lwip_connect(ssl_client->socket, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) {
timeout = 30000; if(timeout <= 0){
timeout = 30;
}
timeout *= 1000;//to milliseconds
lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout));
lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)); lwip_setsockopt(ssl_client->socket, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout));
lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable)); lwip_setsockopt(ssl_client->socket, IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(enable));
......
...@@ -29,7 +29,7 @@ typedef struct sslclient_context { ...@@ -29,7 +29,7 @@ typedef struct sslclient_context {
void ssl_init(sslclient_context *ssl_client); void ssl_init(sslclient_context *ssl_client);
int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey); int start_ssl_client(sslclient_context *ssl_client, const char *host, uint32_t port, int timeout, const char *rootCABuff, const char *cli_cert, const char *cli_key, const char *pskIdent, const char *psKey);
void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key); void stop_ssl_socket(sslclient_context *ssl_client, const char *rootCABuff, const char *cli_cert, const char *cli_key);
int data_to_read(sslclient_context *ssl_client); int data_to_read(sslclient_context *ssl_client);
int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len); int send_ssl_data(sslclient_context *ssl_client, const uint8_t *data, uint16_t len);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment