diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs index 52790ec65..0708796ec 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs @@ -109,7 +109,7 @@ public void Start() } } - public void Abort() + public virtual void Abort() { if (_frame != null) { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs index d5225447d..61f09dbbb 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs @@ -56,7 +56,7 @@ public abstract partial class Frame : FrameContext, IFrameControl private bool _requestProcessingStarted; private Task _requestProcessingTask; protected volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx - protected volatile bool _requestAborted; + protected int _requestAborted; protected CancellationTokenSource _abortedCts; protected CancellationToken? _manuallySetRequestAbortToken; @@ -167,7 +167,7 @@ public CancellationToken RequestAborted var cts = _abortedCts; return cts != null ? cts.Token : - _requestAborted ? new CancellationToken(true) : + (Volatile.Read(ref _requestAborted) == 1) ? new CancellationToken(true) : RequestAbortedSource.Token; } set @@ -185,7 +185,7 @@ private CancellationTokenSource RequestAbortedSource // Get the abort token, lazily-initializing it if necessary. // Make sure it's canceled if an abort request already came in. var cts = LazyInitializer.EnsureInitialized(ref _abortedCts, () => new CancellationTokenSource()); - if (_requestAborted) + if (Volatile.Read(ref _requestAborted) == 1) { cts.Cancel(); } @@ -288,24 +288,31 @@ public Task Stop() /// public void Abort() { - _requestProcessingStopping = true; - _requestAborted = true; + if (Interlocked.CompareExchange(ref _requestAborted, 1, 0) == 0) + { + _requestProcessingStopping = true; - _requestBody?.Abort(); - _responseBody?.Abort(); + _requestBody?.Abort(); + _responseBody?.Abort(); - try - { - ConnectionControl.End(ProduceEndType.SocketDisconnect); - SocketInput.AbortAwaiting(); - RequestAbortedSource.Cancel(); - } - catch (Exception ex) - { - Log.LogError("Abort", ex); - } - finally - { + try + { + ConnectionControl.End(ProduceEndType.SocketDisconnect); + SocketInput.AbortAwaiting(); + } + catch (Exception ex) + { + Log.LogError("Abort", ex); + } + + try + { + RequestAbortedSource.Cancel(); + } + catch (Exception ex) + { + Log.LogError("Abort", ex); + } _abortedCts = null; } } @@ -461,12 +468,12 @@ public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken c private void WriteChunked(ArraySegment data) { - SocketOutput.Write(data, immediate: false, chunk: true); + SocketOutput.Write(data, immediate: true, chunk: true); } private Task WriteChunkedAsync(ArraySegment data, CancellationToken cancellationToken) { - return SocketOutput.WriteAsync(data, immediate: false, chunk: true, cancellationToken: cancellationToken); + return SocketOutput.WriteAsync(data, immediate: true, chunk: true, cancellationToken: cancellationToken); } private Task WriteChunkedResponseSuffix() diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs index 0f3d8d48d..e9498b1d4 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs @@ -3,6 +3,7 @@ using System; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; @@ -111,7 +112,7 @@ public override async Task RequestProcessingAsync() _application.DisposeContext(context, _applicationException); // If _requestAbort is set, the connection has already been closed. - if (!_requestAborted) + if (Volatile.Read(ref _requestAborted) == 0) { _responseBody.ResumeAcceptingWrites(); await ProduceEnd(); @@ -148,7 +149,7 @@ public override async Task RequestProcessingAsync() _abortedCts = null; // If _requestAborted is set, the connection has already been closed. - if (!_requestAborted) + if (Volatile.Read(ref _requestAborted) == 0) { // Inform client no more data will ever arrive ConnectionControl.End(ProduceEndType.SocketShutdownSend); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs index dc70690b3..efd366910 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs @@ -5,6 +5,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http { @@ -51,8 +52,6 @@ public override void SetLength(long value) public override int Read(byte[] buffer, int offset, int count) { - ValidateState(); - // ValueTask uses .GetAwaiter().GetResult() if necessary return ReadAsync(buffer, offset, count).Result; } @@ -60,7 +59,7 @@ public override int Read(byte[] buffer, int offset, int count) #if NET451 public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - ValidateState(); + ValidateState(CancellationToken.None); var task = ReadAsync(buffer, offset, count, CancellationToken.None, state); if (callback != null) @@ -77,7 +76,7 @@ public override int EndRead(IAsyncResult asyncResult) private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) { - ValidateState(); + ValidateState(cancellationToken); var tcs = new TaskCompletionSource(state); var task = _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken); @@ -103,10 +102,13 @@ private Task ReadAsync(byte[] buffer, int offset, int count, CancellationTo public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - ValidateState(); - - // Needs .AsTask to match Stream's Async method return types - return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken).AsTask(); + var task = ValidateState(cancellationToken); + if (task == null) + { + // Needs .AsTask to match Stream's Async method return types + return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken).AsTask(); + } + return task; } public override void Write(byte[] buffer, int offset, int count) @@ -149,24 +151,29 @@ public void StopAcceptingReads() public void Abort() { // We don't want to throw an ODE until the app func actually completes. - // If the request is aborted, we throw an IOException instead. + // If the request is aborted, we throw an TaskCanceledException instead. if (_state != FrameStreamState.Closed) { _state = FrameStreamState.Aborted; } } - private void ValidateState() + private Task ValidateState(CancellationToken cancellationToken) { switch (_state) { case FrameStreamState.Open: - return; + if (cancellationToken.IsCancellationRequested) + { + return TaskUtilities.GetCancelledZeroTask(); + } + break; case FrameStreamState.Closed: throw new ObjectDisposedException(nameof(FrameRequestStream)); case FrameStreamState.Aborted: - throw new IOException("The request has been aborted."); + return TaskUtilities.GetCancelledZeroTask(); } + return null; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs index f45470f55..606c112d3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs @@ -5,6 +5,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http { @@ -37,16 +38,19 @@ public override long Length public override void Flush() { - ValidateState(); + ValidateState(CancellationToken.None); _context.FrameControl.Flush(); } public override Task FlushAsync(CancellationToken cancellationToken) { - ValidateState(); - - return _context.FrameControl.FlushAsync(cancellationToken); + var task = ValidateState(cancellationToken); + if (task == null) + { + return _context.FrameControl.FlushAsync(cancellationToken); + } + return task; } public override long Seek(long offset, SeekOrigin origin) @@ -66,16 +70,19 @@ public override int Read(byte[] buffer, int offset, int count) public override void Write(byte[] buffer, int offset, int count) { - ValidateState(); + ValidateState(CancellationToken.None); _context.FrameControl.Write(new ArraySegment(buffer, offset, count)); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - ValidateState(); - - return _context.FrameControl.WriteAsync(new ArraySegment(buffer, offset, count), cancellationToken); + var task = ValidateState(cancellationToken); + if (task == null) + { + return _context.FrameControl.WriteAsync(new ArraySegment(buffer, offset, count), cancellationToken); + } + return task; } public Stream StartAcceptingWrites() @@ -112,24 +119,33 @@ public void StopAcceptingWrites() public void Abort() { // We don't want to throw an ODE until the app func actually completes. - // If the request is aborted, we throw an IOException instead. if (_state != FrameStreamState.Closed) { _state = FrameStreamState.Aborted; } } - private void ValidateState() + private Task ValidateState(CancellationToken cancellationToken) { switch (_state) { case FrameStreamState.Open: - return; + if (cancellationToken.IsCancellationRequested) + { + return TaskUtilities.GetCancelledTask(cancellationToken); + } + break; case FrameStreamState.Closed: throw new ObjectDisposedException(nameof(FrameResponseStream)); case FrameStreamState.Aborted: - throw new IOException("The request has been aborted."); + if (cancellationToken.IsCancellationRequested) + { + // Aborted state only throws on write if cancellationToken requests it + return TaskUtilities.GetCancelledTask(cancellationToken); + } + break; } + return null; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs index ce9f4aa9c..daeeeaa9e 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs @@ -2,11 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Server.Kestrel.Http { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs index 94e335b34..83213cbea 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs @@ -5,6 +5,7 @@ using System.IO; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http @@ -184,7 +185,7 @@ public void ConsumingComplete( public void AbortAwaiting() { - _awaitableError = new ObjectDisposedException(nameof(SocketInput), "The request was aborted"); + _awaitableError = new TaskCanceledException("The request was aborted"); Complete(); } @@ -238,6 +239,10 @@ public void GetResult() var error = _awaitableError; if (error != null) { + if (error is TaskCanceledException || error is InvalidOperationException) + { + throw error; + } throw new IOException(error.Message, error); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs index 15a56f9fc..eed0230e3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; @@ -22,6 +21,7 @@ public class SocketOutput : ISocketOutput private const int _maxPooledWriteContexts = 32; private static readonly WaitCallback _returnBlocks = (state) => ReturnBlocks((MemoryPoolBlock2)state); + private static readonly Action _connectionCancellation = (state) => ((SocketOutput)state).CancellationTriggered(); private readonly KestrelThread _thread; private readonly UvStreamHandle _socket; @@ -45,6 +45,7 @@ public class SocketOutput : ISocketOutput // The number of write operations that have been scheduled so far // but have not completed. private bool _writePending = false; + private bool _cancelled = false; private int _numBytesPreCompleted = 0; private Exception _lastWriteError; private WriteContext _nextWriteContext; @@ -78,6 +79,7 @@ public SocketOutput( public Task WriteAsync( ArraySegment buffer, + CancellationToken cancellationToken, bool immediate = true, bool chunk = false, bool socketShutdownSend = false, @@ -89,9 +91,21 @@ public Task WriteAsync( lock (_contextLock) { + if (_socket.IsClosed) + { + _log.ConnectionDisconnectedWrite(_connectionId, buffer.Count, _lastWriteError); + + return TaskUtilities.CompletedTask; + } + if (buffer.Count > 0) { var tail = ProducingStart(); + if (tail.IsDefault) + { + return TaskUtilities.CompletedTask; + } + if (chunk) { _numBytesPreCompleted += ChunkWriter.WriteBeginChunkBytes(ref tail, buffer.Count); @@ -146,13 +160,36 @@ public Task WriteAsync( } else { - // immediate write, which is not eligable for instant completion above - tcs = new TaskCompletionSource(buffer.Count); - _tasksPending.Enqueue(new WaitingTask() { - CompletionSource = tcs, - BytesToWrite = buffer.Count, - IsSync = isSync - }); + if (cancellationToken.CanBeCanceled) + { + if (cancellationToken.IsCancellationRequested) + { + _connection.Abort(); + _cancelled = true; + return TaskUtilities.GetCancelledTask(cancellationToken); + } + else + { + // immediate write, which is not eligable for instant completion above + tcs = new TaskCompletionSource(); + _tasksPending.Enqueue(new WaitingTask() + { + CancellationToken = cancellationToken, + CancellationRegistration = cancellationToken.Register(_connectionCancellation, this), + BytesToWrite = buffer.Count, + CompletionSource = tcs + }); + } + } + else + { + tcs = new TaskCompletionSource(); + _tasksPending.Enqueue(new WaitingTask() { + IsSync = isSync, + BytesToWrite = buffer.Count, + CompletionSource = tcs + }); + } } if (!_writePending && immediate) @@ -177,12 +214,14 @@ public void End(ProduceEndType endType) { case ProduceEndType.SocketShutdownSend: WriteAsync(default(ArraySegment), + default(CancellationToken), immediate: true, socketShutdownSend: true, socketDisconnect: false); break; case ProduceEndType.SocketDisconnect: WriteAsync(default(ArraySegment), + default(CancellationToken), immediate: true, socketShutdownSend: false, socketDisconnect: true); @@ -198,7 +237,7 @@ public MemoryPoolIterator2 ProducingStart() if (_tail == null) { - throw new IOException("The socket has been closed."); + return default(MemoryPoolIterator2); } _lastStart = new MemoryPoolIterator2(_tail, _tail.End); @@ -251,6 +290,24 @@ private void ProducingCompleteNoPreComplete(MemoryPoolIterator2 end) } } + private void CancellationTriggered() + { + lock (_contextLock) + { + if (!_cancelled) + { + // Abort the connection for any failed write + // Queued on threadpool so get it in as first op. + _connection.Abort(); + _cancelled = true; + + CompleteAllWrites(); + + _log.ConnectionError(_connectionId, new TaskCanceledException("Write operation canceled. Aborting connection.")); + } + } + } + private static void ReturnBlocks(MemoryPoolBlock2 block) { while (block != null) @@ -296,19 +353,20 @@ private void WriteAllPending() } // This may called on the libuv event loop - // This is always called with the _contextLock already acquired private void OnWriteCompleted(WriteContext writeContext) { + // Called inside _contextLock var bytesWritten = writeContext.ByteCount; var status = writeContext.WriteStatus; var error = writeContext.WriteError; if (error != null) { - _lastWriteError = new IOException(error.Message, error); - - // Abort the connection for any failed write. + // Abort the connection for any failed write + // Queued on threadpool so get it in as first op. _connection.Abort(); + _cancelled = true; + _lastWriteError = error; } PoolWriteContext(writeContext); @@ -317,43 +375,75 @@ private void OnWriteCompleted(WriteContext writeContext) // completed writes that we haven't triggered callbacks for yet. _numBytesPreCompleted -= bytesWritten; - // bytesLeftToBuffer can be greater than _maxBytesPreCompleted - // This allows large writes to complete once they've actually finished. - var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; - while (_tasksPending.Count > 0 && - (_tasksPending.Peek().BytesToWrite) <= bytesLeftToBuffer) + if (error == null) { - var waitingTask = _tasksPending.Dequeue(); - var bytesToWrite = waitingTask.BytesToWrite; + CompleteFinishedWrites(status); + _log.ConnectionWriteCallback(_connectionId, status); + } + else + { + CompleteAllWrites(); + _log.ConnectionError(_connectionId, error); + } + } - _numBytesPreCompleted += bytesToWrite; - bytesLeftToBuffer -= bytesToWrite; + private void CompleteNextWrite(ref int bytesLeftToBuffer) + { + // Called inside _contextLock + var waitingTask = _tasksPending.Dequeue(); + var bytesToWrite = waitingTask.BytesToWrite; - if (_lastWriteError == null) + _numBytesPreCompleted += bytesToWrite; + bytesLeftToBuffer -= bytesToWrite; + + // Dispose registration if there is one + waitingTask.CancellationRegistration?.Dispose(); + + if (waitingTask.CancellationToken.IsCancellationRequested) + { + if (waitingTask.IsSync) { - if (waitingTask.IsSync) - { - waitingTask.CompletionSource.TrySetResult(null); - } - else - { - _threadPool.Complete(waitingTask.CompletionSource); - } + waitingTask.CompletionSource.TrySetCanceled(); } else { - if (waitingTask.IsSync) - { - waitingTask.CompletionSource.TrySetException(_lastWriteError); - } - else - { - _threadPool.Error(waitingTask.CompletionSource, _lastWriteError); - } + _threadPool.Cancel(waitingTask.CompletionSource); } } + else + { + if (waitingTask.IsSync) + { + waitingTask.CompletionSource.TrySetResult(null); + } + else + { + _threadPool.Complete(waitingTask.CompletionSource); + } + } + } + + private void CompleteFinishedWrites(int status) + { + // Called inside _contextLock + // bytesLeftToBuffer can be greater than _maxBytesPreCompleted + // This allows large writes to complete once they've actually finished. + var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; + while (_tasksPending.Count > 0 && + (_tasksPending.Peek().BytesToWrite) <= bytesLeftToBuffer) + { + CompleteNextWrite(ref bytesLeftToBuffer); + } + } - _log.ConnectionWriteCallback(_connectionId, status); + private void CompleteAllWrites() + { + // Called inside _contextLock + var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; + while (_tasksPending.Count > 0) + { + CompleteNextWrite(ref bytesLeftToBuffer); + } } // This is called on the libuv event loop @@ -383,7 +473,7 @@ private void ReturnAllBlocks() private void PoolWriteContext(WriteContext writeContext) { - // called inside _contextLock + // Called inside _contextLock if (_writeContextPool.Count < _maxPooledWriteContexts) { writeContext.Reset(); @@ -393,12 +483,23 @@ private void PoolWriteContext(WriteContext writeContext) void ISocketOutput.Write(ArraySegment buffer, bool immediate, bool chunk) { - WriteAsync(buffer, immediate, chunk, isSync: true).GetAwaiter().GetResult(); + WriteAsync(buffer, CancellationToken.None, immediate, chunk, isSync: true).GetAwaiter().GetResult(); } Task ISocketOutput.WriteAsync(ArraySegment buffer, bool immediate, bool chunk, CancellationToken cancellationToken) { - return WriteAsync(buffer, immediate, chunk); + if (cancellationToken.IsCancellationRequested) + { + _connection.Abort(); + _cancelled = true; + return TaskUtilities.GetCancelledTask(cancellationToken); + } + else if (_cancelled) + { + return TaskUtilities.CompletedTask; + } + + return WriteAsync(buffer, cancellationToken, immediate, chunk); } private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 end, out int bytes, out int buffers) @@ -649,6 +750,8 @@ private struct WaitingTask { public bool IsSync; public int BytesToWrite; + public CancellationToken CancellationToken; + public IDisposable CancellationRegistration; public TaskCompletionSource CompletionSource; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs index 1cab42524..0ed3c6565 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs @@ -29,6 +29,10 @@ public interface IKestrelTrace : ILogger void ConnectionWriteCallback(long connectionId, int status); + void ConnectionError(long connectionId, Exception ex); + + void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex); + void ApplicationError(Exception ex); } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs index 404bc01a5..f9217bd99 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs @@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public interface IThreadPool { void Complete(TaskCompletionSource tcs); + void Cancel(TaskCompletionSource tcs); void Error(TaskCompletionSource tcs, Exception ex); void Run(Action action); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs index f9b82ce25..5e8a2ee28 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs @@ -21,6 +21,8 @@ public class KestrelTrace : IKestrelTrace private static readonly Action _connectionWroteFin; private static readonly Action _connectionKeepAlive; private static readonly Action _connectionDisconnect; + private static readonly Action _connectionError; + private static readonly Action _connectionDisconnectedWrite; protected readonly ILogger _logger; @@ -39,6 +41,8 @@ static KestrelTrace() // ConnectionWrite: Reserved: 11 // ConnectionWriteCallback: Reserved: 12 // ApplicationError: Reserved: 13 - LoggerMessage.Define overload not present + _connectionError = LoggerMessage.Define(LogLevel.Information, 14, @"Connection id ""{ConnectionId}"" communication error"); + _connectionDisconnectedWrite = LoggerMessage.Define(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client."); } public KestrelTrace(ILogger logger) @@ -114,6 +118,16 @@ public virtual void ApplicationError(Exception ex) _logger.LogError(13, "An unhandled exception was thrown by the application.", ex); } + public virtual void ConnectionError(long connectionId, Exception ex) + { + _connectionError(_logger, connectionId, ex); + } + + public virtual void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex) + { + _connectionDisconnectedWrite(_logger, connectionId, count, ex); + } + public virtual void Log(LogLevel logLevel, int eventId, object state, Exception exception, Func formatter) { _logger.Log(logLevel, eventId, state, exception, formatter); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs index 70f142f53..a5f41987d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs @@ -12,6 +12,7 @@ public class LoggingThreadPool : IThreadPool private readonly IKestrelTrace _log; private readonly WaitCallback _runAction; + private readonly WaitCallback _cancelTcs; private readonly WaitCallback _completeTcs; public LoggingThreadPool(IKestrelTrace log) @@ -42,6 +43,18 @@ public LoggingThreadPool(IKestrelTrace log) _log.ApplicationError(e); } }; + + _cancelTcs = (o) => + { + try + { + ((TaskCompletionSource)o).TrySetCanceled(); + } + catch (Exception e) + { + _log.ApplicationError(e); + } + }; } public void Run(Action action) @@ -54,6 +67,11 @@ public void Complete(TaskCompletionSource tcs) ThreadPool.QueueUserWorkItem(_completeTcs, tcs); } + public void Cancel(TaskCompletionSource tcs) + { + ThreadPool.QueueUserWorkItem(_cancelTcs, tcs); + } + public void Error(TaskCompletionSource tcs, Exception ex) { // ex ang _log are closure captured diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs index 62b19ced9..8e22ee401 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs @@ -724,6 +724,11 @@ public void CopyFrom(ArraySegment buffer) public void CopyFrom(byte[] data, int offset, int count) { + if (IsDefault) + { + return; + } + Debug.Assert(_block != null); Debug.Assert(_block.Next == null); Debug.Assert(_block.End == _index); @@ -766,6 +771,11 @@ public void CopyFrom(byte[] data, int offset, int count) public unsafe void CopyFromAscii(string data) { + if (IsDefault) + { + return; + } + Debug.Assert(_block != null); Debug.Assert(_block.Next == null); Debug.Assert(_block.End == _index); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs index a59713eea..5e52222d3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure @@ -13,5 +14,24 @@ public static class TaskUtilities public static Task CompletedTask = Task.FromResult(null); #endif public static Task ZeroTask = Task.FromResult(0); + + public static Task GetCancelledTask(CancellationToken cancellationToken) + { +#if DOTNET5_4 + return Task.FromCanceled(cancellationToken); +#else + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; +#endif + } + + public static Task GetCancelledZeroTask() + { + // Task.FromCanceled doesn't return Task + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Networking/Libuv.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Networking/Libuv.cs index 037f77eaa..792e92887 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Networking/Libuv.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Networking/Libuv.cs @@ -98,6 +98,11 @@ public Libuv() } } + // Second ctor that doesn't set any fields only to be used by MockLibuv + internal Libuv(bool onlyForTesting) + { + } + public readonly bool IsWindows; public int Check(int statusCode) diff --git a/test/Microsoft.AspNet.Server.KestrelTests/TestHelpers/MockConnection.cs b/test/Microsoft.AspNet.Server.KestrelTests/TestHelpers/MockConnection.cs new file mode 100644 index 000000000..20a937213 --- /dev/null +++ b/test/Microsoft.AspNet.Server.KestrelTests/TestHelpers/MockConnection.cs @@ -0,0 +1,25 @@ +using System.Threading; +using Microsoft.AspNet.Server.Kestrel.Http; +using Microsoft.AspNet.Server.Kestrel.Networking; + +namespace Microsoft.AspNet.Server.KestrelTests.TestHelpers +{ + public class MockConnection : Connection + { + public MockConnection(UvStreamHandle socket) + : base (new ListenerContext(), socket) + { + + } + + public override void Abort() + { + if (RequestAbortedSource != null) + { + RequestAbortedSource.Cancel(); + } + } + + public CancellationTokenSource RequestAbortedSource { get; set; } + } +} diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index a44f09612..811cf121a 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -1084,7 +1084,7 @@ await connection.ReceiveEnd( } } - await Assert.ThrowsAsync(async () => await readTcs.Task); + await Assert.ThrowsAsync(async () => await readTcs.Task); // The cancellation token for only the last request should be triggered. var abortedRequestId = await registrationTcs.Task; @@ -1096,6 +1096,12 @@ await connection.ReceiveEnd( [FrameworkSkipCondition(RuntimeFrameworks.Mono, SkipReason = "Test hangs after execution on Mono.")] public async Task FailedWritesResultInAbortedRequest(ServiceContext testContext) { + const int resetEventTimeout = 2000; + // This should match _maxBytesPreCompleted in SocketOutput + const int maxBytesPreCompleted = 65536; + // Ensure string is long enough to disable write-behind buffering + var largeString = new string('a', maxBytesPreCompleted + 1); + var writeTcs = new TaskCompletionSource(); var registrationWh = new ManualResetEventSlim(); var connectionCloseWh = new ManualResetEventSlim(); @@ -1112,27 +1118,22 @@ public async Task FailedWritesResultInAbortedRequest(ServiceContext testContext) connectionCloseWh.Wait(); response.Headers.Clear(); - response.Headers["Content-Length"] = new[] { "5" }; try { // Ensure write is long enough to disable write-behind buffering - for (int i = 0; i < 10; i++) + for (int i = 0; i < 100; i++) { - await response.WriteAsync(new string('a', 65537)); + await response.WriteAsync(largeString, lifetime.RequestAborted).ConfigureAwait(false); } } catch (Exception ex) { writeTcs.SetException(ex); - - // Give a chance for RequestAborted to trip before the app completes - registrationWh.Wait(1000); - throw; } - writeTcs.SetCanceled(); + writeTcs.SetException(new Exception("This shouldn't be reached.")); }, testContext)) { using (var connection = new TestConnection()) @@ -1141,16 +1142,16 @@ await connection.Send( "POST / HTTP/1.1", "Content-Length: 5", "", - "Hello"); + "Hello").ConfigureAwait(false); // Don't wait to receive the response. Just close the socket. } connectionCloseWh.Set(); // Write failed - await Assert.ThrowsAsync(async () => await writeTcs.Task); + await Assert.ThrowsAsync(async () => await writeTcs.Task); // RequestAborted tripped - Assert.True(registrationWh.Wait(200)); + Assert.True(registrationWh.Wait(resetEventTimeout)); } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs index f9d1f7ac2..47d9eb2a8 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs @@ -102,7 +102,7 @@ public void ConsumingOutOfOrderFailsGracefully() private static void TestConcurrentFaultedTask(Task t) { Assert.True(t.IsFaulted); - Assert.IsType(typeof(System.IO.IOException), t.Exception.InnerException); + Assert.IsType(typeof(System.InvalidOperationException), t.Exception.InnerException); Assert.Equal(t.Exception.InnerException.Message, "Concurrent reads are not supported."); } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs index 01c42c5f1..974829e93 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs @@ -50,7 +50,7 @@ public void CanWrite1MB() var completedWh = new ManualResetEventSlim(); // Act - socketOutput.WriteAsync(buffer).ContinueWith( + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith( (t) => { Assert.Null(t.Exception); @@ -101,14 +101,14 @@ public void WritesDontCompleteImmediatelyWhenTooManyBytesAreAlreadyPreCompleted( }; // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // Too many bytes are already pre-completed for the second write to pre-complete. Assert.False(completedWh.Wait(1000)); @@ -162,28 +162,28 @@ public void WritesDontCompleteImmediatelyWhenTooManyBytesIncludingNonImmediateAr }; // Act - socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is not immediate. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The second write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted); // Assert // The third write should pre-complete since it is not immediate, even though too many. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // Too many bytes are already pre-completed for the fourth write to pre-complete. Assert.False(completedWh.Wait(1000)); @@ -198,6 +198,199 @@ public void WritesDontCompleteImmediatelyWhenTooManyBytesIncludingNonImmediateAr } } + [Fact] + public async Task OnlyWritesRequestingCancellationAreErroredOnCancellation() + { + // This should match _maxBytesPreCompleted in SocketOutput + var maxBytesPreCompleted = 65536; + var completeQueue = new Queue>(); + + // Arrange + var mockLibuv = new MockLibuv + { + OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + } + }; + + using (var kestrelEngine = new KestrelEngine(mockLibuv, new TestServiceContext())) + using (var memory = new MemoryPool2()) + { + kestrelEngine.Start(count: 1); + + var kestrelThread = kestrelEngine.Threads[0]; + var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); + var trace = new KestrelTrace(new TestKestrelTrace()); + var ltp = new LoggingThreadPool(trace); + ISocketOutput socketOutput = new SocketOutput(kestrelThread, socket, memory, new MockConnection(socket), 0, trace, ltp, new Queue()); + + var bufferSize = maxBytesPreCompleted; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + var cts = new CancellationTokenSource(); + + // Act + var task1Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + // task1 should complete successfully as < _maxBytesPreCompleted + + // First task is completed and successful + Assert.True(task1Success.IsCompleted); + Assert.False(task1Success.IsCanceled); + Assert.False(task1Success.IsFaulted); + + task1Success.GetAwaiter().GetResult(); + + // following tasks should wait. + + var task2Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + var task3Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken)); + + // Give time for tasks to percolate + await Task.Delay(1000).ConfigureAwait(false); + + // Second task is not completed + Assert.False(task2Throw.IsCompleted); + Assert.False(task2Throw.IsCanceled); + Assert.False(task2Throw.IsFaulted); + + // Third task is not completed + Assert.False(task3Success.IsCompleted); + Assert.False(task3Success.IsCanceled); + Assert.False(task3Success.IsFaulted); + + cts.Cancel(); + + // Give time for tasks to percolate + await Task.Delay(1000).ConfigureAwait(false); + + // Second task is now canceled + Assert.True(task2Throw.IsCompleted); + Assert.True(task2Throw.IsCanceled); + Assert.False(task2Throw.IsFaulted); + + // Third task is now completed + Assert.True(task3Success.IsCompleted); + Assert.False(task3Success.IsCanceled); + Assert.False(task3Success.IsFaulted); + + // Fourth task immediately cancels as the token is canceled + var task4Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + + Assert.True(task4Throw.IsCompleted); + Assert.True(task4Throw.IsCanceled); + Assert.False(task4Throw.IsFaulted); + + Assert.Throws(() => task4Throw.GetAwaiter().GetResult()); + + var task5Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken)); + // task5 should complete immediately + + Assert.True(task5Success.IsCompleted); + Assert.False(task5Success.IsCanceled); + Assert.False(task5Success.IsFaulted); + + cts = new CancellationTokenSource(); + + var task6Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + // task6 should complete immediately but not cancel as its cancellation token isn't set + + Assert.True(task6Throw.IsCompleted); + Assert.False(task6Throw.IsCanceled); + Assert.False(task6Throw.IsFaulted); + + task6Throw.GetAwaiter().GetResult(); + + Assert.True(true); + } + } + + [Fact] + public async Task FailedWriteCompletesOrCancelsAllPendingTasks() + { + // This should match _maxBytesPreCompleted in SocketOutput + var maxBytesPreCompleted = 65536; + var completeQueue = new Queue>(); + + // Arrange + var mockLibuv = new MockLibuv + { + OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + } + }; + + using (var kestrelEngine = new KestrelEngine(mockLibuv, new TestServiceContext())) + using (var memory = new MemoryPool2()) + using (var abortedSource = new CancellationTokenSource()) + { + kestrelEngine.Start(count: 1); + + var kestrelThread = kestrelEngine.Threads[0]; + var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); + var trace = new KestrelTrace(new TestKestrelTrace()); + var ltp = new LoggingThreadPool(trace); + + var mockConnection = new MockConnection(socket); + mockConnection.RequestAbortedSource = abortedSource; + ISocketOutput socketOutput = new SocketOutput(kestrelThread, socket, memory, mockConnection, 0, trace, ltp, new Queue()); + + var bufferSize = maxBytesPreCompleted; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + // Act + var task1Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: abortedSource.Token); + // task1 should complete successfully as < _maxBytesPreCompleted + + // First task is completed and successful + Assert.True(task1Success.IsCompleted); + Assert.False(task1Success.IsCanceled); + Assert.False(task1Success.IsFaulted); + + task1Success.GetAwaiter().GetResult(); + + // following tasks should wait. + var task2Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: CancellationToken.None); + var task3Canceled = socketOutput.WriteAsync(fullBuffer, cancellationToken: abortedSource.Token); + + // Give time for tasks to percolate + await Task.Delay(1000).ConfigureAwait(false); + + // Second task is not completed + Assert.False(task2Success.IsCompleted); + Assert.False(task2Success.IsCanceled); + Assert.False(task2Success.IsFaulted); + + // Third task is not completed + Assert.False(task3Canceled.IsCompleted); + Assert.False(task3Canceled.IsCanceled); + Assert.False(task3Canceled.IsFaulted); + + // Cause the first write to fail. + completeQueue.Dequeue()(-1); + + // Give time for tasks to percolate + await Task.Delay(1000).ConfigureAwait(false); + + // Second task is now completed + Assert.True(task2Success.IsCompleted); + Assert.False(task2Success.IsCanceled); + Assert.False(task2Success.IsFaulted); + + // Third task is now canceled + Assert.True(task3Canceled.IsCompleted); + Assert.True(task3Canceled.IsCanceled); + Assert.False(task3Canceled.IsFaulted); + } + } + [Fact] public void WritesDontGetCompletedTooQuickly() { @@ -247,7 +440,7 @@ public void WritesDontGetCompletedTooQuickly() }; // Act (Pre-complete the maximum number of bytes in preparation for the rest of the test) - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); @@ -257,8 +450,8 @@ public void WritesDontGetCompletedTooQuickly() onWriteWh.Reset(); // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted2); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted2); Assert.True(onWriteWh.Wait(1000)); completeQueue.Dequeue()(0); @@ -320,7 +513,7 @@ public void ProducingStartAndProducingCompleteCanBeUsedDirectly() socketOutput.ProducingComplete(end); // A call to Write is required to ensure a write is scheduled - socketOutput.WriteAsync(default(ArraySegment)); + socketOutput.WriteAsync(default(ArraySegment), default(CancellationToken)); Assert.True(nBufferWh.Wait(1000)); Assert.Equal(2, nBuffers); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/MockLibuv.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/MockLibuv.cs index 4cc9549d7..2f2f2f110 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/MockLibuv.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/MockLibuv.cs @@ -15,6 +15,7 @@ public class MockLibuv : Libuv private Func, int> _onWrite; unsafe public MockLibuv() + : base(onlyForTesting: true) { _uv_write = UvWrite; @@ -66,6 +67,8 @@ unsafe public MockLibuv() _uv_close = (handle, callback) => callback(handle); _uv_loop_close = handle => 0; _uv_walk = (loop, callback, ignore) => 0; + _uv_err_name = errno => IntPtr.Zero; + _uv_strerror = errno => IntPtr.Zero; } public Func, int> OnWrite