Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,13 @@ private bool AcquireServerCredentials(ref byte[]? thumbPrint)
_sslAuthenticationOptions.CertificateContext = SslStreamCertificateContext.Create(selectedCert);
}

Debug.Assert(_sslAuthenticationOptions.CertificateContext != null);
//
// Note selectedCert is a safe ref possibly cloned from the user passed Cert object
//
byte[] guessedThumbPrint = selectedCert.GetCertHash();
SafeFreeCredentials? cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy);
bool sendTrustedList = _sslAuthenticationOptions.CertificateContext!.Trust?._sendTrustInHandshake ?? false;
SafeFreeCredentials? cachedCredentialHandle = SslSessionsCache.TryCachedCredential(guessedThumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy, sendTrustedList);

if (cachedCredentialHandle != null)
{
Expand Down Expand Up @@ -763,6 +765,7 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
byte[]? result = Array.Empty<byte>();
SecurityStatusPal status = default;
bool cachedCreds = false;
bool sendTrustList = false;
byte[]? thumbPrint = null;

//
Expand All @@ -779,6 +782,11 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
cachedCreds = _sslAuthenticationOptions.IsServer
? AcquireServerCredentials(ref thumbPrint)
: AcquireClientCredentials(ref thumbPrint);

if (cachedCreds && _sslAuthenticationOptions.IsServer)
{
sendTrustList = _sslAuthenticationOptions.CertificateContext?.Trust?._sendTrustInHandshake ?? false;
}
}

if (_sslAuthenticationOptions.IsServer)
Expand Down Expand Up @@ -820,7 +828,7 @@ private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte
//
if (!cachedCreds && _securityContext != null && !_securityContext.IsInvalid && _credentialsHandle != null && !_credentialsHandle.IsInvalid)
{
SslSessionsCache.CacheCredential(_credentialsHandle, thumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy);
SslSessionsCache.CacheCredential(_credentialsHandle, thumbPrint, _sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.IsServer, _sslAuthenticationOptions.EncryptionPolicy, sendTrustList);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@ internal static class SslSessionsCache
private readonly int _allowedProtocols;
private readonly EncryptionPolicy _encryptionPolicy;
private readonly bool _isServerMode;
private readonly bool _sendTrustList;

//
// SECURITY: X509Certificate.GetCertHash() is virtual hence before going here,
// the caller of this ctor has to ensure that a user cert object was inspected and
// optionally cloned.
//
internal SslCredKey(byte[]? thumbPrint, int allowedProtocols, bool isServerMode, EncryptionPolicy encryptionPolicy)
internal SslCredKey(byte[]? thumbPrint, int allowedProtocols, bool isServerMode, EncryptionPolicy encryptionPolicy, bool sendTrustList)
{
_thumbPrint = thumbPrint ?? Array.Empty<byte>();
_allowedProtocols = allowedProtocols;
_encryptionPolicy = encryptionPolicy;
_isServerMode = isServerMode;
_sendTrustList = sendTrustList;
}

public override int GetHashCode()
Expand Down Expand Up @@ -65,6 +67,7 @@ public override int GetHashCode()
hashCode ^= _allowedProtocols;
hashCode ^= (int)_encryptionPolicy;
hashCode ^= _isServerMode ? 0x10000 : 0x20000;
hashCode ^= _sendTrustList ? 0x40000 : 0x80000;

return hashCode;
}
Expand Down Expand Up @@ -96,6 +99,11 @@ public bool Equals(SslCredKey other)
return false;
}

if (_sendTrustList != other._sendTrustList)
{
return false;
}

for (int i = 0; i < thumbPrint.Length; ++i)
{
if (thumbPrint[i] != otherThumbPrint[i])
Expand All @@ -114,15 +122,15 @@ public bool Equals(SslCredKey other)
// ATTN: The returned handle can be invalid, the callers of InitializeSecurityContext and AcceptSecurityContext
// must be prepared to execute a back-out code if the call fails.
//
internal static SafeFreeCredentials? TryCachedCredential(byte[]? thumbPrint, SslProtocols sslProtocols, bool isServer, EncryptionPolicy encryptionPolicy)
internal static SafeFreeCredentials? TryCachedCredential(byte[]? thumbPrint, SslProtocols sslProtocols, bool isServer, EncryptionPolicy encryptionPolicy, bool sendTrustList = false)
{
if (s_cachedCreds.IsEmpty)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, $"Not found, Current Cache Count = {s_cachedCreds.Count}");
return null;
}

var key = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy);
var key = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy, sendTrustList);

