From f400259e59d59bf83033bcf3c778f72a377b794e Mon Sep 17 00:00:00 2001 From: wfurt Date: Mon, 8 Jun 2020 00:20:10 -0700 Subject: [PATCH] allow access to crypto props from RemoteCertificateValidationCallback --- .../src/System/Net/Security/SslStream.cs | 27 +++++++++++++------ .../ClientAsyncAuthenticateTest.cs | 25 +++++++++++++++-- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index f3638799d77324..132ad45963c593 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -484,7 +484,7 @@ public virtual SslProtocols SslProtocol { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -560,7 +560,7 @@ public virtual TlsCipherSuite NegotiatedCipherSuite { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); return _context!.ConnectionInfo?.TlsCipherSuite ?? default(TlsCipherSuite); } } @@ -569,7 +569,7 @@ public virtual CipherAlgorithmType CipherAlgorithm { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -583,7 +583,7 @@ public virtual int CipherStrength { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -598,7 +598,7 @@ public virtual HashAlgorithmType HashAlgorithm { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -612,7 +612,7 @@ public virtual int HashStrength { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -627,7 +627,7 @@ public virtual ExchangeAlgorithmType KeyExchangeAlgorithm { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -642,7 +642,7 @@ public virtual int KeyExchangeStrength { get { - ThrowIfExceptionalOrNotAuthenticated(); + ThrowIfExceptionalOrNotHandshake(); SslConnectionInfo? info = _context!.ConnectionInfo; if (info == null) { @@ -863,6 +863,17 @@ private void ThrowIfExceptionalOrNotAuthenticated() } } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ThrowIfExceptionalOrNotHandshake() + { + ThrowIfExceptional(); + + if (!IsAuthenticated && _context?.ConnectionInfo == null) + { + ThrowNotAuthenticated(); + } + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private void ThrowIfExceptionalOrNotAuthenticatedOrShutdown() { diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/ClientAsyncAuthenticateTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/ClientAsyncAuthenticateTest.cs index a8dba5b6b0b6b0..9926e08102dc90 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/ClientAsyncAuthenticateTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/ClientAsyncAuthenticateTest.cs @@ -40,6 +40,12 @@ public async Task ClientAsyncAuthenticate_ServerRequireEncryption_ConnectWithEnc await ClientAsyncSslHelper(EncryptionPolicy.RequireEncryption); } + [Fact] + public async Task ClientAsyncAuthenticate_ConnectionInfoInCallback_DoesNotThrow() + { + await ClientAsyncSslHelper(EncryptionPolicy.RequireEncryption, SslProtocols.Tls12, SslProtocolSupport.DefaultSslProtocols, AllowAnyServerCertificateAndVerifyConnectionInfo); + } + [Fact] public async Task ClientAsyncAuthenticate_ServerNoEncryption_NoConnect() { @@ -139,7 +145,8 @@ private Task ClientAsyncSslHelper(SslProtocols clientSslProtocols, SslProtocols private async Task ClientAsyncSslHelper( EncryptionPolicy encryptionPolicy, SslProtocols clientSslProtocols, - SslProtocols serverSslProtocols) + SslProtocols serverSslProtocols, + RemoteCertificateValidationCallback certificateCallback = null) { _log.WriteLine("Server: " + serverSslProtocols + "; Client: " + clientSslProtocols); @@ -150,7 +157,7 @@ private async Task ClientAsyncSslHelper( { server.SslProtocols = serverSslProtocols; await client.ConnectAsync(server.RemoteEndPoint.Address, server.RemoteEndPoint.Port); - using (SslStream sslStream = new SslStream(client.GetStream(), false, AllowAnyServerCertificate, null)) + using (SslStream sslStream = new SslStream(client.GetStream(), false, certificateCallback != null ? certificateCallback : AllowAnyServerCertificate, null)) { Task clientAuthTask = sslStream.AuthenticateAsClientAsync("localhost", null, clientSslProtocols, false); await clientAuthTask.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); @@ -173,6 +180,20 @@ private bool AllowAnyServerCertificate( return true; // allow everything } + private bool AllowAnyServerCertificateAndVerifyConnectionInfo( + object sender, + X509Certificate certificate, + X509Chain chain, + SslPolicyErrors sslPolicyErrors) + { + SslStream stream = (SslStream)sender; + + Assert.NotEqual(SslProtocols.None, stream.SslProtocol); + Assert.NotEqual(CipherAlgorithmType.None, stream.CipherAlgorithm); + + return true; // allow everything + } + #endregion Helpers } }