Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ The base classes on which everything else is built. They expose all possible sce

## TLS support
Support for TLS is added using mbed TLS, for now only the client part is supported. You can enable this by adding the flag ASYNC_TCP_SSL_ENABLED to your build flags (-DASYNC_TCP_SSL_ENABLED). If you'd like to set a root certificate you can use the setRootCa function on AsyncClient. Feel free to add support for the server side as well :-)

In addition to the regular certificate based cipher suites there is also support for Pre-Shared Key
cipher suites. Use `setPsk` to define the PSK identifier and PSK itself. The PSK needs to be
provided in the form of a hex string (and easy way to generate a PSK is to use md5sum).
19 changes: 16 additions & 3 deletions src/AsyncTCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ AsyncClient::AsyncClient(tcp_pcb* pcb)
, _root_ca_len(0)
, _pcb_secure(false)
, _handshake_done(true)
, _psk_ident(0)
, _psk(0)
#endif // ASYNC_TCP_SSL_ENABLED
, _pcb_sent_at(0)
, _close_pcb(false)
Expand Down Expand Up @@ -502,6 +504,11 @@ void AsyncClient::setRootCa(const char* rootca, const size_t len) {
_root_ca = (char*)rootca;
_root_ca_len = len;
}

void AsyncClient::setPsk(const char* psk_ident, const char* psk) {
_psk_ident = psk_ident;
_psk = psk;
}
#endif // ASYNC_TCP_SSL_ENABLED

AsyncClient& AsyncClient::operator=(const AsyncClient& other){
Expand Down Expand Up @@ -545,11 +552,17 @@ int8_t AsyncClient::_connected(void* pcb, int8_t err){
tcp_poll(_pcb, &_tcp_poll, 1);
#if ASYNC_TCP_SSL_ENABLED
if(_pcb_secure){
if(tcp_ssl_new_client(_pcb, _hostname.empty() ? NULL : _hostname.c_str(), _root_ca, _root_ca_len) < 0){
bool err = false;
if(_root_ca) {
err = tcp_ssl_new_client(_pcb, _hostname.empty() ? NULL : _hostname.c_str(), _root_ca, _root_ca_len) < 0;
} else {
err = tcp_ssl_new_psk_client(_pcb, _psk_ident, _psk) < 0;
}
if (err) {
log_e("closing....");
return _close();
}

tcp_ssl_arg(_pcb, this);
tcp_ssl_data(_pcb, &_s_data);
tcp_ssl_handshake(_pcb, &_s_handshake);
Expand Down Expand Up @@ -663,7 +676,7 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) {
log_e("_recv err: %d\n", read_bytes);
_close();
}

//return read_bytes;
}
return ERR_OK;
Expand Down
3 changes: 3 additions & 0 deletions src/AsyncTCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class AsyncClient {
char* _root_ca;
bool _pcb_secure;
bool _handshake_done;
const char* _psk_ident;
const char* _psk;
#endif // ASYNC_TCP_SSL_ENABLED
uint32_t _pcb_sent_at;
bool _close_pcb;
Expand Down Expand Up @@ -122,6 +124,7 @@ class AsyncClient {
bool connect(IPAddress ip, uint16_t port, bool secure = false);
bool connect(const char* host, uint16_t port, bool secure = false);
void setRootCa(const char* rootca, const size_t len);
void setPsk(const char* psk_ident, const char* psk);
#else
bool connect(IPAddress ip, uint16_t port);
bool connect(const char* host, uint16_t port);
Expand Down
96 changes: 88 additions & 8 deletions src/tcp_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <string.h>

// #define TCP_SSL_DEBUG(...) ets_printf(__VA_ARGS__)
#define TCP_SSL_DEBUG(...)
#define TCP_SSL_DEBUG(...)

static const char pers[] = "esp32-tls";

Expand Down Expand Up @@ -139,7 +139,7 @@ int tcp_ssl_send(void *ctx, const unsigned char *buf, size_t len) {
} else {
tcp_len = len;
}

if (tcp_len > 2 * tcp_ssl->tcp->mss) {
tcp_len = 2 * tcp_ssl->tcp->mss;
}
Expand Down Expand Up @@ -243,7 +243,7 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) {
TCP_SSL_DEBUG("error setting SSL config.\n");

tcp_ssl_free(tcp);
return -1;
}
Expand All @@ -267,7 +267,7 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro
} else {
mbedtls_ssl_conf_authmode(&tcp_ssl->ssl_conf, MBEDTLS_SSL_VERIFY_OPTIONAL);
}

if(hostname != NULL) {
TCP_SSL_DEBUG("setting the hostname: %s\n", hostname);
if((ret = mbedtls_ssl_set_hostname(&tcp_ssl->ssl_ctx, hostname)) != 0){
Expand All @@ -282,7 +282,87 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro

if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) {
tcp_ssl_free(tcp);


return handle_error(ret);
}

mbedtls_ssl_set_bio(&tcp_ssl->ssl_ctx, (void*)tcp_ssl, tcp_ssl_send, tcp_ssl_recv, NULL);

// Start handshake.
ret = mbedtls_ssl_handshake(&tcp_ssl->ssl_ctx);
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
TCP_SSL_DEBUG("handshake error!\n");
return handle_error(ret);
}

return ERR_OK;
}

// Open an SSL connection using a PSK (pre-shared-key) cipher suite.
int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* pskey) {
tcp_ssl_t* tcp_ssl;

if(tcp == NULL) return -1;
if(tcp_ssl_get(tcp) != NULL) return -1;

tcp_ssl = tcp_ssl_new(tcp);
if(tcp_ssl == NULL) return -1;

mbedtls_entropy_init(&tcp_ssl->entropy_ctx);
mbedtls_ctr_drbg_init(&tcp_ssl->drbg_ctx);
mbedtls_ssl_init(&tcp_ssl->ssl_ctx);
mbedtls_ssl_config_init(&tcp_ssl->ssl_conf);

mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func,
&tcp_ssl->entropy_ctx, (const unsigned char*)pers, strlen(pers));

