From 53b0fe9867fe8a15501b5596bb177dad3c51a28a Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 16 May 2019 08:36:14 -0700 Subject: [PATCH] Use CancellationToken.UnsafeRegister in a few more places CancellationToken.Register captures the current ExecutionContext and uses it to invoke the callback if/when it's invoked. That's generally desirable and is the right default, but in cases where we know for certain the callback doesn't care about EC (e.g. we're not invoking any 3rd-party code), we can use UnsafeRegister instead (newly added in 3.0), which skips capturing the ExecutionContext, as if Capture returned null. This helps few a couple of small costs: - Avoids thread local lookups to capture the current EC. - Avoids additional delegate invocations and thread local gets/sets to invoke the callback with the captured EC. - Avoids holding on to the EC in case it's needed, which can potentially keep alive an unbounded amount of state due to AsyncLocals. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 4 +- .../src/System/IO/FileSystemWatcher.Linux.cs | 2 +- .../src/System/IO/FileSystemWatcher.OSX.cs | 2 +- .../System/IO/Pipes/PipeCompletionSource.cs | 2 +- .../src/System/IO/Pipes/PipeStream.Unix.cs | 2 +- .../Http/SocketsHttpHandler/ConnectHelper.cs | 2 +- ...em.Net.WebSockets.WebSocketProtocol.csproj | 3 +- .../WebSockets/ManagedWebSocketExtensions.cs | 80 +---------------- .../ManagedWebSocketExtensions.netstandard.cs | 88 +++++++++++++++++++ .../src/System/Threading/Tasks/Parallel.cs | 12 +-- 10 files changed, 107 insertions(+), 90 deletions(-) create mode 100644 src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs diff --git a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs index f422f5e33e91..73343bee07dc 100644 --- a/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -179,7 +179,9 @@ private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeS _receiveBuffer = new byte[ReceiveBufferMinLength]; // Set up the abort source so that if it's triggered, we transition the instance appropriately. - _abortSource.Token.Register(s => + // There's no need to store the resulting CancellationTokenRegistration, as this instance owns + // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. + _abortSource.Token.UnsafeRegister(s => { var thisRef = (ManagedWebSocket)s; diff --git a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs index 448ee1658784..62cb2e18dde9 100644 --- a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs +++ b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.Linux.cs @@ -526,7 +526,7 @@ private void ProcessEvents() // When cancellation is requested, clear out all watches. This should force any active or future reads // on the inotify handle to return 0 bytes read immediately, allowing us to wake up from the blocking call // and exit the processing loop and clean up. - var ctr = _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + var ctr = _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); try { // Previous event information diff --git a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs index b608879181e7..d0be36b03c55 100644 --- a/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs +++ b/src/System.IO.FileSystem.Watcher/src/System/IO/FileSystemWatcher.OSX.cs @@ -169,7 +169,7 @@ internal RunningInstance( _includeChildren = includeChildren; _filterFlags = filter; _cancellationToken = cancelToken; - _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this); + _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this); _stopping = false; } diff --git a/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs b/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs index ab7fd35d1b91..31189655eaa0 100644 --- a/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs +++ b/src/System.IO.Pipes/src/System/IO/Pipes/PipeCompletionSource.cs @@ -70,7 +70,7 @@ internal void RegisterForCancellation(CancellationToken cancellationToken) if (state == NoResult) { // Register the cancellation - _cancellationRegistration = cancellationToken.Register(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); + _cancellationRegistration = cancellationToken.UnsafeRegister(thisRef => ((PipeCompletionSource)thisRef).Cancel(), this); // Grab the state for case if IO completed while we were setting the registration. state = Interlocked.Exchange(ref _state, NoResult); diff --git a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs index e1d098ea2ead..d516af88eb55 100644 --- a/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs +++ b/src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs @@ -202,7 +202,7 @@ private async Task ReadAsyncCore(Memory destination, CancellationToke if (!t.IsCompletedSuccessfully) { var cancelTcs = new TaskCompletionSource(); - using (cancellationToken.Register(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) + using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource)s).TrySetResult(true), cancelTcs)) { if (t == await Task.WhenAny(t, cancelTcs.Task).ConfigureAwait(false)) { diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs index b2e6f6be02c2..2f7d99e75cc7 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/ConnectHelper.cs @@ -60,7 +60,7 @@ public static async ValueTask ConnectAsync(string host, int port, Cancel if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea)) { // Connect completing asynchronously. Enable it to be canceled and wait for it. - using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) + using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea)) { await saea.Builder.Task.ConfigureAwait(false); } diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj b/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj index fecb41572908..33d21d3ec7c9 100644 --- a/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System.Net.WebSockets.WebSocketProtocol.csproj @@ -14,10 +14,11 @@ Common\System\Net\WebSockets\WebSocketValidate.cs + - + diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs index 52b9e2fdc116..7eb3997a65b2 100644 --- a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.cs @@ -2,87 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Buffers; -using System.Diagnostics; -using System.IO; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text; using System.Threading; -using System.Threading.Tasks; namespace System.Net.WebSockets { - internal static class ManagedWebSocketExtensions + internal static partial class ManagedWebSocketExtensions { - internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) - { - fixed (byte* b = &MemoryMarshal.GetReference(bytes)) - { - return encoding.GetString(b, bytes.Length); - } - } - - internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) - { - return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(destination.Length); - return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); - - async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) - { - try - { - int result = await readTask.ConfigureAwait(false); - new Span(localBuffer, 0, result).CopyTo(localDestination.Span); - return result; - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - - internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) - { - if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) - { - return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); - } - else - { - byte[] buffer = ArrayPool.Shared.Rent(source.Length); - source.Span.CopyTo(buffer); - return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); - - async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) - { - try - { - await writeTask.ConfigureAwait(false); - } - finally - { - ArrayPool.Shared.Return(localBuffer); - } - } - } - } - } - - internal static class BitConverter - { - internal static unsafe int ToInt32(ReadOnlySpan value) - { - Debug.Assert(value.Length >= sizeof(int)); - return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); - } + internal static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action callback, object state) => + cancellationToken.Register(callback, state); } } diff --git a/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs new file mode 100644 index 000000000000..89e232651af9 --- /dev/null +++ b/src/System.Net.WebSockets.WebSocketProtocol/src/System/Net/WebSockets/ManagedWebSocketExtensions.netstandard.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal static partial class ManagedWebSocketExtensions + { + internal static unsafe string GetString(this UTF8Encoding encoding, Span bytes) + { + fixed (byte* b = &MemoryMarshal.GetReference(bytes)) + { + return encoding.GetString(b, bytes.Length); + } + } + + internal static ValueTask ReadAsync(this Stream stream, Memory destination, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(destination, out ArraySegment array)) + { + return new ValueTask(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(destination.Length); + return new ValueTask(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination)); + + async Task FinishReadAsync(Task readTask, byte[] localBuffer, Memory localDestination) + { + try + { + int result = await readTask.ConfigureAwait(false); + new Span(localBuffer, 0, result).CopyTo(localDestination.Span); + return result; + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + + internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + if (MemoryMarshal.TryGetArray(source, out ArraySegment array)) + { + return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken)); + } + else + { + byte[] buffer = ArrayPool.Shared.Rent(source.Length); + source.Span.CopyTo(buffer); + return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer)); + + async Task FinishWriteAsync(Task writeTask, byte[] localBuffer) + { + try + { + await writeTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(localBuffer); + } + } + } + } + } + + internal static class BitConverter + { + internal static unsafe int ToInt32(ReadOnlySpan value) + { + Debug.Assert(value.Length >= sizeof(int)); + return Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(value)); + } + } +} diff --git a/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs b/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs index 3dc264a34680..2484140f5a65 100644 --- a/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs +++ b/src/System.Threading.Tasks.Parallel/src/System/Threading/Tasks/Parallel.cs @@ -1060,13 +1060,13 @@ private static ParallelLoopResult ForWorker( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -1322,13 +1322,13 @@ private static ParallelLoopResult ForWorker64( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // ETW event for Parallel For begin int forkJoinContextID = 0; @@ -3121,13 +3121,13 @@ private static ParallelLoopResult PartitionerForEachWorker( // if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled) ? default(CancellationTokenRegistration) - : parallelOptions.CancellationToken.Register((o) => + : parallelOptions.CancellationToken.UnsafeRegister((o) => { // Record our cancellation before stopping processing oce = new OperationCanceledException(parallelOptions.CancellationToken); // Cause processing to stop sharedPStateFlags.Cancel(); - }, state: null, useSynchronizationContext: false); + }, state: null); // Get our dynamic partitioner -- depends on whether source is castable to OrderablePartitioner // Also, do some error checking.