diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index 3f8daec8964724..e522e7ee6dff62 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -333,14 +333,11 @@ internal static int Encrypt(SafeSslHandle context, ReadOnlySpan input, ref int retVal; Exception? innerError = null; - lock (context) - { - retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length); + retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length); - if (retVal != input.Length) - { - errorCode = GetSslError(context, retVal, out innerError); - } + if (retVal != input.Length) + { + errorCode = GetSslError(context, retVal, out innerError); } if (retVal != input.Length) @@ -390,30 +387,27 @@ internal static int Decrypt(SafeSslHandle context, byte[] outBuffer, int offset, int retVal = BioWrite(context.InputBio!, outBuffer, offset, count); Exception? innerError = null; - lock (context) + if (retVal == count) { - if (retVal == count) + unsafe { - unsafe + fixed (byte* fixedBuffer = outBuffer) { - fixed (byte* fixedBuffer = outBuffer) - { - retVal = Ssl.SslRead(context, fixedBuffer + offset, outBuffer.Length); - } - } - - if (retVal > 0) - { - count = retVal; + retVal = Ssl.SslRead(context, fixedBuffer + offset, outBuffer.Length); } } - if (retVal != count) + if (retVal > 0) { - errorCode = GetSslError(context, retVal, out innerError); + count = retVal; } } + if (retVal != count) + { + errorCode = GetSslError(context, retVal, out innerError); + } + if (retVal != count) { retVal = 0; diff --git a/src/libraries/Common/src/System/Net/SecurityStatusPal.cs b/src/libraries/Common/src/System/Net/SecurityStatusPal.cs index 595aec0970aba1..d0f7b26c038d5e 100644 --- a/src/libraries/Common/src/System/Net/SecurityStatusPal.cs +++ b/src/libraries/Common/src/System/Net/SecurityStatusPal.cs @@ -35,6 +35,7 @@ internal enum SecurityStatusPalErrorCode ContextExpired, CredentialsNeeded, Renegotiate, + TryAgain, // Errors OutOfMemory, diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs index 982837e7fc93d9..4995f3fd09bf0e 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs @@ -12,9 +12,8 @@ public partial class SslStream private interface ISslIOAdapter { ValueTask ReadAsync(Memory buffer); - ValueTask ReadLockAsync(Memory buffer); - Task WriteLockAsync(); ValueTask WriteAsync(byte[] buffer, int offset, int count); + Task WaitAsync(TaskCompletionSource waiter); CancellationToken CancellationToken { get; } } @@ -31,12 +30,10 @@ public AsyncSslIOAdapter(SslStream sslStream, CancellationToken cancellationToke public ValueTask ReadAsync(Memory buffer) => _sslStream.InnerStream.ReadAsync(buffer, _cancellationToken); - public ValueTask ReadLockAsync(Memory buffer) => _sslStream.CheckEnqueueReadAsync(buffer); - - public Task WriteLockAsync() => _sslStream.CheckEnqueueWriteAsync(); - public ValueTask WriteAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.WriteAsync(new ReadOnlyMemory(buffer, offset, count), _cancellationToken); + public Task WaitAsync(TaskCompletionSource waiter) => waiter.Task; + public CancellationToken CancellationToken => _cancellationToken; } @@ -48,17 +45,15 @@ public AsyncSslIOAdapter(SslStream sslStream, CancellationToken cancellationToke public ValueTask ReadAsync(Memory buffer) => new ValueTask(_sslStream.InnerStream.Read(buffer.Span)); - public ValueTask ReadLockAsync(Memory buffer) => new ValueTask(_sslStream.CheckEnqueueRead(buffer)); - public ValueTask WriteAsync(byte[] buffer, int offset, int count) { _sslStream.InnerStream.Write(buffer, offset, count); return default; } - public Task WriteLockAsync() + public Task WaitAsync(TaskCompletionSource waiter) { - _sslStream.CheckEnqueueWrite(); + waiter.Task.Wait(); return Task.CompletedTask; } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index 5bd0bdb1bfb8c2..9c714992182952 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -43,24 +43,14 @@ private enum FrameType : byte AppData = 23 } - // - // This block is used to rule the >>re-handshakes<< that are concurrent with read/write I/O requests. - // - private const int LockNone = 0; - private const int LockWrite = 1; - private const int LockHandshake = 2; - private const int LockPendingWrite = 3; - private const int LockRead = 4; - private const int LockPendingRead = 6; + private readonly object _handshakeLock = new object(); + private volatile TaskCompletionSource? _handshakeWaiter; private const int FrameOverhead = 32; private const int ReadBufferSize = 4096 * 4 + FrameOverhead; // We read in 16K chunks + headers. private const int InitialHandshakeBufferSize = 4096 + FrameOverhead; // try to fit at least 4K ServerCertificate private ArrayBuffer _handshakeBuffer; - private int _lockWriteState; - private int _lockReadState; - private void ValidateCreateContext(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertValidationCallback remoteCallback, LocalCertSelectionCallback? localCallback) { ThrowIfExceptional(); @@ -175,7 +165,18 @@ private void CloseInternal() private SecurityStatusPal EncryptData(ReadOnlyMemory buffer, ref byte[] outBuffer, out int outSize) { ThrowIfExceptionalOrNotAuthenticated(); - return _context!.Encrypt(buffer, ref outBuffer, out outSize); + + lock (_handshakeLock) + { + if (_handshakeWaiter != null) + { + outSize = 0; + // avoid waiting under lock. + return new SecurityStatusPal(SecurityStatusPalErrorCode.TryAgain); + } + + return _context!.Encrypt(buffer, ref outBuffer, out outSize); + } } private SecurityStatusPal DecryptData() @@ -218,14 +219,15 @@ private SecurityStatusPal PrivateDecryptData(byte[]? buffer, ref int offset, ref private async Task ReplyOnReAuthenticationAsync(TIOAdapter adapter, byte[]? buffer) where TIOAdapter : ISslIOAdapter { - lock (SyncLock!) + try { - // Note we are already inside the read, so checking for already going concurrent handshake. - _lockReadState = LockHandshake; + await ForceAuthenticationAsync(adapter, receiveFirst: false, buffer).ConfigureAwait(false); + } + finally + { + _handshakeWaiter!.SetResult(true); + _handshakeWaiter = null; } - - await ForceAuthenticationAsync(adapter, receiveFirst: false, buffer).ConfigureAwait(false); - FinishHandshakeRead(LockNone); } // reAuthenticationData is only used on Windows in case of renegotiation. @@ -396,169 +398,6 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken) return true; } - private void FinishHandshakeRead(int newState) - { - lock (SyncLock!) - { - // Lock is redundant here. Included for clarity. - int lockState = Interlocked.Exchange(ref _lockReadState, newState); - - if (lockState != LockPendingRead) - { - return; - } - - _lockReadState = LockRead; - } - } - - // Returns: - // -1 - proceed - // 0 - queued - // X - some bytes are ready, no need for IO - private int CheckEnqueueRead(Memory buffer) - { - ThrowIfExceptionalOrNotAuthenticated(); - - int lockState = Interlocked.CompareExchange(ref _lockReadState, LockRead, LockNone); - if (lockState != LockHandshake) - { - // Proceed, no concurrent handshake is ongoing so no need for a lock. - return -1; - } - - LazyAsyncResult? lazyResult = null; - lock (SyncLock!) - { - // Check again under lock. - if (_lockReadState != LockHandshake) - { - // The other thread has finished before we grabbed the lock. - _lockReadState = LockRead; - return -1; - } - - _lockReadState = LockPendingRead; - } - // Need to exit from lock before waiting. - lazyResult!.InternalWaitForCompletion(); - ThrowIfExceptionalOrNotAuthenticated(); - return -1; - } - - private ValueTask CheckEnqueueReadAsync(Memory buffer) - { - ThrowIfExceptionalOrNotAuthenticated(); - - int lockState = Interlocked.CompareExchange(ref _lockReadState, LockRead, LockNone); - if (lockState != LockHandshake) - { - // Proceed, no concurrent handshake is ongoing so no need for a lock. - return new ValueTask(-1); - } - - lock (SyncLock!) - { - // Check again under lock. - if (_lockReadState != LockHandshake) - { - // The other thread has finished before we grabbed the lock. - _lockReadState = LockRead; - return new ValueTask(-1); - } - - _lockReadState = LockPendingRead; - TaskCompletionSource taskCompletionSource = new TaskCompletionSource(buffer, TaskCreationOptions.RunContinuationsAsynchronously); - return new ValueTask(taskCompletionSource.Task); - } - } - - private Task CheckEnqueueWriteAsync() - { - ThrowIfExceptionalOrNotAuthenticated(); - - // Clear previous request. - int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockWrite, LockNone); - if (lockState != LockHandshake) - { - return Task.CompletedTask; - } - - lock (SyncLock!) - { - if (_lockWriteState != LockHandshake) - { - ThrowIfExceptionalOrNotAuthenticated(); - return Task.CompletedTask; - } - - _lockWriteState = LockPendingWrite; - TaskCompletionSource completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - return completionSource.Task; - } - } - - private void CheckEnqueueWrite() - { - // Clear previous request. - int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockWrite, LockNone); - if (lockState != LockHandshake) - { - // Proceed with write. - return; - } - - LazyAsyncResult? lazyResult = null; - lock (SyncLock!) - { - if (_lockWriteState != LockHandshake) - { - // Handshake has completed before we grabbed the lock. - ThrowIfExceptionalOrNotAuthenticated(); - return; - } - - _lockWriteState = LockPendingWrite; - } - - // Need to exit from lock before waiting. - lazyResult!.InternalWaitForCompletion(); - ThrowIfExceptionalOrNotAuthenticated(); - return; - } - - private void FinishWrite() - { - int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockWrite); - if (lockState != LockHandshake) - { - return; - } - } - - private void FinishHandshake(Exception e) - { - lock (SyncLock!) - { - if (e != null) - { - SetException(e); - } - - // Release read if any. - FinishHandshakeRead(LockNone); - - // If there is a pending write we want to keep it's lock state. - int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake); - if (lockState != LockPendingWrite) - { - return; - } - - _lockWriteState = LockWrite; - } - } - private async ValueTask WriteAsyncChunked(TIOAdapter writeAdapter, ReadOnlyMemory buffer) where TIOAdapter : struct, ISslIOAdapter { @@ -573,22 +412,38 @@ private async ValueTask WriteAsyncChunked(TIOAdapter writeAdapter, R private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnlyMemory buffer) where TIOAdapter : struct, ISslIOAdapter { - // Request a write IO slot. - Task ioSlot = writeAdapter.WriteLockAsync(); - if (!ioSlot.IsCompletedSuccessfully) - { - // Operation is async and has been queued, return. - return WaitForWriteIOSlot(writeAdapter, ioSlot, buffer); - } - byte[] rentedBuffer = ArrayPool.Shared.Rent(buffer.Length + FrameOverhead); byte[] outBuffer = rentedBuffer; - SecurityStatusPal status = EncryptData(buffer, ref outBuffer, out int encryptedBytes); + SecurityStatusPal status; + int encryptedBytes; + while (true) + { + status = EncryptData(buffer, ref outBuffer, out encryptedBytes); + + // TryAgain should be rare, when renegotiation happens exactly when we want to write. + if (status.ErrorCode != SecurityStatusPalErrorCode.TryAgain) + { + break; + } + + TaskCompletionSource? waiter = _handshakeWaiter; + if (waiter != null) + { + Task waiterTask = writeAdapter.WaitAsync(waiter); + // We finished synchronously waiting for renegotiation. We can try again immediately. + if (waiterTask.IsCompletedSuccessfully) + { + continue; + } + + // We need to wait asynchronously as well as for the write when EncryptData is finished. + return WaitAndWriteAsync(writeAdapter, buffer, waiterTask, rentedBuffer); + } + } if (status.ErrorCode != SecurityStatusPalErrorCode.OK) { - // Re-handshake status is not supported. ArrayPool.Shared.Return(rentedBuffer); return new ValueTask(Task.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, SslStreamPal.GetException(status))))); } @@ -597,21 +452,51 @@ private ValueTask WriteSingleChunk(TIOAdapter writeAdapter, ReadOnly if (t.IsCompletedSuccessfully) { ArrayPool.Shared.Return(rentedBuffer); - FinishWrite(); return t; } else { - return CompleteAsync(t, rentedBuffer); + return CompleteWriteAsync(t, rentedBuffer); } - async ValueTask WaitForWriteIOSlot(TIOAdapter wAdapter, Task lockTask, ReadOnlyMemory buff) + async ValueTask WaitAndWriteAsync(TIOAdapter writeAdapter, ReadOnlyMemory buffer, Task waitTask, byte[] rentedBuffer) { - await lockTask.ConfigureAwait(false); - await WriteSingleChunk(wAdapter, buff).ConfigureAwait(false); + byte[]? bufferToReturn = rentedBuffer; + byte[] outBuffer = rentedBuffer; + try + { + // Wait for renegotiation to finish. + await waitTask.ConfigureAwait(false); + + SecurityStatusPal status = EncryptData(buffer, ref outBuffer, out int encryptedBytes); + if (status.ErrorCode == SecurityStatusPalErrorCode.TryAgain) + { + // No need to hold on the buffer any more. + ArrayPool.Shared.Return(bufferToReturn); + bufferToReturn = null; + // Call WriteSingleChunk() recursively to avoid code duplication. + // This should be extremely rare in cases when second renegotiation happens concurrently with Write. + await WriteSingleChunk(writeAdapter, buffer).ConfigureAwait(false); + } + else if (status.ErrorCode == SecurityStatusPalErrorCode.OK) + { + await writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes).ConfigureAwait(false); + } + else + { + throw new IOException(SR.net_io_encrypt, SslStreamPal.GetException(status)); + } + } + finally + { + if (bufferToReturn != null) + { + ArrayPool.Shared.Return(bufferToReturn); + } + } } - async ValueTask CompleteAsync(ValueTask writeTask, byte[] bufferToReturn) + async ValueTask CompleteWriteAsync(ValueTask writeTask, byte[] bufferToReturn) { try { @@ -620,7 +505,6 @@ async ValueTask CompleteAsync(ValueTask writeTask, byte[] bufferToReturn) finally { ArrayPool.Shared.Return(bufferToReturn); - FinishWrite(); } } } @@ -688,12 +572,6 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M return CopyDecryptedData(buffer); } - int copyBytes = await adapter.ReadLockAsync(buffer).ConfigureAwait(false); - if (copyBytes > 0) - { - return copyBytes; - } - ResetReadBuffer(); // Read the next frame header. @@ -735,7 +613,27 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M // DecryptData will decrypt in-place and modify these to point to the actual decrypted data, which may be smaller. _decryptedBytesOffset = _internalOffset; _decryptedBytesCount = payloadBytes; - SecurityStatusPal status = DecryptData(); + + SecurityStatusPal status; + lock (_handshakeLock) + { + status = DecryptData(); + if (status.ErrorCode == SecurityStatusPalErrorCode.Renegotiate) + { + // The status indicates that peer wants to renegotiate. (Windows only) + // In practice, there can be some other reasons too - like TLS1.3 session creation + // of alert handling. We need to pass the data to lsass and it is not safe to do parallel + // write any more as that can change TLS state and the EncryptData() can fail in strange ways. + + // To handle this we call DecryptData() under lock and we create TCS waiter. + // EncryptData() checks that under same lock and if it exist it will not call low-level crypto. + // Instead it will wait synchronously or asynchronously and it will try again after the wait. + // The result will be set when ReplyOnReAuthenticationAsync() is finished e.g. lsass business is over + // or if we bail to continue. If either one happen before EncryptData(), _handshakeWaiter will be set to null + // and EncryptData() will work normally e.g. no waiting, just exclusion with DecryptData() + _handshakeWaiter = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } // Treat the bytes we just decrypted as consumed // Note, we won't do another buffer read until the decrypted bytes are processed @@ -759,6 +657,9 @@ private async ValueTask ReadAsyncInternal(TIOAdapter adapter, M { if (!_sslAuthenticationOptions!.AllowRenegotiation) { + _handshakeWaiter!.SetResult(false); + _handshakeWaiter = null; + if (NetEventSource.IsEnabled) NetEventSource.Fail(this, "Renegotiation was requested but it is disallowed"); throw new IOException(SR.net_ssl_io_renego); } @@ -888,8 +789,6 @@ private async ValueTask WriteAsyncInternal(TIOAdapter writeAdapter, } catch (Exception e) { - FinishWrite(); - if (e is IOException || (e is OperationCanceledException && writeAdapter.CancellationToken.IsCancellationRequested)) { throw;