if(mbedtls_ssl_config_defaults(&tcp_ssl->ssl_conf,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) {
TCP_SSL_DEBUG("error setting SSL config.\n");

tcp_ssl_free(tcp);
return -1;
}

int ret = 0;

TCP_SSL_DEBUG("setting the pre-shared key.\n");
// convert PSK from hex string to binary
if ((strlen(pskey) & 1) != 0 || strlen(pskey) > 2*MBEDTLS_PSK_MAX_LEN) {
TCP_SSL_DEBUG(" failed\n ! pre-shared key not valid hex or too long\n\n");
return -1;
}
unsigned char psk[MBEDTLS_PSK_MAX_LEN];
size_t psk_len = strlen(pskey)/2;
for (int j=0; j<strlen(pskey); j+= 2) {
char c = pskey[j];
if (c >= '0' && c <= '9') c -= '0';
else if (c >= 'A' && c <= 'F') c -= 'A' - 10;
else if (c >= 'a' && c <= 'f') c -= 'a' - 10;
else return -1;
psk[j/2] = c<<4;
c = pskey[j+1];
if (c >= '0' && c <= '9') c -= '0';
else if (c >= 'A' && c <= 'F') c -= 'A' - 10;
else if (c >= 'a' && c <= 'f') c -= 'a' - 10;
else return -1;
psk[j/2] |= c;
}
// set mbedtls config
ret = mbedtls_ssl_conf_psk(&tcp_ssl->ssl_conf, psk, psk_len,
(const unsigned char *)psk_ident, strlen(psk_ident));
if (ret != 0) {
TCP_SSL_DEBUG(" failed\n ! mbedtls_ssl_conf_psk returned -0x%x\n\n", -ret);
return handle_error(ret);
}

mbedtls_ssl_conf_rng(&tcp_ssl->ssl_conf, mbedtls_ctr_drbg_random, &tcp_ssl->drbg_ctx);

if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) {
tcp_ssl_free(tcp);

return handle_error(ret);
}

Expand All @@ -302,7 +382,7 @@ int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len) {
if(tcp == NULL) {
return -1;
}

tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp);

if(tcp_ssl == NULL){
Expand Down Expand Up @@ -351,7 +431,7 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) {

tcp_ssl->tcp_pbuf = p;
tcp_ssl->pbuf_offset = 0;

do {
if(tcp_ssl->ssl_ctx.state != MBEDTLS_SSL_HANDSHAKE_OVER) {
TCP_SSL_DEBUG("start handshake: %d\n", tcp_ssl->ssl_ctx.state);
Expand Down Expand Up @@ -467,4 +547,4 @@ void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg){
}
}

#endif // ASYNC_TCP_SSL_ENABLED
#endif // ASYNC_TCP_SSL_ENABLED
1 change: 1 addition & 0 deletions src/tcp_mbedtls.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ typedef void (* tcp_ssl_error_cb_t)(void *arg, struct tcp_pcb *tcp, int8_t error

uint8_t tcp_ssl_has_client();
int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* root_ca, const size_t root_ca_len);
int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* psk);
int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len);
int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p);
int tcp_ssl_handshake_step(struct tcp_pcb *tcp);
Expand Down