diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index d586d9fe09a57..53dbf0e4ba819 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -23,6 +23,12 @@ class ConnectionInfo { **/ virtual bool peerCertificatePresented() const PURE; + /** + * @return bool whether the local certificate is requested on the client by the server. This + * flag is always true on the server. + **/ + virtual bool localCertificatePresented() const PURE; + /** * @return std::string the URIs in the SAN field of the local certificate. Returns {} if there is * no local certificate, or no SAN field, or no URI. diff --git a/source/extensions/filters/common/expr/context.cc b/source/extensions/filters/common/expr/context.cc index cf2e3a1b06427..3489a7fddbd60 100644 --- a/source/extensions/filters/common/expr/context.cc +++ b/source/extensions/filters/common/expr/context.cc @@ -112,6 +112,7 @@ absl::optional ConnectionWrapper::operator[](CelValue key) const { auto value = key.StringOrDie().value(); if (value == MTLS) { return CelValue::CreateBool(info_.downstreamSslConnection() != nullptr && + info_.downstreamSslConnection()->localCertificatePresented() && info_.downstreamSslConnection()->peerCertificatePresented()); } else if (value == RequestedServerName) { return CelValue::CreateString(info_.requestedServerName()); @@ -144,6 +145,7 @@ absl::optional UpstreamWrapper::operator[](CelValue key) const { } } else if (value == MTLS) { return CelValue::CreateBool(info_.upstreamSslConnection() != nullptr && + info_.upstreamSslConnection()->localCertificatePresented() && info_.upstreamSslConnection()->peerCertificatePresented()); } diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 7737d36b160f6..b076a003bec0c 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -48,7 +48,7 @@ SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, ctx_(std::dynamic_pointer_cast(ctx)), state_(SocketState::PreHandshake) { bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); ssl_ = ssl.get(); - info_ = std::make_shared(std::move(ssl)); + info_ = std::make_shared(std::move(ssl), state); if (state == InitialState::Client) { SSL_set_connect_state(ssl_); } else { @@ -301,11 +301,29 @@ void SslSocket::shutdownSsl() { } } +SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, InitialState state) : ssl_(std::move(ssl)) { + if (state == InitialState::Client) { + SSL_set_cert_cb( + ssl_.get(), + [](SSL*, void* arg) -> int { + auto info = static_cast(arg); + info->local_cert_presented = true; + return 1; + }, + this); + } else { + ASSERT(state == InitialState::Server); + local_cert_presented = true; + } +} + bool SslSocketInfo::peerCertificatePresented() const { bssl::UniquePtr cert(SSL_get_peer_certificate(ssl_.get())); return cert != nullptr; } +bool SslSocketInfo::localCertificatePresented() const { return local_cert_presented; } + std::vector SslSocketInfo::uriSanLocalCertificate() const { if (!cached_uri_san_local_certificate_.empty()) { return cached_uri_san_local_certificate_; diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 4f7660717b807..d67f2da0ac5dc 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -43,10 +43,11 @@ enum class SocketState { PreHandshake, HandshakeInProgress, HandshakeComplete, S class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { public: - SslSocketInfo(bssl::UniquePtr ssl) : ssl_(std::move(ssl)) {} + SslSocketInfo(bssl::UniquePtr ssl, InitialState state); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; + bool localCertificatePresented() const override; std::vector uriSanLocalCertificate() const override; const std::string& sha256PeerCertificateDigest() const override; const std::string& serialNumberPeerCertificate() const override; @@ -83,6 +84,7 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { mutable std::vector cached_dns_san_local_certificate_; mutable std::string cached_session_id_; mutable std::string cached_tls_version_; + bool local_cert_presented = false; }; class SslSocket : public Network::TransportSocket, diff --git a/test/extensions/filters/common/expr/context_test.cc b/test/extensions/filters/common/expr/context_test.cc index 7f11fa75bb55e..827337526895f 100644 --- a/test/extensions/filters/common/expr/context_test.cc +++ b/test/extensions/filters/common/expr/context_test.cc @@ -277,7 +277,9 @@ TEST(Context, ConnectionAttributes) { EXPECT_CALL(info, upstreamHost()).WillRepeatedly(Return(upstream_host)); EXPECT_CALL(info, requestedServerName()).WillRepeatedly(ReturnRef(sni_name)); EXPECT_CALL(*downstream_ssl_info, peerCertificatePresented()).WillRepeatedly(Return(true)); + EXPECT_CALL(*downstream_ssl_info, localCertificatePresented()).WillRepeatedly(Return(true)); EXPECT_CALL(*upstream_ssl_info, peerCertificatePresented()).WillRepeatedly(Return(true)); + EXPECT_CALL(*upstream_ssl_info, localCertificatePresented()).WillRepeatedly(Return(true)); const std::string tls_version = "TLSv1"; EXPECT_CALL(*downstream_ssl_info, tlsVersion()).WillRepeatedly(ReturnRef(tls_version)); EXPECT_CALL(*upstream_host, address()).WillRepeatedly(Return(upstream_address)); diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index aff65278f0877..e48787572dea9 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -100,7 +100,7 @@ class TestUtilOptions : public TestUtilOptionsBase { bool expect_success, Network::Address::IpVersion version) : TestUtilOptionsBase(expect_success, version), client_ctx_yaml_(client_ctx_yaml), server_ctx_yaml_(server_ctx_yaml), expect_no_cert_(false), expect_no_cert_chain_(false), - expect_private_key_method_(false), + expect_local_cert_(false), expect_private_key_method_(false), expected_server_close_event_(Network::ConnectionEvent::RemoteClose) { if (expect_success) { setExpectedServerStats("ssl.handshake"); @@ -131,6 +131,13 @@ class TestUtilOptions : public TestUtilOptionsBase { return *this; } + bool expectLocalCertPresented() const { return expect_local_cert_; } + + TestUtilOptions& setExpectLocalCertPresented() { + expect_local_cert_ = true; + return *this; + } + TestUtilOptions& setExpectedClientCertUri(const std::string& expected_client_cert_uri) { TestUtilOptionsBase::setExpectedClientCertUri(expected_client_cert_uri); return *this; @@ -230,6 +237,7 @@ class TestUtilOptions : public TestUtilOptionsBase { bool expect_no_cert_; bool expect_no_cert_chain_; + bool expect_local_cert_; bool expect_private_key_method_; Network::ConnectionEvent expected_server_close_event_; std::string expected_digest_; @@ -343,6 +351,7 @@ void testUtil(const TestUtilOptions& options) { server_connection->ssl()->subjectPeerCertificate()); } if (!options.expectedLocalSubject().empty()) { + EXPECT_TRUE(server_connection->ssl()->localCertificatePresented()); EXPECT_EQ(options.expectedLocalSubject(), server_connection->ssl()->subjectLocalCertificate()); } @@ -381,10 +390,18 @@ void testUtil(const TestUtilOptions& options) { EXPECT_EQ(EMPTY_STRING, server_connection->ssl()->subjectPeerCertificate()); EXPECT_EQ(std::vector{}, server_connection->ssl()->dnsSansPeerCertificate()); } + if (options.expectNoCertChain()) { EXPECT_EQ(EMPTY_STRING, server_connection->ssl()->urlEncodedPemEncodedPeerCertificateChain()); } + + if (options.expectLocalCertPresented()) { + EXPECT_TRUE(client_connection->ssl()->localCertificatePresented()); + } else { + EXPECT_FALSE(client_connection->ssl()->localCertificatePresented()); + } + // By default, the session is not created with session resumption. The // client should see a session ID but the server should not. EXPECT_EQ(EMPTY_STRING, server_connection->ssl()->sessionId()); @@ -791,6 +808,7 @@ TEST_P(SslSocketTest, GetCertDigest) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedDigest(TEST_NO_SAN_CERT_HASH) + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL)); } @@ -862,6 +880,7 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithIntermediateCA) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedDigest(TEST_NO_SAN_CERT_HASH) + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL)); } @@ -889,6 +908,7 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithoutCommonName) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedDigest(TEST_NO_SAN_CERT_HASH) + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL)); } @@ -917,6 +937,7 @@ TEST_P(SslSocketTest, GetUriWithUriSan) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedClientCertUri("spiffe://lyft.com/test-team") + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_SAN_URI_CERT_SERIAL)); } @@ -990,7 +1011,8 @@ TEST_P(SslSocketTest, GetNoUriWithDnsSan) { // The SAN field only has DNS, expect "" for uriSanPeerCertificate(). TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(test_options.setExpectedSerialNumber(TEST_SAN_DNS_CERT_SERIAL)); + testUtil( + test_options.setExpectedSerialNumber(TEST_SAN_DNS_CERT_SERIAL).setExpectLocalCertPresented()); } TEST_P(SslSocketTest, NoCert) { @@ -1070,6 +1092,7 @@ TEST_P(SslSocketTest, GetUriWithLocalUriSan) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedLocalUri("spiffe://lyft.com/test-team") + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL)); } @@ -1098,6 +1121,7 @@ TEST_P(SslSocketTest, GetSubjectsWithBothCerts) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL) + .setExpectLocalCertPresented() .setExpectedPeerIssuer( "CN=Test CA,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US") .setExpectedPeerSubject( @@ -1134,6 +1158,7 @@ TEST_P(SslSocketTest, GetPeerCert) { TestEnvironment::readFileToStringForTest(TestEnvironment::substitute( "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/no_san_cert.pem")); testUtil(test_options.setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL) + .setExpectLocalCertPresented() .setExpectedPeerIssuer( "CN=Test CA,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US") .setExpectedPeerSubject( @@ -1172,6 +1197,7 @@ TEST_P(SslSocketTest, GetPeerCertChain) { "{{ test_rundir " "}}/test/extensions/transport_sockets/tls/test_data/no_san_chain.pem")); testUtil(test_options.setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL) + .setExpectLocalCertPresented() .setExpectedPeerCertChain(expected_peer_cert_chain)); } @@ -1199,6 +1225,7 @@ TEST_P(SslSocketTest, GetIssueExpireTimesPeerCert) { )EOF"; TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedSerialNumber(TEST_NO_SAN_CERT_SERIAL) + .setExpectLocalCertPresented() .setExpectedValidFromTimePeerCert(TEST_NO_SAN_CERT_NOT_BEFORE) .setExpectedExpirationTimePeerCert(TEST_NO_SAN_CERT_NOT_AFTER)); } @@ -1421,6 +1448,7 @@ TEST_P(SslSocketTest, ClientCertificateHashVerification) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedClientCertUri("spiffe://lyft.com/test-team") + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_SAN_URI_CERT_SERIAL)); } @@ -1447,6 +1475,7 @@ TEST_P(SslSocketTest, ClientCertificateHashVerificationNoCA) { TestUtilOptions test_options(client_ctx_yaml, server_ctx_yaml, true, GetParam()); testUtil(test_options.setExpectedClientCertUri("spiffe://lyft.com/test-team") + .setExpectLocalCertPresented() .setExpectedSerialNumber(TEST_SAN_URI_CERT_SERIAL)); } @@ -3694,7 +3723,8 @@ TEST_P(SslSocketTest, RevokedCertificate) { TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setExpectedSerialNumber(TEST_SAN_DNS2_CERT_SERIAL)); + testUtil(successful_test_options.setExpectedSerialNumber(TEST_SAN_DNS2_CERT_SERIAL) + .setExpectLocalCertPresented()); } TEST_P(SslSocketTest, RevokedCertificateCRLInTrustedCA) { @@ -3735,7 +3765,8 @@ TEST_P(SslSocketTest, RevokedCertificateCRLInTrustedCA) { )EOF"; TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setExpectedSerialNumber(TEST_SAN_DNS2_CERT_SERIAL)); + testUtil(successful_test_options.setExpectedSerialNumber(TEST_SAN_DNS2_CERT_SERIAL) + .setExpectLocalCertPresented()); } TEST_P(SslSocketTest, GetRequestedServerName) { @@ -4220,7 +4251,7 @@ TEST_P(SslSocketTest, RsaPrivateKeyProviderAsyncSignSuccess) { TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setPrivateKeyMethodExpected(true)); + testUtil(successful_test_options.setPrivateKeyMethodExpected(true).setExpectLocalCertPresented()); } // Test asynchronous decryption (RSA). @@ -4252,7 +4283,7 @@ TEST_P(SslSocketTest, RsaPrivateKeyProviderAsyncDecryptSuccess) { TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setPrivateKeyMethodExpected(true)); + testUtil(successful_test_options.setPrivateKeyMethodExpected(true).setExpectLocalCertPresented()); } // Test synchronous signing (ECDHE). @@ -4284,7 +4315,7 @@ TEST_P(SslSocketTest, RsaPrivateKeyProviderSyncSignSuccess) { TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setPrivateKeyMethodExpected(true)); + testUtil(successful_test_options.setPrivateKeyMethodExpected(true).setExpectLocalCertPresented()); } // Test synchronous decryption (RSA). @@ -4316,7 +4347,7 @@ TEST_P(SslSocketTest, RsaPrivateKeyProviderSyncDecryptSuccess) { TestUtilOptions successful_test_options(successful_client_ctx_yaml, server_ctx_yaml, true, GetParam()); - testUtil(successful_test_options.setPrivateKeyMethodExpected(true)); + testUtil(successful_test_options.setPrivateKeyMethodExpected(true).setExpectLocalCertPresented()); } // Test asynchronous signing (ECDHE) failure (invalid signature). diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 041888aa99c9a..490865f8d0793 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -38,6 +38,7 @@ class MockConnectionInfo : public ConnectionInfo { ~MockConnectionInfo() override; MOCK_CONST_METHOD0(peerCertificatePresented, bool()); + MOCK_CONST_METHOD0(localCertificatePresented, bool()); MOCK_CONST_METHOD0(uriSanLocalCertificate, std::vector()); MOCK_CONST_METHOD0(sha256PeerCertificateDigest, const std::string&()); MOCK_CONST_METHOD0(serialNumberPeerCertificate, const std::string&());