diff --git a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs index 9074362eaba4..7f93fa9831a3 100644 --- a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs +++ b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs @@ -5,6 +5,7 @@ using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Security @@ -439,5 +440,10 @@ public override void Write(byte[] buffer, int offset, int count) { _sslState.SecureStream.Write(buffer, offset, count); } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _sslState.SecureStream.WriteAsync(buffer, offset, count, cancellationToken); + } } } diff --git a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs index 69550be0b142..3b51654d23db 100644 --- a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs +++ b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.IO; using System.Threading; +using System.Threading.Tasks; namespace System.Net.Security { @@ -85,9 +86,28 @@ internal int Read(byte[] buffer, int offset, int count) internal void Write(byte[] buffer, int offset, int count) { + ValidateParameters(buffer, offset, count); + + if (Interlocked.Exchange(ref _NestedWrite, 1) == 1) + { + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); + } + ProcessWrite(buffer, offset, count); } + internal Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateParameters(buffer, offset, count); + + if (Interlocked.Exchange(ref _NestedWrite, 1) == 1) + { + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); + } + + return ProcessWriteAsync(buffer, offset, count, cancellationToken); + } + internal bool DataAvailable { get { return InternalBufferCount != 0; } @@ -201,16 +221,32 @@ private void ValidateParameters(byte[] buffer, int offset, int count) // private void ProcessWrite(byte[] buffer, int offset, int count) { - ValidateParameters(buffer, offset, count); + try + { + StartWriting(buffer, offset, count); + } + catch (Exception e) + { + _SslState.FinishWrite(); - if (Interlocked.Exchange(ref _NestedWrite, 1) == 1) + if (e is IOException) + { + throw; + } + + throw new IOException(SR.net_io_write, e); + } + finally { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); + _NestedWrite = 0; } + } + private async Task ProcessWriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { try { - StartWriting(buffer, offset, count); + await StartWritingAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); } catch (Exception e) { @@ -229,6 +265,50 @@ private void ProcessWrite(byte[] buffer, int offset, int count) } } + private void PrepareWritingBuffer(byte[] buffer, ref byte[] outBuffer, int count) + { + if (_PinnableOutputBufferInUse == null) + { + if (_PinnableOutputBuffer == null) + { + _PinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); + } + + _PinnableOutputBufferInUse = buffer; + outBuffer = _PinnableOutputBuffer; + if (PinnableBufferCacheEventSource.Log.IsEnabled()) + { + PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Trying Pinnable", this.GetHashCode(), count, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); + } + } + else + { + if (PinnableBufferCacheEventSource.Log.IsEnabled()) + { + PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.StartWriting BufferInUse", this.GetHashCode(), count); + } + } + } + + private int EncryptWritingBuffer(byte[] buffer, int offset, byte[] outBuffer, int chunkBytes) + { + int encryptedBytes; + SecurityStatusPal errorCode = _SslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); + if (errorCode != SecurityStatusPal.OK) + { + ProtocolToken message = new ProtocolToken(null, errorCode); + throw new IOException(SR.net_io_encrypt, message.GetException()); + } + + if (PinnableBufferCacheEventSource.Log.IsEnabled()) + { + PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Got Encrypted Buffer", + this.GetHashCode(), encryptedBytes, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); + } + + return encryptedBytes; + } + private void StartWriting(byte[] buffer, int offset, int count) { // We loop to this method from the callback. @@ -236,27 +316,48 @@ private void StartWriting(byte[] buffer, int offset, int count) if (count >= 0) { byte[] outBuffer = null; - if (_PinnableOutputBufferInUse == null) + PrepareWritingBuffer(buffer, ref outBuffer, count); + + do { - if (_PinnableOutputBuffer == null) + // Request a write IO slot. + if (_SslState.CheckEnqueueWrite(null)) { - _PinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); + // Operation is async and has been queued, return. + return; } - _PinnableOutputBufferInUse = buffer; - outBuffer = _PinnableOutputBuffer; - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Trying Pinnable", this.GetHashCode(), count, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); - } - } - else + int chunkBytes = Math.Min(count, _SslState.MaxDataSize); + int encryptedBytes = EncryptWritingBuffer(buffer, offset, outBuffer, chunkBytes); + + _SslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + + offset += chunkBytes; + count -= chunkBytes; + + // Release write IO slot. + _SslState.FinishWrite(); + } while (count != 0); + } + + if (buffer == _PinnableOutputBufferInUse) + { + _PinnableOutputBufferInUse = null; + if (PinnableBufferCacheEventSource.Log.IsEnabled()) { - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.StartWriting BufferInUse", this.GetHashCode(), count); - } + PinnableBufferCacheEventSource.Log.DebugMessage1("In System.Net._SslStream.StartWriting Freeing buffer.", this.GetHashCode()); } + } + } + + private async Task StartWritingAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + // We loop to this method from the callback. + // If the last chunk was just completed from async callback (count < 0), we complete user request. + if (count >= 0) + { + byte[] outBuffer = null; + PrepareWritingBuffer(buffer, ref outBuffer, count); do { @@ -268,21 +369,9 @@ private void StartWriting(byte[] buffer, int offset, int count) } int chunkBytes = Math.Min(count, _SslState.MaxDataSize); - int encryptedBytes; - SecurityStatusPal errorCode = _SslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); - if (errorCode != SecurityStatusPal.OK) - { - ProtocolToken message = new ProtocolToken(null, errorCode); - throw new IOException(SR.net_io_encrypt, message.GetException()); - } + int encryptedBytes = EncryptWritingBuffer(buffer, offset, outBuffer, chunkBytes); - if (PinnableBufferCacheEventSource.Log.IsEnabled()) - { - PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Got Encrypted Buffer", - this.GetHashCode(), encryptedBytes, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); - } - - _SslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + await _SslState.InnerStream.WriteAsync(outBuffer, 0, encryptedBytes, cancellationToken).ConfigureAwait(false); offset += chunkBytes; count -= chunkBytes;