Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeS
_receiveBuffer = new byte[ReceiveBufferMinLength];

// Set up the abort source so that if it's triggered, we transition the instance appropriately.
_abortSource.Token.Register(s =>
// There's no need to store the resulting CancellationTokenRegistration, as this instance owns
// the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration.
_abortSource.Token.UnsafeRegister(s =>
{
var thisRef = (ManagedWebSocket)s;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ private void ProcessEvents()
// When cancellation is requested, clear out all watches. This should force any active or future reads
// on the inotify handle to return 0 bytes read immediately, allowing us to wake up from the blocking call
// and exit the processing loop and clean up.
var ctr = _cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this);
var ctr = _cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this);
try
{
// Previous event information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ internal RunningInstance(
_includeChildren = includeChildren;
_filterFlags = filter;
_cancellationToken = cancelToken;
_cancellationToken.Register(obj => ((RunningInstance)obj).CancellationCallback(), this);
_cancellationToken.UnsafeRegister(obj => ((RunningInstance)obj).CancellationCallback(), this);
_stopping = false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal void RegisterForCancellation(CancellationToken cancellationToken)
if (state == NoResult)
{
// Register the cancellation
_cancellationRegistration = cancellationToken.Register(thisRef => ((PipeCompletionSource<TResult>)thisRef).Cancel(), this);
_cancellationRegistration = cancellationToken.UnsafeRegister(thisRef => ((PipeCompletionSource<TResult>)thisRef).Cancel(), this);

// Grab the state for case if IO completed while we were setting the registration.
state = Interlocked.Exchange(ref _state, NoResult);
Expand Down
2 changes: 1 addition & 1 deletion src/System.IO.Pipes/src/System/IO/Pipes/PipeStream.Unix.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ private async Task<int> ReadAsyncCore(Memory<byte> destination, CancellationToke
if (!t.IsCompletedSuccessfully)
{
var cancelTcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), cancelTcs))
using (cancellationToken.UnsafeRegister(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), cancelTcs))
{
if (t == await Task.WhenAny(t, cancelTcs.Task).ConfigureAwait(false))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public static async ValueTask<Stream> ConnectAsync(string host, int port, Cancel
if (Socket.ConnectAsync(SocketType.Stream, ProtocolType.Tcp, saea))
{
// Connect completing asynchronously. Enable it to be canceled and wait for it.
using (cancellationToken.Register(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea))
using (cancellationToken.UnsafeRegister(s => Socket.CancelConnectAsync((SocketAsyncEventArgs)s), saea))
{
await saea.Builder.Task.ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
<Link>Common\System\Net\WebSockets\WebSocketValidate.cs</Link>
</Compile>
<Compile Include="System\Net\WebSockets\ManagedWebSocket.netstandard.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.cs" />
<Compile Include="System\Net\WebSockets\WebSocketProtocol.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)'=='netstandard'">
<Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocketExtensions.netstandard.cs" />
<Reference Include="System.Runtime.CompilerServices.Unsafe" />
</ItemGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.WebSockets
{
internal static class ManagedWebSocketExtensions
internal static partial class ManagedWebSocketExtensions
{
internal static unsafe string GetString(this UTF8Encoding encoding, Span<byte> bytes)
{
fixed (byte* b = &MemoryMarshal.GetReference(bytes))
{
return encoding.GetString(b, bytes.Length);
}
}

internal static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> destination, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(destination, out ArraySegment<byte> array))
{
return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(destination.Length);
return new ValueTask<int>(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination));

async Task<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
{
try
{
int result = await readTask.ConfigureAwait(false);
new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
return result;
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}

internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(source, out ArraySegment<byte> array))
{
return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(source.Length);
source.Span.CopyTo(buffer);
return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer));

async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
{
try
{
await writeTask.ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}
}

internal static class BitConverter
{
internal static unsafe int ToInt32(ReadOnlySpan<byte> value)
{
Debug.Assert(value.Length >= sizeof(int));
return Unsafe.ReadUnaligned<int>(ref MemoryMarshal.GetReference(value));
}
internal static CancellationTokenRegistration UnsafeRegister(this CancellationToken cancellationToken, Action<object> callback, object state) =>
cancellationToken.Register(callback, state);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace System.Net.WebSockets
{
internal static partial class ManagedWebSocketExtensions
{
internal static unsafe string GetString(this UTF8Encoding encoding, Span<byte> bytes)
{
fixed (byte* b = &MemoryMarshal.GetReference(bytes))
{
return encoding.GetString(b, bytes.Length);
}
}

internal static ValueTask<int> ReadAsync(this Stream stream, Memory<byte> destination, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(destination, out ArraySegment<byte> array))
{
return new ValueTask<int>(stream.ReadAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(destination.Length);
return new ValueTask<int>(FinishReadAsync(stream.ReadAsync(buffer, 0, destination.Length, cancellationToken), buffer, destination));

async Task<int> FinishReadAsync(Task<int> readTask, byte[] localBuffer, Memory<byte> localDestination)
{
try
{
int result = await readTask.ConfigureAwait(false);
new Span<byte>(localBuffer, 0, result).CopyTo(localDestination.Span);
return result;
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}

internal static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
{
if (MemoryMarshal.TryGetArray(source, out ArraySegment<byte> array))
{
return new ValueTask(stream.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken));
}
else
{
byte[] buffer = ArrayPool<byte>.Shared.Rent(source.Length);
source.Span.CopyTo(buffer);
return new ValueTask(FinishWriteAsync(stream.WriteAsync(buffer, 0, source.Length, cancellationToken), buffer));

async Task FinishWriteAsync(Task writeTask, byte[] localBuffer)
{
try
{
await writeTask.ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(localBuffer);
}
}
}
}
}

internal static class BitConverter
{
internal static unsafe int ToInt32(ReadOnlySpan<byte> value)
{
Debug.Assert(value.Length >= sizeof(int));
return Unsafe.ReadUnaligned<int>(ref MemoryMarshal.GetReference(value));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1060,13 +1060,13 @@ private static ParallelLoopResult ForWorker<TLocal>(
// if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
? default(CancellationTokenRegistration)
: parallelOptions.CancellationToken.Register((o) =>
: parallelOptions.CancellationToken.UnsafeRegister((o) =>
{
// Record our cancellation before stopping processing
oce = new OperationCanceledException(parallelOptions.CancellationToken);
// Cause processing to stop
sharedPStateFlags.Cancel();
}, state: null, useSynchronizationContext: false);
}, state: null);

// ETW event for Parallel For begin
int forkJoinContextID = 0;
Expand Down Expand Up @@ -1322,13 +1322,13 @@ private static ParallelLoopResult ForWorker64<TLocal>(
// if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
? default(CancellationTokenRegistration)
: parallelOptions.CancellationToken.Register((o) =>
: parallelOptions.CancellationToken.UnsafeRegister((o) =>
{
// Record our cancellation before stopping processing
oce = new OperationCanceledException(parallelOptions.CancellationToken);
// Cause processing to stop
sharedPStateFlags.Cancel();
}, state: null, useSynchronizationContext: false);
}, state: null);

// ETW event for Parallel For begin
int forkJoinContextID = 0;
Expand Down Expand Up @@ -3121,13 +3121,13 @@ private static ParallelLoopResult PartitionerForEachWorker<TSource, TLocal>(
// if cancellation is enabled, we need to register a callback to stop the loop when it gets signaled
CancellationTokenRegistration ctr = (!parallelOptions.CancellationToken.CanBeCanceled)
? default(CancellationTokenRegistration)
: parallelOptions.CancellationToken.Register((o) =>
: parallelOptions.CancellationToken.UnsafeRegister((o) =>
{
// Record our cancellation before stopping processing
oce = new OperationCanceledException(parallelOptions.CancellationToken);
// Cause processing to stop
sharedPStateFlags.Cancel();
}, state: null, useSynchronizationContext: false);
}, state: null);

// Get our dynamic partitioner -- depends on whether source is castable to OrderablePartitioner
// Also, do some error checking.
Expand Down