diff --git a/src/System.Net.Security/System.Net.Security.sln b/src/System.Net.Security/System.Net.Security.sln index 03afd91ac1ef..fd6566223585 100644 --- a/src/System.Net.Security/System.Net.Security.sln +++ b/src/System.Net.Security/System.Net.Security.sln @@ -65,22 +65,22 @@ Global {89F37791-6254-4D60-AB96-ACD3CCA0E771}.Windows_Debug|Any CPU.Build.0 = Windows_Debug|Any CPU {89F37791-6254-4D60-AB96-ACD3CCA0E771}.Windows_Release|Any CPU.ActiveCfg = Windows_Release|Any CPU {89F37791-6254-4D60-AB96-ACD3CCA0E771}.Windows_Release|Any CPU.Build.0 = Windows_Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Debug|Any CPU.Build.0 = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Debug|Any CPU.Build.0 = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Release|Any CPU.ActiveCfg = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Release|Any CPU.Build.0 = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Debug|Any CPU.Build.0 = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Release|Any CPU.ActiveCfg = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Release|Any CPU.Build.0 = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Release|Any CPU.ActiveCfg = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Release|Any CPU.Build.0 = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Debug|Any CPU.Build.0 = Debug|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Release|Any CPU.ActiveCfg = Release|Any CPU - {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Release|Any CPU.Build.0 = Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Debug|Any CPU.ActiveCfg = Windows_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Debug|Any CPU.Build.0 = Windows_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Debug|Any CPU.ActiveCfg = Linux_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Debug|Any CPU.Build.0 = Linux_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Release|Any CPU.ActiveCfg = Linux_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Linux_Release|Any CPU.Build.0 = Linux_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Debug|Any CPU.ActiveCfg = OSX_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Debug|Any CPU.Build.0 = OSX_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Release|Any CPU.ActiveCfg = OSX_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.OSX_Release|Any CPU.Build.0 = OSX_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Release|Any CPU.ActiveCfg = Windows_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Release|Any CPU.Build.0 = Windows_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Debug|Any CPU.ActiveCfg = Windows_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Debug|Any CPU.Build.0 = Windows_Debug|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Release|Any CPU.ActiveCfg = Windows_Release|Any CPU + {A55A2B9A-830F-4330-A0E7-02A9FB30ABD2}.Windows_Release|Any CPU.Build.0 = Windows_Release|Any CPU {0D174EA9-9E61-4519-8D31-7BD2331A1982}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {0D174EA9-9E61-4519-8D31-7BD2331A1982}.Debug|Any CPU.Build.0 = Debug|Any CPU {0D174EA9-9E61-4519-8D31-7BD2331A1982}.Linux_Debug|Any CPU.ActiveCfg = Debug|Any CPU diff --git a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs index 9074362eaba4..4a7b2f48494c 100644 --- a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs +++ b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStream.cs @@ -5,6 +5,7 @@ using System.Security.Authentication; using System.Security.Authentication.ExtendedProtection; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Security @@ -439,5 +440,87 @@ public override void Write(byte[] buffer, int offset, int count) { _sslState.SecureStream.Write(buffer, offset, count); } + + private IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + return _sslState.SecureStream.BeginRead(buffer, offset, count, asyncCallback, asyncState); + } + + private int EndRead(IAsyncResult asyncResult) + { + return _sslState.SecureStream.EndRead(asyncResult); + } + + private IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + return _sslState.SecureStream.BeginWrite(buffer, offset, count, asyncCallback, asyncState); + } + + private void EndWrite(IAsyncResult asyncResult) + { + _sslState.SecureStream.EndWrite(asyncResult); + } + + // ReadAsync - provide async read functionality. + // + // This method provides async read functionality. All we do is + // call through to the Begin/EndRead methods. + // + // Input: + // + // buffer - Buffer to read into. + // offset - Offset into the buffer where we're to read. + // size - Number of bytes to read. + // cancellationtoken - Token used to request cancellation of the operation + // + // Returns: + // + // A Task representing the read. + public override Task ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + return Task.Factory.FromAsync( + (bufferArg, offsetArg, sizeArg, callback, state) => ((SslStream)state).BeginRead(bufferArg, offsetArg, sizeArg, callback, state), + iar => ((SslStream)iar.AsyncState).EndRead(iar), + buffer, + offset, + size, + this); + } + + // WriteAsync - provide async write functionality. + // + // This method provides async write functionality. All we do is + // call through to the Begin/EndWrite methods. + // + // Input: + // + // buffer - Buffer to write into. + // offset - Offset into the buffer where we're to write. + // size - Number of bytes to write. + // cancellationtoken - Token used to request cancellation of the operation + // + // Returns: + // + // A Task representing the write. + public override Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + return Task.Factory.FromAsync( + (bufferArg, offsetArg, sizeArg, callback, state) => ((SslStream)state).BeginWrite(bufferArg, offsetArg, sizeArg, callback, state), + iar => ((SslStream)iar.AsyncState).EndWrite(iar), + buffer, + offset, + size, + this); + } } } diff --git a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs index 69550be0b142..8d08fd2d4365 100644 --- a/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs +++ b/src/System.Net.Security/src/System/Net/SecureProtocols/SslStreamInternal.cs @@ -12,26 +12,32 @@ namespace System.Net.Security // internal class SslStreamInternal { + private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback); + private static readonly AsyncProtocolCallback s_resumeAsyncWriteCallback = new AsyncProtocolCallback(ResumeAsyncWriteCallback); + private static readonly AsyncProtocolCallback s_resumeAsyncReadCallback = new AsyncProtocolCallback(ResumeAsyncReadCallback); + private static readonly AsyncProtocolCallback s_readHeaderCallback = new AsyncProtocolCallback(ReadHeaderCallback); + private static readonly AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback); + private const int PinnableReadBufferSize = 4096 * 4 + 32; // We read in 16K chunks + headers. private static PinnableBufferCache s_PinnableReadBufferCache = new PinnableBufferCache("System.Net.SslStream", PinnableReadBufferSize); private const int PinnableWriteBufferSize = 4096 + 1024; // We write in 4K chunks + encryption overhead. private static PinnableBufferCache s_PinnableWriteBufferCache = new PinnableBufferCache("System.Net.SslStream", PinnableWriteBufferSize); - private SslState _SslState; - private int _NestedWrite; - private int _NestedRead; + private SslState _sslState; + private int _nestedWrite; + private int _nestedRead; // Never updated directly, special properties are used. This is the read buffer. - private byte[] _InternalBuffer; - private bool _InternalBufferFromPinnableCache; + private byte[] _internalBuffer; + private bool _internalBufferFromPinnableCache; - private byte[] _PinnableOutputBuffer; // Used for writes when we can do it. - private byte[] _PinnableOutputBufferInUse; // Remembers what UNENCRYPTED buffer is using _PinnableOutputBuffer. + private byte[] _pinnableOutputBuffer; // Used for writes when we can do it. + private byte[] _pinnableOutputBufferInUse; // Remembers what UNENCRYPTED buffer is using _PinnableOutputBuffer. - private int _InternalOffset; - private int _InternalBufferCount; + private int _internalOffset; + private int _internalBufferCount; - private FixedSizeReader _Reader; + private FixedSizeReader _reader; internal SslStreamInternal(SslState sslState) { @@ -40,52 +46,134 @@ internal SslStreamInternal(SslState sslState) PinnableBufferCacheEventSource.Log.DebugMessage1("CTOR: In System.Net._SslStream.SslStream", this.GetHashCode()); } - _SslState = sslState; - _Reader = new FixedSizeReader(_SslState.InnerStream); + _sslState = sslState; + _reader = new FixedSizeReader(_sslState.InnerStream); } // If we have a read buffer from the pinnable cache, return it. private void FreeReadBuffer() { - if (_InternalBufferFromPinnableCache) + if (_internalBufferFromPinnableCache) { - s_PinnableReadBufferCache.FreeBuffer(_InternalBuffer); - _InternalBufferFromPinnableCache = false; + s_PinnableReadBufferCache.FreeBuffer(_internalBuffer); + _internalBufferFromPinnableCache = false; } - _InternalBuffer = null; + _internalBuffer = null; } ~SslStreamInternal() { - if (_InternalBufferFromPinnableCache) + if (_internalBufferFromPinnableCache) { if (PinnableBufferCacheEventSource.Log.IsEnabled()) { - PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Read Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_InternalBuffer)); + PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Read Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_internalBuffer)); } FreeReadBuffer(); } - if (_PinnableOutputBuffer != null) + if (_pinnableOutputBuffer != null) { if (PinnableBufferCacheEventSource.Log.IsEnabled()) { - PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Write Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_PinnableOutputBuffer)); + PinnableBufferCacheEventSource.Log.DebugMessage2("DTOR: In System.Net._SslStream.~SslStream Freeing Write Buffer", this.GetHashCode(), PinnableBufferCacheEventSource.AddressOfByteArray(_pinnableOutputBuffer)); } - s_PinnableWriteBufferCache.FreeBuffer(_PinnableOutputBuffer); + s_PinnableWriteBufferCache.FreeBuffer(_pinnableOutputBuffer); } } internal int Read(byte[] buffer, int offset, int count) { - return ProcessRead(buffer, offset, count); + return ProcessRead(buffer, offset, count, null); } internal void Write(byte[] buffer, int offset, int count) { - ProcessWrite(buffer, offset, count); + ProcessWrite(buffer, offset, count, null); + } + + internal IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + BufferAsyncResult bufferResult = new BufferAsyncResult(this, buffer, offset, count, asyncState, asyncCallback); + AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(bufferResult); + ProcessRead(buffer, offset, count, asyncRequest); + return bufferResult; + } + + internal int EndRead(IAsyncResult asyncResult) + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + + BufferAsyncResult bufferResult = asyncResult as BufferAsyncResult; + if (bufferResult == null) + { + throw new ArgumentException(SR.Format(SR.net_io_async_result, asyncResult.GetType().FullName), "asyncResult"); + } + + if (Interlocked.Exchange(ref _nestedRead, 0) == 0) + { + throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndRead")); + } + + // No "artificial" timeouts implemented so far, InnerStream controls timeout. + bufferResult.InternalWaitForCompletion(); + + if (bufferResult.Result is Exception) + { + if (bufferResult.Result is IOException) + { + throw (Exception)bufferResult.Result; + } + + throw new IOException(SR.net_io_read, (Exception)bufferResult.Result); + } + + return (int)bufferResult.Result; + } + + internal IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + LazyAsyncResult lazyResult = new LazyAsyncResult(this, asyncState, asyncCallback); + AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(lazyResult); + ProcessWrite(buffer, offset, count, asyncRequest); + return lazyResult; + } + + internal void EndWrite(IAsyncResult asyncResult) + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + + LazyAsyncResult lazyResult = asyncResult as LazyAsyncResult; + if (lazyResult == null) + { + throw new ArgumentException(SR.Format(SR.net_io_async_result, asyncResult.GetType().FullName), "asyncResult"); + } + + if (Interlocked.Exchange(ref _nestedWrite, 0) == 0) + { + throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndWrite")); + } + + // No "artificial" timeouts implemented so far, InnerStream controls timeout. + lazyResult.InternalWaitForCompletion(); + + if (lazyResult.Result is Exception) + { + if (lazyResult.Result is IOException) + { + throw (Exception)lazyResult.Result; + } + + throw new IOException(SR.net_io_write, (Exception)lazyResult.Result); + } } internal bool DataAvailable @@ -97,7 +185,7 @@ private byte[] InternalBuffer { get { - return _InternalBuffer; + return _internalBuffer; } } @@ -105,7 +193,7 @@ private int InternalOffset { get { - return _InternalOffset; + return _internalOffset; } } @@ -113,14 +201,14 @@ private int InternalBufferCount { get { - return _InternalBufferCount; + return _internalBufferCount; } } private void SkipBytes(int decrCount) { - _InternalOffset += decrCount; - _InternalBufferCount -= decrCount; + _internalOffset += decrCount; + _internalBufferCount -= decrCount; } // @@ -129,10 +217,10 @@ private void SkipBytes(int decrCount) // private void EnsureInternalBufferSize(int curOffset, int addSize) { - if (_InternalBuffer == null || _InternalBuffer.Length < addSize + curOffset) + if (_internalBuffer == null || _internalBuffer.Length < addSize + curOffset) { - bool wasPinnable = _InternalBufferFromPinnableCache; - byte[] saved = _InternalBuffer; + bool wasPinnable = _internalBufferFromPinnableCache; + byte[] saved = _internalBuffer; int newSize = addSize + curOffset; if (newSize <= PinnableReadBufferSize) @@ -142,8 +230,8 @@ private void EnsureInternalBufferSize(int curOffset, int addSize) PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.EnsureInternalBufferSize IS pinnable", this.GetHashCode(), newSize); } - _InternalBufferFromPinnableCache = true; - _InternalBuffer = s_PinnableReadBufferCache.AllocateBuffer(); + _internalBufferFromPinnableCache = true; + _internalBuffer = s_PinnableReadBufferCache.AllocateBuffer(); } else { @@ -152,13 +240,13 @@ private void EnsureInternalBufferSize(int curOffset, int addSize) PinnableBufferCacheEventSource.Log.DebugMessage2("In System.Net._SslStream.EnsureInternalBufferSize NOT pinnable", this.GetHashCode(), newSize); } - _InternalBufferFromPinnableCache = false; - _InternalBuffer = new byte[newSize]; + _internalBufferFromPinnableCache = false; + _internalBuffer = new byte[newSize]; } if (saved != null && curOffset != 0) { - Buffer.BlockCopy(saved, 0, _InternalBuffer, 0, curOffset); + Buffer.BlockCopy(saved, 0, _internalBuffer, 0, curOffset); } if (wasPinnable) @@ -166,8 +254,8 @@ private void EnsureInternalBufferSize(int curOffset, int addSize) s_PinnableReadBufferCache.FreeBuffer(saved); } } - _InternalOffset = curOffset; - _InternalBufferCount = curOffset + addSize; + _internalOffset = curOffset; + _internalBufferCount = curOffset + addSize; } // @@ -199,23 +287,26 @@ private void ValidateParameters(byte[] buffer, int offset, int count) // // Sync write method. // - private void ProcessWrite(byte[] buffer, int offset, int count) + private void ProcessWrite(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { ValidateParameters(buffer, offset, count); - if (Interlocked.Exchange(ref _NestedWrite, 1) == 1) + if (Interlocked.Exchange(ref _nestedWrite, 1) == 1) { throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Write", "write")); } + bool failed = false; + try { - StartWriting(buffer, offset, count); + StartWriting(buffer, offset, count, asyncRequest); } catch (Exception e) { - _SslState.FinishWrite(); + _sslState.FinishWrite(); + failed = true; if (e is IOException) { throw; @@ -225,26 +316,34 @@ private void ProcessWrite(byte[] buffer, int offset, int count) } finally { - _NestedWrite = 0; + if (asyncRequest == null || failed) + { + _nestedWrite = 0; + } } } - private void StartWriting(byte[] buffer, int offset, int count) + private void StartWriting(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { + if (asyncRequest != null) + { + asyncRequest.SetNextRequest(buffer, offset, count, s_resumeAsyncWriteCallback); + } + // We loop to this method from the callback. // If the last chunk was just completed from async callback (count < 0), we complete user request. - if (count >= 0) + if (count >= 0 ) { byte[] outBuffer = null; - if (_PinnableOutputBufferInUse == null) + if (_pinnableOutputBufferInUse == null) { - if (_PinnableOutputBuffer == null) + if (_pinnableOutputBuffer == null) { - _PinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); + _pinnableOutputBuffer = s_PinnableWriteBufferCache.AllocateBuffer(); } - _PinnableOutputBufferInUse = buffer; - outBuffer = _PinnableOutputBuffer; + _pinnableOutputBufferInUse = buffer; + outBuffer = _pinnableOutputBuffer; if (PinnableBufferCacheEventSource.Log.IsEnabled()) { PinnableBufferCacheEventSource.Log.DebugMessage3("In System.Net._SslStream.StartWriting Trying Pinnable", this.GetHashCode(), count, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); @@ -261,17 +360,18 @@ private void StartWriting(byte[] buffer, int offset, int count) do { // Request a write IO slot. - if (_SslState.CheckEnqueueWrite(null)) + if (_sslState.CheckEnqueueWrite(asyncRequest)) { // Operation is async and has been queued, return. return; } - int chunkBytes = Math.Min(count, _SslState.MaxDataSize); + int chunkBytes = Math.Min(count, _sslState.MaxDataSize); int encryptedBytes; - SecurityStatusPal errorCode = _SslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); + SecurityStatusPal errorCode = _sslState.EncryptData(buffer, offset, chunkBytes, ref outBuffer, out encryptedBytes); if (errorCode != SecurityStatusPal.OK) { + // Re-handshake status is not supported. ProtocolToken message = new ProtocolToken(null, errorCode); throw new IOException(SR.net_io_encrypt, message.GetException()); } @@ -282,19 +382,41 @@ private void StartWriting(byte[] buffer, int offset, int count) this.GetHashCode(), encryptedBytes, PinnableBufferCacheEventSource.AddressOfByteArray(outBuffer)); } - _SslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + if (asyncRequest != null) + { + // Prepare for the next request. + asyncRequest.SetNextRequest(buffer, offset + chunkBytes, count - chunkBytes, s_resumeAsyncWriteCallback); + IAsyncResult ar = _sslState.InnerStreamAPM.BeginWrite(outBuffer, 0, encryptedBytes, s_writeCallback, asyncRequest); + if (!ar.CompletedSynchronously) + { + return; + } + + _sslState.InnerStreamAPM.EndWrite(ar); + + } + else + { + _sslState.InnerStream.Write(outBuffer, 0, encryptedBytes); + } offset += chunkBytes; count -= chunkBytes; // Release write IO slot. - _SslState.FinishWrite(); + _sslState.FinishWrite(); + } while (count != 0); } - if (buffer == _PinnableOutputBufferInUse) + if (asyncRequest != null) + { + asyncRequest.CompleteUser(); + } + + if (buffer == _pinnableOutputBufferInUse) { - _PinnableOutputBufferInUse = null; + _pinnableOutputBufferInUse = null; if (PinnableBufferCacheEventSource.Log.IsEnabled()) { PinnableBufferCacheEventSource.Log.DebugMessage1("In System.Net._SslStream.StartWriting Freeing buffer.", this.GetHashCode()); @@ -303,17 +425,19 @@ private void StartWriting(byte[] buffer, int offset, int count) } // - // Sync read method. + // Combined sync/async read method. For sync requet asyncRequest==null. // - private int ProcessRead(byte[] buffer, int offset, int count) + private int ProcessRead(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { ValidateParameters(buffer, offset, count); - if (Interlocked.Exchange(ref _NestedRead, 1) == 1) + if (Interlocked.Exchange(ref _nestedRead, 1) == 1) { - throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "Read", "read")); + throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, (asyncRequest!=null? "BeginRead":"Read"), "read")); } + bool failed = false; + try { int copyBytes; @@ -325,15 +449,21 @@ private int ProcessRead(byte[] buffer, int offset, int count) Buffer.BlockCopy(InternalBuffer, InternalOffset, buffer, offset, copyBytes); SkipBytes(copyBytes); } - + + if (asyncRequest != null) { + asyncRequest.CompleteUser((object) copyBytes); + } + return copyBytes; } - return StartReading(buffer, offset, count); + return StartReading(buffer, offset, count, asyncRequest); } catch (Exception e) { - _SslState.FinishRead(null); + _sslState.FinishRead(null); + failed = true; + if (e is IOException) { throw; @@ -343,14 +473,17 @@ private int ProcessRead(byte[] buffer, int offset, int count) } finally { - _NestedRead = 0; + if (asyncRequest == null || failed) + { + _nestedRead = 0; + } } } // // To avoid recursion when decrypted 0 bytes this method will loop until a decrypted result at least 1 byte. // - private int StartReading(byte[] buffer, int offset, int count) + private int StartReading(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { int result = 0; @@ -360,34 +493,42 @@ private int StartReading(byte[] buffer, int offset, int count) { GlobalLog.AssertFormat("SslStream::StartReading()|Previous frame was not consumed. InternalBufferCount:{0}", InternalBufferCount); } + Debug.Fail("SslStream::StartReading()|Previous frame was not consumed. InternalBufferCount:" + InternalBufferCount); } do { - int copyBytes = _SslState.CheckEnqueueRead(buffer, offset, count, null); + if (asyncRequest != null) + { + asyncRequest.SetNextRequest(buffer, offset, count, s_resumeAsyncReadCallback); + } + + int copyBytes = _sslState.CheckEnqueueRead(buffer, offset, count, asyncRequest); if (copyBytes == 0) { - //Queued but not completed! + // Queued but not completed! return 0; } if (copyBytes != -1) { + if (asyncRequest != null) + { + asyncRequest.CompleteUser((object)copyBytes); + } + return copyBytes; } } // When we read -1 bytes means we have decrypted 0 bytes or rehandshaking, need looping. - while ((result = StartFrameHeader(buffer, offset, count)) == -1); + while ((result = StartFrameHeader(buffer, offset, count, asyncRequest)) == -1); return result; } - // - // Need read frame size first - // - private int StartFrameHeader(byte[] buffer, int offset, int count) + private int StartFrameHeader(byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { int readBytes = 0; @@ -399,22 +540,42 @@ private int StartFrameHeader(byte[] buffer, int offset, int count) // Reset internal buffer for a new frame. EnsureInternalBufferSize(0, SecureChannel.ReadHeaderSize); - readBytes = _Reader.ReadPacket(InternalBuffer, 0, SecureChannel.ReadHeaderSize); + if (asyncRequest != null) + { + asyncRequest.SetNextRequest(InternalBuffer, 0, SecureChannel.ReadHeaderSize, s_readHeaderCallback); + _reader.AsyncReadPacket(asyncRequest); + + if (!asyncRequest.MustCompleteSynchronously) + { + return 0; + } - return StartFrameBody(readBytes, buffer, offset, count); + readBytes = asyncRequest.Result; + } + else + { + readBytes = _reader.ReadPacket(InternalBuffer, 0, SecureChannel.ReadHeaderSize); + } + + return StartFrameBody(readBytes, buffer, offset, count, asyncRequest); } - private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count) + private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { if (readBytes == 0) { //EOF : Reset the buffer as we did not read anything into it. SkipBytes(InternalBufferCount); + if (asyncRequest != null) + { + asyncRequest.CompleteUser((object)0); + } + return 0; } // Now readBytes is a payload size. - readBytes = _SslState.GetRemainingFrameSize(InternalBuffer, readBytes); + readBytes = _sslState.GetRemainingFrameSize(InternalBuffer, readBytes); if (readBytes < 0) { @@ -423,15 +584,31 @@ private int StartFrameBody(int readBytes, byte[] buffer, int offset, int count) EnsureInternalBufferSize(SecureChannel.ReadHeaderSize, readBytes); - readBytes = _Reader.ReadPacket(InternalBuffer, SecureChannel.ReadHeaderSize, readBytes); + if (asyncRequest != null) + { + asyncRequest.SetNextRequest(InternalBuffer, SecureChannel.ReadHeaderSize, readBytes, s_readFrameCallback); + + _reader.AsyncReadPacket(asyncRequest); - return ProcessFrameBody(readBytes, buffer, offset, count); + if (!asyncRequest.MustCompleteSynchronously) + { + return 0; + } + + readBytes = asyncRequest.Result; + } + else + { + readBytes = _reader.ReadPacket(InternalBuffer, SecureChannel.ReadHeaderSize, readBytes); + } + + return ProcessFrameBody(readBytes, buffer, offset, count, asyncRequest); } // - // readBytes == SSL Data Payload size on input or 0 on EOF + // readBytes == SSL Data Payload size on input or 0 on EOF. // - private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count) + private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest) { if (readBytes == 0) { @@ -445,7 +622,7 @@ private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count // Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_. int data_offset = 0; - SecurityStatusPal errorCode = _SslState.DecryptData(InternalBuffer, ref data_offset, ref readBytes); + SecurityStatusPal errorCode = _sslState.DecryptData(InternalBuffer, ref data_offset, ref readBytes); if (errorCode != SecurityStatusPal.OK) { @@ -458,7 +635,7 @@ private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count // Reset internal buffer count. SkipBytes(InternalBufferCount); - return ProcessReadErrorCode(errorCode, buffer, offset, count, extraBuffer); + return ProcessReadErrorCode(errorCode, buffer, offset, count, asyncRequest, extraBuffer); } @@ -483,7 +660,11 @@ private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count // This will adjust both the remaining internal buffer count and the offset. SkipBytes(readBytes); - _SslState.FinishRead(null); + _sslState.FinishRead(null); + if (asyncRequest != null) + { + asyncRequest.CompleteUser((object)readBytes); + } return readBytes; } @@ -491,9 +672,8 @@ private int ProcessFrameBody(int readBytes, byte[] buffer, int offset, int count // // Only processing SEC_I_RENEGOTIATE. // - private int ProcessReadErrorCode(SecurityStatusPal errorCode, byte[] buffer, int offset, int count, byte[] extraBuffer) + private int ProcessReadErrorCode(SecurityStatusPal errorCode, byte[] buffer, int offset, int count, AsyncProtocolRequest asyncRequest, byte[] extraBuffer) { - // ERROR - examine what kind ProtocolToken message = new ProtocolToken(null, errorCode); if (GlobalLog.IsEnabled) @@ -503,7 +683,7 @@ private int ProcessReadErrorCode(SecurityStatusPal errorCode, byte[] buffer, int if (message.Renegotiate) { - _SslState.ReplyOnReAuthentication(extraBuffer); + _sslState.ReplyOnReAuthentication(extraBuffer); // Loop on read. return -1; @@ -511,11 +691,155 @@ private int ProcessReadErrorCode(SecurityStatusPal errorCode, byte[] buffer, int if (message.CloseConnection) { - _SslState.FinishRead(null); + _sslState.FinishRead(null); + if (asyncRequest != null) + { + asyncRequest.CompleteUser((object)0); + } + return 0; } throw new IOException(SR.net_io_decrypt, message.GetException()); } + + private static void WriteCallback(IAsyncResult transportResult) + { + if (transportResult.CompletedSynchronously) + { + return; + } + + if (!(transportResult.AsyncState is AsyncProtocolRequest)) + { + if (GlobalLog.IsEnabled) + { + GlobalLog.Assert("SslStream::WriteCallback | State type is wrong, expected AsyncProtocolRequest."); + } + + Debug.Fail("SslStream::WriteCallback|State type is wrong, expected AsyncProtocolRequest."); + } + + AsyncProtocolRequest asyncRequest = (AsyncProtocolRequest)transportResult.AsyncState; + + var sslStream = (SslStreamInternal)asyncRequest.AsyncObject; + + try + { + sslStream._sslState.InnerStreamAPM.EndWrite(transportResult); + sslStream._sslState.FinishWrite(); + + if (asyncRequest.Count == 0) + { + // This was the last chunk. + asyncRequest.Count = -1; + } + + sslStream.StartWriting(asyncRequest.Buffer, asyncRequest.Offset, asyncRequest.Count, asyncRequest); + } + catch (Exception e) + { + if (asyncRequest.IsUserCompleted) + { + // This will throw on a worker thread. + throw; + } + + sslStream._sslState.FinishWrite(); + asyncRequest.CompleteWithError(e); + } + } + + // + // This is used in a rare situation when async Read is resumed from completed handshake. + // + private static void ResumeAsyncReadCallback(AsyncProtocolRequest request) + { + try + { + ((SslStreamInternal)request.AsyncObject).StartReading(request.Buffer, request.Offset, request.Count, request); + } + catch (Exception e) + { + if (request.IsUserCompleted) + { + // This will throw on a worker thread. + throw; + } + + ((SslStreamInternal)request.AsyncObject)._sslState.FinishRead(null); + request.CompleteWithError(e); + } + } + + // + // This is used in a rare situation when async Write is resumed from completed handshake. + // + private static void ResumeAsyncWriteCallback(AsyncProtocolRequest asyncRequest) + { + try + { + ((SslStreamInternal)asyncRequest.AsyncObject).StartWriting(asyncRequest.Buffer, asyncRequest.Offset, asyncRequest.Count, asyncRequest); + } + catch (Exception e) + { + if (asyncRequest.IsUserCompleted) + { + // This will throw on a worker thread. + throw; + } + + ((SslStreamInternal)asyncRequest.AsyncObject)._sslState.FinishWrite(); + asyncRequest.CompleteWithError(e); + } + } + + private static void ReadHeaderCallback(AsyncProtocolRequest asyncRequest) + { + try + { + SslStreamInternal sslStream = (SslStreamInternal)asyncRequest.AsyncObject; + BufferAsyncResult bufferResult = (BufferAsyncResult)asyncRequest.UserAsyncResult; + if (-1 == sslStream.StartFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest)) + { + // in case we decrypted 0 bytes start another reading. + sslStream.StartReading(bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest); + } + } + catch (Exception e) + { + if (asyncRequest.IsUserCompleted) + { + // This will throw on a worker thread. + throw; + } + + asyncRequest.CompleteWithError(e); + } + } + + private static void ReadFrameCallback(AsyncProtocolRequest asyncRequest) + { + try + { + SslStreamInternal sslStream = (SslStreamInternal)asyncRequest.AsyncObject; + BufferAsyncResult bufferResult = (BufferAsyncResult)asyncRequest.UserAsyncResult; + if (-1 == sslStream.ProcessFrameBody(asyncRequest.Result, bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest)) + { + // in case we decrypted 0 bytes start another reading. + sslStream.StartReading(bufferResult.Buffer, bufferResult.Offset, bufferResult.Count, asyncRequest); + } + } + catch (Exception e) + { + if (asyncRequest.IsUserCompleted) + { + // This will throw on a worker thread. + throw; + } + + asyncRequest.CompleteWithError(e); + } + } } } diff --git a/src/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs b/src/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs index 45e88d85dd58..651fa908279f 100644 --- a/src/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs +++ b/src/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs @@ -152,6 +152,7 @@ public void NegotiateStream_StreamToStream_Authentication_EmptyCredentials_Fails auth[1] = server.AuthenticateAsServerAsync(); bool finished = Task.WaitAll(auth, TestConfiguration.PassingTestTimeoutMilliseconds); + Assert.True(finished, "Handshake completed in the allotted time"); // Expected Client property values: Assert.True(client.IsAuthenticated); diff --git a/src/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs b/src/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs new file mode 100644 index 000000000000..d8cb5b955127 --- /dev/null +++ b/src/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Net.Sockets; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; + +using Xunit; + +namespace System.Net.Security.Tests +{ + public class SslStreamNetworkStreamTest + { + [Fact] + public async void SslStream_SendReceiveOverNetworkStream_Ok() + { + X509Certificate2 serverCertificate = TestConfiguration.GetServerCertificate(); + TcpListener listener = new TcpListener(IPAddress.Any, 0); + + using (TcpClient client = new TcpClient()) + { + listener.Start(); + + Task clientConnectTask = client.ConnectAsync(IPAddress.Loopback, ((IPEndPoint)listener.LocalEndpoint).Port); + Task listenerAcceptTask = listener.AcceptTcpClientAsync(); + + await Task.WhenAll(clientConnectTask, listenerAcceptTask); + + TcpClient server = listenerAcceptTask.Result; + using (SslStream clientStream = new SslStream( + client.GetStream(), + false, + new RemoteCertificateValidationCallback(ValidateServerCertificate), + null, + EncryptionPolicy.RequireEncryption)) + using (SslStream serverStream = new SslStream( + server.GetStream(), + false, + null, + null, + EncryptionPolicy.RequireEncryption)) + { + + Task clientAuthenticationTask = clientStream.AuthenticateAsClientAsync( + serverCertificate.GetNameInfo(X509NameType.SimpleName, false), + null, + SslProtocols.Tls12, + false); + + Task serverAuthenticationTask = serverStream.AuthenticateAsServerAsync( + serverCertificate, + false, + SslProtocols.Tls12, + false); + + await Task.WhenAll(clientAuthenticationTask, serverAuthenticationTask); + + byte[] readBuffer = new byte[256]; + Task readTask = clientStream.ReadAsync(readBuffer, 0, readBuffer.Length); + + byte[] writeBuffer = new byte[256]; + Task writeTask = clientStream.WriteAsync(writeBuffer, 0, writeBuffer.Length); + + bool result = Task.WaitAll( + new Task[1] { writeTask }, + TestConfiguration.PassingTestTimeoutMilliseconds); + + Assert.True(result, "WriteAsync timed-out."); + } + } + } + + private static bool ValidateServerCertificate( + object sender, + X509Certificate retrievedServerPublicCertificate, + X509Chain chain, + SslPolicyErrors sslPolicyErrors) + { + // Accept any certificate. + return true; + } + } +} diff --git a/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs b/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs index 70906d35065b..20106795d3b2 100644 --- a/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs +++ b/src/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs @@ -106,14 +106,14 @@ public void SslStream_StreamToStream_Successive_ClientWrite_Async_Success() Task[] tasks = new Task[2]; - tasks[0] = clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length); - tasks[1] = serverSslStream.ReadAsync(recvBuf, 0, _sampleMsg.Length); + tasks[0] = serverSslStream.ReadAsync(recvBuf, 0, _sampleMsg.Length); + tasks[1] = clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length); bool finished = Task.WaitAll(tasks, TestConfiguration.PassingTestTimeoutMilliseconds); Assert.True(finished, "Send/receive completed in the allotted time"); Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected."); - tasks[0] = clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length); - tasks[1] = serverSslStream.ReadAsync(recvBuf, 0, _sampleMsg.Length); + tasks[0] = serverSslStream.ReadAsync(recvBuf, 0, _sampleMsg.Length); + tasks[1] = clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length); finished = Task.WaitAll(tasks, TestConfiguration.PassingTestTimeoutMilliseconds); Assert.True(finished, "Send/receive completed in the allotted time"); Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected."); diff --git a/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj b/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj index ec6598b0c52e..60d84ecaf2aa 100644 --- a/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj +++ b/src/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj @@ -21,7 +21,7 @@ unix\project.lock.json - + @@ -47,6 +47,7 @@ + diff --git a/src/System.Net.Security/tests/FunctionalTests/TestConfiguration.cs b/src/System.Net.Security/tests/FunctionalTests/TestConfiguration.cs index 4e490d9101ca..dbb3bbfd745d 100644 --- a/src/System.Net.Security/tests/FunctionalTests/TestConfiguration.cs +++ b/src/System.Net.Security/tests/FunctionalTests/TestConfiguration.cs @@ -38,6 +38,7 @@ public static X509Certificate2Collection GetClientCertificateCollection() { return GetCertificateCollection("testclient1_at_contoso.com.pfx"); } + private static X509Certificate2Collection GetCertificateCollection(string certificateFileName) { var certCollection = new X509Certificate2Collection(); diff --git a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs index 9df1682896fb..5a75ec514ee4 100644 --- a/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs +++ b/src/System.Net.Security/tests/UnitTests/Fakes/FakeSslState.cs @@ -243,5 +243,25 @@ public override void Write(byte[] buffer, int offset, int count) { throw new NotImplementedException(); } + + internal IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + throw new NotImplementedException(); + } + + internal int EndRead(IAsyncResult asyncResult) + { + throw new NotImplementedException(); + } + + internal IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback asyncCallback, object asyncState) + { + throw new NotImplementedException(); + } + + internal void EndWrite(IAsyncResult asyncResult) + { + throw new NotImplementedException(); + } } }