From d7002764ff43148a2030d2fca3b7e8caba5fc40d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 1 Jun 2025 22:24:55 -0400 Subject: [PATCH 1/2] Support bounded channel with bound of 0 (rendezvous) This PR enables Channel.CreateBounded(0), whereas currently a bound of < 1 is exceptional. A bound is the number of items the channel can buffer, so a bound of 0 means it can't buffer anything, which makes it into a rendezvous, where the reader and writer must be at the channel at the same time in order to directly hand off from the writer to the reader. This is the same meaning as in other languages/libraries, e.g. if in go you don't specify a bound or you specify a bound of 0, you similarly get an unbuffered rendezvous channel. --- .../src/System.Threading.Channels.csproj | 1 + .../src/System/Threading/Channels/Channel.cs | 27 +- .../Threading/Channels/ChannelOptions.cs | 7 +- .../Threading/Channels/ChannelUtilities.cs | 23 + .../Threading/Channels/RendezvousChannel.cs | 537 ++++++++++++++++++ .../SingleConsumerUnboundedChannel.cs | 2 - .../Threading/Channels/UnboundedChannel.cs | 34 +- .../Channels/UnboundedPriorityChannel.cs | 34 +- .../tests/ChannelTestBase.cs | 107 ++-- .../tests/ChannelTests.cs | 4 +- .../tests/RendezvousChannelTests.cs | 317 +++++++++++ .../System.Threading.Channels/tests/Stress.cs | 78 ++- .../System.Threading.Channels.Tests.csproj | 3 +- .../tests/TestBase.cs | 8 + 14 files changed, 1049 insertions(+), 133 deletions(-) create mode 100644 src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs create mode 100644 src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs diff --git a/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj b/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj index 4523f5d8b55835..1d8a85a73e0259 100644 --- a/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj +++ b/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj @@ -25,6 +25,7 @@ System.Threading.Channel<T> + diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs index 554a60d10724a3..1dbea485b3272b 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs @@ -15,6 +15,7 @@ public static Channel CreateUnbounded() => /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// The created channel. + /// is . public static Channel CreateUnbounded(UnboundedChannelOptions options) { ArgumentNullException.ThrowIfNull(options); @@ -35,35 +36,33 @@ public static Channel CreateUnbounded(UnboundedChannelOptions options) /// Channels created with this method apply the /// behavior and prohibit continuations from running synchronously. /// - public static Channel CreateBounded(int capacity) - { - if (capacity < 1) - { - throw new ArgumentOutOfRangeException(nameof(capacity)); - } - - return new BoundedChannel(capacity, BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null); - } + /// is negative. + public static Channel CreateBounded(int capacity) => + capacity > 0 ? new BoundedChannel(capacity, BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null) : + capacity == 0 ? new RendezvousChannel(BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null) : + throw new ArgumentOutOfRangeException(nameof(capacity)); /// Creates a channel subject to the provided options. /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// The created channel. - public static Channel CreateBounded(BoundedChannelOptions options) - { - return CreateBounded(options, itemDropped: null); - } + /// is . + public static Channel CreateBounded(BoundedChannelOptions options) => + CreateBounded(options, itemDropped: null); /// Creates a channel subject to the provided options. /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// Delegate that will be called when item is being dropped from channel. See . /// The created channel. + /// is . public static Channel CreateBounded(BoundedChannelOptions options, Action? itemDropped) { ArgumentNullException.ThrowIfNull(options); - return new BoundedChannel(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped); + return + options.Capacity > 0 ? new BoundedChannel(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped) : + new RendezvousChannel(options.FullMode, !options.AllowSynchronousContinuations, itemDropped); } } } diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs index ab3a722bb909d2..85a2ff49e29983 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs @@ -49,9 +49,10 @@ public sealed class BoundedChannelOptions : ChannelOptions /// Initializes the options. /// The maximum number of items the bounded channel may store. + /// is negative. public BoundedChannelOptions(int capacity) { - if (capacity < 1) + if (capacity < 0) { throw new ArgumentOutOfRangeException(nameof(capacity)); } @@ -60,12 +61,13 @@ public BoundedChannelOptions(int capacity) } /// Gets or sets the maximum number of items the bounded channel may store. + /// is negative. public int Capacity { get => _capacity; set { - if (value < 1) + if (value < 0) { throw new ArgumentOutOfRangeException(nameof(value)); } @@ -74,6 +76,7 @@ public int Capacity } /// Gets or sets the behavior incurred by write operations when the channel is full. + /// is an invalid enum value. public BoundedChannelFullMode FullMode { get => _mode; diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs index 3436834636edb3..0e7f2742d1a4c9 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -317,6 +317,29 @@ internal static void AssertAll(TAsyncOp? head, Func co } } + /// Counts the number of operations in the list. + /// The head of the queue of operations to count. + internal static long CountOperations(TAsyncOp? head) + where TAsyncOp : AsyncOperation + { + TAsyncOp? current = head; + long count = 0; + + if (current is not null) + { + do + { + count++; + + Debug.Assert(current is not null); + current = current.Next; + } + while (current != head); + } + + return count; + } + /// Creates and returns an exception object to indicate that a channel has been closed. internal static Exception CreateInvalidCompletionException(Exception? inner = null) => inner is OperationCanceledException ? inner : diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs new file mode 100644 index 00000000000000..8dc641b8e4cb1b --- /dev/null +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs @@ -0,0 +1,537 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// Provides an unbuffered channel where readers and writers must directly hand off to each other. + [DebuggerDisplay("{DebuggerDisplay,nq}")] + internal sealed class RendezvousChannel : Channel + { + /// Whether to suceed writes immediately even when there's no rendezvousing reader. + private readonly bool _dropWrites; + + /// The delegate that will be invoked when the channel hits its bound and an item is dropped from the channel. + private readonly Action? _itemDropped; + + /// Task signaled when the channel has completed. + private readonly TaskCompletionSource _completion; + + /// Head of linked list of blocked ReadAsync calls. + private BlockedReadAsyncOperation? _blockedReadersHead; + + /// Head of linked list of blocked WriteAsync calls. + private BlockedWriteAsyncOperation? _blockedWritersHead; + + /// Head of linked list of waiting WaitToReadAsync calls. + private WaitingReadAsyncOperation? _waitingReadersHead; + + /// Head of linked list of waiting WaitToWriteAsync calls. + private WaitingWriteAsyncOperation? _waitingWritersHead; + + /// Whether to force continuations to be executed asynchronously from producer writes. + private readonly bool _runContinuationsAsynchronously; + + /// Set to non-null once Complete has been called. + private Exception? _doneWriting; + + /// Initializes the . + /// The mode used when writing to a full channel. + /// Whether to force continuations to be executed asynchronously. + /// Delegate that will be invoked when an item is dropped from the channel. See . + internal RendezvousChannel(BoundedChannelFullMode mode, bool runContinuationsAsynchronously, Action? itemDropped) + { + _dropWrites = mode is not BoundedChannelFullMode.Wait; + + _runContinuationsAsynchronously = runContinuationsAsynchronously; + _itemDropped = itemDropped; + _completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + + Reader = new RendezvousChannelReader(this); + Writer = new RendezvousChannelWriter(this); + } + + [DebuggerDisplay("{DebuggerDisplay,nq}")] + private sealed class RendezvousChannelReader : ChannelReader + { + internal readonly RendezvousChannel _parent; + private readonly BlockedReadAsyncOperation _readerSingleton; + private readonly WaitingReadAsyncOperation _waiterSingleton; + + internal RendezvousChannelReader(RendezvousChannel parent) + { + _parent = parent; + _readerSingleton = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, pooled: true); + _waiterSingleton = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, pooled: true); + } + + public override Task Completion => _parent._completion.Task; + + public override bool CanCount => true; + + public override bool CanPeek => true; + + public override int Count => 0; + + public override bool TryRead([MaybeNullWhen(false)] out T item) + { + RendezvousChannel parent = _parent; + + // Reserve a blocked writer if one is available. + BlockedWriteAsyncOperation? blockedWriter = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null) + { + while (ChannelUtilities.TryDequeue(ref parent._blockedWritersHead, out blockedWriter)) + { + if (blockedWriter.TryReserveCompletionIfCancelable()) + { + break; + } + } + } + } + + // If we got one, transfer its item to the read and complete successfully. + if (blockedWriter is not null) + { + item = blockedWriter.Item!; + blockedWriter.DangerousSetResult(default); + return true; + } + + item = default; + return false; + } + + public override bool TryPeek([MaybeNullWhen(false)] out T item) + { + RendezvousChannel parent = _parent; + + // Peek at a blocked writer if one is available. + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null && + parent._blockedWritersHead is { } blockedWriter) + { + item = blockedWriter.Item!; + return true; + } + } + + item = default; + return false; + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken) + { + RendezvousChannel parent = _parent; + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + BlockedReadAsyncOperation? reader = null; + WaitingWriteAsyncOperation? waitingWriters = null; + BlockedWriteAsyncOperation? blockedWriter = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing so that there will never be more items, fail. + if (parent._doneWriting is not null) + { + return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); + } + + // Reserve a blocked writer if one is available. + while (ChannelUtilities.TryDequeue(ref parent._blockedWritersHead, out blockedWriter)) + { + if (blockedWriter.TryReserveCompletionIfCancelable()) + { + break; + } + } + + // If we couldn't get one, create a waiting reader, and reserve any waiting writers to alert. + if (blockedWriter is null) + { + reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); + + waitingWriters = ChannelUtilities.TryReserveCompletionIfCancelable(ref parent._waitingWritersHead); + } + } + + // Either complete the reserved blocked writer, transferring its item to the read, + // or return the waiting reader task, also alerting any waiting writers. + ValueTask result; + if (blockedWriter is not null) + { + Debug.Assert(reader is null); + Debug.Assert(waitingWriters is null); + result = new(blockedWriter.Item!); + blockedWriter.DangerousSetResult(default); + } + else + { + Debug.Assert(reader is not null); + ChannelUtilities.DangerousSetOperations(waitingWriters, result: true); + result = reader.ValueTaskOfT; + } + + return result; + } + + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + RendezvousChannel parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, a read will never be possible. + if (parent._doneWriting is not null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + default; + } + + // If there are any writers waiting, a read is possible. + if (parent._blockedWritersHead is not null) + { + return new ValueTask(true); + } + + // Register a waiting reader task. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); + return waiter.ValueTaskOfT; + } + } + + internal string DebuggerDisplay + { + get + { + long blockedReaderCount, waitingReaderCount; + lock (_parent.SyncObj) + { + blockedReaderCount = ChannelUtilities.CountOperations(_parent._blockedReadersHead); + waitingReaderCount = ChannelUtilities.CountOperations(_parent._waitingReadersHead); + } + + return $"ReadAsync={blockedReaderCount}, WaitToReadAsync={waitingReaderCount}"; + } + } + } + + [DebuggerDisplay("{DebuggerDisplay,nq}")] + private sealed class RendezvousChannelWriter : ChannelWriter + { + internal readonly RendezvousChannel _parent; + private readonly BlockedWriteAsyncOperation _writerSingleton; + private readonly WaitingWriteAsyncOperation _waiterSingleton; + + internal RendezvousChannelWriter(RendezvousChannel parent) + { + _parent = parent; + _writerSingleton = new BlockedWriteAsyncOperation(runContinuationsAsynchronously: true, pooled: true); + _waiterSingleton = new WaitingWriteAsyncOperation(runContinuationsAsynchronously: true, pooled: true); + } + + public override bool TryComplete(Exception? error) + { + RendezvousChannel parent = _parent; + + BlockedReadAsyncOperation? blockedReadersHead; + BlockedWriteAsyncOperation? blockedWritersHead; + WaitingReadAsyncOperation? waitingReadersHead; + WaitingWriteAsyncOperation? waitingWritersHead; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already marked the channel as completed, bail. + if (parent._doneWriting is not null) + { + return false; + } + + // Mark that we're done writing. + parent._doneWriting = error ?? ChannelUtilities.s_doneWritingSentinel; + + // Snag the queues while holding the lock, so that we don't need to worry + // about concurrent mutation, such as from cancellation of pending operations. + blockedReadersHead = parent._blockedReadersHead; + blockedWritersHead = parent._blockedWritersHead; + waitingReadersHead = parent._waitingReadersHead; + waitingWritersHead = parent._waitingWritersHead; + parent._blockedReadersHead = null; + parent._blockedWritersHead = null; + parent._waitingReadersHead = null; + parent._waitingWritersHead = null; + } + + // Complete the channel's task, as no more data can possibly arrive at this point. We do this outside + // of the lock in case we'll be running synchronous completions, and we + // do it before completing blocked/waiting readers, so that when they + // wake up they'll see the task as being completed. + ChannelUtilities.Complete(parent._completion, error); + + // Complete all pending operations. We don't need to worry about concurrent mutation here: + // No other writers or readers will be able to register operations, and any cancellation callbacks + // will see the queues as being null and exit immediately. + ChannelUtilities.FailOperations(blockedReadersHead, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.FailOperations(blockedWritersHead, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.SetOrFailOperations(waitingReadersHead, result: false, error: error); + ChannelUtilities.SetOrFailOperations(waitingWritersHead, result: false, error: error); + + // Successfully transitioned to completed. + return true; + } + + public override bool TryWrite(T item) + { + RendezvousChannel parent = _parent; + + // Reserve a blocked reader if one is available. + BlockedReadAsyncOperation? blockedReader = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null) + { + while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) + { + if (blockedReader.TryReserveCompletionIfCancelable()) + { + break; + } + } + } + } + + // If we got one, transfer its item to the read and complete successfully. + if (blockedReader is not null) + { + blockedReader.DangerousSetResult(item); + return true; + } + + // There's no concurrent reader, but if we're configured to drop writes, we can succeed immediately. + if (parent._dropWrites) + { + parent._itemDropped?.Invoke(item); + return true; + } + + return false; + } + + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + RendezvousChannel parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, a read will never be possible. + if (parent._doneWriting is not null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + default; + } + + // If there are any readers waiting, a write is possible. + if (parent._blockedReadersHead is not null || parent._dropWrites) + { + return new ValueTask(true); + } + + // There were no readers available, but there could be in the future, so ensure + // there's a waiting writer task and return it. + WaitingWriteAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._waitingWritersHead, waiter); + return waiter.ValueTaskOfT; + } + } + + public override ValueTask WriteAsync(T item, CancellationToken cancellationToken) + { + RendezvousChannel parent = _parent; + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + BlockedWriteAsyncOperation? writer = null; + WaitingReadAsyncOperation? waitingReaders = null; + BlockedReadAsyncOperation? blockedReader = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already been marked as done for writing, we shouldn't be writing. + if (parent._doneWriting is not null) + { + return new ValueTask(Task.FromException(ChannelUtilities.CreateInvalidCompletionException(parent._doneWriting))); + } + + // Reserve a blocked reader if one is available. + while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) + { + if (blockedReader.TryReserveCompletionIfCancelable()) + { + break; + } + } + + // If we couldn't get one, create a blocked writer, and reserve any waiting readers to alert. + if (blockedReader is null && !parent._dropWrites) + { + writer = + !cancellationToken.CanBeCanceled && _writerSingleton.TryOwnAndReset() ? _writerSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + writer.Item = item; + ChannelUtilities.Enqueue(ref parent._blockedWritersHead, writer); + + waitingReaders = ChannelUtilities.TryReserveCompletionIfCancelable(ref parent._waitingReadersHead); + } + } + + if (writer is not null) + { + Debug.Assert(blockedReader is null); + ChannelUtilities.DangerousSetOperations(waitingReaders, result: true); + return writer.ValueTask; + } + + if (blockedReader is not null) + { + blockedReader.DangerousSetResult(item); + } + else + { + Debug.Assert(parent._dropWrites); + parent._itemDropped?.Invoke(item); + } + + return default; + } + + internal string DebuggerDisplay + { + get + { + long blockedWriterCount, waitingWriterCount; + lock (_parent.SyncObj) + { + blockedWriterCount = ChannelUtilities.CountOperations(_parent._blockedWritersHead); + waitingWriterCount = ChannelUtilities.CountOperations(_parent._waitingWritersHead); + } + + return $"WriteAsync={blockedWriterCount}, WaitToWriteAsync={waitingWriterCount}"; + } + } + } + + /// Gets an object used to synchronize all state on the instance. + private object SyncObj => _completion; + + private Action CancellationCallbackDelegate => + field ??= (state, cancellationToken) => + { + AsyncOperation op = (AsyncOperation)state!; + if (op.TrySetCanceled(cancellationToken)) + { + ChannelUtilities.UnsafeQueueUserWorkItem(static state => // escape cancellation callback + { + lock (state.Key.SyncObj) + { + switch (state.Value) + { + case BlockedReadAsyncOperation blockedReader: + ChannelUtilities.Remove(ref state.Key._blockedReadersHead, blockedReader); + break; + + case BlockedWriteAsyncOperation blockedWriter: + ChannelUtilities.Remove(ref state.Key._blockedWritersHead, blockedWriter); + break; + + case WaitingReadAsyncOperation waitingReader: + ChannelUtilities.Remove(ref state.Key._waitingReadersHead, waitingReader); + break; + + case WaitingWriteAsyncOperation waitingWriter: + ChannelUtilities.Remove(ref state.Key._waitingWritersHead, waitingWriter); + break; + + default: + Debug.Fail($"Unexpected operation: {state.Value}"); + break; + } + } + }, new KeyValuePair, AsyncOperation>(this, op)); + } + }; + + private string DebuggerDisplay => + $"{((RendezvousChannelReader)Reader).DebuggerDisplay}, {((RendezvousChannelWriter)Writer).DebuggerDisplay}"; + + [Conditional("DEBUG")] + private void AssertInvariants() + { + Debug.Assert(SyncObj is not null, "The sync obj must not be null."); + Debug.Assert(Monitor.IsEntered(SyncObj), "Invariants can only be validated while holding the lock."); + + if (_blockedReadersHead is not null) + { + Debug.Assert(_blockedWritersHead is null, "There shouldn't be any blocked writer if there's a blocked reader."); + Debug.Assert(_waitingWritersHead is null, "There shouldn't be any waiting writers if there's a blocked reader."); + } + + if (_blockedWritersHead is not null) + { + Debug.Assert(_blockedReadersHead is null, "There shouldn't be any blocked readers if there's a blocked writer."); + Debug.Assert(_waitingReadersHead is null, "There shouldn't be any waiting readers if there's a blocked writer."); + } + + if (_completion.Task.IsCompleted) + { + Debug.Assert(_doneWriting is not null, "We can only complete if we're done writing."); + } + } + } +} diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs index c863f619c041fc..969249384ac1de 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs @@ -251,8 +251,6 @@ public override bool TryComplete(Exception? error) ChannelUtilities.Complete(parent._completion, error); } - Debug.Assert(blockedReader is null || waitingReader is null, "There should only ever be at most one reader."); - // Complete a blocked reader if necessary if (blockedReader is not null) { diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index 6db2a3e5f8f1cc..e294f18a60f1dd 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -97,19 +97,10 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); } - // If we're able to use the singleton reader, do so. - if (!cancellationToken.CanBeCanceled) - { - BlockedReadAsyncOperation singleton = _readerSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._blockedReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a reader. - var reader = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton reader, do so. Otherwise, create a new reader. + BlockedReadAsyncOperation reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); return reader.ValueTaskOfT; } @@ -174,19 +165,10 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo default; } - // If we're able to use the singleton waiter, do so. - if (!cancellationToken.CanBeCanceled) - { - WaitingReadAsyncOperation singleton = _waiterSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._waitingReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a waiter. - var waiter = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton waiter, do so. Otherwise, create a new waiter. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); return waiter.ValueTaskOfT; } diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs index ead76810b0b62b..68991ae6a620df 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs @@ -98,19 +98,10 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); } - // If we're able to use the singleton reader, do so. - if (!cancellationToken.CanBeCanceled) - { - BlockedReadAsyncOperation singleton = _readerSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._blockedReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a reader. - var reader = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton reader, do so. Otherwise, create a new reader. + BlockedReadAsyncOperation reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); return reader.ValueTaskOfT; } @@ -179,19 +170,10 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo default; } - // If we're able to use the singleton waiter, do so. - if (!cancellationToken.CanBeCanceled) - { - WaitingReadAsyncOperation singleton = _waiterSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._waitingReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a waiter. - var waiter = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton waiter, do so. Otherwise, create a new waiter. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); return waiter.ValueTaskOfT; } diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs index cefa3fe6f29901..2799d6509d1a70 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs @@ -22,23 +22,17 @@ public abstract partial class ChannelTestBase : TestBase protected virtual bool RequiresSingleReader => false; protected virtual bool RequiresSingleWriter => false; protected virtual bool BuffersItems => true; - - public static IEnumerable ThreeBools => - from b1 in new[] { false, true } - from b2 in new[] { false, true } - from b3 in new[] { false, true } - select new object[] { b1, b2, b3 }; + protected virtual bool HasDebuggerTypeProxy => true; [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsDebuggerTypeProxyAttributeSupported))] public void ValidateDebuggerAttributes() { Channel c = CreateChannel(); - for (int i = 1; i <= 10; i++) + DebuggerAttributes.ValidateDebuggerDisplayReferences(c); + if (HasDebuggerTypeProxy) { - c.Writer.WriteAsync(i); + DebuggerAttributes.InvokeDebuggerTypeProxyProperties(c); } - DebuggerAttributes.ValidateDebuggerDisplayReferences(c); - DebuggerAttributes.InvokeDebuggerTypeProxyProperties(c); } [Fact] @@ -112,7 +106,7 @@ public void TryComplete_Twice_ReturnsTrueThenFalse() } [Fact] - public async Task TryComplete_ErrorsPropage() + public async Task TryComplete_ErrorsPropagate() { Channel c; @@ -142,6 +136,10 @@ public void Count_ThrowsIfUnsupported() { Assert.Throws(() => c.Reader.Count); } + else + { + Assert.InRange(c.Reader.Count, 0, int.MaxValue); + } } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -226,17 +224,13 @@ public void ManyProducerConsumer_ConcurrentReadWrite_Success(int numReaders, int { tasks[i] = Task.Run(async () => { - try + while (await c.Reader.WaitToReadAsync()) { - while (await c.Reader.WaitToReadAsync()) + while (c.Reader.TryRead(out int value)) { - if (c.Reader.TryRead(out int value)) - { - Interlocked.Add(ref readTotal, value); - } + Interlocked.Add(ref readTotal, value); } } - catch (ChannelClosedException) { } }); } @@ -352,13 +346,13 @@ public void WaitToWriteAsync_BlockedReader_ReturnsTrue() } [Fact] - public void TryRead_DataAvailable_Success() + public async Task TryRead_DataAvailable_Success() { Channel c = CreateChannel(); ValueTask write = c.Writer.WriteAsync(42); - Assert.True(write.IsCompletedSuccessfully); Assert.True(c.Reader.TryRead(out int result)); Assert.Equal(42, result); + await write; } [Fact] @@ -370,7 +364,7 @@ public void TryRead_AfterComplete_ReturnsFalse() } [Fact] - public void TryPeek_SucceedsWhenDataAvailable() + public async Task TryPeek_SucceedsWhenDataAvailable() { Channel c = CreateChannel(); @@ -379,7 +373,7 @@ public void TryPeek_SucceedsWhenDataAvailable() for (int i = 0; i < 3; i++) { // Write a value - Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); + ValueTask write = c.Writer.WriteAsync(42); // Can peek at the written value Assert.True(c.Reader.TryPeek(out int peekedResult)); @@ -389,11 +383,18 @@ public void TryPeek_SucceedsWhenDataAvailable() Assert.True(c.Reader.TryRead(out int readResult)); Assert.Equal(42, readResult); + await write; + // Peeking no longer finds it Assert.False(c.Reader.TryPeek(out int noResult)); Assert.Equal(0, noResult); } + if (!BuffersItems) + { + return; + } + // Write another value Assert.True(c.Writer.WriteAsync(84).IsCompletedSuccessfully); @@ -544,7 +545,7 @@ public void Precancellation_Writing_ReturnsImmediately() Assert.True(writeTask.IsCanceled); ValueTask waitTask = c.Writer.WaitToWriteAsync(new CancellationToken(true)); - Assert.True(writeTask.IsCanceled); + Assert.True(waitTask.IsCanceled); } [Fact] @@ -558,16 +559,23 @@ public void Write_WaitToReadAsync_CompletesSynchronously() [Theory] [InlineData(false)] [InlineData(true)] - public void Precancellation_WaitToReadAsync_ReturnsImmediately(bool dataAvailable) + public async Task Precancellation_WaitToReadAsync_ReturnsImmediately(bool dataAvailable) { Channel c = CreateChannel(); + + ValueTask write = default; if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + write = c.Writer.WriteAsync(42); } ValueTask waitTask = c.Reader.WaitToReadAsync(new CancellationToken(true)); Assert.True(waitTask.IsCanceled); + + if (BuffersItems) + { + await write; + } } [Theory] @@ -634,7 +642,7 @@ public void Precancellation_ReadAsync_ReturnsImmediately(bool dataAvailable) Channel c = CreateChannel(); if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + c.Writer.WriteAsync(42); } ValueTask readTask = c.Reader.ReadAsync(new CancellationToken(true)); @@ -655,10 +663,9 @@ public async Task ReadAsync_Canceled_CanceledAsynchronously() await AssertExtensions.CanceledAsync(cts.Token, async () => await r); - if (c.Writer.TryWrite(42)) - { - Assert.Equal(42, await c.Reader.ReadAsync()); - } + ValueTask vt = c.Writer.WriteAsync(42); + Assert.Equal(42, await c.Reader.ReadAsync()); + await vt; } [Fact] @@ -823,7 +830,7 @@ public async Task ReadAllAsync_UseMoveNextAsyncAfterCompleted_ReturnsFalse(bool } [Fact] - public void ReadAllAsync_AvailableDataCompletesSynchronously() + public async Task ReadAllAsync_AvailableDataCompletesSynchronously() { Channel c = CreateChannel(); @@ -832,11 +839,12 @@ public void ReadAllAsync_AvailableDataCompletesSynchronously() { for (int i = 100; i < 110; i++) { - Assert.True(c.Writer.TryWrite(i)); + ValueTask write = c.Writer.WriteAsync(i); ValueTask vt = e.MoveNextAsync(); Assert.True(vt.IsCompletedSuccessfully); Assert.True(vt.Result); Assert.Equal(i, e.Current); + await write; } } finally @@ -859,7 +867,7 @@ public async Task ReadAllAsync_UnavailableDataCompletesAsynchronously() { ValueTask vt = e.MoveNextAsync(); Assert.False(vt.IsCompleted); - Task producer = Task.Run(() => c.Writer.TryWrite(i)); + Task producer = Task.Run(() => c.Writer.WriteAsync(i).AsTask()); Assert.True(await vt); await producer; Assert.Equal(i, e.Current); @@ -916,8 +924,7 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() { Channel c = CreateChannel(); - Assert.True(c.Writer.TryWrite(42)); - c.Writer.Complete(); + ValueTask write = c.Writer.WriteAsync(42); IAsyncEnumerable enumerable = c.Reader.ReadAllAsync(); IAsyncEnumerator e = enumerable.GetAsyncEnumerator(); @@ -925,6 +932,9 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() Assert.True(await e.MoveNextAsync()); Assert.Equal(42, e.Current); + await write; + c.Writer.Complete(); + Assert.False(await e.MoveNextAsync()); Assert.False(await e.MoveNextAsync()); @@ -942,16 +952,17 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() [InlineData(false, true)] [InlineData(true, false)] [InlineData(true, true)] - public void ReadAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose) + public async Task ReadAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose) { Channel c = CreateChannel(); IAsyncEnumerable enumerable = c.Reader.ReadAllAsync(); for (int i = 0; i < 10; i++) { - Assert.True(c.Writer.TryWrite(i)); + ValueTask write = c.Writer.WriteAsync(i); IAsyncEnumerator e = (sameEnumerable ? enumerable : c.Reader.ReadAllAsync()).GetAsyncEnumerator(); ValueTask vt = e.MoveNextAsync(); + await write; Assert.True(vt.IsCompletedSuccessfully); Assert.True(vt.Result); Assert.Equal(i, e.Current); @@ -1015,9 +1026,10 @@ public async Task ReadAllAsync_DualConcurrentEnumeration_AllItemsEnumerated(bool public async Task ReadAllAsync_CanceledBeforeMoveNextAsync_Throws(bool dataAvailable) { Channel c = CreateChannel(); + if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + _ = c.Writer.WriteAsync(42); } var cts = new CancellationTokenSource(); @@ -1057,10 +1069,11 @@ public async Task WaitToReadAsync_ConsecutiveReadsSucceed() for (int i = 0; i < 5; i++) { ValueTask r = c.Reader.WaitToReadAsync(); - await c.Writer.WriteAsync(i); + ValueTask write = c.Writer.WriteAsync(i); Assert.True(await r); Assert.True(c.Reader.TryRead(out int item)); Assert.Equal(i, item); + await write; } } @@ -1155,7 +1168,8 @@ public async Task WaitToReadAsync_AwaitThenGetResult_Throws() Channel c = CreateChannel(); ValueTask read = c.Reader.WaitToReadAsync(); - Assert.True(c.Writer.TryWrite(42)); + + ValueTask write = c.Writer.WriteAsync(42); Assert.True(await read); Assert.Throws(() => read.GetAwaiter().IsCompleted); Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); @@ -1169,8 +1183,9 @@ public async Task ReadAsync_AwaitThenGetResult_Throws() Channel c = CreateChannel(); ValueTask read = c.Reader.ReadAsync(); - Assert.True(c.Writer.TryWrite(42)); + ValueTask write = c.Writer.WriteAsync(42); Assert.Equal(42, await read); + await write; Assert.Throws(() => read.GetAwaiter().IsCompleted); Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); Assert.Throws(() => read.GetAwaiter().GetResult()); @@ -1186,7 +1201,7 @@ public async Task WaitToWriteAsync_AwaitThenGetResult_Throws() } ValueTask write = c.Writer.WaitToWriteAsync(); - await c.Reader.ReadAsync(); + ValueTask read = c.Reader.ReadAsync(); Assert.True(await write); Assert.Throws(() => write.GetAwaiter().IsCompleted); Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); @@ -1344,7 +1359,7 @@ await Task.Factory.StartNew(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } SynchronizationContext.SetSynchronizationContext(new CustomSynchronizationContext()); @@ -1379,7 +1394,7 @@ await Task.Factory.StartNew(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await continuationRan.Task; @@ -1440,7 +1455,7 @@ await Task.Run(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await Task.Factory.StartNew(() => @@ -1482,7 +1497,7 @@ await Task.Factory.StartNew(() => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await continuationRan.Task; diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTests.cs b/src/libraries/System.Threading.Channels/tests/ChannelTests.cs index 2765e66d4f346a..f42096f7890dae 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTests.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTests.cs @@ -62,7 +62,7 @@ public void Create_NullOptions_ThrowsArgumentException() } [Theory] - [InlineData(0)] + [InlineData(-1)] [InlineData(-2)] public void CreateBounded_InvalidBufferSizes_ThrowArgumentExceptions(int capacity) { @@ -77,7 +77,7 @@ public void BoundedChannelOptions_InvalidModes_ThrowArgumentExceptions(BoundedCh AssertExtensions.Throws("value", () => new BoundedChannelOptions(1) { FullMode = mode }); [Theory] - [InlineData(0)] + [InlineData(-1)] [InlineData(-2)] public void BoundedChannelOptions_InvalidCapacity_ThrowArgumentExceptions(int capacity) => AssertExtensions.Throws("value", () => new BoundedChannelOptions(1) { Capacity = capacity }); diff --git a/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs b/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs new file mode 100644 index 00000000000000..5de9153ca120a6 --- /dev/null +++ b/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs @@ -0,0 +1,317 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Microsoft.DotNet.XUnitExtensions; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class RendezvousChannelTests : ChannelTestBase + { + protected override Channel CreateChannel() => Channel.CreateBounded(0); + + protected override Channel CreateFullChannel() => CreateChannel(); + + protected override bool BuffersItems => false; + + protected override bool HasDebuggerTypeProxy => false; + + [Fact] + public async Task Count_AlwaysZero() + { + Channel c = CreateChannel(); + + Assert.True(c.Reader.CanCount); + Assert.Equal(0, c.Reader.Count); + + var write1 = c.Writer.WriteAsync(1); + var write2 = c.Writer.WriteAsync(2); + + Assert.Equal(0, c.Reader.Count); + Assert.False(write1.IsCompleted); + Assert.False(write2.IsCompleted); + + Assert.Equal(1, await c.Reader.ReadAsync()); + + await write1; + Assert.Equal(0, c.Reader.Count); + Assert.False(write2.IsCompleted); + + Assert.Equal(2, await c.Reader.ReadAsync()); + + await write2; + Assert.Equal(0, c.Reader.Count); + } + + [Fact] + public void TryWrite_TryRead_NoPairing_ReturnsFalse() + { + Channel c = CreateChannel(); + + for (int i = 0; i < 3; i++) + { + Assert.False(c.Writer.TryWrite(42)); + Assert.False(c.Reader.TryRead(out int item)); + Assert.Equal(0, item); + } + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task TryWrite_DropXx_DropsWrite(BoundedChannelFullMode mode) + { + int? dropped = null; + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = mode }, item => dropped = item); + + for (int i = 42; i < 52; i++) + { + var waiter = c.Writer.WaitToWriteAsync(); + AssertSynchronousSuccess(waiter); + Assert.True(await waiter); + + Assert.True(c.Writer.TryWrite(i)); + Assert.Equal(i, dropped); + + dropped = null; + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + Assert.Equal(i, dropped); + } + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task TryWrite_DropXx_ReaderTakesPriority(BoundedChannelFullMode mode) + { + int? dropped = null; + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = mode }, item => dropped = item); + + for (int i = 42; i < 52; i++) + { + ValueTask reader; + + reader = c.Reader.ReadAsync(); + Assert.True(c.Writer.TryWrite(i)); + Assert.Null(dropped); + Assert.Equal(i, await reader); + + reader = c.Reader.ReadAsync(); + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + Assert.Null(dropped); + Assert.Equal(i, await reader); + } + } + + [Fact] + public async Task DroppedDelegateNotCalledOnWaitMode_SyncWrites() + { + bool dropDelegateCalled = false; + + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = BoundedChannelFullMode.Wait }, + item => + { + dropDelegateCalled = true; + }); + + ValueTask reader; + + reader = c.Reader.ReadAsync(); + Assert.True(c.Writer.TryWrite(42)); + Assert.Equal(42, await reader); + + reader = c.Reader.ReadAsync(); + AssertSynchronousSuccess(c.Writer.WriteAsync(43)); + Assert.Equal(43, await reader); + + _ = c.Writer.WriteAsync(44); + + Assert.False(dropDelegateCalled); + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public void DroppedDelegateIsNull_SyncAndAsyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = boundedChannelFullMode }, itemDropped: null); + + Assert.True(c.Writer.TryWrite(1)); + AssertSynchronousSuccess(c.Writer.WriteAsync(2)); + + Assert.False(c.Reader.TryRead(out int item)); + Assert.Equal(0, item); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public void DroppedDelegateCalledAfterLockReleased_SyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = null; + bool dropDelegateCalled = false; + + c = Channel.CreateBounded(new BoundedChannelOptions(0) + { + FullMode = boundedChannelFullMode + }, (droppedItem) => + { + if (dropDelegateCalled) + { + return; + } + dropDelegateCalled = true; + + // Dropped delegate should not be called while holding the channel lock. + // Verify this by trying to write into the channel from different thread. + // If lock is held during callback, this should effectively cause deadlock. + ManualResetEventSlim mres = new(); + ThreadPool.QueueUserWorkItem(delegate + { + c.Writer.TryWrite(3); + mres.Set(); + }); + + mres.Wait(); + }); + + Assert.True(c.Writer.TryWrite(1)); + Assert.True(c.Writer.TryWrite(2)); + + Assert.True(dropDelegateCalled); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task DroppedDelegateCalledAfterLockReleased_AsyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = null; + bool dropDelegateCalled = false; + + c = Channel.CreateBounded(new BoundedChannelOptions(0) + { + FullMode = boundedChannelFullMode + }, (droppedItem) => + { + if (dropDelegateCalled) + { + return; + } + dropDelegateCalled = true; + + // Dropped delegate should not be called while holding the channel synchronisation lock. + // Verify this by trying to write into the channel from different thread. + // If lock is held during callback, this should effectively cause deadlock. + var mres = new ManualResetEventSlim(); + ThreadPool.QueueUserWorkItem(delegate + { + c.Writer.TryWrite(11); + mres.Set(); + }); + + mres.Wait(); + }); + + await c.Writer.WriteAsync(1); + await c.Writer.WriteAsync(2); + + Assert.True(dropDelegateCalled); + } + + [Fact] + public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter() + { + Channel c = CreateChannel(); + + CancellationTokenSource cts = new(); + + ValueTask write1 = c.Writer.WriteAsync(42); + ValueTask write2 = c.Writer.WriteAsync(43, cts.Token); + ValueTask write3 = c.Writer.WriteAsync(44); + + cts.Cancel(); + + Assert.Equal(42, await c.Reader.ReadAsync()); + Assert.Equal(44, await c.Reader.ReadAsync()); + + await write1; + await AssertExtensions.CanceledAsync(cts.Token, async () => await write2); + await write3; + } + + [Fact] + public async Task WaitToWriteAsync_AfterRead_ReturnsTrue() + { + Channel c = CreateChannel(); + + ValueTask write1 = c.Writer.WaitToWriteAsync(); + ValueTask write2 = c.Writer.WaitToWriteAsync(); + Assert.False(write1.IsCompleted); + Assert.False(write2.IsCompleted); + + _ = c.Reader.ReadAsync(); + + Assert.True(await write1); + Assert.True(await write2); + + ValueTask write3 = c.Writer.WaitToWriteAsync(); + AssertSynchronousSuccess(write3); + Assert.True(await write3); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [MemberData(nameof(ThreeBools))] + public void AllowSynchronousContinuations_Reading_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations, bool cancelable, bool waitToReadAsync) + { + var c = Channel.CreateBounded(new BoundedChannelOptions(0) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + CancellationToken ct = cancelable ? new CancellationTokenSource().Token : CancellationToken.None; + + int expectedId = Environment.CurrentManagedThreadId; + Task t = waitToReadAsync ? c.Reader.WaitToReadAsync(ct).AsTask() : c.Reader.ReadAsync(ct).AsTask(); + Task r = t.ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + ValueTask write = c.Writer.WriteAsync(42); + if (!waitToReadAsync) + { + AssertSynchronousSuccess(write); + } + + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + + [ConditionalTheory] + [InlineData(false)] + [InlineData(true)] + public void AllowSynchronousContinuations_CompletionTask_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations) + { + if (!allowSynchronousContinuations && !PlatformDetection.IsThreadingSupported) + { + throw new SkipTestException(nameof(PlatformDetection.IsThreadingSupported)); + } + + var c = Channel.CreateBounded(new BoundedChannelOptions(0) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.Completion.ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.True(c.Writer.TryComplete()); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + } +} diff --git a/src/libraries/System.Threading.Channels/tests/Stress.cs b/src/libraries/System.Threading.Channels/tests/Stress.cs index ffae1e8ce88fbc..a876b13a931435 100644 --- a/src/libraries/System.Threading.Channels/tests/Stress.cs +++ b/src/libraries/System.Threading.Channels/tests/Stress.cs @@ -12,6 +12,7 @@ public class StressTests { public static IEnumerable ReadWriteVariations_TestData() { + // Unbounded foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync} ) foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync} ) foreach (bool singleReader in new [] {false, true}) @@ -19,15 +20,21 @@ public static IEnumerable ReadWriteVariations_TestData() foreach (bool allowSynchronousContinuations in new [] {false, true}) { Func> unbounded = o => Channel.CreateUnbounded((UnboundedChannelOptions)o); - yield return new object[] { unbounded, new UnboundedChannelOptions - { - SingleReader = singleReader, - SingleWriter = singleWriter, - AllowSynchronousContinuations = allowSynchronousContinuations - }, readDelegate, writeDelegate + yield return new object[] + { + unbounded, + new UnboundedChannelOptions + { + SingleReader = singleReader, + SingleWriter = singleWriter, + AllowSynchronousContinuations = allowSynchronousContinuations + }, + readDelegate, + writeDelegate }; } + // Bounded foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync} ) foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync} ) foreach (BoundedChannelFullMode bco in Enum.GetValues(typeof(BoundedChannelFullMode))) @@ -37,13 +44,44 @@ public static IEnumerable ReadWriteVariations_TestData() foreach (bool allowSynchronousContinuations in new [] {false, true}) { Func> bounded = o => Channel.CreateBounded((BoundedChannelOptions)o); - yield return new object[] { bounded, new BoundedChannelOptions(capacity) - { - SingleReader = singleReader, - SingleWriter = singleWriter, - AllowSynchronousContinuations = allowSynchronousContinuations, - FullMode = bco - }, readDelegate, writeDelegate + yield return new object[] + { + bounded, + new BoundedChannelOptions(capacity) + { + SingleReader = singleReader, + SingleWriter = singleWriter, + AllowSynchronousContinuations = allowSynchronousContinuations, + FullMode = bco + }, + readDelegate, + writeDelegate + }; + } + + // Rendezvous + foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync, ReadAsynchronousNoWait } ) + foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync, WriteAsynchronousNoWait } ) + foreach (var bcfm in new[] { BoundedChannelFullMode.Wait, BoundedChannelFullMode.DropWrite }) + foreach (bool allowSynchronousContinuations in new [] {false, true}) + { + if (readDelegate != ReadAsynchronousNoWait && + writeDelegate != WriteAsynchronousNoWait) + { + // At least one side must be persistent. + continue; + } + + yield return new object[] + { + (Func>)(o => Channel.CreateBounded((BoundedChannelOptions)o)), + new BoundedChannelOptions(0) + { + AllowSynchronousContinuations = allowSynchronousContinuations, + FullMode = bcfm + }, + readDelegate, + writeDelegate }; } } @@ -74,9 +112,21 @@ private static async Task ReadAsynchronous(ChannelReader reader) return false; } + private static async Task ReadAsynchronousNoWait(ChannelReader reader) + { + await reader.ReadAsync(); + return true; + } + + private static async Task WriteAsynchronousNoWait(ChannelWriter writer, int value) + { + await writer.WriteAsync(value); + return true; + } + private static async Task ReadSyncAndAsync(ChannelReader reader) { - if (!reader.TryRead(out int value)) + if (!reader.TryRead(out _)) { if (await reader.WaitToReadAsync()) { diff --git a/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj b/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj index 55ef0e463ca130..6792d7d47833a4 100644 --- a/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj +++ b/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent);$(NetFrameworkCurrent) @@ -11,6 +11,7 @@ + diff --git a/src/libraries/System.Threading.Channels/tests/TestBase.cs b/src/libraries/System.Threading.Channels/tests/TestBase.cs index 17873445b2d7c2..b2058f644a770b 100644 --- a/src/libraries/System.Threading.Channels/tests/TestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/TestBase.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using Xunit; @@ -10,6 +12,12 @@ namespace System.Threading.Channels.Tests { public abstract class TestBase { + public static IEnumerable ThreeBools => + from b1 in new[] { false, true } + from b2 in new[] { false, true } + from b3 in new[] { false, true } + select new object[] { b1, b2, b3 }; + protected void AssertSynchronouslyCanceled(Task task, CancellationToken token) { Assert.Equal(TaskStatus.Canceled, task.Status); From 62741bfca0320ac7d219197a78ddf9244d17ed5d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 2 Jun 2025 18:10:06 -0400 Subject: [PATCH 2/2] Address PR feedback --- .../Threading/Channels/BoundedChannel.cs | 16 ++-------- .../Threading/Channels/ChannelUtilities.cs | 17 ++++++++++ .../Threading/Channels/RendezvousChannel.cs | 32 +++---------------- .../Threading/Channels/UnboundedChannel.cs | 8 +---- .../Channels/UnboundedPriorityChannel.cs | 8 +---- 5 files changed, 25 insertions(+), 56 deletions(-) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index 6508b2d06505c0..ed2d2158bb7bee 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -388,13 +388,7 @@ public override bool TryWrite(T item) // There are no items in the channel, which means we may have blocked/waiting readers. // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) @@ -551,13 +545,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken // There are no items in the channel, which means we may have blocked/waiting readers. // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs index 0e7f2742d1a4c9..8bfd1889482622 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -66,6 +66,23 @@ internal static ValueTask GetInvalidCompletionValueTask(Exception error) return new ValueTask(t); } + /// Dequeues from until an element is dequeued that can have completion reserved. + /// The head of the list, with items dequeued up through the returned element, or entirely if is returned. + /// The operation on which completion has been reserved, or null if none can be found. + internal static TAsyncOp? TryDequeueAndReserveCompletionIfCancelable(ref TAsyncOp? head) + where TAsyncOp : AsyncOperation + { + while (ChannelUtilities.TryDequeue(ref head, out var op)) + { + if (op.TryReserveCompletionIfCancelable()) + { + return op; + } + } + + return null; + } + /// Dequeues an operation from the circular doubly-linked list referenced by . /// The head of the list. /// The dequeued operation. diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs index 8dc641b8e4cb1b..7861298fb7263a 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs @@ -90,13 +90,7 @@ public override bool TryRead([MaybeNullWhen(false)] out T item) if (parent._doneWriting is null) { - while (ChannelUtilities.TryDequeue(ref parent._blockedWritersHead, out blockedWriter)) - { - if (blockedWriter.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedWriter = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedWritersHead); } } @@ -156,13 +150,7 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) } // Reserve a blocked writer if one is available. - while (ChannelUtilities.TryDequeue(ref parent._blockedWritersHead, out blockedWriter)) - { - if (blockedWriter.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedWriter = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedWritersHead); // If we couldn't get one, create a waiting reader, and reserve any waiting writers to alert. if (blockedWriter is null) @@ -324,13 +312,7 @@ public override bool TryWrite(T item) if (parent._doneWriting is null) { - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); } } @@ -410,13 +392,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken } // Reserve a blocked reader if one is available. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we couldn't get one, create a blocked writer, and reserve any waiting readers to alert. if (blockedReader is null && !parent._dropWrites) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index e294f18a60f1dd..f9bfb4127f4f93 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -251,13 +251,7 @@ public override bool TryWrite(T item) } // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs index 68991ae6a620df..738af8f8c596d5 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs @@ -257,13 +257,7 @@ public override bool TryWrite(T item) } // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null)