From fb57ac1b7e7aac1b53b7235c31586288ae234353 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sat, 7 Apr 2018 20:12:54 -0700 Subject: [PATCH 1/5] WIP --- src/Common/MemoryBufferWriter.cs | 92 +++++++++- ...crosoft.AspNetCore.Http.Connections.csproj | 1 + .../Formatters/TextMessageFormatter.cs | 5 - .../Protocol/LimitArrayPoolWriteStream.cs | 163 ------------------ .../Protocol/MessagePackHubProtocol.cs | 16 +- ...spNetCore.SignalR.Protocols.MsgPack.csproj | 4 + ...oft.AspNetCore.SignalR.Client.Tests.csproj | 1 + .../TestConnection.cs | 2 +- .../Formatters/TextMessageFormatterTests.cs | 2 +- .../Internal/Protocol/JsonHubProtocolTests.cs | 2 +- 10 files changed, 104 insertions(+), 184 deletions(-) delete mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index 0362196faf..5a327429dc 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -5,11 +5,12 @@ using System.Buffers; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Internal { - internal sealed class MemoryBufferWriter : IBufferWriter + internal sealed class MemoryBufferWriter : Stream, IBufferWriter { [ThreadStatic] private static MemoryBufferWriter _cachedInstance; @@ -30,7 +31,15 @@ public MemoryBufferWriter(int segmentSize = 2048) _segmentSize = segmentSize; } - public int Length => _bytesWritten; + public override long Length => _bytesWritten; + + public override bool CanRead => false; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } public static MemoryBufferWriter Get() { @@ -118,29 +127,43 @@ public Span GetSpan(int sizeHint = 0) return GetMemory(sizeHint).Span; } - public Task CopyToAsync(Stream stream) + public void CopyTo(IBufferWriter destination) + { + if (_fullSegments != null) + { + // Copy full segments + for (var i = 0; i < _fullSegments.Count - 1; i++) + { + destination.Write(_fullSegments[i].AsSpan(0, _segmentSize)); + } + } + + destination.Write(_currentSegment.AsSpan(0, _position)); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { if (_fullSegments == null) { // There is only one segment so write without async - return stream.WriteAsync(_currentSegment, 0, _position); + return destination.WriteAsync(_currentSegment, 0, _position); } - return CopyToSlowAsync(stream); + return CopyToSlowAsync(destination); } - private async Task CopyToSlowAsync(Stream stream) + private async Task CopyToSlowAsync(Stream destination) { if (_fullSegments != null) { // Copy full segments for (var i = 0; i < _fullSegments.Count - 1; i++) { - await stream.WriteAsync(_fullSegments[i], 0, _segmentSize); + await destination.WriteAsync(_fullSegments[i], 0, _segmentSize); } } - await stream.WriteAsync(_currentSegment, 0, _position); + await destination.WriteAsync(_currentSegment, 0, _position); } public byte[] ToArray() @@ -170,5 +193,58 @@ public byte[] ToArray() return result; } + + public override void Flush() + { + + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public unsafe override void WriteByte(byte value) + { + if (_currentSegment != null && _position < _segmentSize) + { + _currentSegment[_position] = value; + _bytesWritten += 1; + _position += 1; + } + else + { + BuffersExtensions.Write(this, new ReadOnlySpan(&value, 1)); + } + } + + public override void Write(byte[] buffer, int offset, int count) + { + BuffersExtensions.Write(this, buffer.AsSpan(offset, count)); + } + +#if NETCOREAPP2_1 + public override void Write(ReadOnlySpan span) + { + BuffersExtensions.Write(this, span); + } +#endif + protected override void Dispose(bool disposing) + { + if (disposing) + { + Reset(); + } + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj b/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj index cd36dcd2e0..e7bc7ccdae 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj @@ -3,6 +3,7 @@ Components for providing real-time bi-directional communication across the Web. netstandard2.0;netcoreapp2.1 + true diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs index 223fd59d6a..058172984e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs @@ -12,11 +12,6 @@ public static class TextMessageFormatter // will not occur (is not a valid character) and therefore it is safe to not escape it public static readonly byte RecordSeparator = 0x1e; - public static void WriteRecordSeparator(Stream output) - { - output.WriteByte(RecordSeparator); - } - public static void WriteRecordSeparator(IBufferWriter output) { var buffer = output.GetSpan(1); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs deleted file mode 100644 index 257ee1ca6b..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR.Internal.Protocol -{ - public sealed class LimitArrayPoolWriteStream : Stream - { - private const int MaxByteArrayLength = 0x7FFFFFC7; - private const int InitialLength = 256; - - private readonly int _maxBufferSize; - private byte[] _buffer; - private int _length; - - public LimitArrayPoolWriteStream() : this(MaxByteArrayLength) { } - - public LimitArrayPoolWriteStream(int maxBufferSize) : this(maxBufferSize, InitialLength) { } - - public LimitArrayPoolWriteStream(int maxBufferSize, long capacity) - { - if (capacity < InitialLength) - { - capacity = InitialLength; - } - else if (capacity > maxBufferSize) - { - throw CreateOverCapacityException(maxBufferSize); - } - - _maxBufferSize = maxBufferSize; - _buffer = ArrayPool.Shared.Rent((int)capacity); - } - - protected override void Dispose(bool disposing) - { - if (_buffer != null) - { - ArrayPool.Shared.Return(_buffer); - _buffer = null; - } - - base.Dispose(disposing); - } - - public ArraySegment GetBuffer() => new ArraySegment(_buffer, 0, _length); - - public byte[] ToArray() - { - var arr = new byte[_length]; - Buffer.BlockCopy(_buffer, 0, arr, 0, _length); - return arr; - } - - private void EnsureCapacity(int value) - { - if ((uint)value > (uint)_maxBufferSize) // value cast handles overflow to negative as well - { - throw CreateOverCapacityException(_maxBufferSize); - } - else if (value > _buffer.Length) - { - Grow(value); - } - } - - private void Grow(int value) - { - Debug.Assert(value > _buffer.Length); - - // Extract the current buffer to be replaced. - var currentBuffer = _buffer; - _buffer = null; - - // Determine the capacity to request for the new buffer. It should be - // at least twice as long as the current one, if not more if the requested - // value is more than that. If the new value would put it longer than the max - // allowed byte array, than shrink to that (and if the required length is actually - // longer than that, we'll let the runtime throw). - var twiceLength = 2 * (uint)currentBuffer.Length; - var newCapacity = twiceLength > MaxByteArrayLength ? - (value > MaxByteArrayLength ? value : MaxByteArrayLength) : - Math.Max(value, (int)twiceLength); - - // Get a new buffer, copy the current one to it, return the current one, and - // set the new buffer as current. - var newBuffer = ArrayPool.Shared.Rent(newCapacity); - Buffer.BlockCopy(currentBuffer, 0, newBuffer, 0, _length); - ArrayPool.Shared.Return(currentBuffer); - _buffer = newBuffer; - } - - public override void Write(byte[] buffer, int offset, int count) - { - Debug.Assert(buffer != null); - Debug.Assert(offset >= 0); - Debug.Assert(count >= 0); - - EnsureCapacity(_length + count); - Buffer.BlockCopy(buffer, offset, _buffer, _length, count); - _length += count; - } - -#if NETCOREAPP2_1 - public override void Write(ReadOnlySpan source) - { - EnsureCapacity(_length + source.Length); - source.CopyTo(new Span(_buffer, _length, source.Length)); - _length += source.Length; - } -#endif - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - Write(buffer, offset, count); - return Task.CompletedTask; - } - -#if NETCOREAPP2_1 - public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) - { - Write(source.Span); - return default; - } -#endif - - public override void WriteByte(byte value) - { - var newLength = _length + 1; - EnsureCapacity(newLength); - _buffer[_length] = value; - _length = newLength; - } - - public override void Flush() { } - public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; - - public override long Length => _length; - public override bool CanWrite => true; - public override bool CanRead => false; - public override bool CanSeek => false; - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - public override int Read(byte[] buffer, int offset, int count) { throw new NotSupportedException(); } - public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); } - public override void SetLength(long value) { throw new NotSupportedException(); } - - private static Exception CreateOverCapacityException(int maxBufferSize) - { - return new InvalidOperationException($"Buffer size of {maxBufferSize} exceeded."); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs index 4789637006..dd2ea32dc3 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs @@ -9,6 +9,7 @@ using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.Extensions.Options; using MsgPack; @@ -263,15 +264,20 @@ private static T ApplyHeaders(IDictionary source, T destinati public void WriteMessage(HubMessage message, IBufferWriter output) { - using (var stream = new LimitArrayPoolWriteStream()) + var writer = MemoryBufferWriter.Get(); + + try { // Write message to a buffer so we can get its length - WriteMessageCore(message, stream); - var buffer = stream.GetBuffer(); + WriteMessageCore(message, writer); // Write length then message to output - BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output); - output.Write(buffer); + BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output); + writer.CopyTo(output); + } + finally + { + MemoryBufferWriter.Return(writer); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj index 767a1169c9..1fb060bc20 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj @@ -7,6 +7,10 @@ true + + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index efdeb1d1fc..77128f3219 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -2,6 +2,7 @@ $(StandardTestTfms) + true diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 72e2035e0f..456c8887c4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -170,7 +170,7 @@ private byte[] FormatMessageToArray(byte[] message) { var output = new MemoryStream(); output.Write(message, 0, message.Length); - TextMessageFormatter.WriteRecordSeparator(output); + output.WriteByte(TextMessageFormatter.RecordSeparator); return output.ToArray(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs index ad4b1c0916..43bbfab16d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs @@ -18,7 +18,7 @@ public void WriteMessage() { var buffer = Encoding.UTF8.GetBytes("ABC"); ms.Write(buffer, 0, buffer.Length); - TextMessageFormatter.WriteRecordSeparator(ms); + ms.WriteByte(TextMessageFormatter.RecordSeparator); Assert.Equal("ABC\u001e", Encoding.UTF8.GetString(ms.ToArray())); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 44b1e9949a..b8e7018298 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -271,7 +271,7 @@ private static byte[] FormatMessageToArray(byte[] message) { var output = new MemoryStream(); output.Write(message, 0, message.Length); - TextMessageFormatter.WriteRecordSeparator(output); + output.WriteByte(TextMessageFormatter.RecordSeparator); return output.ToArray(); } From cf64af9bb857f99ae795fb8036775b9092483119 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sun, 8 Apr 2018 03:35:06 -0700 Subject: [PATCH 2/5] Added MemoryBufferWriter tests - Fixed a bug with CopyTo and CopyToAsync --- src/Common/MemoryBufferWriter.cs | 24 +- ...crosoft.AspNetCore.Http.Connections.csproj | 1 - ...oft.AspNetCore.SignalR.Client.Tests.csproj | 1 - .../Protocol/MemoryBufferWriterTests.cs | 220 ++++++++++++++++++ 4 files changed, 235 insertions(+), 11 deletions(-) create mode 100644 test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index 5a327429dc..024dc5ad37 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -5,6 +5,7 @@ using System.Buffers; using System.Collections.Generic; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -94,6 +95,8 @@ public void Reset() _position = 0; } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Advance(int count) { _bytesWritten += count; @@ -108,7 +111,7 @@ public Memory GetMemory(int sizeHint = 0) _currentSegment = ArrayPool.Shared.Rent(_segmentSize); _position = 0; } - else if (_position == _segmentSize) + else if (_position == _currentSegment.Length) { if (_fullSegments == null) { @@ -132,9 +135,9 @@ public void CopyTo(IBufferWriter destination) if (_fullSegments != null) { // Copy full segments - for (var i = 0; i < _fullSegments.Count - 1; i++) + for (var i = 0; i < _fullSegments.Count; i++) { - destination.Write(_fullSegments[i].AsSpan(0, _segmentSize)); + destination.Write(_fullSegments[i]); } } @@ -157,9 +160,10 @@ private async Task CopyToSlowAsync(Stream destination) if (_fullSegments != null) { // Copy full segments - for (var i = 0; i < _fullSegments.Count - 1; i++) + for (var i = 0; i < _fullSegments.Count; i++) { - await destination.WriteAsync(_fullSegments[i], 0, _segmentSize); + var segment = _fullSegments[i]; + await destination.WriteAsync(segment, 0, segment.Length); } } @@ -184,7 +188,7 @@ public byte[] ToArray() { _fullSegments[i].CopyTo(result, totalWritten); - totalWritten += _segmentSize; + totalWritten += _fullSegments[i].Length; } } @@ -214,9 +218,9 @@ public override void SetLength(long value) throw new NotSupportedException(); } - public unsafe override void WriteByte(byte value) + public override void WriteByte(byte value) { - if (_currentSegment != null && _position < _segmentSize) + if (_currentSegment != null && _position < _currentSegment.Length) { _currentSegment[_position] = value; _bytesWritten += 1; @@ -224,7 +228,9 @@ public unsafe override void WriteByte(byte value) } else { - BuffersExtensions.Write(this, new ReadOnlySpan(&value, 1)); + var memory = GetMemory(); + memory.Span[0] = value; + Advance(1); } } diff --git a/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj b/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj index e7bc7ccdae..cd36dcd2e0 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/Microsoft.AspNetCore.Http.Connections/Microsoft.AspNetCore.Http.Connections.csproj @@ -3,7 +3,6 @@ Components for providing real-time bi-directional communication across the Web. netstandard2.0;netcoreapp2.1 - true diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index 77128f3219..efdeb1d1fc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -2,7 +2,6 @@ $(StandardTestTfms) - true diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs new file mode 100644 index 0000000000..d46be27148 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs @@ -0,0 +1,220 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol +{ + public class MemoryBufferWriterTests + { + private static int MinimumSegmentSize; + + static MemoryBufferWriterTests() + { + var buffer = ArrayPool.Shared.Rent(1); + // Compute the minimum segment size of the array pool + MinimumSegmentSize = buffer.Length; + ArrayPool.Shared.Return(buffer); + } + + [Fact] + public void WritingNotingGivesEmptyData() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + Assert.Equal(0, bufferWriter.Length); + var data = bufferWriter.ToArray(); + Assert.Empty(data); + } + } + + [Fact] + public void WriteByteWorksAsFirstCall() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.WriteByte(234); + var data = bufferWriter.ToArray(); + + Assert.Equal(1, bufferWriter.Length); + Assert.Single(data); + Assert.Equal(234, data[0]); + } + } + + [Fact] + public void WriteByteWorksIfFirstByteInNewSegment() + { + var inputSize = MinimumSegmentSize; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(16, bufferWriter.Length); + bufferWriter.WriteByte(16); + Assert.Equal(17, bufferWriter.Length); + + var data = bufferWriter.ToArray(); + Assert.Equal(input, data.Take(16)); + Assert.Equal(16, data[16]); + } + } + + [Fact] + public void WriteByteWorksIfSegmentHasSpace() + { + var input = new byte[] { 11, 12, 13 }; + + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.Write(input, 0, input.Length); + bufferWriter.WriteByte(14); + + Assert.Equal(4, bufferWriter.Length); + + var data = bufferWriter.ToArray(); + Assert.Equal(4, data.Length); + Assert.Equal(11, data[0]); + Assert.Equal(12, data[1]); + Assert.Equal(13, data[2]); + Assert.Equal(14, data[3]); + } + } + + [Fact] + public void ToArrayWithExactlyFullSegmentsWorks() + { + var inputSize = MinimumSegmentSize * 2; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + var data = bufferWriter.ToArray(); + Assert.Equal(input, data); + } + } + + [Fact] + public void ToArrayWithSomeFullSegmentsWorks() + { + var inputSize = (MinimumSegmentSize * 2) + 1; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + var data = bufferWriter.ToArray(); + Assert.Equal(input, data); + } + } + + [Fact] + public async Task CopyToAsyncWithExactlyFullSegmentsWorks() + { + var inputSize = MinimumSegmentSize * 2; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + var ms = new MemoryStream(); + await bufferWriter.CopyToAsync(ms); + var data = ms.ToArray(); + Assert.Equal(input, data); + } + } + + [Fact] + public async Task CopyToAsyncWithSomeFullSegmentsWorks() + { + // 2 segments + 1 extra byte + var inputSize = (MinimumSegmentSize * 2) + 1; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + var ms = new MemoryStream(); + await bufferWriter.CopyToAsync(ms); + var data = ms.ToArray(); + Assert.Equal(input, data); + } + } + + [Fact] + public void CopyToWithExactlyFullSegmentsWorks() + { + var inputSize = MinimumSegmentSize * 2; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + using (var destination = new MemoryBufferWriter()) + { + bufferWriter.CopyTo(destination); + var data = destination.ToArray(); + Assert.Equal(input, data); + } + } + } + + [Fact] + public void CopyToWithSomeFullSegmentsWorks() + { + var inputSize = (MinimumSegmentSize * 2) + 1; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + using (var destination = new MemoryBufferWriter()) + { + bufferWriter.CopyTo(destination); + var data = destination.ToArray(); + Assert.Equal(input, data); + } + } + } + + [Fact] + public void ResetResetsTheMemoryBufferWriter() + { + var bufferWriter = new MemoryBufferWriter(); + bufferWriter.WriteByte(1); + Assert.Equal(1, bufferWriter.Length); + bufferWriter.Reset(); + Assert.Equal(0, bufferWriter.Length); + } + + [Fact] + public void DisposeResetsTheMemoryBufferWriter() + { + var bufferWriter = new MemoryBufferWriter(); + bufferWriter.WriteByte(1); + Assert.Equal(1, bufferWriter.Length); + bufferWriter.Dispose(); + Assert.Equal(0, bufferWriter.Length); + } + } +} \ No newline at end of file From 28e439ee93537615e65bf1bfe632c9d5c1d73146 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sun, 8 Apr 2018 09:37:05 -0700 Subject: [PATCH 3/5] Made some performance improvements --- src/Common/MemoryBufferWriter.cs | 123 ++++++++++++++++++------------- 1 file changed, 72 insertions(+), 51 deletions(-) diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index 024dc5ad37..bd3a52deda 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -20,27 +20,27 @@ internal sealed class MemoryBufferWriter : Stream, IBufferWriter private bool _inUse; #endif - private readonly int _segmentSize; + private readonly int _minimumSegmentSize; private int _bytesWritten; private List _fullSegments; private byte[] _currentSegment; private int _position; - public MemoryBufferWriter(int segmentSize = 2048) + public MemoryBufferWriter(int minimumSegmentSize = 4096) { - _segmentSize = segmentSize; + _minimumSegmentSize = minimumSegmentSize; } public override long Length => _bytesWritten; - public override bool CanRead => false; - public override bool CanSeek => false; - public override bool CanWrite => true; - - public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } public static MemoryBufferWriter Get() { @@ -105,29 +105,16 @@ public void Advance(int count) public Memory GetMemory(int sizeHint = 0) { - // TODO: Use sizeHint - if (_currentSegment == null) - { - _currentSegment = ArrayPool.Shared.Rent(_segmentSize); - _position = 0; - } - else if (_position == _currentSegment.Length) - { - if (_fullSegments == null) - { - _fullSegments = new List(); - } - _fullSegments.Add(_currentSegment); - _currentSegment = ArrayPool.Shared.Rent(_segmentSize); - _position = 0; - } + EnsureCapacity(sizeHint); return _currentSegment.AsMemory(_position, _currentSegment.Length - _position); } public Span GetSpan(int sizeHint = 0) { - return GetMemory(sizeHint).Span; + EnsureCapacity(sizeHint); + + return _currentSegment.AsSpan(_position, _currentSegment.Length - _position); } public void CopyTo(IBufferWriter destination) @@ -155,6 +142,35 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio return CopyToSlowAsync(destination); } + private void EnsureCapacity(int sizeHint) + { + // TODO: Use sizeHint + if (_currentSegment != null && _position < _currentSegment.Length) + { + // We have capacity in the current segment + return; + } + + AddSegment(); + } + + private void AddSegment() + { + if (_currentSegment != null) + { + // We're adding a segment to the list + if (_fullSegments == null) + { + _fullSegments = new List(); + } + + _fullSegments.Add(_currentSegment); + } + + _currentSegment = ArrayPool.Shared.Rent(_minimumSegmentSize); + _position = 0; + } + private async Task CopyToSlowAsync(Stream destination) { if (_fullSegments != null) @@ -198,53 +214,58 @@ public byte[] ToArray() return result; } - public override void Flush() - { - - } - - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - public override void SetLength(long value) - { - throw new NotSupportedException(); - } + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); public override void WriteByte(byte value) { if (_currentSegment != null && _position < _currentSegment.Length) { _currentSegment[_position] = value; - _bytesWritten += 1; - _position += 1; } else { - var memory = GetMemory(); - memory.Span[0] = value; - Advance(1); + EnsureCapacity(1); + _currentSegment[_position] = value; } + + _bytesWritten++; + _position++; } public override void Write(byte[] buffer, int offset, int count) { - BuffersExtensions.Write(this, buffer.AsSpan(offset, count)); + if (_currentSegment != null && _position + count < _currentSegment.Length) + { + Buffer.BlockCopy(buffer, offset, _currentSegment, _position, count); + + _position += count; + _bytesWritten += count; + } + else + { + BuffersExtensions.Write(this, buffer.AsSpan(offset, count)); + } } #if NETCOREAPP2_1 public override void Write(ReadOnlySpan span) { - BuffersExtensions.Write(this, span); + if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan().Slice(_position))) + { + _position += span.Length; + _bytesWritten += span.Length; + } + else + { + BuffersExtensions.Write(this, span); + } } #endif + protected override void Dispose(bool disposing) { if (disposing) From 5c3f38730f258000caa3ef65210217615ffb8956 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sun, 8 Apr 2018 09:53:56 -0700 Subject: [PATCH 4/5] Add more tests for MemoryBufferWriter --- src/Common/MemoryBufferWriter.cs | 2 -- .../Protocol/MemoryBufferWriterTests.cs | 21 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index bd3a52deda..6ce8a55513 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -95,8 +95,6 @@ public void Reset() _position = 0; } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Advance(int count) { _bytesWritten += count; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs index d46be27148..55976b3095 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs @@ -197,6 +197,27 @@ public void CopyToWithSomeFullSegmentsWorks() } } +#if NETCOREAPP2_1 + [Fact] + public void WriteSpanWorksAtNonZeroOffset() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.WriteByte(1); + bufferWriter.Write(new byte[] { 2, 3, 4 }.AsSpan()); + + Assert.Equal(4, bufferWriter.Length); + + var data = bufferWriter.ToArray(); + Assert.Equal(4, data.Length); + Assert.Equal(1, data[0]); + Assert.Equal(2, data[1]); + Assert.Equal(3, data[2]); + Assert.Equal(4, data[3]); + } + } +#endif + [Fact] public void ResetResetsTheMemoryBufferWriter() { From 6176487d17e742f990a89b45aa7667a4f7b27358 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sun, 8 Apr 2018 10:46:41 -0700 Subject: [PATCH 5/5] PR feedback --- src/Common/MemoryBufferWriter.cs | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index 6ce8a55513..d50b8966ac 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -49,9 +49,11 @@ public static MemoryBufferWriter Get() { writer = new MemoryBufferWriter(); } - - // Taken off the thread static - _cachedInstance = null; + else + { + // Taken off the thread static + _cachedInstance = null; + } #if DEBUG if (writer._inUse) { @@ -120,7 +122,8 @@ public void CopyTo(IBufferWriter destination) if (_fullSegments != null) { // Copy full segments - for (var i = 0; i < _fullSegments.Count; i++) + var count = _fullSegments.Count; + for (var i = 0; i < count; i++) { destination.Write(_fullSegments[i]); } @@ -173,8 +176,9 @@ private async Task CopyToSlowAsync(Stream destination) { if (_fullSegments != null) { - // Copy full segments - for (var i = 0; i < _fullSegments.Count; i++) + // Copy full segments + var count = _fullSegments.Count; + for (var i = 0; i < count; i++) { var segment = _fullSegments[i]; await destination.WriteAsync(segment, 0, segment.Length); @@ -198,11 +202,12 @@ public byte[] ToArray() if (_fullSegments != null) { // Copy full segments - for (var i = 0; i < _fullSegments.Count; i++) + var count = _fullSegments.Count; + for (var i = 0; i < count; i++) { - _fullSegments[i].CopyTo(result, totalWritten); - - totalWritten += _fullSegments[i].Length; + var segment = _fullSegments[i]; + segment.CopyTo(result, totalWritten); + totalWritten += segment.Length; } } @@ -220,27 +225,28 @@ public override void Flush() { } public override void WriteByte(byte value) { - if (_currentSegment != null && _position < _currentSegment.Length) + if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length) { _currentSegment[_position] = value; } else { - EnsureCapacity(1); - _currentSegment[_position] = value; + AddSegment(); + _currentSegment[0] = value; } - _bytesWritten++; _position++; + _bytesWritten++; } public override void Write(byte[] buffer, int offset, int count) { - if (_currentSegment != null && _position + count < _currentSegment.Length) + var position = _position; + if (_currentSegment != null && position < _currentSegment.Length - count) { - Buffer.BlockCopy(buffer, offset, _currentSegment, _position, count); + Buffer.BlockCopy(buffer, offset, _currentSegment, position, count); - _position += count; + _position = position + count; _bytesWritten += count; } else