diff --git a/src/CommunityToolkit.HighPerformance/Extensions/ReadOnlySequenceExtensions.cs b/src/CommunityToolkit.HighPerformance/Extensions/ReadOnlySequenceExtensions.cs new file mode 100644 index 000000000..e4ca8f766 --- /dev/null +++ b/src/CommunityToolkit.HighPerformance/Extensions/ReadOnlySequenceExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using CommunityToolkit.HighPerformance.Streams; +using System; +using System.Buffers; +using System.IO; +using System.Runtime.CompilerServices; + +namespace CommunityToolkit.HighPerformance; + +/// +/// Helpers for working with the type. +/// +public static class ReadOnlySequenceExtensions +{ + /// + /// Returns a wrapping the contents of the given of instance. + /// + /// The input of instance. + /// A wrapping the data within . + /// + /// Since this method only receives a instance, which does not track + /// the lifetime of its underlying buffer, it is responsibility of the caller to manage that. + /// In particular, the caller must ensure that the target buffer is not disposed as long + /// as the returned is in use, to avoid unexpected issues. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Stream AsStream(this ReadOnlySequence sequence) + { + return ReadOnlySequenceStream.Create(sequence); + } +} diff --git a/src/CommunityToolkit.HighPerformance/Streams/MemoryStream.Validate.cs b/src/CommunityToolkit.HighPerformance/Streams/MemoryStream.Validate.cs index e7dae1359..126d90a44 100644 --- a/src/CommunityToolkit.HighPerformance/Streams/MemoryStream.Validate.cs +++ b/src/CommunityToolkit.HighPerformance/Streams/MemoryStream.Validate.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.CompilerServices; @@ -24,6 +25,20 @@ public static void ValidatePosition(long position, int length) } } + /// + /// Validates the argument (it needs to be in the [0, length]) range. + /// + /// The new value being set. + /// The maximum length of the target . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void ValidatePosition(long position, long length) + { + if ((ulong)position > (ulong)length) + { + ThrowArgumentOutOfRangeExceptionForPosition(); + } + } + /// /// Validates the or arguments. /// @@ -31,7 +46,7 @@ public static void ValidatePosition(long position, int length) /// The offset within . /// The number of elements to process within . [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static void ValidateBuffer(byte[]? buffer, int offset, int count) + public static void ValidateBuffer([NotNull] byte[]? buffer, int offset, int count) { if (buffer is null) { diff --git a/src/CommunityToolkit.HighPerformance/Streams/ReadOnlySequenceStream.cs b/src/CommunityToolkit.HighPerformance/Streams/ReadOnlySequenceStream.cs new file mode 100644 index 000000000..2e76f6a3f --- /dev/null +++ b/src/CommunityToolkit.HighPerformance/Streams/ReadOnlySequenceStream.cs @@ -0,0 +1,304 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace CommunityToolkit.HighPerformance.Streams; + +/// +/// A implementation wrapping a of instance. +/// +internal sealed partial class ReadOnlySequenceStream : Stream +{ + /// + /// The instance currently in use. + /// + private readonly ReadOnlySequence source; + + /// + /// The current position within . + /// + private long position; + + /// + /// Indicates whether or not the current instance has been disposed + /// + private bool disposed; + + /// + /// Initializes a new instance of the class with the specified source. + /// + /// The source. + public ReadOnlySequenceStream(ReadOnlySequence source) + { + this.source = source; + } + + /// + public sealed override bool CanRead + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => !this.disposed; + } + + /// + public sealed override bool CanSeek + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => !this.disposed; + } + + /// + public sealed override bool CanWrite + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => false; + } + + /// + public sealed override long Length + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + MemoryStream.ValidateDisposed(this.disposed); + + return this.source.Length; + } + } + + /// + public sealed override long Position + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + MemoryStream.ValidateDisposed(this.disposed); + + return this.position; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + set + { + MemoryStream.ValidateDisposed(this.disposed); + MemoryStream.ValidatePosition(value, this.source.Length); + + this.position = value; + } + } + + /// + /// Creates a new from the input of instance. + /// + /// The input instance. + /// A wrapping the underlying data for . + public static Stream Create(ReadOnlySequence sequence) + { + return new ReadOnlySequenceStream(sequence); + } + + /// + public sealed override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + MemoryStream.ValidateDisposed(this.disposed); + + if (this.position >= this.source.Length) + { + return Task.CompletedTask; + } + + if (this.source.IsSingleSegment) + { + ReadOnlyMemory buffer = this.source.First.Slice(unchecked((int)this.position)); + + this.position = this.source.Length; + + return destination.WriteAsync(buffer, cancellationToken).AsTask(); + } + + async Task CoreCopyToAsync(Stream destination, CancellationToken cancellationToken) + { + ReadOnlySequence sequence = this.source.Slice(this.position); + + this.position = this.source.Length; + + foreach (ReadOnlyMemory segment in sequence) + { + await destination.WriteAsync(segment, cancellationToken).ConfigureAwait(false); + } + } + + return CoreCopyToAsync(destination, cancellationToken); + } + catch (OperationCanceledException e) + { + return Task.FromCanceled(e.CancellationToken); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + /// + public sealed override void Flush() + { + } + + /// + public sealed override Task FlushAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + return Task.CompletedTask; + } + + /// + public sealed override Task ReadAsync(byte[]? buffer, int offset, int count, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + + try + { + int result = Read(buffer, offset, count); + + return Task.FromResult(result); + } + catch (OperationCanceledException e) + { + return Task.FromCanceled(e.CancellationToken); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + public sealed override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw MemoryStream.GetNotSupportedException(); + } + + /// + public sealed override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + throw MemoryStream.GetNotSupportedException(); + } + + /// + public sealed override long Seek(long offset, SeekOrigin origin) + { + MemoryStream.ValidateDisposed(this.disposed); + + long index = origin switch + { + SeekOrigin.Begin => offset, + SeekOrigin.Current => this.position + offset, + SeekOrigin.End => this.source.Length + offset, + _ => MemoryStream.ThrowArgumentExceptionForSeekOrigin() + }; + + MemoryStream.ValidatePosition(index, this.source.Length); + + this.position = index; + + return index; + } + + /// + public sealed override void SetLength(long value) + { + throw MemoryStream.GetNotSupportedException(); + } + + /// + public sealed override int Read(byte[]? buffer, int offset, int count) + { + MemoryStream.ValidateDisposed(this.disposed); + MemoryStream.ValidateBuffer(buffer, offset, count); + + if (this.position >= this.source.Length) + { + return 0; + } + + ReadOnlySequence sequence = this.source.Slice(this.position); + Span destination = buffer.AsSpan(offset, count); + int bytesCopied = 0; + + foreach (ReadOnlyMemory segment in sequence) + { + int bytesToCopy = Math.Min(segment.Length, destination.Length); + + segment.Span.Slice(0, bytesToCopy).CopyTo(destination); + + destination = destination.Slice(bytesToCopy); + + bytesCopied += bytesToCopy; + + this.position += bytesToCopy; + + if (destination.Length == 0) + { + break; + } + } + + return bytesCopied; + } + + /// + public sealed override int ReadByte() + { + MemoryStream.ValidateDisposed(this.disposed); + + if (this.position == this.source.Length) + { + return -1; + } + + ReadOnlySequence sequence = this.source.Slice(this.position); + + this.position++; + + return sequence.First.Span[0]; + } + + /// + public sealed override void Write(byte[]? buffer, int offset, int count) + { + throw MemoryStream.GetNotSupportedException(); + } + + /// + public sealed override void WriteByte(byte value) + { + throw MemoryStream.GetNotSupportedException(); + } + + /// + protected override void Dispose(bool disposing) + { + this.disposed = true; + } +} diff --git a/tests/CommunityToolkit.HighPerformance.UnitTests/Streams/Test_ReadOnlySequenceStream.cs b/tests/CommunityToolkit.HighPerformance.UnitTests/Streams/Test_ReadOnlySequenceStream.cs new file mode 100644 index 000000000..0cb053f2a --- /dev/null +++ b/tests/CommunityToolkit.HighPerformance.UnitTests/Streams/Test_ReadOnlySequenceStream.cs @@ -0,0 +1,407 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace CommunityToolkit.HighPerformance.UnitTests.Streams; + +[TestClass] +public partial class Test_ReadOnlySequenceStream +{ + [TestMethod] + public void Test_ReadOnlySequenceStream_Lifecycle() + { + ReadOnlySequence sequence = CreateReadOnlySequence(new byte[100]); + + Stream stream = sequence.AsStream(); + + Assert.IsTrue(stream.CanRead); + Assert.IsTrue(stream.CanSeek); + Assert.IsFalse(stream.CanWrite); + Assert.AreEqual(stream.Length, sequence.Length); + Assert.AreEqual(stream.Position, 0); + + stream.Dispose(); + + Assert.IsFalse(stream.CanRead); + Assert.IsFalse(stream.CanSeek); + Assert.IsFalse(stream.CanWrite); + + _ = Assert.ThrowsException(() => stream.Length); + _ = Assert.ThrowsException(() => stream.Position); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_Seek() + { + Stream stream = CreateReadOnlySequence(new byte[50], new byte[50]).AsStream(); + + Assert.AreEqual(stream.Position, 0); + + stream.Position = 42; + + Assert.AreEqual(stream.Position, 42); + + _ = Assert.ThrowsException(() => stream.Position = -1); + _ = Assert.ThrowsException(() => stream.Position = 120); + + _ = stream.Seek(0, SeekOrigin.Begin); + + _ = Assert.ThrowsException(() => stream.Seek(-1, SeekOrigin.Begin)); + _ = Assert.ThrowsException(() => stream.Seek(120, SeekOrigin.Begin)); + + Assert.AreEqual(stream.Position, 0); + + _ = stream.Seek(-1, SeekOrigin.End); + + _ = Assert.ThrowsException(() => stream.Seek(20, SeekOrigin.End)); + _ = Assert.ThrowsException(() => stream.Seek(-120, SeekOrigin.End)); + + Assert.AreEqual(stream.Position, stream.Length - 1); + + _ = stream.Seek(42, SeekOrigin.Begin); + _ = stream.Seek(20, SeekOrigin.Current); + _ = stream.Seek(-30, SeekOrigin.Current); + + _ = Assert.ThrowsException(() => stream.Seek(-64, SeekOrigin.Current)); + _ = Assert.ThrowsException(() => stream.Seek(80, SeekOrigin.Current)); + + Assert.AreEqual(stream.Position, 32); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_Read_Array() + { + Memory data = CreateRandomData(64); + + Stream stream = CreateReadOnlySequence(data).AsStream(); + + stream.Position = 0; + + byte[] result = new byte[data.Length]; + + int bytesRead = stream.Read(result, 0, result.Length); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(stream.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result)); + + stream.Dispose(); + + _ = Assert.ThrowsException(() => stream.Read(result, 0, result.Length)); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_NotFromStart_Read_Array() + { + const int offset = 8; + + Memory data = CreateRandomData(64); + + Stream stream = CreateReadOnlySequence(data).AsStream(); + + stream.Position = offset; + + byte[] result = new byte[data.Length - offset]; + + int bytesRead = stream.Read(result, 0, result.Length); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(stream.Position, result.Length + offset); + Assert.IsTrue(data.Span.Slice(offset).SequenceEqual(result)); + + stream.Dispose(); + + _ = Assert.ThrowsException(() => stream.Read(result, 0, result.Length)); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_ReadByte() + { + Memory data = new byte[] { 1, 128, 255, 32 }; + + Stream stream = CreateReadOnlySequence(data.Slice(0,2), data.Slice(2, 2)).AsStream(); + + Span result = stackalloc byte[4]; + + foreach (ref byte value in result) + { + value = checked((byte)stream.ReadByte()); + } + + Assert.AreEqual(stream.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result)); + + int exitCode = stream.ReadByte(); + + Assert.AreEqual(exitCode, -1); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_Read_Span() + { + Memory data = CreateRandomData(64); + + Stream stream = CreateReadOnlySequence(data).AsStream(); + + stream.Position = 0; + + Span result = new byte[data.Length]; + + int bytesRead = stream.Read(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(stream.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result)); + } + + [TestMethod] + public async Task Test_ReadOnlySequenceStream_ReadAsync_Memory() + { + Memory data = CreateRandomData(64); + + Stream stream = CreateReadOnlySequence(data).AsStream(); + + Memory result = new byte[data.Length]; + + int bytesRead = await stream.ReadAsync(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(stream.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result.Span)); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_SigleSegment_CopyTo() + { + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data).AsStream(); + + Stream destination = new byte[100].AsMemory().AsStream(); + + source.CopyTo(destination); + + Assert.AreEqual(source.Position, destination.Position); + + destination.Position = 0; + + Memory result = new byte[data.Length]; + + int bytesRead = destination.Read(result.Span); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result.Span)); + } + + [TestMethod] + public void Test_ReadOnlySequenceStream_CopyTo() + { + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data.Slice(0, 32), data.Slice(32)).AsStream(); + + Stream destination = new byte[100].AsMemory().AsStream(); + + source.CopyTo(destination); + + Assert.AreEqual(source.Position, destination.Position); + + destination.Position = 0; + + Memory result = new byte[data.Length]; + + int bytesRead = destination.Read(result.Span); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result.Span)); + } + + [TestMethod] + public async Task Test_ReadOnlySequenceStream_SigleSegment_CopyToAsync() + { + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data).AsStream(); + + Stream destination = new byte[100].AsMemory().AsStream(); + + await source.CopyToAsync(destination); + + Assert.AreEqual(source.Position, destination.Position); + + destination.Position = 0; + + Memory result = new byte[data.Length]; + + int bytesRead = await destination.ReadAsync(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result.Span)); + } + + [TestMethod] + public async Task Test_ReadOnlySequenceStream_SigleSegment_NotFromStart_CopyToAsync() + { + const int offset = 8; + + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data).AsStream(); + + source.Position = offset; + + Stream destination = new byte[100].AsMemory().AsStream(); + + await source.CopyToAsync(destination); + + Assert.AreEqual(source.Position, destination.Position + offset); + + destination.Position = 0; + + Memory result = new byte[data.Length - offset]; + + int bytesRead = await destination.ReadAsync(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length - offset); + Assert.IsTrue(data.Span.Slice(offset).SequenceEqual(result.Span)); + } + + [TestMethod] + public async Task Test_ReadOnlySequenceStream_MultipleSegments_CopyToAsync() + { + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data.Slice(0, 16), data.Slice(16, 16), data.Slice(32, 16), data.Slice(48, 16)).AsStream(); + + Stream destination = new byte[100].AsMemory().AsStream(); + + await source.CopyToAsync(destination); + + Assert.AreEqual(source.Position, destination.Position); + + destination.Position = 0; + + Memory result = new byte[data.Length]; + + int bytesRead = await destination.ReadAsync(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length); + Assert.IsTrue(data.Span.SequenceEqual(result.Span)); + } + + [TestMethod] + public async Task Test_ReadOnlySequenceStream_MultipleSegments_NotFromStart_CopyToAsync() + { + const int offset = 8; + + Memory data = CreateRandomData(64); + + Stream source = CreateReadOnlySequence(data.Slice(0, 16), data.Slice(16, 16), data.Slice(32, 16), data.Slice(48, 16)).AsStream(); + + source.Position = offset; + + Stream destination = new byte[100].AsMemory().AsStream(); + + await source.CopyToAsync(destination); + + Assert.AreEqual(source.Position, destination.Position + offset); + + destination.Position = 0; + + Memory result = new byte[data.Length - offset]; + + int bytesRead = await destination.ReadAsync(result); + + Assert.AreEqual(bytesRead, result.Length); + Assert.AreEqual(destination.Position, data.Length - offset); + Assert.IsTrue(data.Span.Slice(offset).SequenceEqual(result.Span)); + } + + /// + /// Creates a random array filled with random data. + /// + /// The number of array items to create. + /// The returned random array. + private static byte[] CreateRandomData(int count) + { + Random? random = new(DateTime.Now.Ticks.GetHashCode()); + + byte[] data = new byte[count]; + + foreach (ref byte n in MemoryMarshal.AsBytes(data.AsSpan())) + { + n = (byte)random.Next(0, byte.MaxValue); + } + + return data; + } + + /// + /// Creates a value from the input segments. + /// + /// The input segments. + /// The resulting value. + private static ReadOnlySequence CreateReadOnlySequence(params ReadOnlyMemory[] segments) + { + if (segments is not { Length: > 0 }) + { + return ReadOnlySequence.Empty; + } + + if (segments.Length == 1) + { + return new(segments[0]); + } + + ReadOnlySequenceSegmentOfByte first = new(segments[0]); + ReadOnlySequenceSegmentOfByte last = first; + long length = first.Memory.Length; + + for (int i = 1; i < segments.Length; i++) + { + ReadOnlyMemory segment = segments[i]; + + length += segment.Length; + + last = last.Append(segment); + } + + return new(first, 0, last, (int)(length - last.RunningIndex)); + } + + /// + /// A custom that supports appending new segments. + /// + private sealed class ReadOnlySequenceSegmentOfByte : ReadOnlySequenceSegment + { + public ReadOnlySequenceSegmentOfByte(ReadOnlyMemory memory) + { + Memory = memory; + } + + public ReadOnlySequenceSegmentOfByte Append(ReadOnlyMemory memory) + { + ReadOnlySequenceSegmentOfByte nextSegment = new(memory) + { + RunningIndex = RunningIndex + Memory.Length + }; + + Next = nextSegment; + + return nextSegment; + } + } +}