diff --git a/src/Microsoft.AspNet.Server.Kestrel/Filter/StreamSocketOutput.cs b/src/Microsoft.AspNet.Server.Kestrel/Filter/StreamSocketOutput.cs index f5c9795af..6ee20338f 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Filter/StreamSocketOutput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Filter/StreamSocketOutput.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Server.Kestrel.Http; @@ -12,27 +13,44 @@ namespace Microsoft.AspNet.Server.Kestrel.Filter { public class StreamSocketOutput : ISocketOutput { + private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n"); private static readonly byte[] _nullBuffer = new byte[0]; private readonly Stream _outputStream; private readonly MemoryPool2 _memory; private MemoryPoolBlock2 _producingBlock; + private object _writeLock = new object(); + public StreamSocketOutput(Stream outputStream, MemoryPool2 memory) { _outputStream = outputStream; _memory = memory; } - void ISocketOutput.Write(ArraySegment buffer, bool immediate) + public void Write(ArraySegment buffer, bool immediate, bool chunk) { - _outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); + lock (_writeLock) + { + if (chunk && buffer.Array != null) + { + var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count); + _outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count); + } + + _outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); + + if (chunk && buffer.Array != null) + { + _outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length); + } + } } - Task ISocketOutput.WriteAsync(ArraySegment buffer, bool immediate, CancellationToken cancellationToken) + public Task WriteAsync(ArraySegment buffer, bool immediate, bool chunk, CancellationToken cancellationToken) { // TODO: Use _outputStream.WriteAsync - _outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); + Write(buffer, immediate, chunk); return TaskUtilities.CompletedTask; } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/ChunkWriter.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/ChunkWriter.cs new file mode 100644 index 000000000..2d3885700 --- /dev/null +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/ChunkWriter.cs @@ -0,0 +1,62 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; + +namespace Microsoft.AspNet.Server.Kestrel.Http +{ + public static class ChunkWriter + { + private static readonly ArraySegment _endChunkBytes = CreateAsciiByteArraySegment("\r\n"); + private static readonly byte[] _hex = Encoding.ASCII.GetBytes("0123456789abcdef"); + + private static ArraySegment CreateAsciiByteArraySegment(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + return new ArraySegment(bytes); + } + + public static ArraySegment BeginChunkBytes(int dataCount) + { + var bytes = new byte[10] + { + _hex[((dataCount >> 0x1c) & 0x0f)], + _hex[((dataCount >> 0x18) & 0x0f)], + _hex[((dataCount >> 0x14) & 0x0f)], + _hex[((dataCount >> 0x10) & 0x0f)], + _hex[((dataCount >> 0x0c) & 0x0f)], + _hex[((dataCount >> 0x08) & 0x0f)], + _hex[((dataCount >> 0x04) & 0x0f)], + _hex[((dataCount >> 0x00) & 0x0f)], + (byte)'\r', + (byte)'\n', + }; + + // Determine the most-significant non-zero nibble + int total, shift; + total = (dataCount > 0xffff) ? 0x10 : 0x00; + dataCount >>= total; + shift = (dataCount > 0x00ff) ? 0x08 : 0x00; + dataCount >>= shift; + total |= shift; + total |= (dataCount > 0x000f) ? 0x04 : 0x00; + + var offset = 7 - (total >> 2); + return new ArraySegment(bytes, offset, 10 - offset); + } + + public static int WriteBeginChunkBytes(ref MemoryPoolIterator2 start, int dataCount) + { + var chunkSegment = BeginChunkBytes(dataCount); + start.CopyFrom(chunkSegment); + return chunkSegment.Count; + } + + public static void WriteEndChunkBytes(ref MemoryPoolIterator2 start) + { + start.CopyFrom(_endChunkBytes); + } + } +} diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index 8b6dd828a..f315bb702 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -23,11 +23,9 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public abstract partial class Frame : FrameContext, IFrameControl { private static readonly Encoding _ascii = Encoding.ASCII; - private static readonly ArraySegment _endChunkBytes = CreateAsciiByteArraySegment("\r\n"); private static readonly ArraySegment _endChunkedResponseBytes = CreateAsciiByteArraySegment("0\r\n\r\n"); private static readonly ArraySegment _continueBytes = CreateAsciiByteArraySegment("HTTP/1.1 100 Continue\r\n\r\n"); private static readonly ArraySegment _emptyData = new ArraySegment(new byte[0]); - private static readonly byte[] _hex = Encoding.ASCII.GetBytes("0123456789abcdef"); private static readonly byte[] _bytesConnectionClose = Encoding.ASCII.GetBytes("\r\nConnection: close"); private static readonly byte[] _bytesConnectionKeepAlive = Encoding.ASCII.GetBytes("\r\nConnection: keep-alive"); @@ -472,45 +470,12 @@ public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken c private void WriteChunked(ArraySegment data) { - SocketOutput.Write(BeginChunkBytes(data.Count), immediate: false); - SocketOutput.Write(data, immediate: false); - SocketOutput.Write(_endChunkBytes, immediate: true); + SocketOutput.Write(data, immediate: false, chunk: true); } private async Task WriteChunkedAsync(ArraySegment data, CancellationToken cancellationToken) { - await SocketOutput.WriteAsync(BeginChunkBytes(data.Count), immediate: false, cancellationToken: cancellationToken); - await SocketOutput.WriteAsync(data, immediate: false, cancellationToken: cancellationToken); - await SocketOutput.WriteAsync(_endChunkBytes, immediate: true, cancellationToken: cancellationToken); - } - - public static ArraySegment BeginChunkBytes(int dataCount) - { - var bytes = new byte[10] - { - _hex[((dataCount >> 0x1c) & 0x0f)], - _hex[((dataCount >> 0x18) & 0x0f)], - _hex[((dataCount >> 0x14) & 0x0f)], - _hex[((dataCount >> 0x10) & 0x0f)], - _hex[((dataCount >> 0x0c) & 0x0f)], - _hex[((dataCount >> 0x08) & 0x0f)], - _hex[((dataCount >> 0x04) & 0x0f)], - _hex[((dataCount >> 0x00) & 0x0f)], - (byte)'\r', - (byte)'\n', - }; - - // Determine the most-significant non-zero nibble - int total, shift; - total = (dataCount > 0xffff) ? 0x10 : 0x00; - dataCount >>= total; - shift = (dataCount > 0x00ff) ? 0x08 : 0x00; - dataCount >>= shift; - total |= shift; - total |= (dataCount > 0x000f) ? 0x04 : 0x00; - - var offset = 7 - (total >> 2); - return new ArraySegment(bytes, offset, 10 - offset); + await SocketOutput.WriteAsync(data, immediate: false, chunk: true, cancellationToken: cancellationToken); } private void WriteChunkedResponseSuffix() diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/ISocketOutput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/ISocketOutput.cs index 597326ca8..4d9a08357 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/ISocketOutput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/ISocketOutput.cs @@ -13,8 +13,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http /// public interface ISocketOutput { - void Write(ArraySegment buffer, bool immediate = true); - Task WriteAsync(ArraySegment buffer, bool immediate = true, CancellationToken cancellationToken = default(CancellationToken)); + void Write(ArraySegment buffer, bool immediate = true, bool chunk = false); + Task WriteAsync(ArraySegment buffer, bool immediate = true, bool chunk = false, CancellationToken cancellationToken = default(CancellationToken)); /// /// Returns an iterator pointing to the tail of the response buffer. Response data can be appended diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs index f3e8adc25..cf3f8dc52 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.AspNet.Server.Kestrel.Infrastructure; using Microsoft.AspNet.Server.Kestrel.Networking; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNet.Server.Kestrel.Http { @@ -29,7 +30,7 @@ public class SocketOutput : ISocketOutput private readonly IKestrelTrace _log; private readonly IThreadPool _threadPool; - // This locks all access to _tail, _isProducing and _returnFromOnProducingComplete. + // This locks all access to _tail and _lastStart. // _head does not require a lock, since it is only used in the ctor and uv thread. private readonly object _returnLock = new object(); @@ -48,7 +49,6 @@ public class SocketOutput : ISocketOutput private Exception _lastWriteError; private WriteContext _nextWriteContext; private readonly Queue> _tasksPending; - private readonly Queue> _tasksCompleted; private readonly Queue _writeContextPool; private readonly Queue _writeReqPool; @@ -69,7 +69,6 @@ public SocketOutput( _log = log; _threadPool = threadPool; _tasksPending = new Queue>(_initialTaskQueues); - _tasksCompleted = new Queue>(_initialTaskQueues); _writeContextPool = new Queue(_maxPooledWriteContexts); _writeReqPool = writeReqPool; @@ -80,22 +79,35 @@ public SocketOutput( public Task WriteAsync( ArraySegment buffer, bool immediate = true, + bool chunk = false, bool socketShutdownSend = false, bool socketDisconnect = false) { - if (buffer.Count > 0) - { - var tail = ProducingStart(); - tail.CopyFrom(buffer); - // We do our own accounting below - ProducingCompleteNoPreComplete(tail); - } TaskCompletionSource tcs = null; - var scheduleWrite = false; lock (_contextLock) { + if (buffer.Count > 0) + { + var tail = ProducingStart(); + if (chunk) + { + _numBytesPreCompleted += ChunkWriter.WriteBeginChunkBytes(ref tail, buffer.Count); + } + + tail.CopyFrom(buffer); + + if (chunk) + { + ChunkWriter.WriteEndChunkBytes(ref tail); + _numBytesPreCompleted += 2; + } + + // We do our own accounting below + ProducingCompleteNoPreComplete(tail); + } + if (_nextWriteContext == null) { if (_writeContextPool.Count > 0) @@ -253,9 +265,9 @@ private void ScheduleWrite() // This is called on the libuv event loop private void WriteAllPending() { - WriteContext writingContext; + WriteContext writingContext = null; - lock (_contextLock) + if (Monitor.TryEnter(_contextLock)) { _writePending = false; @@ -264,23 +276,28 @@ private void WriteAllPending() writingContext = _nextWriteContext; _nextWriteContext = null; } - else - { - return; - } + + Monitor.Exit(_contextLock); + } + else + { + ScheduleWrite(); } - writingContext.DoWriteIfNeeded(); + if (writingContext != null) + { + writingContext.DoWriteIfNeeded(); + } } - // This is called on the libuv event loop + // This may called on the libuv event loop + // This is always called with the _contextLock already acquired private void OnWriteCompleted(WriteContext writeContext) { var bytesWritten = writeContext.ByteCount; var status = writeContext.WriteStatus; var error = writeContext.WriteError; - if (error != null) { _lastWriteError = new IOException(error.Message, error); @@ -289,33 +306,24 @@ private void OnWriteCompleted(WriteContext writeContext) _connection.Abort(); } - lock (_contextLock) - { - PoolWriteContext(writeContext); + PoolWriteContext(writeContext); - // _numBytesPreCompleted can temporarily go negative in the event there are - // completed writes that we haven't triggered callbacks for yet. - _numBytesPreCompleted -= bytesWritten; - - // bytesLeftToBuffer can be greater than _maxBytesPreCompleted - // This allows large writes to complete once they've actually finished. - var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; - while (_tasksPending.Count > 0 && - (int)(_tasksPending.Peek().Task.AsyncState) <= bytesLeftToBuffer) - { - var tcs = _tasksPending.Dequeue(); - var bytesToWrite = (int)tcs.Task.AsyncState; + // _numBytesPreCompleted can temporarily go negative in the event there are + // completed writes that we haven't triggered callbacks for yet. + _numBytesPreCompleted -= bytesWritten; - _numBytesPreCompleted += bytesToWrite; - bytesLeftToBuffer -= bytesToWrite; + // bytesLeftToBuffer can be greater than _maxBytesPreCompleted + // This allows large writes to complete once they've actually finished. + var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; + while (_tasksPending.Count > 0 && + (int)(_tasksPending.Peek().Task.AsyncState) <= bytesLeftToBuffer) + { + var tcs = _tasksPending.Dequeue(); + var bytesToWrite = (int)tcs.Task.AsyncState; - _tasksCompleted.Enqueue(tcs); - } - } + _numBytesPreCompleted += bytesToWrite; + bytesLeftToBuffer -= bytesToWrite; - while (_tasksCompleted.Count > 0) - { - var tcs = _tasksCompleted.Dequeue(); if (_lastWriteError == null) { _threadPool.Complete(tcs); @@ -327,7 +335,6 @@ private void OnWriteCompleted(WriteContext writeContext) } _log.ConnectionWriteCallback(_connectionId, status); - _tasksCompleted.Clear(); } // This is called on the libuv event loop @@ -365,9 +372,9 @@ private void PoolWriteContext(WriteContext writeContext) } } - void ISocketOutput.Write(ArraySegment buffer, bool immediate) + void ISocketOutput.Write(ArraySegment buffer, bool immediate, bool chunk) { - var task = WriteAsync(buffer, immediate); + var task = WriteAsync(buffer, immediate, chunk); if (task.Status == TaskStatus.RanToCompletion) { @@ -379,9 +386,9 @@ void ISocketOutput.Write(ArraySegment buffer, bool immediate) } } - Task ISocketOutput.WriteAsync(ArraySegment buffer, bool immediate, CancellationToken cancellationToken) + Task ISocketOutput.WriteAsync(ArraySegment buffer, bool immediate, bool chunk, CancellationToken cancellationToken) { - return WriteAsync(buffer, immediate); + return WriteAsync(buffer, immediate, chunk); } private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 end, out int bytes, out int buffers) @@ -409,6 +416,7 @@ private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 private class WriteContext { private static WaitCallback _returnWrittenBlocks = (state) => ReturnWrittenBlocks((MemoryPoolBlock2)state); + private static WaitCallback _completeWrite = (state) => ((WriteContext)state).CompleteOnThreadPool(); private SocketOutput Self; private UvWriteReq _writeReq; @@ -504,19 +512,48 @@ public void DoDisconnectIfNeeded() { if (SocketDisconnect == false || Self._socket.IsClosed) { - Complete(); + CompleteOnUvThread(); return; } Self._socket.Dispose(); Self.ReturnAllBlocks(); Self._log.ConnectionStop(Self._connectionId); - Complete(); + CompleteOnUvThread(); + } + + public void CompleteOnUvThread() + { + if (Monitor.TryEnter(Self._contextLock)) + { + try + { + Self.OnWriteCompleted(this); + } + finally + { + Monitor.Exit(Self._contextLock); + } + } + else + { + ThreadPool.QueueUserWorkItem(_completeWrite, this); + } } - public void Complete() + public void CompleteOnThreadPool() { - Self.OnWriteCompleted(this); + lock (Self._contextLock) + { + try + { + Self.OnWriteCompleted(this); + } + catch (Exception ex) + { + Self._log.LogError("SocketOutput.OnWriteCompleted", ex); + } + } } private void PoolWriteReq(UvWriteReq writeReq) diff --git a/test/Microsoft.AspNet.Server.KestrelTests/ChunkWriterTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/ChunkWriterTests.cs new file mode 100644 index 000000000..3a2dd7103 --- /dev/null +++ b/test/Microsoft.AspNet.Server.KestrelTests/ChunkWriterTests.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; +using Microsoft.AspNet.Server.Kestrel.Http; +using Xunit; + +namespace Microsoft.AspNet.Server.KestrelTests +{ + public class ChunkWriterTests + { + [Theory] + [InlineData(1, "1\r\n")] + [InlineData(10, "a\r\n")] + [InlineData(0x08, "8\r\n")] + [InlineData(0x10, "10\r\n")] + [InlineData(0x080, "80\r\n")] + [InlineData(0x100, "100\r\n")] + [InlineData(0x0800, "800\r\n")] + [InlineData(0x1000, "1000\r\n")] + [InlineData(0x08000, "8000\r\n")] + [InlineData(0x10000, "10000\r\n")] + [InlineData(0x080000, "80000\r\n")] + [InlineData(0x100000, "100000\r\n")] + [InlineData(0x0800000, "800000\r\n")] + [InlineData(0x1000000, "1000000\r\n")] + [InlineData(0x08000000, "8000000\r\n")] + [InlineData(0x10000000, "10000000\r\n")] + [InlineData(0x7fffffffL, "7fffffff\r\n")] + public void ChunkedPrefixMustBeHexCrLfWithoutLeadingZeros(int dataCount, string expected) + { + var beginChunkBytes = ChunkWriter.BeginChunkBytes(dataCount); + + Assert.Equal(Encoding.ASCII.GetBytes(expected), beginChunkBytes.ToArray()); + } + } +} diff --git a/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs index 793084ef5..5960510ef 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs @@ -13,31 +13,6 @@ namespace Microsoft.AspNet.Server.KestrelTests { public class FrameTests { - [Theory] - [InlineData(1, "1\r\n")] - [InlineData(10, "a\r\n")] - [InlineData(0x08, "8\r\n")] - [InlineData(0x10, "10\r\n")] - [InlineData(0x080, "80\r\n")] - [InlineData(0x100, "100\r\n")] - [InlineData(0x0800, "800\r\n")] - [InlineData(0x1000, "1000\r\n")] - [InlineData(0x08000, "8000\r\n")] - [InlineData(0x10000, "10000\r\n")] - [InlineData(0x080000, "80000\r\n")] - [InlineData(0x100000, "100000\r\n")] - [InlineData(0x0800000, "800000\r\n")] - [InlineData(0x1000000, "1000000\r\n")] - [InlineData(0x08000000, "8000000\r\n")] - [InlineData(0x10000000, "10000000\r\n")] - [InlineData(0x7fffffffL, "7fffffff\r\n")] - public void ChunkedPrefixMustBeHexCrLfWithoutLeadingZeros(int dataCount, string expected) - { - var beginChunkBytes = Frame.BeginChunkBytes(dataCount); - - Assert.Equal(Encoding.ASCII.GetBytes(expected), beginChunkBytes.ToArray()); - } - [Theory] [InlineData("Cookie: \r\n\r\n", 1)] [InlineData("Cookie:\r\n\r\n", 1)]