diff --git a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs index f422f5e33e91..73343bee07dc 100644 --- a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -179,7 +179,9 @@ private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeS _receiveBuffer = new byte[ReceiveBufferMinLength]; // Set up the abort source so that if it's triggered, we transition the instance appropriately. - _abortSource.Token.Register(s => + // There's no need to store the resulting CancellationTokenRegistration, as this instance owns + // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. + _abortSource.Token.UnsafeRegister(s => { var thisRef = (ManagedWebSocket)s; diff --git a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs index 448ee1658784..62cb2e18dde9 100644 --- a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs +++ b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs @@ -526,7 +526,7 @@ private void ProcessEvents() // When cancellation is requested, clear out all watches. This should force any active or future reads // on the inotify handle to return 0 bytes read immediately, allowing us to wake up from the blocking call // and exit the processing loop and clean up. - var ctr = _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + var ctr = _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); try { // Previous event information diff --git a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs index b608879181e7..d0be36b03c55 100644 --- a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs +++ b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs @@ -169,7 +169,7 @@ internal RunningInstance( _includeChildren = includeChildren; _filterFlags = filter; _cancellationToken = cancelToken; - _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); _stopping = false; } diff --git a/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs b/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs index ab7fd35d1b91..31189655eaa0 100644 --- a/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs +++ b/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs @@ -70,7 +70,7 @@ internal void RegisterForCancellation(CancellationToken cancellationToken) if (state == NoResult) { // Register the cancellation - _cancellationRegistration = cancellationToken.Register(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); + _cancellationRegistration = cancellationToken.UnsafeRegister(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); // Grab the state for case if IO completed while we were setting the registration. state = Interlocked.Exchange(ref _state, NoResult); diff --git a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs index e1d098ea2ead..d516af88eb55 100644 --- a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs +++ b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs @@ -202,7 +202,7 @@ private async Task ReadAsyncCore(Memory destination, CancellationToke if (!t.IsCompletedSuccessfully) { var cancelTcs = new TaskCompletionSource(); - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) + using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) { if (t == await Task.WhenAny(t, cancelTcs.Task).ConfigureAwait(false)) { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index b2e6f6be02c2..2f7d99e75cc7 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -60,7 +60,7 @@ public static async ValueTask ConnectAsync(string host, int port, Cancel if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea)) { // Connect completing asynchronously. Enable it to be canceled and wait for it. - using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) + using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) { await saea.Builder.Task.ConfigureAwait(false); } diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj b/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj index fecb41572908..33d21d3ec7c9 100644 --- a/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj @@ -14,10 +14,11 @@ Common\System\Net\WebSockets\WebSocketValidate.cs + - + diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs index 52b9e2fdc116..7eb3997a65b2 100644 --- a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs @@ -2,87 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text; using System.Threading; -using System.Threading.Tasks; namespace System.Net.WebSockets { - internal static class ManagedWebSocketExtensions + internal static partial class ManagedWebSocketExtensions { - internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) - { - fixed (byte* b = &MemoryMarshal.GetReference(bytes)) - { - return encoding.GetString(b, bytes.Length); - } - } - - internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) - { - return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(destination.Length); - return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); - - async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) - { - try - { - int result = await readTask.ConfigureAwait(false); - new Span(localBuffer, 0, result).CopyTo(localDestination.Span); - return result; - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - - internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) - { - return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(source.Length); - source.Span.CopyTo(buffer); - return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); - - async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) - { - try - { - await writeTask.ConfigureAwait(false); - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - } - - internal static class BitConverter - { - internal static unsafe int ToInt32(ReadOnlySpan value) - { - Debug.Assert(value.Length >= sizeof(int)); - return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); - } + internal static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action callback, object state) => + cancellationToken.Register(callback, state); } } diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs new file mode 100644 index 000000000000..89e232651af9 --- /dev/null +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal static partial class ManagedWebSocketExtensions + { + internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) + { + fixed (byte* b = &MemoryMarshal.GetReference(bytes)) + { + return encoding.GetString(b, bytes.Length); + } + } + + internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) + { + return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(destination.Length); + return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); + + async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) + { + try + { + int result = await readTask.ConfigureAwait(false); + new Span(localBuffer, 0, result).CopyTo(localDestination.Span); + return result; + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + + internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) + { + return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(source.Length); + source.Span.CopyTo(buffer); + return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); + + async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) + { + try + { + await writeTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + } + + internal static class BitConverter + { + internal static unsafe int ToInt32(ReadOnlySpan value) + { + Debug.Assert(value.Length >= sizeof(int)); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); + } + } +} diff --git a/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs b/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs index 3dc264a34680..2484140f5a65 100644 --- a/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs +++ b/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs @@ -1060,13 +1060,13 @@ private static ParallelLoopResult ForWorker( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -1322,13 +1322,13 @@ private static ParallelLoopResult ForWorker64( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -3121,13 +3121,13 @@ private static ParallelLoopResult PartitionerForEachWorker( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // Get our dynamic partitioner -- depends on whether source is castable to OrderablePartitioner // Also, do some error checking.