diff --git a/src/libraries/System.Security.Cryptography.Primitives/ref/System.Security.Cryptography.Primitives.cs b/src/libraries/System.Security.Cryptography.Primitives/ref/System.Security.Cryptography.Primitives.cs index dbe106511bcdbf..5e9515946979a8 100644 --- a/src/libraries/System.Security.Cryptography.Primitives/ref/System.Security.Cryptography.Primitives.cs +++ b/src/libraries/System.Security.Cryptography.Primitives/ref/System.Security.Cryptography.Primitives.cs @@ -75,6 +75,8 @@ public CryptoStream(System.IO.Stream stream, System.Security.Cryptography.ICrypt public override System.IAsyncResult BeginRead(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public override System.IAsyncResult BeginWrite(byte[] buffer, int offset, int count, System.AsyncCallback? callback, object? state) { throw null; } public void Clear() { } + public override void CopyTo(System.IO.Stream destination, int bufferSize) { throw null; } + public override System.Threading.Tasks.Task CopyToAsync(System.IO.Stream destination, int bufferSize, System.Threading.CancellationToken cancellationToken) { throw null; } protected override void Dispose(bool disposing) { } public override System.Threading.Tasks.ValueTask DisposeAsync() { throw null; } public override int EndRead(System.IAsyncResult asyncResult) { throw null; } @@ -85,11 +87,13 @@ public void FlushFinalBlock() { } public System.Threading.Tasks.ValueTask FlushFinalBlockAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override int Read(byte[] buffer, int offset, int count) { throw null; } public override System.Threading.Tasks.Task ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask ReadAsync(System.Memory buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; } public override int ReadByte() { throw null; } public override long Seek(long offset, System.IO.SeekOrigin origin) { throw null; } public override void SetLength(long value) { } public override void Write(byte[] buffer, int offset, int count) { } public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) { throw null; } + public override System.Threading.Tasks.ValueTask WriteAsync(System.ReadOnlyMemory buffer, System.Threading.CancellationToken cancellationToken = default) { throw null; } public override void WriteByte(byte value) { } } public enum CryptoStreamMode diff --git a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs index ff5003d7bc78a4..7b2796584b3bc3 100644 --- a/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs +++ b/src/libraries/System.Security.Cryptography.Primitives/src/System/Security/Cryptography/CryptoStream.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -202,16 +203,19 @@ public override void SetLength(long value) public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { CheckReadArguments(buffer, offset, count); - return ReadAsyncInternal(buffer, offset, count, cancellationToken); + return ReadAsyncInternal(buffer.AsMemory(offset, count), cancellationToken).AsTask(); } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => - TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), callback, state); + /// + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (!CanRead) + return ValueTask.FromException(new NotSupportedException(SR.NotSupported_UnreadableStream)); - public override int EndRead(IAsyncResult asyncResult) => - TaskToApm.End(asyncResult); + return ReadAsyncInternal(buffer, cancellationToken); + } - private async Task ReadAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + private async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken = default) { // To avoid a race with a stream's position pointer & generating race // conditions with internal buffer indexes in our own streams that @@ -222,7 +226,7 @@ private async Task ReadAsyncInternal(byte[] buffer, int offset, int count, await AsyncActiveSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); try { - return await ReadAsyncCore(buffer, offset, count, cancellationToken, useAsync: true).ConfigureAwait(false); + return await ReadAsyncCore(buffer, cancellationToken, useAsync: true).ConfigureAwait(false); } finally { @@ -230,6 +234,12 @@ private async Task ReadAsyncInternal(byte[] buffer, int offset, int count, } } + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => + TaskToApm.Begin(ReadAsync(buffer, offset, count, CancellationToken.None), callback, state); + + public override int EndRead(IAsyncResult asyncResult) => + TaskToApm.End(asyncResult); + public override int ReadByte() { // If we have enough bytes in the buffer such that reading 1 will still leave bytes @@ -268,7 +278,10 @@ public override void WriteByte(byte value) public override int Read(byte[] buffer, int offset, int count) { CheckReadArguments(buffer, offset, count); - return ReadAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).GetAwaiter().GetResult(); + ValueTask completedValueTask = ReadAsyncCore(buffer.AsMemory(offset, count), default(CancellationToken), useAsync: false); + Debug.Assert(completedValueTask.IsCompleted); + + return completedValueTask.GetAwaiter().GetResult(); } private void CheckReadArguments(byte[] buffer, int offset, int count) @@ -278,22 +291,22 @@ private void CheckReadArguments(byte[] buffer, int offset, int count) throw new NotSupportedException(SR.NotSupported_UnreadableStream); } - private async Task ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken, bool useAsync) + private async ValueTask ReadAsyncCore(Memory buffer, CancellationToken cancellationToken, bool useAsync) { // read <= count bytes from the input stream, transforming as we go. // Basic idea: first we deliver any bytes we already have in the // _OutputBuffer, because we know they're good. Then, if asked to deliver // more bytes, we read & transform a block at a time until either there are // no bytes ready or we've delivered enough. - int bytesToDeliver = count; - int currentOutputIndex = offset; + int bytesToDeliver = buffer.Length; + int currentOutputIndex = 0; Debug.Assert(_outputBuffer != null); if (_outputBufferIndex != 0) { // we have some already-transformed bytes in the output buffer - if (_outputBufferIndex <= count) + if (_outputBufferIndex <= buffer.Length) { - Buffer.BlockCopy(_outputBuffer, 0, buffer, offset, _outputBufferIndex); + _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span); bytesToDeliver -= _outputBufferIndex; currentOutputIndex += _outputBufferIndex; int toClear = _outputBuffer.Length - _outputBufferIndex; @@ -302,14 +315,14 @@ private async Task ReadAsyncCore(byte[] buffer, int offset, int count, Canc } else { - Buffer.BlockCopy(_outputBuffer, 0, buffer, offset, count); - Buffer.BlockCopy(_outputBuffer, count, _outputBuffer, 0, _outputBufferIndex - count); - _outputBufferIndex -= count; + _outputBuffer.AsSpan(0, buffer.Length).CopyTo(buffer.Span); + Buffer.BlockCopy(_outputBuffer, buffer.Length, _outputBuffer, 0, _outputBufferIndex - buffer.Length); + _outputBufferIndex -= buffer.Length; int toClear = _outputBuffer.Length - _outputBufferIndex; CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - return (count); + return buffer.Length; } } // _finalBlockTransformed == true implies we're at the end of the input stream @@ -319,7 +332,7 @@ private async Task ReadAsyncCore(byte[] buffer, int offset, int count, Canc // eventually, we'll just always return 0 here because there's no more to read if (_finalBlockTransformed) { - return (count - bytesToDeliver); + return buffer.Length - bytesToDeliver; } // ok, now loop until we've delivered enough or there's nothing available int amountRead = 0; @@ -373,7 +386,7 @@ await _stream.ReadAsync(new Memory(tempInputBuffer, _inputBufferIndex, num // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. tempOutputBuffer = ArrayPool.Shared.Rent(numWholeReadBlocks * _outputBlockSize); numOutputBytes = _transform.TransformBlock(tempInputBuffer, 0, numWholeReadBlocksInBytes, tempOutputBuffer, 0); - Buffer.BlockCopy(tempOutputBuffer, 0, buffer, currentOutputIndex, numOutputBytes); + tempOutputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex)); // Clear what was written while we know how much that was CryptographicOperations.ZeroMemory(new Span(tempOutputBuffer, 0, numOutputBytes)); @@ -429,22 +442,22 @@ await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _input if (bytesToDeliver >= numOutputBytes) { - Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, numOutputBytes); + _outputBuffer.AsSpan(0, numOutputBytes).CopyTo(buffer.Span.Slice(currentOutputIndex)); CryptographicOperations.ZeroMemory(new Span(_outputBuffer, 0, numOutputBytes)); currentOutputIndex += numOutputBytes; bytesToDeliver -= numOutputBytes; } else { - Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, bytesToDeliver); + _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex)); _outputBufferIndex = numOutputBytes - bytesToDeliver; Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex); int toClear = _outputBuffer.Length - _outputBufferIndex; CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - return count; + return buffer.Length; } } - return count; + return buffer.Length; ProcessFinalBlock: // if so, then call TransformFinalBlock to get whatever is left @@ -458,36 +471,39 @@ await _stream.ReadAsync(new Memory(_inputBuffer, _inputBufferIndex, _input // now, return either everything we just got or just what's asked for, whichever is smaller if (bytesToDeliver < _outputBufferIndex) { - Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, bytesToDeliver); + _outputBuffer.AsSpan(0, bytesToDeliver).CopyTo(buffer.Span.Slice(currentOutputIndex)); _outputBufferIndex -= bytesToDeliver; Buffer.BlockCopy(_outputBuffer, bytesToDeliver, _outputBuffer, 0, _outputBufferIndex); int toClear = _outputBuffer.Length - _outputBufferIndex; CryptographicOperations.ZeroMemory(new Span(_outputBuffer, _outputBufferIndex, toClear)); - return (count); + return buffer.Length; } else { - Buffer.BlockCopy(_outputBuffer, 0, buffer, currentOutputIndex, _outputBufferIndex); + _outputBuffer.AsSpan(0, _outputBufferIndex).CopyTo(buffer.Span.Slice(currentOutputIndex)); bytesToDeliver -= _outputBufferIndex; _outputBufferIndex = 0; CryptographicOperations.ZeroMemory(_outputBuffer); - return (count - bytesToDeliver); + return buffer.Length - bytesToDeliver; } } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { CheckWriteArguments(buffer, offset, count); - return WriteAsyncInternal(buffer, offset, count, cancellationToken); + return WriteAsyncInternal(buffer.AsMemory(offset, count), cancellationToken).AsTask(); } - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => - TaskToApm.Begin(WriteAsync(buffer, offset, count, CancellationToken.None), callback, state); + /// + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (!CanWrite) + return ValueTask.FromException(new NotSupportedException(SR.NotSupported_UnwritableStream)); - public override void EndWrite(IAsyncResult asyncResult) => - TaskToApm.End(asyncResult); + return WriteAsyncInternal(buffer, cancellationToken); + } - private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + private async ValueTask WriteAsyncInternal(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { // To avoid a race with a stream's position pointer & generating race // conditions with internal buffer indexes in our own streams that @@ -498,7 +514,7 @@ private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, Canc await AsyncActiveSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); try { - await WriteAsyncCore(buffer, offset, count, cancellationToken, useAsync: true).ConfigureAwait(false); + await WriteAsyncCore(buffer, cancellationToken, useAsync: true).ConfigureAwait(false); } finally { @@ -506,10 +522,16 @@ private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, Canc } } + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) => + TaskToApm.Begin(WriteAsync(buffer, offset, count, CancellationToken.None), callback, state); + + public override void EndWrite(IAsyncResult asyncResult) => + TaskToApm.End(asyncResult); + public override void Write(byte[] buffer, int offset, int count) { CheckWriteArguments(buffer, offset, count); - WriteAsyncCore(buffer, offset, count, default(CancellationToken), useAsync: false).AsTask().GetAwaiter().GetResult(); + WriteAsyncCore(buffer.AsMemory(offset, count), default, useAsync: false).AsTask().GetAwaiter().GetResult(); } private void CheckWriteArguments(byte[] buffer, int offset, int count) @@ -519,22 +541,22 @@ private void CheckWriteArguments(byte[] buffer, int offset, int count) throw new NotSupportedException(SR.NotSupported_UnwritableStream); } - private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken, bool useAsync) + private async ValueTask WriteAsyncCore(ReadOnlyMemory buffer, CancellationToken cancellationToken, bool useAsync) { // write <= count bytes to the output stream, transforming as we go. // Basic idea: using bytes in the _InputBuffer first, make whole blocks, // transform them, and write them out. Cache any remaining bytes in the _InputBuffer. - int bytesToWrite = count; - int currentInputIndex = offset; + int bytesToWrite = buffer.Length; + int currentInputIndex = 0; // if we have some bytes in the _InputBuffer, we have to deal with those first, // so let's try to make an entire block out of it if (_inputBufferIndex > 0) { Debug.Assert(_inputBuffer != null); - if (count >= _inputBlockSize - _inputBufferIndex) + if (buffer.Length >= _inputBlockSize - _inputBufferIndex) { // we have enough to transform at least a block, so fill the input block - Buffer.BlockCopy(buffer, offset, _inputBuffer, _inputBufferIndex, _inputBlockSize - _inputBufferIndex); + buffer.Slice(0, _inputBlockSize - _inputBufferIndex).CopyTo(_inputBuffer.AsMemory(_inputBufferIndex)); currentInputIndex += (_inputBlockSize - _inputBufferIndex); bytesToWrite -= (_inputBlockSize - _inputBufferIndex); _inputBufferIndex = _inputBlockSize; @@ -544,8 +566,8 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can { // not enough to transform a block, so just copy the bytes into the _InputBuffer // and return - Buffer.BlockCopy(buffer, offset, _inputBuffer, _inputBufferIndex, count); - _inputBufferIndex += count; + buffer.CopyTo(_inputBuffer.AsMemory(_inputBufferIndex)); + _inputBufferIndex += buffer.Length; return; } } @@ -585,8 +607,7 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can try { - numOutputBytes = - _transform.TransformBlock(buffer, currentInputIndex, numWholeBlocksInBytes, tempOutputBuffer, 0); + numOutputBytes = TransformBlock(_transform, buffer.Slice(currentInputIndex, numWholeBlocksInBytes), tempOutputBuffer, 0); if (useAsync) { @@ -614,7 +635,7 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can { Debug.Assert(_outputBuffer != null); // do it the slow way - numOutputBytes = _transform.TransformBlock(buffer, currentInputIndex, _inputBlockSize, _outputBuffer, 0); + numOutputBytes = TransformBlock(_transform, buffer.Slice(currentInputIndex, _inputBlockSize), _outputBuffer, 0); if (useAsync) await _stream.WriteAsync(new ReadOnlyMemory(_outputBuffer, 0, numOutputBytes), cancellationToken).ConfigureAwait(false); @@ -630,12 +651,126 @@ private async ValueTask WriteAsyncCore(byte[] buffer, int offset, int count, Can Debug.Assert(_inputBuffer != null); // In this case, we don't have an entire block's worth left, so store it up in the // input buffer, which by now must be empty. - Buffer.BlockCopy(buffer, currentInputIndex, _inputBuffer, 0, bytesToWrite); + buffer.Slice(currentInputIndex, bytesToWrite).CopyTo(_inputBuffer); _inputBufferIndex += bytesToWrite; return; } } return; + + unsafe static int TransformBlock(ICryptoTransform transform, ReadOnlyMemory inputBuffer, byte[] outputBuffer, int outputOffset) + { + if (MemoryMarshal.TryGetArray(inputBuffer, out ArraySegment segment)) + { + // Skip the copy if readonlymemory is actually an array. + Debug.Assert(segment.Array is not null); + return transform.TransformBlock(segment.Array, segment.Offset, inputBuffer.Length, outputBuffer, outputOffset); + } + else + { + // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. + byte[] rentedBuffer = ArrayPool.Shared.Rent(inputBuffer.Length); + int result = default; + + // Pin the rented buffer for security. + fixed (byte* _ = &rentedBuffer[0]) + { + try + { + inputBuffer.CopyTo(rentedBuffer); + result = transform.TransformBlock(rentedBuffer, 0, inputBuffer.Length, outputBuffer, outputOffset); + } + finally + { + CryptographicOperations.ZeroMemory(rentedBuffer.AsSpan(0, inputBuffer.Length)); + } + } + + ArrayPool.Shared.Return(rentedBuffer); + rentedBuffer = null!; + return result; + } + } + } + + /// + public unsafe override void CopyTo(Stream destination, int bufferSize) + { + CheckCopyToArguments(destination, bufferSize); + + // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. + byte[] rentedBuffer = ArrayPool.Shared.Rent(bufferSize); + // Pin the array for security. + fixed (byte* _ = &rentedBuffer[0]) + { + try + { + int bytesRead; + do + { + bytesRead = Read(rentedBuffer, 0, bufferSize); + destination.Write(rentedBuffer, 0, bytesRead); + } while (bytesRead > 0); + } + finally + { + CryptographicOperations.ZeroMemory(rentedBuffer.AsSpan(0, bufferSize)); + } + } + ArrayPool.Shared.Return(rentedBuffer); + rentedBuffer = null!; + } + + /// + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + CheckCopyToArguments(destination, bufferSize); + return CopyToAsyncInternal(destination, bufferSize, cancellationToken); + } + + private async Task CopyToAsyncInternal(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + // Use ArrayPool.Shared instead of CryptoPool because the array is passed out. + byte[] rentedBuffer = ArrayPool.Shared.Rent(bufferSize); + // Pin the array for security. + GCHandle pinHandle = GCHandle.Alloc(rentedBuffer, GCHandleType.Pinned); + try + { + int bytesRead; + do + { + bytesRead = await ReadAsync(rentedBuffer.AsMemory(0, bufferSize), cancellationToken).ConfigureAwait(false); + await destination.WriteAsync(rentedBuffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); + } while (bytesRead > 0); + } + finally + { + CryptographicOperations.ZeroMemory(rentedBuffer.AsSpan(0, bufferSize)); + pinHandle.Free(); + } + ArrayPool.Shared.Return(rentedBuffer); + rentedBuffer = null!; + } + + private void CheckCopyToArguments(Stream destination, int bufferSize) + { + if (destination is null) + throw new ArgumentNullException(nameof(destination)); + + EnsureNotDisposed(destination, nameof(destination)); + + if (!destination.CanWrite) + throw new NotSupportedException(SR.NotSupported_UnwritableStream); + if (bufferSize <= 0) + throw new ArgumentOutOfRangeException(nameof(bufferSize), SR.ArgumentOutOfRange_NeedPosNum); + if (!CanRead) + throw new NotSupportedException(SR.NotSupported_UnreadableStream); + } + + private static void EnsureNotDisposed(Stream stream, string objectName) + { + if (!stream.CanRead && !stream.CanWrite) + throw new ObjectDisposedException(objectName); } public void Clear()