diff --git a/iocore/net/P_SSLNetVConnection.h b/iocore/net/P_SSLNetVConnection.h index 7c8b94ae1b6..d38955251e2 100644 --- a/iocore/net/P_SSLNetVConnection.h +++ b/iocore/net/P_SSLNetVConnection.h @@ -331,6 +331,8 @@ class SSLNetVConnection : public UnixNetVConnection, ink_hrtime sslLastWriteTime = 0; int64_t sslTotalBytesSent = 0; + SSL_SESSION *client_sess = nullptr; + // The serverName is either a pointer to the (null-terminated) name fetched from the // SSL object or the empty string. const char * diff --git a/iocore/net/SSLClientUtils.cc b/iocore/net/SSLClientUtils.cc index 4d14569545a..231010bd079 100644 --- a/iocore/net/SSLClientUtils.cc +++ b/iocore/net/SSLClientUtils.cc @@ -165,13 +165,15 @@ ssl_new_session_callback(SSL *ssl, SSL_SESSION *sess) std::string lookup_key; ts::bwprint(lookup_key, "{}:{}:{}", sni_addr.c_str(), SSL_get_SSL_CTX(ssl), get_verify_str(ssl)); origin_sess_cache->insert_session(lookup_key, sess, ssl); - return 1; } else { if (is_debug_tag_set("ssl.origin_session_cache")) { Debug("ssl.origin_session_cache", "Failed to fetch SNI/IP."); } - return 0; } + + // return 0 here since we're converting the sessions using i2d_SSL_SESSION, + // meaning if we return 1, openssl will keep an extra refcount on the session. + return 0; } SSL_CTX * diff --git a/iocore/net/SSLNetVConnection.cc b/iocore/net/SSLNetVConnection.cc index da1b473f1b7..74e4e10f37b 100644 --- a/iocore/net/SSLNetVConnection.cc +++ b/iocore/net/SSLNetVConnection.cc @@ -946,10 +946,23 @@ SSLNetVConnection::clear() _ca_cert_file.reset(); _ca_cert_dir.reset(); + // SSL_SESSION_free() must only be called for SSL_SESSION objects, + // for which the reference count was explicitly incremented (e.g. + // by calling SSL_get1_session(), see SSL_get_session(3)) or when + // the SSL_SESSION object was generated outside a TLS handshake + // operation, e.g. by using d2i_SSL_SESSION(3). It must not be called + // on other SSL_SESSION objects, as this would cause incorrect + // reference counts and therefore program failures. + if (client_sess != nullptr) { + SSL_SESSION_free(client_sess); + client_sess = nullptr; + } + if (ssl != nullptr) { SSL_free(ssl); ssl = nullptr; } + ALPNSupport::clear(); TLSBasicSupport::clear(); TLSSessionResumptionSupport::clear(); @@ -2078,7 +2091,14 @@ SSLNetVConnection::_ssl_connect() sess = this->getOriginSession(ssl, lookup_key); if (sess) { - SSL_set_session(ssl, sess); + if (SSL_set_session(ssl, sess) == 0) { + SSL_SESSION_free(sess); + } else { + if (this->client_sess) { + SSL_SESSION_free(this->client_sess); + } + this->client_sess = sess; + } } } } diff --git a/iocore/net/SSLSessionCache.cc b/iocore/net/SSLSessionCache.cc index 93051a35c7f..1bcc52f41c7 100644 --- a/iocore/net/SSLSessionCache.cc +++ b/iocore/net/SSLSessionCache.cc @@ -326,8 +326,25 @@ SSLOriginSessionCache::insert_session(const std::string &lookup_key, SSL_SESSION Debug("ssl.origin_session_cache", "insert session: %s = %p", lookup_key.c_str(), sess); } + size_t len = i2d_SSL_SESSION(sess, nullptr); // make sure we're not going to need more than SSL_MAX_ORIG_SESSION_SIZE bytes + + /* do not cache a session that's too big. */ + if (len > static_cast(SSL_MAX_ORIG_SESSION_SIZE)) { + Debug("ssl.origin_session_cache", "Unable to save SSL session because size of %zd exceeds the max of %d", len, + SSL_MAX_ORIG_SESSION_SIZE); + return; + } else if (len == 0) { + Debug("ssl.origin_session_cache", "Unable to save SSL session because size is 0"); + return; + } + + Ptr buf; + buf = new_IOBufferData(buffer_size_to_index(len, MAX_BUFFER_SIZE_INDEX), MEMALIGNED); + ink_release_assert(static_cast(buf->block_size()) >= len); + unsigned char *loc = reinterpret_cast(buf->data()); + i2d_SSL_SESSION(sess, &loc); ssl_curve_id curve = (ssl == nullptr) ? 0 : SSLGetCurveNID(ssl); - ats_scoped_obj ssl_orig_session(new SSLOriginSession(lookup_key, sess, curve)); + ats_scoped_obj ssl_orig_session(new SSLOriginSession(lookup_key, buf, len, curve)); auto new_node = ssl_orig_session.release(); std::unique_lock lock(mutex); @@ -358,7 +375,8 @@ SSLOriginSessionCache::get_session(const std::string &lookup_key, SSL_SESSION ** return false; } - *sess = entry->second->session; + const unsigned char *loc = reinterpret_cast(entry->second->asn1_data->data()); + *sess = d2i_SSL_SESSION(nullptr, &loc, entry->second->len_asn1_data); if (curve != nullptr) { *curve = entry->second->curve_id; } @@ -380,3 +398,23 @@ SSLOriginSessionCache::remove_oldest_session(const std::unique_locksecond; + orig_sess_que.remove(node); + orig_sess_map.erase(entry); + delete node; + } + + return; +} diff --git a/iocore/net/SSLSessionCache.h b/iocore/net/SSLSessionCache.h index fdf99660867..2a099d36c1e 100644 --- a/iocore/net/SSLSessionCache.h +++ b/iocore/net/SSLSessionCache.h @@ -33,6 +33,7 @@ #include #define SSL_MAX_SESSION_SIZE 256 +#define SSL_MAX_ORIG_SESSION_SIZE 4096 struct ssl_session_cache_exdata { ssl_curve_id curve = 0; @@ -187,16 +188,15 @@ class SSLOriginSession { public: std::string key; - SSL_SESSION *session; + Ptr asn1_data; /* this is the ASN1 representation of the SSL_CTX */ + size_t len_asn1_data; ssl_curve_id curve_id; - SSLOriginSession(const std::string &lookup_key, SSL_SESSION *sess, ssl_curve_id curve) - : key(lookup_key), session(sess), curve_id(curve) + SSLOriginSession(const std::string &lookup_key, const Ptr &asn1, size_t len_asn1, ssl_curve_id curve) + : key(lookup_key), asn1_data(asn1), len_asn1_data(len_asn1), curve_id(curve) { } - ~SSLOriginSession() { SSL_SESSION_free(session); } - LINK(SSLOriginSession, link); }; @@ -208,6 +208,7 @@ class SSLOriginSessionCache void insert_session(const std::string &lookup_key, SSL_SESSION *sess, SSL *ssl); bool get_session(const std::string &lookup_key, SSL_SESSION **sess, ssl_curve_id *curve); + void remove_session(const std::string &lookup_key); private: void remove_oldest_session(const std::unique_lock &lock); diff --git a/iocore/net/TLSSessionResumptionSupport.cc b/iocore/net/TLSSessionResumptionSupport.cc index e8e981e7639..2ae41eb5a08 100644 --- a/iocore/net/TLSSessionResumptionSupport.cc +++ b/iocore/net/TLSSessionResumptionSupport.cc @@ -183,6 +183,8 @@ TLSSessionResumptionSupport::getOriginSession(SSL *ssl, const std::string &looku // Double check the timeout if (is_ssl_session_timed_out(session)) { SSL_INCREMENT_DYN_STAT(ssl_origin_session_cache_miss); + origin_sess_cache->remove_session(lookup_key); + SSL_SESSION_free(session); session = nullptr; } else { SSL_INCREMENT_DYN_STAT(ssl_origin_session_cache_hit);