diff --git a/README.md b/README.md index 8a11cea5..2d6b743b 100644 --- a/README.md +++ b/README.md @@ -12,3 +12,7 @@ The base classes on which everything else is built. They expose all possible sce ## TLS support Support for TLS is added using mbed TLS, for now only the client part is supported. You can enable this by adding the flag ASYNC_TCP_SSL_ENABLED to your build flags (-DASYNC_TCP_SSL_ENABLED). If you'd like to set a root certificate you can use the setRootCa function on AsyncClient. Feel free to add support for the server side as well :-) + +In addition to the regular certificate based cipher suites there is also support for Pre-Shared Key +cipher suites. Use `setPsk` to define the PSK identifier and PSK itself. The PSK needs to be +provided in the form of a hex string (and easy way to generate a PSK is to use md5sum). diff --git a/src/AsyncTCP.cpp b/src/AsyncTCP.cpp index 1fe71191..aaaeb512 100644 --- a/src/AsyncTCP.cpp +++ b/src/AsyncTCP.cpp @@ -429,6 +429,8 @@ AsyncClient::AsyncClient(tcp_pcb* pcb) , _root_ca_len(0) , _pcb_secure(false) , _handshake_done(true) +, _psk_ident(0) +, _psk(0) #endif // ASYNC_TCP_SSL_ENABLED , _pcb_sent_at(0) , _close_pcb(false) @@ -502,6 +504,11 @@ void AsyncClient::setRootCa(const char* rootca, const size_t len) { _root_ca = (char*)rootca; _root_ca_len = len; } + +void AsyncClient::setPsk(const char* psk_ident, const char* psk) { + _psk_ident = psk_ident; + _psk = psk; +} #endif // ASYNC_TCP_SSL_ENABLED AsyncClient& AsyncClient::operator=(const AsyncClient& other){ @@ -545,11 +552,17 @@ int8_t AsyncClient::_connected(void* pcb, int8_t err){ tcp_poll(_pcb, &_tcp_poll, 1); #if ASYNC_TCP_SSL_ENABLED if(_pcb_secure){ - if(tcp_ssl_new_client(_pcb, _hostname.empty() ? NULL : _hostname.c_str(), _root_ca, _root_ca_len) < 0){ + bool err = false; + if(_root_ca) { + err = tcp_ssl_new_client(_pcb, _hostname.empty() ? NULL : _hostname.c_str(), _root_ca, _root_ca_len) < 0; + } else { + err = tcp_ssl_new_psk_client(_pcb, _psk_ident, _psk) < 0; + } + if (err) { log_e("closing...."); return _close(); } - + tcp_ssl_arg(_pcb, this); tcp_ssl_data(_pcb, &_s_data); tcp_ssl_handshake(_pcb, &_s_handshake); @@ -663,7 +676,7 @@ int8_t AsyncClient::_recv(tcp_pcb* pcb, pbuf* pb, int8_t err) { log_e("_recv err: %d\n", read_bytes); _close(); } - + //return read_bytes; } return ERR_OK; diff --git a/src/AsyncTCP.h b/src/AsyncTCP.h index 117f7ff6..8117a63f 100644 --- a/src/AsyncTCP.h +++ b/src/AsyncTCP.h @@ -77,6 +77,8 @@ class AsyncClient { char* _root_ca; bool _pcb_secure; bool _handshake_done; + const char* _psk_ident; + const char* _psk; #endif // ASYNC_TCP_SSL_ENABLED uint32_t _pcb_sent_at; bool _close_pcb; @@ -122,6 +124,7 @@ class AsyncClient { bool connect(IPAddress ip, uint16_t port, bool secure = false); bool connect(const char* host, uint16_t port, bool secure = false); void setRootCa(const char* rootca, const size_t len); + void setPsk(const char* psk_ident, const char* psk); #else bool connect(IPAddress ip, uint16_t port); bool connect(const char* host, uint16_t port); diff --git a/src/tcp_mbedtls.c b/src/tcp_mbedtls.c index 55ec9cc8..fc06d72b 100644 --- a/src/tcp_mbedtls.c +++ b/src/tcp_mbedtls.c @@ -7,7 +7,7 @@ #include // #define TCP_SSL_DEBUG(...) ets_printf(__VA_ARGS__) -#define TCP_SSL_DEBUG(...) +#define TCP_SSL_DEBUG(...) static const char pers[] = "esp32-tls"; @@ -139,7 +139,7 @@ int tcp_ssl_send(void *ctx, const unsigned char *buf, size_t len) { } else { tcp_len = len; } - + if (tcp_len > 2 * tcp_ssl->tcp->mss) { tcp_len = 2 * tcp_ssl->tcp->mss; } @@ -243,7 +243,7 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) { TCP_SSL_DEBUG("error setting SSL config.\n"); - + tcp_ssl_free(tcp); return -1; } @@ -267,7 +267,7 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro } else { mbedtls_ssl_conf_authmode(&tcp_ssl->ssl_conf, MBEDTLS_SSL_VERIFY_OPTIONAL); } - + if(hostname != NULL) { TCP_SSL_DEBUG("setting the hostname: %s\n", hostname); if((ret = mbedtls_ssl_set_hostname(&tcp_ssl->ssl_ctx, hostname)) != 0){ @@ -282,7 +282,87 @@ int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* ro if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) { tcp_ssl_free(tcp); - + + return handle_error(ret); + } + + mbedtls_ssl_set_bio(&tcp_ssl->ssl_ctx, (void*)tcp_ssl, tcp_ssl_send, tcp_ssl_recv, NULL); + + // Start handshake. + ret = mbedtls_ssl_handshake(&tcp_ssl->ssl_ctx); + if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + TCP_SSL_DEBUG("handshake error!\n"); + return handle_error(ret); + } + + return ERR_OK; +} + +// Open an SSL connection using a PSK (pre-shared-key) cipher suite. +int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* pskey) { + tcp_ssl_t* tcp_ssl; + + if(tcp == NULL) return -1; + if(tcp_ssl_get(tcp) != NULL) return -1; + + tcp_ssl = tcp_ssl_new(tcp); + if(tcp_ssl == NULL) return -1; + + mbedtls_entropy_init(&tcp_ssl->entropy_ctx); + mbedtls_ctr_drbg_init(&tcp_ssl->drbg_ctx); + mbedtls_ssl_init(&tcp_ssl->ssl_ctx); + mbedtls_ssl_config_init(&tcp_ssl->ssl_conf); + + mbedtls_ctr_drbg_seed(&tcp_ssl->drbg_ctx, mbedtls_entropy_func, + &tcp_ssl->entropy_ctx, (const unsigned char*)pers, strlen(pers)); + + if(mbedtls_ssl_config_defaults(&tcp_ssl->ssl_conf, + MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT)) { + TCP_SSL_DEBUG("error setting SSL config.\n"); + + tcp_ssl_free(tcp); + return -1; + } + + int ret = 0; + + TCP_SSL_DEBUG("setting the pre-shared key.\n"); + // convert PSK from hex string to binary + if ((strlen(pskey) & 1) != 0 || strlen(pskey) > 2*MBEDTLS_PSK_MAX_LEN) { + TCP_SSL_DEBUG(" failed\n ! pre-shared key not valid hex or too long\n\n"); + return -1; + } + unsigned char psk[MBEDTLS_PSK_MAX_LEN]; + size_t psk_len = strlen(pskey)/2; + for (int j=0; j= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] = c<<4; + c = pskey[j+1]; + if (c >= '0' && c <= '9') c -= '0'; + else if (c >= 'A' && c <= 'F') c -= 'A' - 10; + else if (c >= 'a' && c <= 'f') c -= 'a' - 10; + else return -1; + psk[j/2] |= c; + } + // set mbedtls config + ret = mbedtls_ssl_conf_psk(&tcp_ssl->ssl_conf, psk, psk_len, + (const unsigned char *)psk_ident, strlen(psk_ident)); + if (ret != 0) { + TCP_SSL_DEBUG(" failed\n ! mbedtls_ssl_conf_psk returned -0x%x\n\n", -ret); + return handle_error(ret); + } + + mbedtls_ssl_conf_rng(&tcp_ssl->ssl_conf, mbedtls_ctr_drbg_random, &tcp_ssl->drbg_ctx); + + if ((ret = mbedtls_ssl_setup(&tcp_ssl->ssl_ctx, &tcp_ssl->ssl_conf)) != 0) { + tcp_ssl_free(tcp); + return handle_error(ret); } @@ -302,7 +382,7 @@ int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len) { if(tcp == NULL) { return -1; } - + tcp_ssl_t * tcp_ssl = tcp_ssl_get(tcp); if(tcp_ssl == NULL){ @@ -351,7 +431,7 @@ int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p) { tcp_ssl->tcp_pbuf = p; tcp_ssl->pbuf_offset = 0; - + do { if(tcp_ssl->ssl_ctx.state != MBEDTLS_SSL_HANDSHAKE_OVER) { TCP_SSL_DEBUG("start handshake: %d\n", tcp_ssl->ssl_ctx.state); @@ -467,4 +547,4 @@ void tcp_ssl_err(struct tcp_pcb *tcp, tcp_ssl_error_cb_t arg){ } } -#endif // ASYNC_TCP_SSL_ENABLED \ No newline at end of file +#endif // ASYNC_TCP_SSL_ENABLED diff --git a/src/tcp_mbedtls.h b/src/tcp_mbedtls.h index ff0c9a87..492c70bb 100644 --- a/src/tcp_mbedtls.h +++ b/src/tcp_mbedtls.h @@ -31,6 +31,7 @@ typedef void (* tcp_ssl_error_cb_t)(void *arg, struct tcp_pcb *tcp, int8_t error uint8_t tcp_ssl_has_client(); int tcp_ssl_new_client(struct tcp_pcb *tcp, const char* hostname, const char* root_ca, const size_t root_ca_len); +int tcp_ssl_new_psk_client(struct tcp_pcb *tcp, const char* psk_ident, const char* psk); int tcp_ssl_write(struct tcp_pcb *tcp, uint8_t *data, size_t len); int tcp_ssl_read(struct tcp_pcb *tcp, struct pbuf *p); int tcp_ssl_handshake_step(struct tcp_pcb *tcp);