//SafeCredentialReference? cached;
SafeFreeCredentials? credentials = GetCachedCredential(key);
Expand All @@ -147,7 +155,7 @@ public bool Equals(SslCredKey other)
//
// ATTN: The thumbPrint must be from inspected and possibly cloned user Cert object or we get a security hole in SslCredKey ctor.
//
internal static void CacheCredential(SafeFreeCredentials creds, byte[]? thumbPrint, SslProtocols sslProtocols, bool isServer, EncryptionPolicy encryptionPolicy)
internal static void CacheCredential(SafeFreeCredentials creds, byte[]? thumbPrint, SslProtocols sslProtocols, bool isServer, EncryptionPolicy encryptionPolicy, bool sendTrustList = false)
{
Debug.Assert(creds != null, "creds == null");

Expand All @@ -157,7 +165,7 @@ internal static void CacheCredential(SafeFreeCredentials creds, byte[]? thumbPri
return;
}

SslCredKey key = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy);
SslCredKey key = new SslCredKey(thumbPrint, (int)sslProtocols, isServer, encryptionPolicy, sendTrustList);

SafeFreeCredentials? credentials = GetCachedCredential(key);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ public static SafeFreeCredentials AcquireCredentialsHandle(SslStreamCertificateC
{
// New crypto API supports TLS1.3 but it does not allow to force NULL encryption.
SafeFreeCredentials cred = !UseNewCryptoApi || policy == EncryptionPolicy.NoEncryption ?
AcquireCredentialsHandleSchannelCred(certificateContext?.Certificate, protocols, policy, isServer) :
AcquireCredentialsHandleSchCredentials(certificateContext?.Certificate, protocols, policy, isServer);
AcquireCredentialsHandleSchannelCred(certificateContext, protocols, policy, isServer) :
AcquireCredentialsHandleSchCredentials(certificateContext, protocols, policy, isServer);
if (certificateContext != null && certificateContext.Trust != null && certificateContext.Trust._sendTrustInHandshake)
{
AttachCertificateStore(cred, certificateContext.Trust._store!);
Expand Down Expand Up @@ -157,8 +157,9 @@ private static unsafe void AttachCertificateStore(SafeFreeCredentials cred, X509

// This is legacy crypto API used on .NET Framework and older Windows versions.
// It only supports TLS up to 1.2
public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchannelCred(X509Certificate2? certificate, SslProtocols protocols, EncryptionPolicy policy, bool isServer)
public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchannelCred(SslStreamCertificateContext? certificateContext, SslProtocols protocols, EncryptionPolicy policy, bool isServer)
{
X509Certificate2? certificate = certificateContext?.Certificate;
int protocolFlags = GetProtocolFlagsFromSslProtocols(protocols, isServer);
Interop.SspiCli.SCHANNEL_CRED.Flags flags;
Interop.SspiCli.CredentialUse direction;
Expand All @@ -183,6 +184,10 @@ public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchannelCred(X5
{
direction = Interop.SspiCli.CredentialUse.SECPKG_CRED_INBOUND;
flags = Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_SEND_AUX_RECORD;
if (certificateContext?.Trust?._sendTrustInHandshake == true)
{
flags |= Interop.SspiCli.SCHANNEL_CRED.Flags.SCH_CRED_NO_SYSTEM_MAPPER;
}
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Info($"flags=({flags}), ProtocolFlags=({protocolFlags}), EncryptionPolicy={policy}");
Expand All @@ -203,15 +208,20 @@ public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchannelCred(X5
}

// This function uses new crypto API to support TLS 1.3 and beyond.
public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchCredentials(X509Certificate2? certificate, SslProtocols protocols, EncryptionPolicy policy, bool isServer)
public static unsafe SafeFreeCredentials AcquireCredentialsHandleSchCredentials(SslStreamCertificateContext? certificateContext, SslProtocols protocols, EncryptionPolicy policy, bool isServer)
{
X509Certificate2? certificate = certificateContext?.Certificate;
int protocolFlags = GetProtocolFlagsFromSslProtocols(protocols, isServer);
Interop.SspiCli.SCH_CREDENTIALS.Flags flags;
Interop.SspiCli.CredentialUse direction;
if (isServer)
{
direction = Interop.SspiCli.CredentialUse.SECPKG_CRED_INBOUND;
flags = Interop.SspiCli.SCH_CREDENTIALS.Flags.SCH_SEND_AUX_RECORD;
if (certificateContext?.Trust?._sendTrustInHandshake == true)
{
flags |= Interop.SspiCli.SCH_CREDENTIALS.Flags.SCH_CRED_NO_SYSTEM_MAPPER;
}
}
else
{
Expand Down