diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index a4eb7d64938d9..cd43c0deb003d 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -4,6 +4,8 @@ #include "envoy/common/pure.h" #include "envoy/ssl/connection.h" +#include "absl/types/optional.h" + namespace Envoy { namespace Network { @@ -136,6 +138,32 @@ class TransportSocket { typedef std::unique_ptr TransportSocketPtr; +/** + * Options for creating transport sockets. + */ +class TransportSocketOptions { +public: + virtual ~TransportSocketOptions() {} + + /** + * @return the const optional server name to set in the transport socket, for example SNI for + * SSL, regardless of the upstream cluster configuration. Filters that influence + * upstream connection selection, such as tcp_proxy, should take this option into account + * and should pass it through to the connection pool to ensure the correct endpoints are + * selected and the upstream connection is set up accordingly. + */ + virtual const absl::optional& serverNameOverride() const PURE; + + /** + * @param vector of bytes to which the option should append hash key data that will be used + * to separate connections based on the option. Any data already in the key vector must + * not be modified. + */ + virtual void hashKey(std::vector& key) const PURE; +}; + +typedef std::shared_ptr TransportSocketOptionsSharedPtr; + /** * A factory for creating transport socket. It will be associated to filter chains and clusters. */ @@ -149,9 +177,11 @@ class TransportSocketFactory { virtual bool implementsSecureTransport() const PURE; /** + * @param options for creating the transport socket * @return Network::TransportSocketPtr a transport socket to be passed to connection. */ - virtual TransportSocketPtr createTransportSocket() const PURE; + virtual TransportSocketPtr + createTransportSocket(TransportSocketOptionsSharedPtr options) const PURE; }; typedef std::unique_ptr TransportSocketFactoryPtr; diff --git a/include/envoy/upstream/cluster_manager.h b/include/envoy/upstream/cluster_manager.h index c7e7aba180af7..78392bb818c6d 100644 --- a/include/envoy/upstream/cluster_manager.h +++ b/include/envoy/upstream/cluster_manager.h @@ -131,9 +131,10 @@ class ClusterManager { * Can return nullptr if there is no host available in the cluster or if the cluster does not * exist. */ - virtual Tcp::ConnectionPool::Instance* tcpConnPoolForCluster(const std::string& cluster, - ResourcePriority priority, - LoadBalancerContext* context) PURE; + virtual Tcp::ConnectionPool::Instance* + tcpConnPoolForCluster(const std::string& cluster, ResourcePriority priority, + LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options) PURE; /** * Allocate a load balanced TCP connection for a cluster. The created connection is already @@ -143,8 +144,9 @@ class ClusterManager { * Returns both a connection and the host that backs the connection. Both can be nullptr if there * is no host available in the cluster. */ - virtual Host::CreateConnectionData tcpConnForCluster(const std::string& cluster, - LoadBalancerContext* context) PURE; + virtual Host::CreateConnectionData + tcpConnForCluster(const std::string& cluster, LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options) PURE; /** * Returns a client that can be used to make async HTTP calls against the given cluster. The @@ -271,7 +273,8 @@ class ClusterManagerFactory { virtual Tcp::ConnectionPool::InstancePtr allocateTcpConnPool(Event::Dispatcher& dispatcher, HostConstSharedPtr host, ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options) PURE; + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) PURE; /** * Allocate a cluster from configuration proto. diff --git a/include/envoy/upstream/upstream.h b/include/envoy/upstream/upstream.h index 3fa19d34dee53..6c4c0241e1e82 100644 --- a/include/envoy/upstream/upstream.h +++ b/include/envoy/upstream/upstream.h @@ -73,7 +73,8 @@ class Host : virtual public HostDescription { */ virtual CreateConnectionData createConnection(Event::Dispatcher& dispatcher, - const Network::ConnectionSocket::OptionsSharedPtr& options) const PURE; + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) const PURE; /** * Create a health check connection for this host. diff --git a/source/common/common/BUILD b/source/common/common/BUILD index adc49e7607f49..a7274d4fc0262 100644 --- a/source/common/common/BUILD +++ b/source/common/common/BUILD @@ -273,6 +273,11 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "scalar_to_byte_vector_lib", + hdrs = ["scalar_to_byte_vector.h"], +) + envoy_cc_library( name = "token_bucket_impl_lib", srcs = ["token_bucket_impl.cc"], diff --git a/source/common/common/scalar_to_byte_vector.h b/source/common/common/scalar_to_byte_vector.h new file mode 100644 index 0000000000000..9db11f90e56f1 --- /dev/null +++ b/source/common/common/scalar_to_byte_vector.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +#include + +namespace Envoy { +template void pushScalarToByteVector(T val, std::vector& bytes) { + uint8_t* byte_ptr = reinterpret_cast(&val); + for (uint32_t byte_index = 0; byte_index < sizeof val; byte_index++) { + bytes.push_back(*byte_ptr++); + } +} +} // namespace Envoy diff --git a/source/common/http/http1/conn_pool.cc b/source/common/http/http1/conn_pool.cc index cfcde98c7e12e..e7265d56a63d1 100644 --- a/source/common/http/http1/conn_pool.cc +++ b/source/common/http/http1/conn_pool.cc @@ -285,7 +285,7 @@ ConnPoolImpl::ActiveClient::ActiveClient(ConnPoolImpl& parent) parent_.conn_connect_ms_ = std::make_unique( parent_.host_->cluster().stats().upstream_cx_connect_ms_, parent_.dispatcher_.timeSystem()); Upstream::Host::CreateConnectionData data = - parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_); + parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_, nullptr); real_host_description_ = data.host_description_; codec_client_ = parent_.createCodecClient(data); codec_client_->addConnectionCallbacks(*this); diff --git a/source/common/http/http2/conn_pool.cc b/source/common/http/http2/conn_pool.cc index 1fd97c83bd7d3..df2d8ce4c42c8 100644 --- a/source/common/http/http2/conn_pool.cc +++ b/source/common/http/http2/conn_pool.cc @@ -262,7 +262,7 @@ ConnPoolImpl::ActiveClient::ActiveClient(ConnPoolImpl& parent) parent_.conn_connect_ms_ = std::make_unique( parent_.host_->cluster().stats().upstream_cx_connect_ms_, parent_.dispatcher_.timeSystem()); Upstream::Host::CreateConnectionData data = - parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_); + parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_, nullptr); real_host_description_ = data.host_description_; client_ = parent_.createCodecClient(data); client_->addConnectionCallbacks(*this); diff --git a/source/common/network/BUILD b/source/common/network/BUILD index fcb0f6c7d6e81..d88ad8cabae62 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -240,3 +240,24 @@ envoy_cc_library( "@envoy_api//envoy/api/v2/core:base_cc", ], ) + +envoy_cc_library( + name = "transport_socket_options_lib", + srcs = ["transport_socket_options_impl.cc"], + hdrs = ["transport_socket_options_impl.h"], + deps = [ + "//include/envoy/network:transport_socket_interface", + "//source/common/common:scalar_to_byte_vector_lib", + "//source/common/common:utility_lib", + ], +) + +envoy_cc_library( + name = "upstream_server_name_lib", + srcs = ["upstream_server_name.cc"], + hdrs = ["upstream_server_name.h"], + deps = [ + "//include/envoy/stream_info:filter_state_interface", + "//source/common/common:macros", + ], +) diff --git a/source/common/network/raw_buffer_socket.cc b/source/common/network/raw_buffer_socket.cc index 6987797052ecc..1d5f2fb240d51 100644 --- a/source/common/network/raw_buffer_socket.cc +++ b/source/common/network/raw_buffer_socket.cc @@ -82,7 +82,8 @@ std::string RawBufferSocket::protocol() const { return EMPTY_STRING; } void RawBufferSocket::onConnected() { callbacks_->raiseEvent(ConnectionEvent::Connected); } -TransportSocketPtr RawBufferSocketFactory::createTransportSocket() const { +TransportSocketPtr +RawBufferSocketFactory::createTransportSocket(TransportSocketOptionsSharedPtr) const { return std::make_unique(); } diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index 8b8b205ce38f2..aeb48825e949f 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -29,7 +29,7 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggable& key) const { + if (!override_server_name_.has_value()) { + return; + } + + pushScalarToByteVector(StringUtil::CaseInsensitiveHash()(override_server_name_.value()), key); +} +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/transport_socket_options_impl.h b/source/common/network/transport_socket_options_impl.h new file mode 100644 index 0000000000000..ba544a67d7656 --- /dev/null +++ b/source/common/network/transport_socket_options_impl.h @@ -0,0 +1,26 @@ +#pragma once + +#include "envoy/network/transport_socket.h" + +namespace Envoy { +namespace Network { + +class TransportSocketOptionsImpl : public TransportSocketOptions { +public: + TransportSocketOptionsImpl(absl::string_view override_server_name = "") + : override_server_name_(override_server_name.empty() + ? absl::nullopt + : absl::optional(override_server_name)) {} + + // Network::TransportSocketOptions + const absl::optional& serverNameOverride() const override { + return override_server_name_; + } + void hashKey(std::vector& key) const override; + +private: + const absl::optional override_server_name_; +}; + +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/upstream_server_name.cc b/source/common/network/upstream_server_name.cc new file mode 100644 index 0000000000000..d4bd93a94dc15 --- /dev/null +++ b/source/common/network/upstream_server_name.cc @@ -0,0 +1,12 @@ +#include "common/network/upstream_server_name.h" + +#include "common/common/macros.h" + +namespace Envoy { +namespace Network { + +const std::string& UpstreamServerName::key() { + CONSTRUCT_ON_FIRST_USE(std::string, "envoy.network.upstream_server_name"); +} +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/upstream_server_name.h b/source/common/network/upstream_server_name.h new file mode 100644 index 0000000000000..e50bb906af5c6 --- /dev/null +++ b/source/common/network/upstream_server_name.h @@ -0,0 +1,26 @@ +#pragma once + +#include "envoy/stream_info/filter_state.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Network { + +/** + * Server name to set in the upstream connection. The filters like tcp_proxy should use this + * value to override the server name specified in the upstream cluster, for example to override + * the SNI value in the upstream TLS context. + */ +class UpstreamServerName : public StreamInfo::FilterState::Object { +public: + UpstreamServerName(absl::string_view server_name) : server_name_(server_name) {} + const std::string& value() const { return server_name_; } + static const std::string& key(); + +private: + const std::string server_name_; +}; + +} // namespace Network +} // namespace Envoy diff --git a/source/common/ssl/context_impl.cc b/source/common/ssl/context_impl.cc index 7929cadf2a102..2fd2eee4ea1a0 100644 --- a/source/common/ssl/context_impl.cc +++ b/source/common/ssl/context_impl.cc @@ -275,7 +275,7 @@ std::vector ContextImpl::parseAlpnProtocols(const std::string& alpn_pro return out; } -bssl::UniquePtr ContextImpl::newSsl() const { +bssl::UniquePtr ContextImpl::newSsl(absl::optional) const { return bssl::UniquePtr(SSL_new(ctx_.get())); } @@ -498,11 +498,15 @@ ClientContextImpl::ClientContextImpl(Stats::Scope& scope, const ClientContextCon } } -bssl::UniquePtr ClientContextImpl::newSsl() const { - bssl::UniquePtr ssl_con(ContextImpl::newSsl()); +bssl::UniquePtr +ClientContextImpl::newSsl(absl::optional override_server_name) const { + bssl::UniquePtr ssl_con(ContextImpl::newSsl(absl::nullopt)); - if (!server_name_indication_.empty()) { - int rc = SSL_set_tlsext_host_name(ssl_con.get(), server_name_indication_.c_str()); + std::string server_name_indication = + override_server_name.has_value() ? override_server_name.value() : server_name_indication_; + + if (!server_name_indication.empty()) { + int rc = SSL_set_tlsext_host_name(ssl_con.get(), server_name_indication.c_str()); RELEASE_ASSERT(rc, ""); } diff --git a/source/common/ssl/context_impl.h b/source/common/ssl/context_impl.h index f1cf16d118c3b..4fb733df025be 100644 --- a/source/common/ssl/context_impl.h +++ b/source/common/ssl/context_impl.h @@ -11,6 +11,7 @@ #include "common/ssl/context_manager_impl.h" +#include "absl/types/optional.h" #include "openssl/ssl.h" namespace Envoy { @@ -41,7 +42,7 @@ struct SslStats { class ContextImpl : public virtual Context { public: - virtual bssl::UniquePtr newSsl() const; + virtual bssl::UniquePtr newSsl(absl::optional override_server_name) const; /** * Logs successful TLS handshake and updates stats. @@ -142,7 +143,7 @@ class ClientContextImpl : public ContextImpl, public ClientContext { ClientContextImpl(Stats::Scope& scope, const ClientContextConfig& config, TimeSource& time_source); - bssl::UniquePtr newSsl() const override; + bssl::UniquePtr newSsl(absl::optional override_server_name) const override; private: const std::string server_name_indication_; diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index 08e4b257efcb1..983f1b586ed0a 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -35,8 +35,12 @@ class NotReadySslSocket : public Network::TransportSocket { }; } // namespace -SslSocket::SslSocket(ContextSharedPtr ctx, InitialState state) - : ctx_(std::dynamic_pointer_cast(ctx)), ssl_(ctx_->newSsl()) { +SslSocket::SslSocket(ContextSharedPtr ctx, InitialState state, + Network::TransportSocketOptionsSharedPtr transport_socket_options) + : ctx_(std::dynamic_pointer_cast(ctx)), + ssl_(ctx_->newSsl(transport_socket_options != nullptr + ? transport_socket_options->serverNameOverride() + : absl::nullopt)) { if (state == InitialState::Client) { SSL_set_connect_state(ssl_.get()); } else { @@ -370,7 +374,8 @@ ClientSslSocketFactory::ClientSslSocketFactory(ClientContextConfigPtr config, config_->setSecretUpdateCallback([this]() { onAddOrUpdateSecret(); }); } -Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() const { +Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket( + Network::TransportSocketOptionsSharedPtr transport_socket_options) const { // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. @@ -380,7 +385,8 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() cons ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Client); + return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Client, + transport_socket_options); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.upstream_context_secrets_not_ready_.inc(); @@ -409,7 +415,8 @@ ServerSslSocketFactory::ServerSslSocketFactory(ServerContextConfigPtr config, config_->setSecretUpdateCallback([this]() { onAddOrUpdateSecret(); }); } -Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() const { +Network::TransportSocketPtr +ServerSslSocketFactory::createTransportSocket(Network::TransportSocketOptionsSharedPtr) const { // onAddOrUpdateSecret() could be invoked in the middle of checking the existence of ssl_ctx and // creating SslSocket using ssl_ctx. Capture ssl_ctx_ into a local variable so that we check and // use the same ssl_ctx to create SslSocket. @@ -419,7 +426,7 @@ Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() cons ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Server); + return std::make_unique(std::move(ssl_ctx), Ssl::InitialState::Server, nullptr); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.downstream_context_secrets_not_ready_.inc(); diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index babcd7ac1ca69..d9888ba228913 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -39,7 +39,8 @@ class SslSocket : public Network::TransportSocket, public Connection, protected Logger::Loggable { public: - SslSocket(ContextSharedPtr ctx, InitialState state); + SslSocket(ContextSharedPtr ctx, InitialState state, + Network::TransportSocketOptionsSharedPtr transport_socket_options); // Ssl::Connection bool peerCertificatePresented() const override; @@ -87,7 +88,8 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory, ClientSslSocketFactory(ClientContextConfigPtr config, Ssl::ContextManager& manager, Stats::Scope& stats_scope); - Network::TransportSocketPtr createTransportSocket() const override; + Network::TransportSocketPtr + createTransportSocket(Network::TransportSocketOptionsSharedPtr options) const override; bool implementsSecureTransport() const override; // Secret::SecretCallbacks @@ -109,7 +111,8 @@ class ServerSslSocketFactory : public Network::TransportSocketFactory, ServerSslSocketFactory(ServerContextConfigPtr config, Ssl::ContextManager& manager, Stats::Scope& stats_scope, const std::vector& server_names); - Network::TransportSocketPtr createTransportSocket() const override; + Network::TransportSocketPtr + createTransportSocket(Network::TransportSocketOptionsSharedPtr options) const override; bool implementsSecureTransport() const override; // Secret::SecretCallbacks diff --git a/source/common/tcp/conn_pool.cc b/source/common/tcp/conn_pool.cc index 4ab2c6e5151f6..e753d9ba0e75d 100644 --- a/source/common/tcp/conn_pool.cc +++ b/source/common/tcp/conn_pool.cc @@ -11,8 +11,10 @@ namespace Tcp { ConnPoolImpl::ConnPoolImpl(Event::Dispatcher& dispatcher, Upstream::HostConstSharedPtr host, Upstream::ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options) + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) : dispatcher_(dispatcher), host_(host), priority_(priority), socket_options_(options), + transport_socket_options_(transport_socket_options), upstream_ready_timer_(dispatcher_.createTimer([this]() { onUpstreamReady(); })) {} ConnPoolImpl::~ConnPoolImpl() { @@ -356,8 +358,8 @@ ConnPoolImpl::ActiveConn::ActiveConn(ConnPoolImpl& parent) parent_.conn_connect_ms_ = std::make_unique( parent_.host_->cluster().stats().upstream_cx_connect_ms_, parent_.dispatcher_.timeSystem()); - Upstream::Host::CreateConnectionData data = - parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_); + Upstream::Host::CreateConnectionData data = parent_.host_->createConnection( + parent_.dispatcher_, parent_.socket_options_, parent_.transport_socket_options_); real_host_description_ = data.host_description_; conn_ = std::move(data.connection_); diff --git a/source/common/tcp/conn_pool.h b/source/common/tcp/conn_pool.h index e484eef80e1f4..faf206726a55c 100644 --- a/source/common/tcp/conn_pool.h +++ b/source/common/tcp/conn_pool.h @@ -22,7 +22,8 @@ class ConnPoolImpl : Logger::Loggable, public ConnectionPool:: public: ConnPoolImpl(Event::Dispatcher& dispatcher, Upstream::HostConstSharedPtr host, Upstream::ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options); + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options); ~ConnPoolImpl(); @@ -148,6 +149,7 @@ class ConnPoolImpl : Logger::Loggable, public ConnectionPool:: Upstream::HostConstSharedPtr host_; Upstream::ResourcePriority priority_; const Network::ConnectionSocket::OptionsSharedPtr socket_options_; + Network::TransportSocketOptionsSharedPtr transport_socket_options_; std::list pending_conns_; // conns awaiting connected event std::list ready_conns_; // conns ready for assignment diff --git a/source/common/tcp_proxy/BUILD b/source/common/tcp_proxy/BUILD index 0d51ab033a0d2..f28529e05dbbf 100644 --- a/source/common/tcp_proxy/BUILD +++ b/source/common/tcp_proxy/BUILD @@ -31,9 +31,12 @@ envoy_cc_library( "//source/common/access_log:access_log_lib", "//source/common/common:assert_lib", "//source/common/common:empty_string", + "//source/common/common:macros", "//source/common/common:minimal_logger_lib", "//source/common/network:cidr_range_lib", "//source/common/network:filter_lib", + "//source/common/network:transport_socket_options_lib", + "//source/common/network:upstream_server_name_lib", "//source/common/network:utility_lib", "//source/common/router:metadatamatchcriteria_lib", "//source/common/stream_info:stream_info_lib", diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index 6bb77c64e7f42..a871f48939c30 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -15,14 +15,21 @@ #include "common/common/assert.h" #include "common/common/empty_string.h" #include "common/common/fmt.h" +#include "common/common/macros.h" #include "common/common/utility.h" #include "common/config/well_known_names.h" +#include "common/network/transport_socket_options_impl.h" +#include "common/network/upstream_server_name.h" #include "common/router/metadatamatchcriteria_impl.h" namespace Envoy { namespace TcpProxy { -const std::string PerConnectionCluster::Key = "envoy.tcp_proxy.cluster"; +using ::Envoy::Network::UpstreamServerName; + +const std::string& PerConnectionCluster::key() { + CONSTRUCT_ON_FIRST_USE(std::string, "envoy.tcp_proxy.cluster"); +} Config::Route::Route( const envoy::config::filter::network::tcp_proxy::v2::TcpProxy::DeprecatedV1::TCPRoute& config) { @@ -112,10 +119,10 @@ Config::Config(const envoy::config::filter::network::tcp_proxy::v2::TcpProxy& co const std::string& Config::getRegularRouteFromEntries(Network::Connection& connection) { // First check if the per-connection state to see if we need to route to a pre-selected cluster if (connection.streamInfo().filterState().hasData( - PerConnectionCluster::Key)) { + PerConnectionCluster::key())) { const PerConnectionCluster& per_connection_cluster = connection.streamInfo().filterState().getDataReadOnly( - PerConnectionCluster::Key); + PerConnectionCluster::key()); return per_connection_cluster.value(); } @@ -358,8 +365,20 @@ Network::FilterStatus Filter::initializeUpstreamConnection() { return Network::FilterStatus::StopIteration; } + Network::TransportSocketOptionsSharedPtr transport_socket_options; + + if (downstreamConnection() && + downstreamConnection()->streamInfo().filterState().hasData( + UpstreamServerName::key())) { + const auto& original_requested_server_name = + downstreamConnection()->streamInfo().filterState().getDataReadOnly( + UpstreamServerName::key()); + transport_socket_options = std::make_shared( + original_requested_server_name.value()); + } + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( - cluster_name, Upstream::ResourcePriority::Default, this); + cluster_name, Upstream::ResourcePriority::Default, this, transport_socket_options); if (!conn_pool) { // Either cluster is unknown or there are no healthy hosts. tcpConnPoolForCluster() increments // cluster->stats().upstream_cx_none_healthy in the latter case. diff --git a/source/common/tcp_proxy/tcp_proxy.h b/source/common/tcp_proxy/tcp_proxy.h index bba6f6faad9ca..405b6cd2fbc62 100644 --- a/source/common/tcp_proxy/tcp_proxy.h +++ b/source/common/tcp_proxy/tcp_proxy.h @@ -162,7 +162,7 @@ class PerConnectionCluster : public StreamInfo::FilterState::Object { public: PerConnectionCluster(absl::string_view cluster) : cluster_(cluster) {} const std::string& value() const { return cluster_; } - static const std::string Key; + static const std::string& key(); private: const std::string cluster_; diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index 3cc51594e0629..d495cd7644de3 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -665,9 +665,9 @@ ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster, ResourceP return entry->second->connPool(priority, protocol, context); } -Tcp::ConnectionPool::Instance* -ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePriority priority, - LoadBalancerContext* context) { +Tcp::ConnectionPool::Instance* ClusterManagerImpl::tcpConnPoolForCluster( + const std::string& cluster, ResourcePriority priority, LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { ThreadLocalClusterManagerImpl& cluster_manager = tls_->getTyped(); auto entry = cluster_manager.thread_local_clusters_.find(cluster); @@ -676,7 +676,7 @@ ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePr } // Select a host and create a connection pool for it if it does not already exist. - return entry->second->tcpConnPool(priority, context); + return entry->second->tcpConnPool(priority, context, transport_socket_options); } void ClusterManagerImpl::postThreadLocalClusterUpdate(const Cluster& cluster, uint32_t priority, @@ -706,8 +706,9 @@ void ClusterManagerImpl::postThreadLocalHealthFailure(const HostSharedPtr& host) [this, host] { ThreadLocalClusterManagerImpl::onHostHealthFailure(host, *tls_); }); } -Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::string& cluster, - LoadBalancerContext* context) { +Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster( + const std::string& cluster, LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { ThreadLocalClusterManagerImpl& cluster_manager = tls_->getTyped(); auto entry = cluster_manager.thread_local_clusters_.find(cluster); @@ -717,8 +718,8 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri HostConstSharedPtr logical_host = entry->second->lb_->chooseHost(context); if (logical_host) { - auto conn_info = - logical_host->createConnection(cluster_manager.thread_local_dispatcher_, nullptr); + auto conn_info = logical_host->createConnection(cluster_manager.thread_local_dispatcher_, + nullptr, transport_socket_options); if ((entry->second->cluster_info_->features() & ClusterInfo::Features::CLOSE_CONNECTIONS_ON_HOST_HEALTH_FAILURE) && conn_info.connection_ != nullptr) { @@ -1130,7 +1131,8 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::connPool( Tcp::ConnectionPool::Instance* ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::tcpConnPool( - ResourcePriority priority, LoadBalancerContext* context) { + ResourcePriority priority, LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { HostConstSharedPtr host = lb_->chooseHost(context); if (!host) { ENVOY_LOG(debug, "no healthy host for TCP connection pool"); @@ -1156,11 +1158,16 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::tcpConnPool( } } + if (transport_socket_options != nullptr) { + transport_socket_options->hashKey(hash_key); + } + TcpConnPoolsContainer& container = parent_.host_tcp_conn_pool_map_[host]; if (!container.pools_[hash_key]) { container.pools_[hash_key] = parent_.parent_.factory_.allocateTcpConnPool( parent_.thread_local_dispatcher_, host, priority, - have_options ? context->downstreamConnection()->socketOptions() : nullptr); + have_options ? context->downstreamConnection()->socketOptions() : nullptr, + transport_socket_options); } return container.pools_[hash_key].get(); @@ -1191,9 +1198,10 @@ Http::ConnectionPool::InstancePtr ProdClusterManagerFactory::allocateConnPool( Tcp::ConnectionPool::InstancePtr ProdClusterManagerFactory::allocateTcpConnPool( Event::Dispatcher& dispatcher, HostConstSharedPtr host, ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options) { + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { return Tcp::ConnectionPool::InstancePtr{ - new Tcp::ConnPoolImpl(dispatcher, host, priority, options)}; + new Tcp::ConnPoolImpl(dispatcher, host, priority, options, transport_socket_options)}; } ClusterSharedPtr ProdClusterManagerFactory::clusterFromProto( diff --git a/source/common/upstream/cluster_manager_impl.h b/source/common/upstream/cluster_manager_impl.h index 65befbaaf1641..42c3f9caba066 100644 --- a/source/common/upstream/cluster_manager_impl.h +++ b/source/common/upstream/cluster_manager_impl.h @@ -58,7 +58,8 @@ class ProdClusterManagerFactory : public ClusterManagerFactory { Tcp::ConnectionPool::InstancePtr allocateTcpConnPool(Event::Dispatcher& dispatcher, HostConstSharedPtr host, ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options) override; + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) override; ClusterSharedPtr clusterFromProto(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, AccessLog::AccessLogManager& log_manager, @@ -190,11 +191,13 @@ class ClusterManagerImpl : public ClusterManager, Logger::LoggablegetTyped(); ASSERT(data.current_resolved_address_); return {HostImpl::createConnection(dispatcher, *parent_.info_, data.current_resolved_address_, - options), + options, transport_socket_options), HostDescriptionConstSharedPtr{ new RealHostDescription(data.current_resolved_address_, parent_.localityLbEndpoint(), parent_.lbEndpoint(), shared_from_this())}}; diff --git a/source/common/upstream/logical_dns_cluster.h b/source/common/upstream/logical_dns_cluster.h index be27f3f8443f0..0c25c5ff8255e 100644 --- a/source/common/upstream/logical_dns_cluster.h +++ b/source/common/upstream/logical_dns_cluster.h @@ -51,9 +51,9 @@ class LogicalDnsCluster : public ClusterImplBase { parent_(parent) {} // Upstream::Host - CreateConnectionData - createConnection(Event::Dispatcher& dispatcher, - const Network::ConnectionSocket::OptionsSharedPtr& options) const override; + CreateConnectionData createConnection( + Event::Dispatcher& dispatcher, const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) const override; // Upstream::HostDescription // Override setting health check address, since for logical DNS the registered host has 0.0.0.0 diff --git a/source/common/upstream/upstream_impl.cc b/source/common/upstream/upstream_impl.cc index b1d1271fe9b49..ee5eb0831d632 100644 --- a/source/common/upstream/upstream_impl.cc +++ b/source/common/upstream/upstream_impl.cc @@ -146,22 +146,24 @@ parseExtensionProtocolOptions(const envoy::api::v2::Cluster& config) { } // namespace -Host::CreateConnectionData -HostImpl::createConnection(Event::Dispatcher& dispatcher, - const Network::ConnectionSocket::OptionsSharedPtr& options) const { - return {createConnection(dispatcher, *cluster_, address_, options), shared_from_this()}; +Host::CreateConnectionData HostImpl::createConnection( + Event::Dispatcher& dispatcher, const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) const { + return {createConnection(dispatcher, *cluster_, address_, options, transport_socket_options), + shared_from_this()}; } Host::CreateConnectionData HostImpl::createHealthCheckConnection(Event::Dispatcher& dispatcher) const { - return {createConnection(dispatcher, *cluster_, healthCheckAddress(), nullptr), + return {createConnection(dispatcher, *cluster_, healthCheckAddress(), nullptr, nullptr), shared_from_this()}; } Network::ClientConnectionPtr HostImpl::createConnection(Event::Dispatcher& dispatcher, const ClusterInfo& cluster, Network::Address::InstanceConstSharedPtr address, - const Network::ConnectionSocket::OptionsSharedPtr& options) { + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { Network::ConnectionSocket::OptionsSharedPtr connection_options; if (cluster.clusterSocketOptions() != nullptr) { if (options) { @@ -177,7 +179,8 @@ HostImpl::createConnection(Event::Dispatcher& dispatcher, const ClusterInfo& clu } Network::ClientConnectionPtr connection = dispatcher.createClientConnection( - address, cluster.sourceAddress(), cluster.transportSocketFactory().createTransportSocket(), + address, cluster.sourceAddress(), + cluster.transportSocketFactory().createTransportSocket(transport_socket_options), connection_options); connection->setBufferLimits(cluster.perConnectionBufferLimitBytes()); return connection; diff --git a/source/common/upstream/upstream_impl.h b/source/common/upstream/upstream_impl.h index 290091d5cde6b..bc1b48cb1c1b8 100644 --- a/source/common/upstream/upstream_impl.h +++ b/source/common/upstream/upstream_impl.h @@ -171,9 +171,9 @@ class HostImpl : public HostDescriptionImpl, // Upstream::Host std::vector counters() const override { return stats_store_.counters(); } - CreateConnectionData - createConnection(Event::Dispatcher& dispatcher, - const Network::ConnectionSocket::OptionsSharedPtr& options) const override; + CreateConnectionData createConnection( + Event::Dispatcher& dispatcher, const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options) const override; CreateConnectionData createHealthCheckConnection(Event::Dispatcher& dispatcher) const override; std::vector gauges() const override { return stats_store_.gauges(); } void healthFlagClear(HealthFlag flag) override { health_flags_ &= ~enumToInt(flag); } @@ -203,7 +203,8 @@ class HostImpl : public HostDescriptionImpl, static Network::ClientConnectionPtr createConnection(Event::Dispatcher& dispatcher, const ClusterInfo& cluster, Network::Address::InstanceConstSharedPtr address, - const Network::ConnectionSocket::OptionsSharedPtr& options); + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr transport_socket_options); private: std::atomic health_flags_{}; diff --git a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc index a44b08fc8f9c3..6996187f22d63 100644 --- a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc +++ b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc @@ -23,7 +23,7 @@ ClientPtr ClientImpl::create(Upstream::HostConstSharedPtr host, Event::Dispatche std::unique_ptr client( new ClientImpl(host, dispatcher, std::move(encoder), decoder_factory, config)); - client->connection_ = host->createConnection(dispatcher, nullptr).connection_; + client->connection_ = host->createConnection(dispatcher, nullptr, nullptr).connection_; client->connection_->addConnectionCallbacks(*client); client->connection_->addReadFilter(Network::ReadFilterSharedPtr{new UpstreamReadFilter(*client)}); client->connection_->connect(); diff --git a/source/extensions/filters/network/sni_cluster/sni_cluster.cc b/source/extensions/filters/network/sni_cluster/sni_cluster.cc index 62eb7143ce865..2b403b586a198 100644 --- a/source/extensions/filters/network/sni_cluster/sni_cluster.cc +++ b/source/extensions/filters/network/sni_cluster/sni_cluster.cc @@ -19,7 +19,8 @@ Network::FilterStatus SniClusterFilter::onNewConnection() { // Set the tcp_proxy cluster to the same value as SNI. The data is mutable to allow // other filters to change it. read_callbacks_->connection().streamInfo().filterState().setData( - TcpProxy::PerConnectionCluster::Key, std::make_unique(sni), + TcpProxy::PerConnectionCluster::key(), + std::make_unique(sni), StreamInfo::FilterState::StateType::Mutable); } diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 1cdd87c8bcb1a..32a6be1ac19b3 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -246,7 +246,7 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { ASSERT(protocol != ProtocolType::Auto); Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( - route_entry_->clusterName(), Upstream::ResourcePriority::Default, this); + route_entry_->clusterName(), Upstream::ResourcePriority::Default, this, nullptr); if (!conn_pool) { callbacks_->sendLocalReply( AppException(AppExceptionType::InternalError, diff --git a/source/extensions/stat_sinks/common/statsd/statsd.cc b/source/extensions/stat_sinks/common/statsd/statsd.cc index ba848e4aac26d..b96653b47ae8b 100644 --- a/source/extensions/stat_sinks/common/statsd/statsd.cc +++ b/source/extensions/stat_sinks/common/statsd/statsd.cc @@ -233,7 +233,7 @@ void TcpStatsdSink::TlsSink::write(Buffer::Instance& buffer) { if (!connection_) { Upstream::Host::CreateConnectionData info = - parent_.cluster_manager_.tcpConnForCluster(parent_.cluster_info_->name(), nullptr); + parent_.cluster_manager_.tcpConnForCluster(parent_.cluster_info_->name(), nullptr, nullptr); if (!info.connection_) { return; } diff --git a/source/extensions/transport_sockets/alts/tsi_socket.cc b/source/extensions/transport_sockets/alts/tsi_socket.cc index 562d3c9b674cd..65d655fbc578f 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.cc +++ b/source/extensions/transport_sockets/alts/tsi_socket.cc @@ -249,7 +249,8 @@ TsiSocketFactory::TsiSocketFactory(HandshakerFactory handshaker_factory, bool TsiSocketFactory::implementsSecureTransport() const { return true; } -Network::TransportSocketPtr TsiSocketFactory::createTransportSocket() const { +Network::TransportSocketPtr +TsiSocketFactory::createTransportSocket(Network::TransportSocketOptionsSharedPtr) const { return std::make_unique(handshaker_factory_, handshake_validator_); } diff --git a/source/extensions/transport_sockets/alts/tsi_socket.h b/source/extensions/transport_sockets/alts/tsi_socket.h index 70f8a1d7aeff0..8e3ee5e954438 100644 --- a/source/extensions/transport_sockets/alts/tsi_socket.h +++ b/source/extensions/transport_sockets/alts/tsi_socket.h @@ -98,7 +98,8 @@ class TsiSocketFactory : public Network::TransportSocketFactory { TsiSocketFactory(HandshakerFactory handshaker_factory, HandshakeValidator handshake_validator); bool implementsSecureTransport() const override; - Network::TransportSocketPtr createTransportSocket() const override; + Network::TransportSocketPtr + createTransportSocket(Network::TransportSocketOptionsSharedPtr options) const override; private: HandshakerFactory handshaker_factory_; diff --git a/source/extensions/transport_sockets/capture/capture.cc b/source/extensions/transport_sockets/capture/capture.cc index a1a5df79b39bc..fe4688e93ffe5 100644 --- a/source/extensions/transport_sockets/capture/capture.cc +++ b/source/extensions/transport_sockets/capture/capture.cc @@ -99,9 +99,11 @@ CaptureSocketFactory::CaptureSocketFactory( : path_prefix_(path_prefix), format_(format), transport_socket_factory_(std::move(transport_socket_factory)), time_system_(time_system) {} -Network::TransportSocketPtr CaptureSocketFactory::createTransportSocket() const { - return std::make_unique( - path_prefix_, format_, transport_socket_factory_->createTransportSocket(), time_system_); +Network::TransportSocketPtr +CaptureSocketFactory::createTransportSocket(Network::TransportSocketOptionsSharedPtr) const { + return std::make_unique(path_prefix_, format_, + transport_socket_factory_->createTransportSocket(nullptr), + time_system_); } bool CaptureSocketFactory::implementsSecureTransport() const { diff --git a/source/extensions/transport_sockets/capture/capture.h b/source/extensions/transport_sockets/capture/capture.h index f7031146718ca..b1419a1b95a38 100644 --- a/source/extensions/transport_sockets/capture/capture.h +++ b/source/extensions/transport_sockets/capture/capture.h @@ -49,7 +49,8 @@ class CaptureSocketFactory : public Network::TransportSocketFactory { Event::TimeSystem& time_system); // Network::TransportSocketFactory - Network::TransportSocketPtr createTransportSocket() const override; + Network::TransportSocketPtr + createTransportSocket(Network::TransportSocketOptionsSharedPtr options) const override; bool implementsSecureTransport() const override; private: diff --git a/source/server/config_validation/cluster_manager.cc b/source/server/config_validation/cluster_manager.cc index 0c6d5c1d3dfdc..22ad6fb2c84d8 100644 --- a/source/server/config_validation/cluster_manager.cc +++ b/source/server/config_validation/cluster_manager.cc @@ -48,8 +48,9 @@ ValidationClusterManager::httpConnPoolForCluster(const std::string&, ResourcePri return nullptr; } -Host::CreateConnectionData ValidationClusterManager::tcpConnForCluster(const std::string&, - LoadBalancerContext*) { +Host::CreateConnectionData +ValidationClusterManager::tcpConnForCluster(const std::string&, LoadBalancerContext*, + Network::TransportSocketOptionsSharedPtr) { return Host::CreateConnectionData{nullptr, nullptr}; } diff --git a/source/server/config_validation/cluster_manager.h b/source/server/config_validation/cluster_manager.h index cde44eff40053..74b60532cad74 100644 --- a/source/server/config_validation/cluster_manager.h +++ b/source/server/config_validation/cluster_manager.h @@ -52,7 +52,8 @@ class ValidationClusterManager : public ClusterManagerImpl { Http::ConnectionPool::Instance* httpConnPoolForCluster(const std::string&, ResourcePriority, Http::Protocol, LoadBalancerContext*) override; - Host::CreateConnectionData tcpConnForCluster(const std::string&, LoadBalancerContext*) override; + Host::CreateConnectionData tcpConnForCluster(const std::string&, LoadBalancerContext*, + Network::TransportSocketOptionsSharedPtr) override; Http::AsyncClient& httpAsyncClientForCluster(const std::string&) override; private: diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index 156aa7742c6f6..9557c0f47d36b 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -213,7 +213,7 @@ void ConnectionHandlerImpl::ActiveListener::newConnection(Network::ConnectionSoc return; } - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); Network::ConnectionPtr new_connection = parent_.dispatcher_.createServerConnection(std::move(socket), std::move(transport_socket)); new_connection->setBufferLimits(config_.perConnectionBufferLimitBytes()); diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index 9efe45e965fc8..8aed68eeacb77 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -476,7 +476,7 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { ON_CALL(*mock_cluster_info_, transportSocketFactory()) .WillByDefault(ReturnRef(*mock_cluster_info_->transport_socket_factory_)); async_client_transport_socket_ = - mock_cluster_info_->transport_socket_factory_->createTransportSocket(); + mock_cluster_info_->transport_socket_factory_->createTransportSocket(nullptr); fake_upstream_ = std::make_unique(createUpstreamSslContext(), 0, FakeHttpConnection::Type::HTTP2, ipVersion(), test_time_.timeSystem()); diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index b339e162b75b2..03df366dd33ed 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -255,7 +255,7 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { EXPECT_EQ(manager.initializeReadFilters(), true); - EXPECT_CALL(factory_context.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + EXPECT_CALL(factory_context.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _, _)) .WillOnce(Return(&conn_pool)); request_callbacks->complete(RateLimit::LimitStatus::OK, nullptr); diff --git a/test/common/ssl/BUILD b/test/common/ssl/BUILD index 3ef019b7367c8..dc9f5193d48e5 100644 --- a/test/common/ssl/BUILD +++ b/test/common/ssl/BUILD @@ -28,6 +28,7 @@ envoy_cc_test( "//source/common/event:dispatcher_lib", "//source/common/json:json_loader_lib", "//source/common/network:listen_socket_lib", + "//source/common/network:transport_socket_options_lib", "//source/common/network:utility_lib", "//source/common/ssl:context_config_lib", "//source/common/ssl:context_lib", diff --git a/test/common/ssl/ssl_socket_test.cc b/test/common/ssl/ssl_socket_test.cc index de940af11bf37..cb2f43c761a0d 100644 --- a/test/common/ssl/ssl_socket_test.cc +++ b/test/common/ssl/ssl_socket_test.cc @@ -10,6 +10,7 @@ #include "common/json/json_loader.h" #include "common/network/address_impl.h" #include "common/network/listen_socket_impl.h" +#include "common/network/transport_socket_options_impl.h" #include "common/network/utility.h" #include "common/ssl/context_config_impl.h" #include "common/ssl/context_impl.h" @@ -45,6 +46,7 @@ namespace Ssl { namespace { +// TODO replace the long parameter list with an options object void testUtil(const std::string& client_ctx_yaml, const std::string& server_ctx_yaml, const std::string& expected_digest, const std::string& expected_uri, const std::string& expected_local_uri, const std::string& expected_serial_number, @@ -79,13 +81,13 @@ void testUtil(const std::string& client_ctx_yaml, const std::string& server_ctx_ client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - client_ssl_socket_factory.createTransportSocket(), nullptr); + client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::ConnectionPtr server_connection; Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher.createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -156,6 +158,7 @@ void testUtil(const std::string& client_ctx_yaml, const std::string& server_ctx_ } } +// TODO replace the long parameter list with an options object const std::string testUtilV2( const envoy::api::v2::Listener& server_proto, const envoy::api::v2::auth::UpstreamTlsContext& client_ctx_proto, @@ -163,7 +166,8 @@ const std::string testUtilV2( const std::string& expected_protocol_version, const std::string& expected_server_cert_digest, const std::string& expected_client_cert_uri, const std::string& expected_requested_server_name, const std::string& expected_alpn_protocol, const std::string& expected_server_stats, - const std::string& expected_client_stats, const Network::Address::IpVersion version) { + const std::string& expected_client_stats, const Network::Address::IpVersion version, + Network::TransportSocketOptionsSharedPtr transport_socket_options) { Event::SimulatedTimeSystem time_system; testing::NiceMock factory_context; ContextManagerImpl manager(time_system); @@ -194,7 +198,7 @@ const std::string testUtilV2( client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - client_ssl_socket_factory.createTransportSocket(), nullptr); + client_ssl_socket_factory.createTransportSocket(transport_socket_options), nullptr); if (!client_session.empty()) { const Ssl::SslSocket* ssl_socket = @@ -213,9 +217,13 @@ const std::string testUtilV2( Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { - socket->setRequestedServerName(client_ctx_proto.sni()); + std::string sni = transport_socket_options != NULL && + transport_socket_options->serverNameOverride().has_value() + ? transport_socket_options->serverNameOverride().value() + : client_ctx_proto.sni(); + socket->setRequestedServerName(sni); Network::ConnectionPtr new_connection = dispatcher.createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -245,6 +253,23 @@ const std::string testUtilV2( if (!expected_protocol_version.empty()) { EXPECT_EQ(expected_protocol_version, SSL_get_version(client_ssl_socket)); } + + absl::optional server_ssl_requested_server_name; + const Ssl::SslSocket* server_ssl_socket = + dynamic_cast(server_connection->ssl()); + SSL* server_ssl = server_ssl_socket->rawSslForTest(); + auto requested_server_name = SSL_get_servername(server_ssl, TLSEXT_NAMETYPE_host_name); + if (requested_server_name != nullptr) { + server_ssl_requested_server_name = std::string(requested_server_name); + } + + if (!expected_requested_server_name.empty()) { + EXPECT_TRUE(server_ssl_requested_server_name.has_value()); + EXPECT_EQ(expected_requested_server_name, server_ssl_requested_server_name.value()); + } else { + EXPECT_FALSE(server_ssl_requested_server_name.has_value()); + } + SSL_SESSION* client_ssl_session = SSL_get_session(client_ssl_socket); EXPECT_TRUE(SSL_SESSION_is_resumable(client_ssl_session)); uint8_t* session_data; @@ -354,7 +379,7 @@ TEST_P(SslSocketTest, GetCertDigest) { filename: "{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; testUtil(client_ctx_yaml, server_ctx_yaml, @@ -496,7 +521,8 @@ TEST_P(SslSocketTest, GetCertDigestInline) { testUtilV2(listener, client_ctx, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, GetCertDigestServerCertWithIntermediateCA) { @@ -518,7 +544,7 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithIntermediateCA) { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key3.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; testUtil(client_ctx_yaml, server_ctx_yaml, @@ -545,7 +571,7 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithoutCommonName) { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_only_dns_key.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; testUtil(client_ctx_yaml, server_ctx_yaml, @@ -572,7 +598,7 @@ TEST_P(SslSocketTest, GetUriWithUriSan) { filename: "{{ test_tmpdir }}/unittestkey.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" verify_subject_alt_name: "spiffe://lyft.com/test-team" )EOF"; @@ -599,7 +625,7 @@ TEST_P(SslSocketTest, GetNoUriWithDnsSan) { filename: "{{ test_tmpdir }}/unittestkey.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; // The SAN field only has DNS, expect "" for uriSanPeerCertificate(). @@ -644,7 +670,7 @@ TEST_P(SslSocketTest, GetUriWithLocalUriSan) { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; testUtil(client_ctx_yaml, server_ctx_yaml, "", "", "spiffe://lyft.com/test-team", @@ -670,7 +696,7 @@ TEST_P(SslSocketTest, GetSubjectsWithBothCerts) { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" require_client_certificate: true )EOF"; @@ -699,7 +725,7 @@ TEST_P(SslSocketTest, GetPeerCert) { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" require_client_certificate: true )EOF"; @@ -741,7 +767,7 @@ TEST_P(SslSocketTest, FailedClientAuthCaVerificationNoClientCert) { filename: "{{ test_tmpdir }}/unittestkey.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" require_client_certificate: true )EOF"; @@ -768,7 +794,7 @@ TEST_P(SslSocketTest, FailedClientAuthCaVerification) { filename: "{{ test_tmpdir }}/unittestkey.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" )EOF"; testUtil(client_ctx_yaml, server_ctx_yaml, "", "", "", "", "", "", "", "ssl.fail_verify_error", @@ -789,7 +815,7 @@ TEST_P(SslSocketTest, FailedClientAuthSanVerificationNoClientCert) { filename: "{{ test_tmpdir }}/unittestkey.pem" validation_context: trusted_ca: - filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" + filename: "{{ test_rundir }}/test/common/ssl/test_data/ca_cert.pem" verify_subject_alt_name: "example.com" )EOF"; @@ -832,7 +858,7 @@ TEST_P(SslSocketTest, FailedClientCertificateDefaultExpirationVerification) { configureServerAndExpiredClientCertificate(listener, client); testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", - "ssl.fail_verify_error", "ssl.connection_error", GetParam()); + "ssl.fail_verify_error", "ssl.connection_error", GetParam(), nullptr); } // Expired certificates will not be accepted when explicitly disallowed via @@ -850,7 +876,7 @@ TEST_P(SslSocketTest, FailedClientCertificateExpirationVerification) { ->set_allow_expired_certificate(false); testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", - "ssl.fail_verify_error", "ssl.connection_error", GetParam()); + "ssl.fail_verify_error", "ssl.connection_error", GetParam(), nullptr); } // Expired certificates will be accepted when explicitly allowed via allow_expired_certificate. @@ -867,7 +893,7 @@ TEST_P(SslSocketTest, ClientCertificateExpirationAllowedVerification) { ->set_allow_expired_certificate(true); testUtilV2(listener, client, "", true, "", "", "spiffe://lyft.com/test-team", "", "", - "ssl.handshake", "ssl.handshake", GetParam()); + "ssl.handshake", "ssl.handshake", GetParam(), nullptr); } // Allow expired certificates, but add a certificate hash requirement so it still fails. @@ -888,7 +914,7 @@ TEST_P(SslSocketTest, FailedClientCertAllowExpiredBadHashVerification) { "0000000000000000000000000000000000000000000000000000000000000000"); testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", - "ssl.fail_verify_cert_hash", "ssl.connection_error", GetParam()); + "ssl.fail_verify_cert_hash", "ssl.connection_error", GetParam(), nullptr); } // Allow expired certificatess, but use the wrong CA so it should fail still. @@ -911,7 +937,7 @@ TEST_P(SslSocketTest, FailedClientCertAllowServerExpiredWrongCAVerification) { TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/fake_ca_cert.pem")); testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", - "ssl.fail_verify_error", "ssl.connection_error", GetParam()); + "ssl.fail_verify_error", "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, ClientCertificateHashVerification) { @@ -996,13 +1022,15 @@ TEST_P(SslSocketTest, ClientCertificateHashListVerification) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, ClientCertificateHashListVerificationNoCA) { @@ -1033,13 +1061,15 @@ TEST_P(SslSocketTest, ClientCertificateHashListVerificationNoCA) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoClientCertificate) { @@ -1193,13 +1223,15 @@ TEST_P(SslSocketTest, ClientCertificateSpkiVerification) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, ClientCertificateSpkiVerificationNoCA) { @@ -1230,13 +1262,15 @@ TEST_P(SslSocketTest, ClientCertificateSpkiVerificationNoCA) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoClientCertificate) { @@ -1262,12 +1296,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoClientCertificate envoy::api::v2::auth::UpstreamTlsContext client; testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCANoClientCertificate) { @@ -1291,12 +1325,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCANoClientCertifi envoy::api::v2::auth::UpstreamTlsContext client; testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongClientCertificate) { @@ -1328,12 +1362,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongClientCertific TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCAWrongClientCertificate) { @@ -1363,12 +1397,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCAWrongClientCert TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongCA) { @@ -1400,12 +1434,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongCA) { TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerification) { @@ -1440,13 +1474,15 @@ TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerification) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerificationNoCA) { @@ -1479,13 +1515,15 @@ TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerificationNoCA) { testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); // Works even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", "ssl.handshake", GetParam(), + nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoClientCertificate) { @@ -1511,12 +1549,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoClientCert envoy::api::v2::auth::UpstreamTlsContext client; testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCANoClientCertificate) { @@ -1540,12 +1578,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCANoClient envoy::api::v2::auth::UpstreamTlsContext client; testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongClientCertificate) { @@ -1577,12 +1615,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongClientC TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCAWrongClientCertificate) { @@ -1612,12 +1650,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCAWrongCli TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongCA) { @@ -1649,12 +1687,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongCA) { TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem")); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Fails even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); } // Make sure that we do not flush code and do an immediate close if we have not completed the @@ -1699,7 +1737,7 @@ TEST_P(SslSocketTest, FlushCloseDuringHandshake) { EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1763,7 +1801,7 @@ TEST_P(SslSocketTest, HalfClose) { client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - client_ssl_socket_factory.createTransportSocket(), nullptr); + client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->enableHalfClose(true); client_connection->addReadFilter(client_read_filter); client_connection->connect(); @@ -1775,7 +1813,7 @@ TEST_P(SslSocketTest, HalfClose) { EXPECT_CALL(listener_callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); listener_callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(listener_callbacks, onNewConnection_(_)) @@ -1854,7 +1892,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + ssl_socket_factory.createTransportSocket(nullptr), nullptr); // Verify that server sent list with 2 acceptable client certificate CA names. const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); @@ -1874,7 +1912,7 @@ TEST_P(SslSocketTest, ClientAuthMultipleCAs) { EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1942,7 +1980,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket1.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::MockConnectionCallbacks client_connection_callbacks; client_connection->addConnectionCallbacks(client_connection_callbacks); @@ -1955,8 +1993,8 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, Network::TransportSocketFactory& tsf = socket->localAddress() == socket1.localAddress() ? server_ssl_socket_factory1 : server_ssl_socket_factory2; - Network::ConnectionPtr new_connection = - dispatcher.createServerConnection(std::move(socket), tsf.createTransportSocket()); + Network::ConnectionPtr new_connection = dispatcher.createServerConnection( + std::move(socket), tsf.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -1982,7 +2020,7 @@ void testTicketSessionResumption(const std::string& server_ctx_yaml1, client_connection = dispatcher.createClientConnection( socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); SSL_set_session(ssl_socket->rawSslForTest(), ssl_session); @@ -2353,7 +2391,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { ClientSslSocketFactory ssl_socket_factory(std::move(client_cfg), manager, client_stats_store); Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + ssl_socket_factory.createTransportSocket(nullptr), nullptr); Network::MockConnectionCallbacks client_connection_callbacks; client_connection->addConnectionCallbacks(client_connection_callbacks); @@ -2369,7 +2407,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { : server2_ssl_socket_factory; Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(accepted_socket), tsf.createTransportSocket()); + std::move(accepted_socket), tsf.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -2399,7 +2437,7 @@ TEST_P(SslSocketTest, ClientAuthCrossListenerSessionResumption) { client_connection = dispatcher_->createClientConnection( socket2.localAddress(), Network::Address::InstanceConstSharedPtr(), - ssl_socket_factory.createTransportSocket(), nullptr); + ssl_socket_factory.createTransportSocket(nullptr), nullptr); client_connection->addConnectionCallbacks(client_connection_callbacks); const Ssl::SslSocket* ssl_socket = dynamic_cast(client_connection->ssl()); SSL_set_session(ssl_socket->rawSslForTest(), ssl_session); @@ -2464,7 +2502,7 @@ TEST_P(SslSocketTest, SslError) { EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory.createTransportSocket()); + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr)); callbacks.onNewConnection(std::move(new_connection)); })); EXPECT_CALL(callbacks, onNewConnection_(_)) @@ -2502,44 +2540,44 @@ TEST_P(SslSocketTest, ProtocolVersions) { // Connection using defaults (client & server) succeeds, negotiating TLSv1.2. testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using defaults (client & server) succeeds, negotiating TLSv1.2, // even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); client.set_allow_renegotiation(false); // Connection using TLSv1.0 (client) and defaults (server) succeeds. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); // Connection using TLSv1.1 (client) and defaults (server) succeeds. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using TLSv1.2 (client) and defaults (server) succeeds. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using TLSv1.3 (client) and defaults (server) fails. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Connection using TLSv1.3 (client) and TLSv1.0-1.3 (server) succeeds. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using defaults (client) and TLSv1.0 (server) succeeds. client_params->clear_tls_minimum_protocol_version(); @@ -2547,31 +2585,31 @@ TEST_P(SslSocketTest, ProtocolVersions) { server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); // Connection using defaults (client) and TLSv1.1 (server) succeeds. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using defaults (client) and TLSv1.2 (server) succeeds. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); // Connection using defaults (client) and TLSv1.3 (server) fails. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); // Connection using TLSv1.0-TLSv1.3 (client) and TLSv1.3 (server) succeeds. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); } TEST_P(SslSocketTest, ALPN) { @@ -2591,32 +2629,32 @@ TEST_P(SslSocketTest, ALPN) { // Connection using defaults (client & server) succeeds, no ALPN is negotiated. testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); // Connection using defaults (client & server) succeeds, no ALPN is negotiated, // even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client.set_allow_renegotiation(false); // Client connects without ALPN to a server with "test" ALPN, no ALPN is negotiated. server_ctx->add_alpn_protocols("test"); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); server_ctx->clear_alpn_protocols(); // Client connects with "test" ALPN to a server without ALPN, no ALPN is negotiated. client_ctx->add_alpn_protocols("test"); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client_ctx->clear_alpn_protocols(); // Client connects with "test" ALPN to a server with "test" ALPN, "test" ALPN is negotiated. client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test"); testUtilV2(listener, client, "", true, "", "", "", "", "test", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); @@ -2626,7 +2664,7 @@ TEST_P(SslSocketTest, ALPN) { client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test"); testUtilV2(listener, client, "", true, "", "", "", "", "test", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client.set_allow_renegotiation(false); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); @@ -2635,7 +2673,7 @@ TEST_P(SslSocketTest, ALPN) { client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test2"); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); } @@ -2658,12 +2696,12 @@ TEST_P(SslSocketTest, CipherSuites) { // Connection using defaults (client & server) succeeds. testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); // Connection using defaults (client & server) succeeds, even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client.set_allow_renegotiation(false); // Client connects with one of the supported cipher suites, connection succeeds. @@ -2671,7 +2709,7 @@ TEST_P(SslSocketTest, CipherSuites) { server_params->add_cipher_suites("ECDHE-RSA-CHACHA20-POLY1305"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client_params->clear_cipher_suites(); server_params->clear_cipher_suites(); @@ -2679,7 +2717,7 @@ TEST_P(SslSocketTest, CipherSuites) { client_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); server_params->add_cipher_suites("ECDHE-RSA-CHACHA20-POLY1305"); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); client_params->clear_cipher_suites(); server_params->clear_cipher_suites(); } @@ -2702,12 +2740,12 @@ TEST_P(SslSocketTest, EcdhCurves) { // Connection using defaults (client & server) succeeds. testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); // Connection using defaults (client & server) succeeds, even with client renegotiation. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client.set_allow_renegotiation(false); // Client connects with one of the supported ECDH curves, connection succeeds. @@ -2716,7 +2754,7 @@ TEST_P(SslSocketTest, EcdhCurves) { server_params->add_ecdh_curves("P-256"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", "ssl.handshake", - GetParam()); + GetParam(), nullptr); client_params->clear_ecdh_curves(); server_params->clear_ecdh_curves(); server_params->clear_cipher_suites(); @@ -2726,7 +2764,7 @@ TEST_P(SslSocketTest, EcdhCurves) { server_params->add_ecdh_curves("P-256"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", - "ssl.connection_error", GetParam()); + "ssl.connection_error", GetParam(), nullptr); client_params->clear_ecdh_curves(); server_params->clear_ecdh_curves(); server_params->clear_cipher_suites(); @@ -2830,7 +2868,45 @@ TEST_P(SslSocketTest, GetRequestedServerName) { client.set_sni("lyft.com"); testUtilV2(listener, client, "", true, "", "", "", "lyft.com", "", "ssl.handshake", - "ssl.handshake", GetParam()); + "ssl.handshake", GetParam(), nullptr); +} + +TEST_P(SslSocketTest, OverrideRequestedServerName) { + envoy::api::v2::Listener listener; + envoy::api::v2::listener::FilterChain* filter_chain = listener.add_filter_chains(); + envoy::api::v2::auth::TlsCertificate* server_cert = + filter_chain->mutable_tls_context()->mutable_common_tls_context()->add_tls_certificates(); + server_cert->mutable_certificate_chain()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem")); + server_cert->mutable_private_key()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem")); + + envoy::api::v2::auth::UpstreamTlsContext client; + client.set_sni("lyft.com"); + + Network::TransportSocketOptionsSharedPtr transport_socket_options( + new Network::TransportSocketOptionsImpl("example.com")); + + testUtilV2(listener, client, "", true, "", "", "", "example.com", "", "ssl.handshake", + "ssl.handshake", GetParam(), transport_socket_options); +} + +TEST_P(SslSocketTest, OverrideRequestedServerNameWithoutSniInUpstreamTlsContext) { + envoy::api::v2::Listener listener; + envoy::api::v2::listener::FilterChain* filter_chain = listener.add_filter_chains(); + envoy::api::v2::auth::TlsCertificate* server_cert = + filter_chain->mutable_tls_context()->mutable_common_tls_context()->add_tls_certificates(); + server_cert->mutable_certificate_chain()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem")); + server_cert->mutable_private_key()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem")); + + envoy::api::v2::auth::UpstreamTlsContext client; + + Network::TransportSocketOptionsSharedPtr transport_socket_options( + new Network::TransportSocketOptionsImpl("example.com")); + testUtilV2(listener, client, "", true, "", "", "", "example.com", "", "ssl.handshake", + "ssl.handshake", GetParam(), transport_socket_options); } // Validate that if downstream secrets are not yet downloaded from SDS server, Envoy creates @@ -2863,7 +2939,7 @@ TEST_P(SslSocketTest, DownstreamNotReadySslSocket) { ContextManagerImpl manager(time_system); Ssl::ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, stats_store, std::vector{}); - auto transport_socket = server_ssl_socket_factory.createTransportSocket(); + auto transport_socket = server_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); Buffer::OwnedImpl buffer; @@ -2903,7 +2979,7 @@ TEST_P(SslSocketTest, UpstreamNotReadySslSocket) { ContextManagerImpl manager(time_system); Ssl::ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, stats_store); - auto transport_socket = client_ssl_socket_factory.createTransportSocket(); + auto transport_socket = client_ssl_socket_factory.createTransportSocket(nullptr); EXPECT_EQ(EMPTY_STRING, transport_socket->protocol()); EXPECT_EQ(nullptr, transport_socket->ssl()); Buffer::OwnedImpl buffer; @@ -2935,7 +3011,7 @@ class SslReadBufferLimitTest : public SslSocketTest { std::move(client_cfg), *manager_, client_stats_store_); client_connection_ = dispatcher_->createClientConnection( socket_.localAddress(), source_address_, - client_ssl_socket_factory_->createTransportSocket(), nullptr); + client_ssl_socket_factory_->createTransportSocket(nullptr), nullptr); client_connection_->addConnectionCallbacks(client_callbacks_); client_connection_->connect(); read_filter_.reset(new Network::MockReadFilter()); @@ -2948,7 +3024,7 @@ class SslReadBufferLimitTest : public SslSocketTest { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket()); + std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr)); new_connection->setBufferLimits(read_buffer_limit); listener_callbacks_.onNewConnection(std::move(new_connection)); })); @@ -3033,7 +3109,7 @@ class SslReadBufferLimitTest : public SslSocketTest { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket()); + std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr)); new_connection->setBufferLimits(read_buffer_limit); listener_callbacks_.onNewConnection(std::move(new_connection)); })); @@ -3157,7 +3233,7 @@ TEST_P(SslReadBufferLimitTest, TestBind) { EXPECT_CALL(listener_callbacks_, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { Network::ConnectionPtr new_connection = dispatcher_->createServerConnection( - std::move(socket), server_ssl_socket_factory_->createTransportSocket()); + std::move(socket), server_ssl_socket_factory_->createTransportSocket(nullptr)); new_connection->setBufferLimits(0); listener_callbacks_.onNewConnection(std::move(new_connection)); })); diff --git a/test/common/tcp/conn_pool_test.cc b/test/common/tcp/conn_pool_test.cc index d30918d7ea65c..b3f26916eabad 100644 --- a/test/common/tcp/conn_pool_test.cc +++ b/test/common/tcp/conn_pool_test.cc @@ -76,7 +76,7 @@ class ConnPoolImplForTest : public ConnPoolImpl { Upstream::ClusterInfoConstSharedPtr cluster, NiceMock* upstream_ready_timer) : ConnPoolImpl(dispatcher, Upstream::makeTestHost(cluster, "tcp://127.0.0.1:9000"), - Upstream::ResourcePriority::Default, nullptr), + Upstream::ResourcePriority::Default, nullptr, nullptr), mock_dispatcher_(dispatcher), mock_upstream_ready_timer_(upstream_ready_timer) {} ~ConnPoolImplForTest() { @@ -181,7 +181,7 @@ class TcpConnPoolImplDestructorTest : public testing::Test { : upstream_ready_timer_(new NiceMock(&dispatcher_)), conn_pool_{new ConnPoolImpl(dispatcher_, Upstream::makeTestHost(cluster_, "tcp://127.0.0.1:9000"), - Upstream::ResourcePriority::Default, nullptr)} {} + Upstream::ResourcePriority::Default, nullptr, nullptr)} {} ~TcpConnPoolImplDestructorTest() {} diff --git a/test/common/tcp_proxy/BUILD b/test/common/tcp_proxy/BUILD index 5252736d157b4..4563a193ec43b 100644 --- a/test/common/tcp_proxy/BUILD +++ b/test/common/tcp_proxy/BUILD @@ -16,6 +16,8 @@ envoy_cc_test( "//source/common/config:filter_json_lib", "//source/common/event:dispatcher_lib", "//source/common/network:address_lib", + "//source/common/network:transport_socket_options_lib", + "//source/common/network:upstream_server_name_lib", "//source/common/stats:stats_lib", "//source/common/tcp_proxy", "//source/common/upstream:upstream_includes", diff --git a/test/common/tcp_proxy/tcp_proxy_test.cc b/test/common/tcp_proxy/tcp_proxy_test.cc index 157e3f09cbba2..704b68c8bc4f6 100644 --- a/test/common/tcp_proxy/tcp_proxy_test.cc +++ b/test/common/tcp_proxy/tcp_proxy_test.cc @@ -7,6 +7,8 @@ #include "common/buffer/buffer_impl.h" #include "common/config/filter_json.h" #include "common/network/address_impl.h" +#include "common/network/transport_socket_options_impl.h" +#include "common/network/upstream_server_name.h" #include "common/router/metadatamatchcriteria_impl.h" #include "common/tcp_proxy/tcp_proxy.h" #include "common/upstream/upstream_impl.h" @@ -40,6 +42,8 @@ using testing::SaveArg; namespace Envoy { namespace TcpProxy { +using ::Envoy::Network::UpstreamServerName; + namespace { Config constructConfigFromJson(const Json::Object& json, Server::Configuration::FactoryContext& context) { @@ -413,7 +417,8 @@ class TcpProxyTest : public testing::Test { { testing::InSequence sequence; for (uint32_t i = 0; i < connections; i++) { - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + EXPECT_CALL(factory_context_.cluster_manager_, + tcpConnPoolForCluster("fake_cluster", _, _, _)) .WillOnce(Return(&conn_pool_)) .RetiresOnSaturation(); EXPECT_CALL(conn_pool_, newConnection(_)) @@ -424,7 +429,7 @@ class TcpProxyTest : public testing::Test { })) .RetiresOnSaturation(); } - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _, _)) .WillRepeatedly(Return(nullptr)); } @@ -1134,7 +1139,7 @@ TEST_F(TcpProxyRoutingTest, RoutableConnection) { connection_.local_address_ = std::make_shared("1.2.3.4", 9999); // Expect filter to try to open a connection to specified cluster. - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _, _)) .WillOnce(Return(nullptr)); filter_->onNewConnection(); @@ -1156,11 +1161,43 @@ TEST_F(TcpProxyRoutingTest, UseClusterFromPerConnectionCluster) { // Expect filter to try to open a connection to specified cluster. EXPECT_CALL(factory_context_.cluster_manager_, - tcpConnPoolForCluster("filter_state_cluster", _, _)) + tcpConnPoolForCluster("filter_state_cluster", _, _, _)) .WillOnce(Return(nullptr)); filter_->onNewConnection(); } +// Test that the tcp proxy forwards the requested server name from FilterState if set +TEST_F(TcpProxyRoutingTest, UpstreamServerName) { + setup(); + + NiceMock stream_info; + stream_info.filterState().setData("envoy.network.upstream_server_name", + std::make_unique("www.example.com"), + StreamInfo::FilterState::StateType::ReadOnly); + + ON_CALL(connection_, streamInfo()).WillByDefault(ReturnRef(stream_info)); + EXPECT_CALL(Const(connection_), streamInfo()).WillRepeatedly(ReturnRef(stream_info)); + + // Expect filter to try to open a connection to a cluster with the transport socket options with + // override-server-name + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster(_, _, _, _)) + .WillOnce(Invoke([](const std::string& cluster, Upstream::ResourcePriority, + Upstream::LoadBalancerContext*, + Network::TransportSocketOptionsSharedPtr transport_socket_options) + -> Tcp::ConnectionPool::Instance* { + EXPECT_EQ(cluster, "fake_cluster"); + EXPECT_NE(transport_socket_options, nullptr); + EXPECT_TRUE(transport_socket_options->serverNameOverride().has_value()); + EXPECT_EQ(transport_socket_options->serverNameOverride().value(), "www.example.com"); + return nullptr; + })); + + // Port 9999 is within the specified destination port range. + connection_.local_address_ = std::make_shared("1.2.3.4", 9999); + + filter_->onNewConnection(); +} + } // namespace TcpProxy } // namespace Envoy diff --git a/test/common/upstream/BUILD b/test/common/upstream/BUILD index 084e342e78630..2b0fd97f7a317 100644 --- a/test/common/upstream/BUILD +++ b/test/common/upstream/BUILD @@ -37,6 +37,7 @@ envoy_cc_test( "//source/common/config:utility_lib", "//source/common/event:dispatcher_lib", "//source/common/network:socket_option_lib", + "//source/common/network:transport_socket_options_lib", "//source/common/network:utility_lib", "//source/common/ssl:context_lib", "//source/common/stats:stats_lib", diff --git a/test/common/upstream/cluster_manager_impl_test.cc b/test/common/upstream/cluster_manager_impl_test.cc index 55a3c809c0f0a..c960b057966e8 100644 --- a/test/common/upstream/cluster_manager_impl_test.cc +++ b/test/common/upstream/cluster_manager_impl_test.cc @@ -9,6 +9,7 @@ #include "common/config/bootstrap_json.h" #include "common/config/utility.h" #include "common/network/socket_option_impl.h" +#include "common/network/transport_socket_options_impl.h" #include "common/network/utility.h" #include "common/ssl/context_manager_impl.h" #include "common/upstream/cluster_manager_impl.h" @@ -72,7 +73,8 @@ class TestClusterManagerFactory : public ClusterManagerFactory { Tcp::ConnectionPool::InstancePtr allocateTcpConnPool(Event::Dispatcher&, HostConstSharedPtr host, ResourcePriority, - const Network::ConnectionSocket::OptionsSharedPtr&) override { + const Network::ConnectionSocket::OptionsSharedPtr&, + Network::TransportSocketOptionsSharedPtr) override { return Tcp::ConnectionPool::InstancePtr{allocateTcpConnPool_(host)}; } @@ -713,9 +715,18 @@ TEST_F(ClusterManagerImplTest, UnknownCluster) { EXPECT_EQ(nullptr, cluster_manager_->get("hello")); EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster("hello", ResourcePriority::Default, Http::Protocol::Http2, nullptr)); - EXPECT_EQ(nullptr, - cluster_manager_->tcpConnPoolForCluster("hello", ResourcePriority::Default, nullptr)); - EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello", nullptr), EnvoyException); + Network::TransportSocketOptionsSharedPtr transport_socket_options; + EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("hello", ResourcePriority::Default, + nullptr, transport_socket_options)); + EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello", nullptr, transport_socket_options), + EnvoyException); + + transport_socket_options = std::make_shared("example.com"); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("hello", ResourcePriority::Default, + nullptr, transport_socket_options)); + EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello", nullptr, transport_socket_options), + EnvoyException); + EXPECT_THROW(cluster_manager_->httpAsyncClientForCluster("hello"), EnvoyException); factory_.tls_.shutdownThread(); } @@ -743,7 +754,7 @@ TEST_F(ClusterManagerImplTest, VerifyBufferLimits) { EXPECT_CALL(*connection, setBufferLimits(8192)); EXPECT_CALL(factory_.tls_.dispatcher_, createClientConnection_(_, _, _, _)) .WillOnce(Return(connection)); - auto conn_data = cluster_manager_->tcpConnForCluster("cluster_1", nullptr); + auto conn_data = cluster_manager_->tcpConnForCluster("cluster_1", nullptr, nullptr); EXPECT_EQ(connection, conn_data.connection_.get()); factory_.tls_.shutdownThread(); } @@ -1096,7 +1107,7 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { Tcp::ConnectionPool::MockInstance* cp2 = new Tcp::ConnectionPool::MockInstance(); EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp2)); EXPECT_EQ(cp2, cluster_manager_->tcpConnPoolForCluster("fake_cluster", ResourcePriority::Default, - nullptr)); + nullptr, nullptr)); Network::MockClientConnection* connection = new Network::MockClientConnection(); ON_CALL(*cluster2->info_, features()) @@ -1105,7 +1116,7 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { .WillOnce(Return(connection)); EXPECT_CALL(*connection, setBufferLimits(_)); EXPECT_CALL(*connection, addConnectionCallbacks(_)); - auto conn_info = cluster_manager_->tcpConnForCluster("fake_cluster", nullptr); + auto conn_info = cluster_manager_->tcpConnForCluster("fake_cluster", nullptr, nullptr); EXPECT_EQ(conn_info.connection_.get(), connection); // Now remove the cluster. This should drain the connection pools, but not affect @@ -1259,7 +1270,8 @@ TEST_F(ClusterManagerImplTest, CloseTcpConnectionPoolsOnHealthFailure) { create(parseBootstrapFromJson(json)); EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp1)); - cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::Default, nullptr); + cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::Default, nullptr, + nullptr); outlier_detector.runCallbacks(test_host); health_checker.runCallbacks(test_host, HealthTransition::Unchanged); @@ -1269,7 +1281,8 @@ TEST_F(ClusterManagerImplTest, CloseTcpConnectionPoolsOnHealthFailure) { outlier_detector.runCallbacks(test_host); EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp2)); - cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::High, nullptr); + cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::High, nullptr, + nullptr); } // Order of these calls is implementation dependent, so can't sequence them! @@ -1330,7 +1343,7 @@ TEST_F(ClusterManagerImplTest, CloseTcpConnectionsOnHealthFailure) { EXPECT_CALL(factory_.tls_.dispatcher_, createClientConnection_(_, _, _, _)) .WillOnce(Return(connection1)); - conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr); + conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr, nullptr); outlier_detector.runCallbacks(test_host); health_checker.runCallbacks(test_host, HealthTransition::Unchanged); @@ -1342,11 +1355,11 @@ TEST_F(ClusterManagerImplTest, CloseTcpConnectionsOnHealthFailure) { connection1 = new NiceMock(); EXPECT_CALL(factory_.tls_.dispatcher_, createClientConnection_(_, _, _, _)) .WillOnce(Return(connection1)); - conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr); + conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr, nullptr); EXPECT_CALL(factory_.tls_.dispatcher_, createClientConnection_(_, _, _, _)) .WillOnce(Return(connection2)); - conn_info2 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr); + conn_info2 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr, nullptr); } // Order of these calls is implementation dependent, so can't sequence them! @@ -1402,7 +1415,7 @@ TEST_F(ClusterManagerImplTest, DoNotCloseTcpConnectionsOnHealthFailure) { EXPECT_CALL(factory_.tls_.dispatcher_, createClientConnection_(_, _, _, _)) .WillOnce(Return(connection1)); - conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr); + conn_info1 = cluster_manager_->tcpConnForCluster("some_cluster", nullptr, nullptr); outlier_detector.runCallbacks(test_host); health_checker.runCallbacks(test_host, HealthTransition::Unchanged); @@ -1444,8 +1457,9 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, - nullptr)); - EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1", nullptr).connection_); + nullptr, nullptr)); + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnForCluster("cluster_1", nullptr, nullptr).connection_); EXPECT_EQ(3UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); // Set up for an initialize callback. @@ -1492,14 +1506,18 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { .WillRepeatedly(ReturnNew()); // This should provide us a CP for each of the above hosts. - Tcp::ConnectionPool::MockInstance* tcp1 = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp2 = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp1_high = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp2_high = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp1 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp1_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); EXPECT_NE(tcp1, tcp2); EXPECT_NE(tcp1_high, tcp2_high); @@ -1533,10 +1551,12 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { EXPECT_EQ(cp2, cp3); EXPECT_EQ(cp2_high, cp3_high); - Tcp::ConnectionPool::MockInstance* tcp3 = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp3_high = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp3 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp3_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); EXPECT_EQ(tcp2, tcp3); EXPECT_EQ(tcp2_high, tcp3_high); @@ -1548,6 +1568,233 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { factory_.tls_.shutdownThread(); } +TEST_F(ClusterManagerImplTest, DynamicHostRemoveWithTls) { + const std::string yaml = R"EOF( + static_resources: + clusters: + - name: cluster_1 + connect_timeout: 0.250s + type: STRICT_DNS + dns_resolvers: + - socket_address: + address: 1.2.3.4 + port_value: 80 + lb_policy: ROUND_ROBIN + hosts: + - socket_address: + address: localhost + port_value: 11001 + )EOF"; + + std::shared_ptr dns_resolver(new Network::MockDnsResolver()); + EXPECT_CALL(factory_.dispatcher_, createDnsResolver(_)).WillOnce(Return(dns_resolver)); + + Network::DnsResolver::ResolveCb dns_callback; + Event::MockTimer* dns_timer_ = new NiceMock(&factory_.dispatcher_); + Network::MockActiveDnsQuery active_dns_query; + EXPECT_CALL(*dns_resolver, resolve(_, _, _)) + .WillRepeatedly(DoAll(SaveArg<2>(&dns_callback), Return(&active_dns_query))); + create(parseBootstrapFromV2Yaml(yaml)); + EXPECT_FALSE(cluster_manager_->get("cluster_1")->info()->addedViaApi()); + + Network::TransportSocketOptionsSharedPtr transport_socket_options_example_com( + new Network::TransportSocketOptionsImpl("example.com")); + Network::TransportSocketOptionsSharedPtr transport_socket_options_ibm_com( + new Network::TransportSocketOptionsImpl("ibm.com")); + + // Test for no hosts returning the correct values before we have hosts. + EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, + nullptr, nullptr)); + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnForCluster("cluster_1", nullptr, nullptr).connection_); + + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr, + transport_socket_options_example_com)); + EXPECT_EQ(nullptr, + cluster_manager_ + ->tcpConnForCluster("cluster_1", nullptr, transport_socket_options_example_com) + .connection_); + + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr, + transport_socket_options_ibm_com)); + EXPECT_EQ(nullptr, cluster_manager_ + ->tcpConnForCluster("cluster_1", nullptr, transport_socket_options_ibm_com) + .connection_); + + EXPECT_EQ(7UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); + + // Set up for an initialize callback. + ReadyWatcher initialized; + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); + EXPECT_CALL(initialized, ready()); + + dns_callback(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2"})); + + // After we are initialized, we should immediately get called back if someone asks for an + // initialize callback. + EXPECT_CALL(initialized, ready()); + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); + + EXPECT_CALL(factory_, allocateConnPool_(_)) + .Times(4) + .WillRepeatedly(ReturnNew()); + + // This should provide us a CP for each of the above hosts. + Http::ConnectionPool::MockInstance* cp1 = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Http::ConnectionPool::MockInstance* cp2 = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Http::ConnectionPool::MockInstance* cp1_high = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::High, Http::Protocol::Http11, nullptr)); + Http::ConnectionPool::MockInstance* cp2_high = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::High, Http::Protocol::Http11, nullptr)); + + EXPECT_NE(cp1, cp2); + EXPECT_NE(cp1_high, cp2_high); + EXPECT_NE(cp1, cp1_high); + + Http::ConnectionPool::Instance::DrainedCb drained_cb; + EXPECT_CALL(*cp1, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb)); + Http::ConnectionPool::Instance::DrainedCb drained_cb_high; + EXPECT_CALL(*cp1_high, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb_high)); + + EXPECT_CALL(factory_, allocateTcpConnPool_(_)) + .Times(8) + .WillRepeatedly(ReturnNew()); + + // This should provide us a CP for each of the above hosts, and for different SNIs + Tcp::ConnectionPool::MockInstance* tcp1 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp1_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); + + Tcp::ConnectionPool::MockInstance* tcp1_example_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_example_com)); + Tcp::ConnectionPool::MockInstance* tcp2_example_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_example_com)); + + Tcp::ConnectionPool::MockInstance* tcp1_ibm_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_ibm_com)); + Tcp::ConnectionPool::MockInstance* tcp2_ibm_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_ibm_com)); + + EXPECT_NE(tcp1, tcp2); + EXPECT_NE(tcp1_high, tcp2_high); + EXPECT_NE(tcp1, tcp1_high); + + EXPECT_NE(tcp1_ibm_com, tcp2_ibm_com); + EXPECT_NE(tcp1_ibm_com, tcp1); + EXPECT_NE(tcp1_ibm_com, tcp2); + EXPECT_NE(tcp1_ibm_com, tcp1_high); + EXPECT_NE(tcp1_ibm_com, tcp2_high); + EXPECT_NE(tcp1_ibm_com, tcp1_example_com); + EXPECT_NE(tcp1_ibm_com, tcp2_example_com); + + EXPECT_NE(tcp2_ibm_com, tcp1); + EXPECT_NE(tcp2_ibm_com, tcp2); + EXPECT_NE(tcp2_ibm_com, tcp1_high); + EXPECT_NE(tcp2_ibm_com, tcp2_high); + EXPECT_NE(tcp2_ibm_com, tcp1_example_com); + EXPECT_NE(tcp2_ibm_com, tcp2_example_com); + + EXPECT_NE(tcp1_example_com, tcp1); + EXPECT_NE(tcp1_example_com, tcp2); + EXPECT_NE(tcp1_example_com, tcp1_high); + EXPECT_NE(tcp1_example_com, tcp2_high); + EXPECT_NE(tcp1_example_com, tcp2_example_com); + + EXPECT_NE(tcp2_example_com, tcp1); + EXPECT_NE(tcp2_example_com, tcp2); + EXPECT_NE(tcp2_example_com, tcp1_high); + EXPECT_NE(tcp2_example_com, tcp2_high); + + EXPECT_CALL(factory_.tls_.dispatcher_, deferredDelete_(_)).Times(6); + + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb; + EXPECT_CALL(*tcp1, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb)); + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb_high; + EXPECT_CALL(*tcp1_high, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb_high)); + + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb_example_com; + EXPECT_CALL(*tcp1_example_com, addDrainedCallback(_)) + .WillOnce(SaveArg<0>(&tcp_drained_cb_example_com)); + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb_ibm_com; + EXPECT_CALL(*tcp1_ibm_com, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb_ibm_com)); + + // Remove the first host, this should lead to the first cp being drained. + dns_timer_->callback_(); + dns_callback(TestUtility::makeDnsResponse({"127.0.0.2"})); + drained_cb(); + drained_cb = nullptr; + tcp_drained_cb(); + tcp_drained_cb = nullptr; + drained_cb_high(); + drained_cb_high = nullptr; + tcp_drained_cb_high(); + tcp_drained_cb_high = nullptr; + tcp_drained_cb_example_com(); + tcp_drained_cb_example_com = nullptr; + tcp_drained_cb_ibm_com(); + tcp_drained_cb_ibm_com = nullptr; + + // Make sure we get back the same connection pool for the 2nd host as we did before the change. + Http::ConnectionPool::MockInstance* cp3 = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Http::ConnectionPool::MockInstance* cp3_high = + dynamic_cast(cluster_manager_->httpConnPoolForCluster( + "cluster_1", ResourcePriority::High, Http::Protocol::Http11, nullptr)); + EXPECT_EQ(cp2, cp3); + EXPECT_EQ(cp2_high, cp3_high); + + Tcp::ConnectionPool::MockInstance* tcp3 = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp3_high = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::High, nullptr, nullptr)); + + Tcp::ConnectionPool::MockInstance* tcp3_example_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_example_com)); + Tcp::ConnectionPool::MockInstance* tcp3_ibm_com = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, transport_socket_options_ibm_com)); + + EXPECT_EQ(tcp2, tcp3); + EXPECT_EQ(tcp2_high, tcp3_high); + + EXPECT_EQ(tcp2_example_com, tcp3_example_com); + EXPECT_EQ(tcp2_ibm_com, tcp3_ibm_com); + + // Now add and remove a host that we never have a conn pool to. This should not lead to any + // drain callbacks, etc. + dns_timer_->callback_(); + dns_callback(TestUtility::makeDnsResponse({"127.0.0.2", "127.0.0.3"})); + + factory_.tls_.shutdownThread(); +} + // This is a regression test for a use-after-free in // ClusterManagerImpl::ThreadLocalClusterManagerImpl::drainConnPools(), where a removal at one // priority from the ConnPoolsContainer would delete the ConnPoolsContainer mid-iteration over the @@ -1593,8 +1840,9 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemoveDefaultPriority) { dynamic_cast(cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); // Immediate drain, since this can happen with the HTTP codecs. EXPECT_CALL(*cp, addDrainedCallback(_)) @@ -1668,8 +1916,9 @@ TEST_F(ClusterManagerImplTest, ConnPoolDestroyWithDraining) { dynamic_cast(cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); - Tcp::ConnectionPool::MockInstance* tcp = dynamic_cast( - cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp = + dynamic_cast(cluster_manager_->tcpConnPoolForCluster( + "cluster_1", ResourcePriority::Default, nullptr, nullptr)); // Remove the first host, this should lead to the cp being drained. Http::ConnectionPool::Instance::DrainedCb drained_cb; @@ -1712,8 +1961,9 @@ TEST_F(ClusterManagerImplTest, OriginalDstInitialization) { EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, - nullptr)); - EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1", nullptr).connection_); + nullptr, nullptr)); + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnForCluster("cluster_1", nullptr, nullptr).connection_); EXPECT_EQ(3UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); factory_.tls_.shutdownThread(); @@ -2213,7 +2463,7 @@ class SockoptsTest : public ClusterManagerImplTest { } return connection_; })); - cluster_manager_->tcpConnForCluster("SockoptsCluster", nullptr); + cluster_manager_->tcpConnForCluster("SockoptsCluster", nullptr, nullptr); } void expectSetsockoptFreebind() { @@ -2232,7 +2482,7 @@ class SockoptsTest : public ClusterManagerImplTest { EXPECT_EQ(nullptr, options.get()); return connection_; })); - auto conn_data = cluster_manager_->tcpConnForCluster("SockoptsCluster", nullptr); + auto conn_data = cluster_manager_->tcpConnForCluster("SockoptsCluster", nullptr, nullptr); EXPECT_EQ(connection_, conn_data.connection_.get()); } @@ -2419,7 +2669,7 @@ class TcpKeepaliveTest : public ClusterManagerImplTest { options, socket, envoy::api::v2::core::SocketOption::STATE_PREBIND))); return connection_; })); - cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr); + cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr, nullptr); return; } NiceMock os_sys_calls; @@ -2471,7 +2721,7 @@ class TcpKeepaliveTest : public ClusterManagerImplTest { return 0; })); } - auto conn_data = cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr); + auto conn_data = cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr, nullptr); EXPECT_EQ(connection_, conn_data.connection_.get()); } @@ -2485,7 +2735,7 @@ class TcpKeepaliveTest : public ClusterManagerImplTest { EXPECT_EQ(nullptr, options.get()); return connection_; })); - auto conn_data = cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr); + auto conn_data = cluster_manager_->tcpConnForCluster("TcpKeepaliveCluster", nullptr, nullptr); EXPECT_EQ(connection_, conn_data.connection_.get()); } diff --git a/test/common/upstream/logical_dns_cluster_test.cc b/test/common/upstream/logical_dns_cluster_test.cc index 44bf2ebe260f1..1b694483be9d8 100644 --- a/test/common/upstream/logical_dns_cluster_test.cc +++ b/test/common/upstream/logical_dns_cluster_test.cc @@ -116,7 +116,7 @@ class LogicalDnsClusterTest : public testing::Test { createClientConnection_( PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); + logical_host->createConnection(dispatcher_, nullptr, nullptr); logical_host->outlierDetector().putHttpResponseCode(200); expectResolve(Network::DnsLookupFamily::V4Only, expected_address); @@ -135,7 +135,7 @@ class LogicalDnsClusterTest : public testing::Test { createClientConnection_( PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) .WillOnce(Return(new NiceMock())); - Host::CreateConnectionData data = logical_host->createConnection(dispatcher_, nullptr); + Host::CreateConnectionData data = logical_host->createConnection(dispatcher_, nullptr, nullptr); EXPECT_FALSE(data.host_description_->canary()); EXPECT_EQ(&cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]->cluster(), &data.host_description_->cluster()); @@ -167,7 +167,7 @@ class LogicalDnsClusterTest : public testing::Test { createClientConnection_( PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); + logical_host->createConnection(dispatcher_, nullptr, nullptr); expectResolve(Network::DnsLookupFamily::V4Only, expected_address); resolve_timer_->callback_(); @@ -181,7 +181,7 @@ class LogicalDnsClusterTest : public testing::Test { createClientConnection_( PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); + logical_host->createConnection(dispatcher_, nullptr, nullptr); // Make sure we cancel. EXPECT_CALL(active_dns_query_, cancel()); diff --git a/test/common/upstream/original_dst_cluster_test.cc b/test/common/upstream/original_dst_cluster_test.cc index 8659d3d10494e..4c8ea8c0d2cc0 100644 --- a/test/common/upstream/original_dst_cluster_test.cc +++ b/test/common/upstream/original_dst_cluster_test.cc @@ -428,7 +428,7 @@ TEST_F(OriginalDstClusterTest, Connection) { EXPECT_CALL(dispatcher_, createClientConnection_(PointeesEq(connection.local_address_), _, _, _)) .WillOnce(Return(new NiceMock())); - host->createConnection(dispatcher_, nullptr); + host->createConnection(dispatcher_, nullptr, nullptr); } TEST_F(OriginalDstClusterTest, MultipleClusters) { diff --git a/test/extensions/filters/network/sni_cluster/sni_cluster_test.cc b/test/extensions/filters/network/sni_cluster/sni_cluster_test.cc index 74a8ae6474ba0..9cc39bd6d2e61 100644 --- a/test/extensions/filters/network/sni_cluster/sni_cluster_test.cc +++ b/test/extensions/filters/network/sni_cluster/sni_cluster_test.cc @@ -51,7 +51,7 @@ TEST(SniCluster, SetTcpProxyClusterOnlyIfSniIsPresent) { filter.onNewConnection(); EXPECT_FALSE(stream_info.filterState().hasData( - TcpProxy::PerConnectionCluster::Key)); + TcpProxy::PerConnectionCluster::key())); } // with sni @@ -61,11 +61,11 @@ TEST(SniCluster, SetTcpProxyClusterOnlyIfSniIsPresent) { filter.onNewConnection(); EXPECT_TRUE(stream_info.filterState().hasData( - TcpProxy::PerConnectionCluster::Key)); + TcpProxy::PerConnectionCluster::key())); auto per_connection_cluster = stream_info.filterState().getDataReadOnly( - TcpProxy::PerConnectionCluster::Key); + TcpProxy::PerConnectionCluster::key()); EXPECT_EQ(per_connection_cluster.value(), "filter_state_cluster"); } } diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index d35809d97f678..3566e96809e8e 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -483,7 +483,7 @@ TEST_F(ThriftRouterTest, NoHealthyHosts) { EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); - EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _)) + EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _, _)) .WillOnce(Return(nullptr)); EXPECT_CALL(callbacks_, sendLocalReply(_, _)) diff --git a/test/extensions/transport_sockets/alts/alts_integration_test.cc b/test/extensions/transport_sockets/alts/alts_integration_test.cc index 336c3d36d56af..14f08ba382e96 100644 --- a/test/extensions/transport_sockets/alts/alts_integration_test.cc +++ b/test/extensions/transport_sockets/alts/alts_integration_test.cc @@ -98,7 +98,8 @@ class AltsIntegrationTestBase : public HttpIntegrationTest, Network::ClientConnectionPtr makeAltsConnection() { Network::Address::InstanceConstSharedPtr address = getAddress(version_, lookupPort("http")); return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - client_alts_->createTransportSocket(), nullptr); + client_alts_->createTransportSocket(nullptr), + nullptr); } std::string fakeHandshakerServerAddress(bool connect_to_handshaker) { diff --git a/test/extensions/transport_sockets/alts/tsi_socket_test.cc b/test/extensions/transport_sockets/alts/tsi_socket_test.cc index 3a443115d7160..f3e42b231f952 100644 --- a/test/extensions/transport_sockets/alts/tsi_socket_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_socket_test.cc @@ -399,7 +399,7 @@ class TsiSocketFactoryTest : public testing::Test { }; TEST_F(TsiSocketFactoryTest, CreateTransportSocket) { - EXPECT_NE(nullptr, socket_factory_->createTransportSocket()); + EXPECT_NE(nullptr, socket_factory_->createTransportSocket(nullptr)); } TEST_F(TsiSocketFactoryTest, ImplementsSecureTransport) { diff --git a/test/integration/sds_dynamic_integration_test.cc b/test/integration/sds_dynamic_integration_test.cc index 0fa56d186a276..25a78a904d4dd 100644 --- a/test/integration/sds_dynamic_integration_test.cc +++ b/test/integration/sds_dynamic_integration_test.cc @@ -205,7 +205,8 @@ class SdsDynamicDownstreamIntegrationTest : public SdsDynamicIntegrationBaseTest Network::ClientConnectionPtr makeSslClientConnection() { Network::Address::InstanceConstSharedPtr address = getSslAddress(version_, lookupPort("http")); return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - client_ssl_ctx_->createTransportSocket(), nullptr); + client_ssl_ctx_->createTransportSocket(nullptr), + nullptr); } protected: diff --git a/test/integration/sds_static_integration_test.cc b/test/integration/sds_static_integration_test.cc index 38e5e2e194d7a..139f68c8fb688 100644 --- a/test/integration/sds_static_integration_test.cc +++ b/test/integration/sds_static_integration_test.cc @@ -85,7 +85,8 @@ class SdsStaticDownstreamIntegrationTest Network::ClientConnectionPtr makeSslClientConnection() { Network::Address::InstanceConstSharedPtr address = getSslAddress(version_, lookupPort("http")); return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - client_ssl_ctx_->createTransportSocket(), nullptr); + client_ssl_ctx_->createTransportSocket(nullptr), + nullptr); } private: diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index 4fac5833b27d9..2b0856ba0dce4 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -55,14 +55,15 @@ Network::ClientConnectionPtr SslIntegrationTest::makeSslClientConnection(bool al if (alpn) { return dispatcher_->createClientConnection( address, Network::Address::InstanceConstSharedPtr(), - san ? client_ssl_ctx_alpn_san_->createTransportSocket() - : client_ssl_ctx_alpn_->createTransportSocket(), + san ? client_ssl_ctx_alpn_san_->createTransportSocket(nullptr) + : client_ssl_ctx_alpn_->createTransportSocket(nullptr), nullptr); } else { - return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - san ? client_ssl_ctx_san_->createTransportSocket() - : client_ssl_ctx_plain_->createTransportSocket(), - nullptr); + return dispatcher_->createClientConnection( + address, Network::Address::InstanceConstSharedPtr(), + san ? client_ssl_ctx_san_->createTransportSocket(nullptr) + : client_ssl_ctx_plain_->createTransportSocket(nullptr), + nullptr); } } diff --git a/test/integration/tcp_conn_pool_integration_test.cc b/test/integration/tcp_conn_pool_integration_test.cc index e4905962c959f..0d5f8f0452435 100644 --- a/test/integration/tcp_conn_pool_integration_test.cc +++ b/test/integration/tcp_conn_pool_integration_test.cc @@ -26,7 +26,7 @@ class TestFilter : public Network::ReadFilter { UNREFERENCED_PARAMETER(end_stream); Tcp::ConnectionPool::Instance* pool = cluster_manager_.tcpConnPoolForCluster( - "cluster_0", Upstream::ResourcePriority::Default, nullptr); + "cluster_0", Upstream::ResourcePriority::Default, nullptr, nullptr); ASSERT(pool != nullptr); requests_.emplace_back(*this, data); diff --git a/test/integration/tcp_proxy_integration_test.cc b/test/integration/tcp_proxy_integration_test.cc index 23fc78173d849..c4bcefb4a8b02 100644 --- a/test/integration/tcp_proxy_integration_test.cc +++ b/test/integration/tcp_proxy_integration_test.cc @@ -392,7 +392,7 @@ void TcpProxySslIntegrationTest::setupConnections() { context_ = Ssl::createClientSslTransportSocketFactory(false, false, *context_manager_); ssl_client_ = dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - context_->createTransportSocket(), nullptr); + context_->createTransportSocket(nullptr), nullptr); // Perform the SSL handshake. Loopback is whitelisted in tcp_proxy.json for the ssl_auth // filter so there will be no pause waiting on auth data. diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index ce919d7bae0f4..c70e946c1d09b 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -95,7 +95,7 @@ Network::ClientConnectionPtr XfccIntegrationTest::makeMtlsClientConnection() { Network::Utility::resolveUrl("tcp://" + Network::Test::getLoopbackAddressUrlString(version_) + ":" + std::to_string(lookupPort("http"))); return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), - client_mtls_ssl_ctx_->createTransportSocket(), + client_mtls_ssl_ctx_->createTransportSocket(nullptr), nullptr); } diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 4413113bb67ae..5699044b11748 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -471,7 +471,7 @@ class MockTransportSocketFactory : public TransportSocketFactory { ~MockTransportSocketFactory(); MOCK_CONST_METHOD0(implementsSecureTransport, bool()); - MOCK_CONST_METHOD0(createTransportSocket, TransportSocketPtr()); + MOCK_CONST_METHOD1(createTransportSocket, TransportSocketPtr(TransportSocketOptionsSharedPtr)); }; class MockTransportSocketCallbacks : public TransportSocketCallbacks { diff --git a/test/mocks/stream_info/mocks.cc b/test/mocks/stream_info/mocks.cc index 04861ceae97ee..b8a15b15b466a 100644 --- a/test/mocks/stream_info/mocks.cc +++ b/test/mocks/stream_info/mocks.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" using testing::_; +using testing::Const; using testing::Invoke; using testing::Return; using testing::ReturnPointee; @@ -65,6 +66,7 @@ MockStreamInfo::MockStreamInfo() ON_CALL(*this, bytesSent()).WillByDefault(ReturnPointee(&bytes_sent_)); ON_CALL(*this, dynamicMetadata()).WillByDefault(ReturnRef(metadata_)); ON_CALL(*this, filterState()).WillByDefault(ReturnRef(filter_state_)); + ON_CALL(Const(*this), filterState()).WillByDefault(ReturnRef(filter_state_)); ON_CALL(*this, setRequestedServerName(_)) .WillByDefault(Invoke([this](const absl::string_view requested_server_name) { requested_server_name_ = std::string(requested_server_name); diff --git a/test/mocks/upstream/host.h b/test/mocks/upstream/host.h index b62a9e1aa4a22..14600932fad49 100644 --- a/test/mocks/upstream/host.h +++ b/test/mocks/upstream/host.h @@ -108,9 +108,9 @@ class MockHost : public Host { MockHost(); ~MockHost(); - CreateConnectionData - createConnection(Event::Dispatcher& dispatcher, - const Network::ConnectionSocket::OptionsSharedPtr& options) const override { + CreateConnectionData createConnection(Event::Dispatcher& dispatcher, + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr) const override { MockCreateConnectionData data = createConnection_(dispatcher, options); return {Network::ClientConnectionPtr{data.connection_}, data.host_description_}; } diff --git a/test/mocks/upstream/mocks.cc b/test/mocks/upstream/mocks.cc index f3d13cabb6aab..beabb3d66868c 100644 --- a/test/mocks/upstream/mocks.cc +++ b/test/mocks/upstream/mocks.cc @@ -104,7 +104,7 @@ MockClusterManager::MockClusterManager(TimeSource&) : MockClusterManager() {} MockClusterManager::MockClusterManager() { ON_CALL(*this, httpConnPoolForCluster(_, _, _, _)).WillByDefault(Return(&conn_pool_)); - ON_CALL(*this, tcpConnPoolForCluster(_, _, _)).WillByDefault(Return(&tcp_conn_pool_)); + ON_CALL(*this, tcpConnPoolForCluster(_, _, _, _)).WillByDefault(Return(&tcp_conn_pool_)); ON_CALL(*this, httpAsyncClientForCluster(_)).WillByDefault(ReturnRef(async_client_)); ON_CALL(*this, httpAsyncClientForCluster(_)).WillByDefault((ReturnRef(async_client_))); ON_CALL(*this, bindConfig()).WillByDefault(ReturnRef(bind_config_)); diff --git a/test/mocks/upstream/mocks.h b/test/mocks/upstream/mocks.h index e574b9702d6f6..2e971ec67c27b 100644 --- a/test/mocks/upstream/mocks.h +++ b/test/mocks/upstream/mocks.h @@ -213,11 +213,11 @@ class MockClusterManagerFactory : public ClusterManagerFactory { ResourcePriority priority, Http::Protocol protocol, const Network::ConnectionSocket::OptionsSharedPtr& options)); - MOCK_METHOD4( - allocateTcpConnPool, - Tcp::ConnectionPool::InstancePtr(Event::Dispatcher& dispatcher, HostConstSharedPtr host, - ResourcePriority priority, - const Network::ConnectionSocket::OptionsSharedPtr& options)); + MOCK_METHOD5(allocateTcpConnPool, Tcp::ConnectionPool::InstancePtr( + Event::Dispatcher& dispatcher, HostConstSharedPtr host, + ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options, + Network::TransportSocketOptionsSharedPtr)); MOCK_METHOD5(clusterFromProto, ClusterSharedPtr(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, @@ -251,7 +251,8 @@ class MockClusterManager : public ClusterManager { } Host::CreateConnectionData tcpConnForCluster(const std::string& cluster, - LoadBalancerContext* context) override { + LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr) override { MockHost::MockCreateConnectionData data = tcpConnForCluster_(cluster, context); return {Network::ClientConnectionPtr{data.connection_}, data.host_description_}; } @@ -268,9 +269,11 @@ class MockClusterManager : public ClusterManager { Http::ConnectionPool::Instance*(const std::string& cluster, ResourcePriority priority, Http::Protocol protocol, LoadBalancerContext* context)); - MOCK_METHOD3(tcpConnPoolForCluster, - Tcp::ConnectionPool::Instance*(const std::string& cluster, ResourcePriority priority, - LoadBalancerContext* context)); + MOCK_METHOD4(tcpConnPoolForCluster, + Tcp::ConnectionPool::Instance*( + const std::string& cluster, ResourcePriority priority, + LoadBalancerContext* context, + Network::TransportSocketOptionsSharedPtr transport_socket_options)); MOCK_METHOD2(tcpConnForCluster_, MockHost::MockCreateConnectionData(const std::string& cluster, LoadBalancerContext* context)); diff --git a/test/server/config_validation/cluster_manager_test.cc b/test/server/config_validation/cluster_manager_test.cc index 8bfca0a806ebd..058676ce7f606 100644 --- a/test/server/config_validation/cluster_manager_test.cc +++ b/test/server/config_validation/cluster_manager_test.cc @@ -46,7 +46,7 @@ TEST(ValidationClusterManagerTest, MockedMethods) { bootstrap, stats, tls, runtime, random, local_info, log_manager, admin); EXPECT_EQ(nullptr, cluster_manager->httpConnPoolForCluster("cluster", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); - Host::CreateConnectionData data = cluster_manager->tcpConnForCluster("cluster", nullptr); + Host::CreateConnectionData data = cluster_manager->tcpConnForCluster("cluster", nullptr, nullptr); EXPECT_EQ(nullptr, data.connection_); EXPECT_EQ(nullptr, data.host_description_); diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index 1278a2268e3d5..072e99a604c48 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -1154,7 +1154,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationP filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1196,7 +1196,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationI filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1244,7 +1244,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesM findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1283,7 +1283,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithTransportPro filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1322,7 +1322,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1370,7 +1370,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->uriSanLocalCertificate(); EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); @@ -1379,7 +1379,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1389,7 +1389,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(8081, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -1399,7 +1399,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->uriSanLocalCertificate(); EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); @@ -1446,7 +1446,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->uriSanLocalCertificate(); EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); @@ -1455,7 +1455,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(1234, true, "192.168.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1465,7 +1465,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(1234, true, "192.168.1.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -1475,7 +1475,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinati filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); uri = ssl_socket->uriSanLocalCertificate(); EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); @@ -1531,7 +1531,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto uri = ssl_socket->uriSanLocalCertificate(); EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); @@ -1541,7 +1541,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1552,7 +1552,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, true, "127.0.0.1", true, "server2.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -1563,7 +1563,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam findFilterChain(1234, true, "127.0.0.1", true, "www.wildcard.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); ssl_socket = dynamic_cast(transport_socket.get()); server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 2); @@ -1605,7 +1605,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithTransport filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1647,7 +1647,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithApplicati findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1); @@ -1702,7 +1702,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithMultipleR true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(nullptr); auto ssl_socket = dynamic_cast(transport_socket.get()); auto server_names = ssl_socket->dnsSansLocalCertificate(); EXPECT_EQ(server_names.size(), 1);