diff --git a/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj b/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj index 4523f5d8b55835..1d8a85a73e0259 100644 --- a/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj +++ b/src/libraries/System.Threading.Channels/src/System.Threading.Channels.csproj @@ -25,6 +25,7 @@ System.Threading.Channel<T> + diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs index 6508b2d06505c0..ed2d2158bb7bee 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/BoundedChannel.cs @@ -388,13 +388,7 @@ public override bool TryWrite(T item) // There are no items in the channel, which means we may have blocked/waiting readers. // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) @@ -551,13 +545,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken // There are no items in the channel, which means we may have blocked/waiting readers. // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs index 554a60d10724a3..1dbea485b3272b 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/Channel.cs @@ -15,6 +15,7 @@ public static Channel CreateUnbounded() => /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// The created channel. + /// is . public static Channel CreateUnbounded(UnboundedChannelOptions options) { ArgumentNullException.ThrowIfNull(options); @@ -35,35 +36,33 @@ public static Channel CreateUnbounded(UnboundedChannelOptions options) /// Channels created with this method apply the /// behavior and prohibit continuations from running synchronously. /// - public static Channel CreateBounded(int capacity) - { - if (capacity < 1) - { - throw new ArgumentOutOfRangeException(nameof(capacity)); - } - - return new BoundedChannel(capacity, BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null); - } + /// is negative. + public static Channel CreateBounded(int capacity) => + capacity > 0 ? new BoundedChannel(capacity, BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null) : + capacity == 0 ? new RendezvousChannel(BoundedChannelFullMode.Wait, runContinuationsAsynchronously: true, itemDropped: null) : + throw new ArgumentOutOfRangeException(nameof(capacity)); /// Creates a channel subject to the provided options. /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// The created channel. - public static Channel CreateBounded(BoundedChannelOptions options) - { - return CreateBounded(options, itemDropped: null); - } + /// is . + public static Channel CreateBounded(BoundedChannelOptions options) => + CreateBounded(options, itemDropped: null); /// Creates a channel subject to the provided options. /// Specifies the type of data in the channel. /// Options that guide the behavior of the channel. /// Delegate that will be called when item is being dropped from channel. See . /// The created channel. + /// is . public static Channel CreateBounded(BoundedChannelOptions options, Action? itemDropped) { ArgumentNullException.ThrowIfNull(options); - return new BoundedChannel(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped); + return + options.Capacity > 0 ? new BoundedChannel(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped) : + new RendezvousChannel(options.FullMode, !options.AllowSynchronousContinuations, itemDropped); } } } diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs index ab3a722bb909d2..85a2ff49e29983 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelOptions.cs @@ -49,9 +49,10 @@ public sealed class BoundedChannelOptions : ChannelOptions /// Initializes the options. /// The maximum number of items the bounded channel may store. + /// is negative. public BoundedChannelOptions(int capacity) { - if (capacity < 1) + if (capacity < 0) { throw new ArgumentOutOfRangeException(nameof(capacity)); } @@ -60,12 +61,13 @@ public BoundedChannelOptions(int capacity) } /// Gets or sets the maximum number of items the bounded channel may store. + /// is negative. public int Capacity { get => _capacity; set { - if (value < 1) + if (value < 0) { throw new ArgumentOutOfRangeException(nameof(value)); } @@ -74,6 +76,7 @@ public int Capacity } /// Gets or sets the behavior incurred by write operations when the channel is full. + /// is an invalid enum value. public BoundedChannelFullMode FullMode { get => _mode; diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs index 3436834636edb3..8bfd1889482622 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/ChannelUtilities.cs @@ -66,6 +66,23 @@ internal static ValueTask GetInvalidCompletionValueTask(Exception error) return new ValueTask(t); } + /// Dequeues from until an element is dequeued that can have completion reserved. + /// The head of the list, with items dequeued up through the returned element, or entirely if is returned. + /// The operation on which completion has been reserved, or null if none can be found. + internal static TAsyncOp? TryDequeueAndReserveCompletionIfCancelable(ref TAsyncOp? head) + where TAsyncOp : AsyncOperation + { + while (ChannelUtilities.TryDequeue(ref head, out var op)) + { + if (op.TryReserveCompletionIfCancelable()) + { + return op; + } + } + + return null; + } + /// Dequeues an operation from the circular doubly-linked list referenced by . /// The head of the list. /// The dequeued operation. @@ -317,6 +334,29 @@ internal static void AssertAll(TAsyncOp? head, Func co } } + /// Counts the number of operations in the list. + /// The head of the queue of operations to count. + internal static long CountOperations(TAsyncOp? head) + where TAsyncOp : AsyncOperation + { + TAsyncOp? current = head; + long count = 0; + + if (current is not null) + { + do + { + count++; + + Debug.Assert(current is not null); + current = current.Next; + } + while (current != head); + } + + return count; + } + /// Creates and returns an exception object to indicate that a channel has been closed. internal static Exception CreateInvalidCompletionException(Exception? inner = null) => inner is OperationCanceledException ? inner : diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs new file mode 100644 index 00000000000000..7861298fb7263a --- /dev/null +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/RendezvousChannel.cs @@ -0,0 +1,513 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + /// Provides an unbuffered channel where readers and writers must directly hand off to each other. + [DebuggerDisplay("{DebuggerDisplay,nq}")] + internal sealed class RendezvousChannel : Channel + { + /// Whether to suceed writes immediately even when there's no rendezvousing reader. + private readonly bool _dropWrites; + + /// The delegate that will be invoked when the channel hits its bound and an item is dropped from the channel. + private readonly Action? _itemDropped; + + /// Task signaled when the channel has completed. + private readonly TaskCompletionSource _completion; + + /// Head of linked list of blocked ReadAsync calls. + private BlockedReadAsyncOperation? _blockedReadersHead; + + /// Head of linked list of blocked WriteAsync calls. + private BlockedWriteAsyncOperation? _blockedWritersHead; + + /// Head of linked list of waiting WaitToReadAsync calls. + private WaitingReadAsyncOperation? _waitingReadersHead; + + /// Head of linked list of waiting WaitToWriteAsync calls. + private WaitingWriteAsyncOperation? _waitingWritersHead; + + /// Whether to force continuations to be executed asynchronously from producer writes. + private readonly bool _runContinuationsAsynchronously; + + /// Set to non-null once Complete has been called. + private Exception? _doneWriting; + + /// Initializes the . + /// The mode used when writing to a full channel. + /// Whether to force continuations to be executed asynchronously. + /// Delegate that will be invoked when an item is dropped from the channel. See . + internal RendezvousChannel(BoundedChannelFullMode mode, bool runContinuationsAsynchronously, Action? itemDropped) + { + _dropWrites = mode is not BoundedChannelFullMode.Wait; + + _runContinuationsAsynchronously = runContinuationsAsynchronously; + _itemDropped = itemDropped; + _completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + + Reader = new RendezvousChannelReader(this); + Writer = new RendezvousChannelWriter(this); + } + + [DebuggerDisplay("{DebuggerDisplay,nq}")] + private sealed class RendezvousChannelReader : ChannelReader + { + internal readonly RendezvousChannel _parent; + private readonly BlockedReadAsyncOperation _readerSingleton; + private readonly WaitingReadAsyncOperation _waiterSingleton; + + internal RendezvousChannelReader(RendezvousChannel parent) + { + _parent = parent; + _readerSingleton = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, pooled: true); + _waiterSingleton = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, pooled: true); + } + + public override Task Completion => _parent._completion.Task; + + public override bool CanCount => true; + + public override bool CanPeek => true; + + public override int Count => 0; + + public override bool TryRead([MaybeNullWhen(false)] out T item) + { + RendezvousChannel parent = _parent; + + // Reserve a blocked writer if one is available. + BlockedWriteAsyncOperation? blockedWriter = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null) + { + blockedWriter = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedWritersHead); + } + } + + // If we got one, transfer its item to the read and complete successfully. + if (blockedWriter is not null) + { + item = blockedWriter.Item!; + blockedWriter.DangerousSetResult(default); + return true; + } + + item = default; + return false; + } + + public override bool TryPeek([MaybeNullWhen(false)] out T item) + { + RendezvousChannel parent = _parent; + + // Peek at a blocked writer if one is available. + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null && + parent._blockedWritersHead is { } blockedWriter) + { + item = blockedWriter.Item!; + return true; + } + } + + item = default; + return false; + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken) + { + RendezvousChannel parent = _parent; + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + BlockedReadAsyncOperation? reader = null; + WaitingWriteAsyncOperation? waitingWriters = null; + BlockedWriteAsyncOperation? blockedWriter = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing so that there will never be more items, fail. + if (parent._doneWriting is not null) + { + return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); + } + + // Reserve a blocked writer if one is available. + blockedWriter = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedWritersHead); + + // If we couldn't get one, create a waiting reader, and reserve any waiting writers to alert. + if (blockedWriter is null) + { + reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); + + waitingWriters = ChannelUtilities.TryReserveCompletionIfCancelable(ref parent._waitingWritersHead); + } + } + + // Either complete the reserved blocked writer, transferring its item to the read, + // or return the waiting reader task, also alerting any waiting writers. + ValueTask result; + if (blockedWriter is not null) + { + Debug.Assert(reader is null); + Debug.Assert(waitingWriters is null); + result = new(blockedWriter.Item!); + blockedWriter.DangerousSetResult(default); + } + else + { + Debug.Assert(reader is not null); + ChannelUtilities.DangerousSetOperations(waitingWriters, result: true); + result = reader.ValueTaskOfT; + } + + return result; + } + + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + RendezvousChannel parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, a read will never be possible. + if (parent._doneWriting is not null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + default; + } + + // If there are any writers waiting, a read is possible. + if (parent._blockedWritersHead is not null) + { + return new ValueTask(true); + } + + // Register a waiting reader task. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); + return waiter.ValueTaskOfT; + } + } + + internal string DebuggerDisplay + { + get + { + long blockedReaderCount, waitingReaderCount; + lock (_parent.SyncObj) + { + blockedReaderCount = ChannelUtilities.CountOperations(_parent._blockedReadersHead); + waitingReaderCount = ChannelUtilities.CountOperations(_parent._waitingReadersHead); + } + + return $"ReadAsync={blockedReaderCount}, WaitToReadAsync={waitingReaderCount}"; + } + } + } + + [DebuggerDisplay("{DebuggerDisplay,nq}")] + private sealed class RendezvousChannelWriter : ChannelWriter + { + internal readonly RendezvousChannel _parent; + private readonly BlockedWriteAsyncOperation _writerSingleton; + private readonly WaitingWriteAsyncOperation _waiterSingleton; + + internal RendezvousChannelWriter(RendezvousChannel parent) + { + _parent = parent; + _writerSingleton = new BlockedWriteAsyncOperation(runContinuationsAsynchronously: true, pooled: true); + _waiterSingleton = new WaitingWriteAsyncOperation(runContinuationsAsynchronously: true, pooled: true); + } + + public override bool TryComplete(Exception? error) + { + RendezvousChannel parent = _parent; + + BlockedReadAsyncOperation? blockedReadersHead; + BlockedWriteAsyncOperation? blockedWritersHead; + WaitingReadAsyncOperation? waitingReadersHead; + WaitingWriteAsyncOperation? waitingWritersHead; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already marked the channel as completed, bail. + if (parent._doneWriting is not null) + { + return false; + } + + // Mark that we're done writing. + parent._doneWriting = error ?? ChannelUtilities.s_doneWritingSentinel; + + // Snag the queues while holding the lock, so that we don't need to worry + // about concurrent mutation, such as from cancellation of pending operations. + blockedReadersHead = parent._blockedReadersHead; + blockedWritersHead = parent._blockedWritersHead; + waitingReadersHead = parent._waitingReadersHead; + waitingWritersHead = parent._waitingWritersHead; + parent._blockedReadersHead = null; + parent._blockedWritersHead = null; + parent._waitingReadersHead = null; + parent._waitingWritersHead = null; + } + + // Complete the channel's task, as no more data can possibly arrive at this point. We do this outside + // of the lock in case we'll be running synchronous completions, and we + // do it before completing blocked/waiting readers, so that when they + // wake up they'll see the task as being completed. + ChannelUtilities.Complete(parent._completion, error); + + // Complete all pending operations. We don't need to worry about concurrent mutation here: + // No other writers or readers will be able to register operations, and any cancellation callbacks + // will see the queues as being null and exit immediately. + ChannelUtilities.FailOperations(blockedReadersHead, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.FailOperations(blockedWritersHead, ChannelUtilities.CreateInvalidCompletionException(error)); + ChannelUtilities.SetOrFailOperations(waitingReadersHead, result: false, error: error); + ChannelUtilities.SetOrFailOperations(waitingWritersHead, result: false, error: error); + + // Successfully transitioned to completed. + return true; + } + + public override bool TryWrite(T item) + { + RendezvousChannel parent = _parent; + + // Reserve a blocked reader if one is available. + BlockedReadAsyncOperation? blockedReader = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + if (parent._doneWriting is null) + { + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); + } + } + + // If we got one, transfer its item to the read and complete successfully. + if (blockedReader is not null) + { + blockedReader.DangerousSetResult(item); + return true; + } + + // There's no concurrent reader, but if we're configured to drop writes, we can succeed immediately. + if (parent._dropWrites) + { + parent._itemDropped?.Invoke(item); + return true; + } + + return false; + } + + public override ValueTask WaitToWriteAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + RendezvousChannel parent = _parent; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we're done writing, a read will never be possible. + if (parent._doneWriting is not null) + { + return parent._doneWriting != ChannelUtilities.s_doneWritingSentinel ? + new ValueTask(Task.FromException(parent._doneWriting)) : + default; + } + + // If there are any readers waiting, a write is possible. + if (parent._blockedReadersHead is not null || parent._dropWrites) + { + return new ValueTask(true); + } + + // There were no readers available, but there could be in the future, so ensure + // there's a waiting writer task and return it. + WaitingWriteAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + ChannelUtilities.Enqueue(ref parent._waitingWritersHead, waiter); + return waiter.ValueTaskOfT; + } + } + + public override ValueTask WriteAsync(T item, CancellationToken cancellationToken) + { + RendezvousChannel parent = _parent; + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + BlockedWriteAsyncOperation? writer = null; + WaitingReadAsyncOperation? waitingReaders = null; + BlockedReadAsyncOperation? blockedReader = null; + lock (parent.SyncObj) + { + parent.AssertInvariants(); + + // If we've already been marked as done for writing, we shouldn't be writing. + if (parent._doneWriting is not null) + { + return new ValueTask(Task.FromException(ChannelUtilities.CreateInvalidCompletionException(parent._doneWriting))); + } + + // Reserve a blocked reader if one is available. + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); + + // If we couldn't get one, create a blocked writer, and reserve any waiting readers to alert. + if (blockedReader is null && !parent._dropWrites) + { + writer = + !cancellationToken.CanBeCanceled && _writerSingleton.TryOwnAndReset() ? _writerSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + writer.Item = item; + ChannelUtilities.Enqueue(ref parent._blockedWritersHead, writer); + + waitingReaders = ChannelUtilities.TryReserveCompletionIfCancelable(ref parent._waitingReadersHead); + } + } + + if (writer is not null) + { + Debug.Assert(blockedReader is null); + ChannelUtilities.DangerousSetOperations(waitingReaders, result: true); + return writer.ValueTask; + } + + if (blockedReader is not null) + { + blockedReader.DangerousSetResult(item); + } + else + { + Debug.Assert(parent._dropWrites); + parent._itemDropped?.Invoke(item); + } + + return default; + } + + internal string DebuggerDisplay + { + get + { + long blockedWriterCount, waitingWriterCount; + lock (_parent.SyncObj) + { + blockedWriterCount = ChannelUtilities.CountOperations(_parent._blockedWritersHead); + waitingWriterCount = ChannelUtilities.CountOperations(_parent._waitingWritersHead); + } + + return $"WriteAsync={blockedWriterCount}, WaitToWriteAsync={waitingWriterCount}"; + } + } + } + + /// Gets an object used to synchronize all state on the instance. + private object SyncObj => _completion; + + private Action CancellationCallbackDelegate => + field ??= (state, cancellationToken) => + { + AsyncOperation op = (AsyncOperation)state!; + if (op.TrySetCanceled(cancellationToken)) + { + ChannelUtilities.UnsafeQueueUserWorkItem(static state => // escape cancellation callback + { + lock (state.Key.SyncObj) + { + switch (state.Value) + { + case BlockedReadAsyncOperation blockedReader: + ChannelUtilities.Remove(ref state.Key._blockedReadersHead, blockedReader); + break; + + case BlockedWriteAsyncOperation blockedWriter: + ChannelUtilities.Remove(ref state.Key._blockedWritersHead, blockedWriter); + break; + + case WaitingReadAsyncOperation waitingReader: + ChannelUtilities.Remove(ref state.Key._waitingReadersHead, waitingReader); + break; + + case WaitingWriteAsyncOperation waitingWriter: + ChannelUtilities.Remove(ref state.Key._waitingWritersHead, waitingWriter); + break; + + default: + Debug.Fail($"Unexpected operation: {state.Value}"); + break; + } + } + }, new KeyValuePair, AsyncOperation>(this, op)); + } + }; + + private string DebuggerDisplay => + $"{((RendezvousChannelReader)Reader).DebuggerDisplay}, {((RendezvousChannelWriter)Writer).DebuggerDisplay}"; + + [Conditional("DEBUG")] + private void AssertInvariants() + { + Debug.Assert(SyncObj is not null, "The sync obj must not be null."); + Debug.Assert(Monitor.IsEntered(SyncObj), "Invariants can only be validated while holding the lock."); + + if (_blockedReadersHead is not null) + { + Debug.Assert(_blockedWritersHead is null, "There shouldn't be any blocked writer if there's a blocked reader."); + Debug.Assert(_waitingWritersHead is null, "There shouldn't be any waiting writers if there's a blocked reader."); + } + + if (_blockedWritersHead is not null) + { + Debug.Assert(_blockedReadersHead is null, "There shouldn't be any blocked readers if there's a blocked writer."); + Debug.Assert(_waitingReadersHead is null, "There shouldn't be any waiting readers if there's a blocked writer."); + } + + if (_completion.Task.IsCompleted) + { + Debug.Assert(_doneWriting is not null, "We can only complete if we're done writing."); + } + } + } +} diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs index c863f619c041fc..969249384ac1de 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/SingleConsumerUnboundedChannel.cs @@ -251,8 +251,6 @@ public override bool TryComplete(Exception? error) ChannelUtilities.Complete(parent._completion, error); } - Debug.Assert(blockedReader is null || waitingReader is null, "There should only ever be at most one reader."); - // Complete a blocked reader if necessary if (blockedReader is not null) { diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs index 6db2a3e5f8f1cc..f9bfb4127f4f93 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedChannel.cs @@ -97,19 +97,10 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); } - // If we're able to use the singleton reader, do so. - if (!cancellationToken.CanBeCanceled) - { - BlockedReadAsyncOperation singleton = _readerSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._blockedReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a reader. - var reader = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton reader, do so. Otherwise, create a new reader. + BlockedReadAsyncOperation reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); return reader.ValueTaskOfT; } @@ -174,19 +165,10 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo default; } - // If we're able to use the singleton waiter, do so. - if (!cancellationToken.CanBeCanceled) - { - WaitingReadAsyncOperation singleton = _waiterSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._waitingReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a waiter. - var waiter = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton waiter, do so. Otherwise, create a new waiter. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); return waiter.ValueTaskOfT; } @@ -269,13 +251,7 @@ public override bool TryWrite(T item) } // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) diff --git a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs index ead76810b0b62b..738af8f8c596d5 100644 --- a/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs +++ b/src/libraries/System.Threading.Channels/src/System/Threading/Channels/UnboundedPriorityChannel.cs @@ -98,19 +98,10 @@ public override ValueTask ReadAsync(CancellationToken cancellationToken) return ChannelUtilities.GetInvalidCompletionValueTask(parent._doneWriting); } - // If we're able to use the singleton reader, do so. - if (!cancellationToken.CanBeCanceled) - { - BlockedReadAsyncOperation singleton = _readerSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._blockedReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a reader. - var reader = new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton reader, do so. Otherwise, create a new reader. + BlockedReadAsyncOperation reader = + !cancellationToken.CanBeCanceled && _readerSingleton.TryOwnAndReset() ? _readerSingleton : + new BlockedReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._blockedReadersHead, reader); return reader.ValueTaskOfT; } @@ -179,19 +170,10 @@ public override ValueTask WaitToReadAsync(CancellationToken cancellationTo default; } - // If we're able to use the singleton waiter, do so. - if (!cancellationToken.CanBeCanceled) - { - WaitingReadAsyncOperation singleton = _waiterSingleton; - if (singleton.TryOwnAndReset()) - { - ChannelUtilities.Enqueue(ref parent._waitingReadersHead, singleton); - return singleton.ValueTaskOfT; - } - } - - // Otherwise, create and queue a waiter. - var waiter = new WaitingReadAsyncOperation(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); + // If we're able to use the singleton waiter, do so. Otherwise, create a new waiter. + WaitingReadAsyncOperation waiter = + !cancellationToken.CanBeCanceled && _waiterSingleton.TryOwnAndReset() ? _waiterSingleton : + new(parent._runContinuationsAsynchronously, cancellationToken, cancellationCallback: _parent.CancellationCallbackDelegate); ChannelUtilities.Enqueue(ref parent._waitingReadersHead, waiter); return waiter.ValueTaskOfT; } @@ -275,13 +257,7 @@ public override bool TryWrite(T item) } // Try to get a blocked reader that we can transfer the item to. - while (ChannelUtilities.TryDequeue(ref parent._blockedReadersHead, out blockedReader)) - { - if (blockedReader.TryReserveCompletionIfCancelable()) - { - break; - } - } + blockedReader = ChannelUtilities.TryDequeueAndReserveCompletionIfCancelable(ref parent._blockedReadersHead); // If we weren't able to get a reader, instead queue the item and get any waiters that need to be notified. if (blockedReader is null) diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs index cefa3fe6f29901..2799d6509d1a70 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTestBase.cs @@ -22,23 +22,17 @@ public abstract partial class ChannelTestBase : TestBase protected virtual bool RequiresSingleReader => false; protected virtual bool RequiresSingleWriter => false; protected virtual bool BuffersItems => true; - - public static IEnumerable ThreeBools => - from b1 in new[] { false, true } - from b2 in new[] { false, true } - from b3 in new[] { false, true } - select new object[] { b1, b2, b3 }; + protected virtual bool HasDebuggerTypeProxy => true; [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsDebuggerTypeProxyAttributeSupported))] public void ValidateDebuggerAttributes() { Channel c = CreateChannel(); - for (int i = 1; i <= 10; i++) + DebuggerAttributes.ValidateDebuggerDisplayReferences(c); + if (HasDebuggerTypeProxy) { - c.Writer.WriteAsync(i); + DebuggerAttributes.InvokeDebuggerTypeProxyProperties(c); } - DebuggerAttributes.ValidateDebuggerDisplayReferences(c); - DebuggerAttributes.InvokeDebuggerTypeProxyProperties(c); } [Fact] @@ -112,7 +106,7 @@ public void TryComplete_Twice_ReturnsTrueThenFalse() } [Fact] - public async Task TryComplete_ErrorsPropage() + public async Task TryComplete_ErrorsPropagate() { Channel c; @@ -142,6 +136,10 @@ public void Count_ThrowsIfUnsupported() { Assert.Throws(() => c.Reader.Count); } + else + { + Assert.InRange(c.Reader.Count, 0, int.MaxValue); + } } [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] @@ -226,17 +224,13 @@ public void ManyProducerConsumer_ConcurrentReadWrite_Success(int numReaders, int { tasks[i] = Task.Run(async () => { - try + while (await c.Reader.WaitToReadAsync()) { - while (await c.Reader.WaitToReadAsync()) + while (c.Reader.TryRead(out int value)) { - if (c.Reader.TryRead(out int value)) - { - Interlocked.Add(ref readTotal, value); - } + Interlocked.Add(ref readTotal, value); } } - catch (ChannelClosedException) { } }); } @@ -352,13 +346,13 @@ public void WaitToWriteAsync_BlockedReader_ReturnsTrue() } [Fact] - public void TryRead_DataAvailable_Success() + public async Task TryRead_DataAvailable_Success() { Channel c = CreateChannel(); ValueTask write = c.Writer.WriteAsync(42); - Assert.True(write.IsCompletedSuccessfully); Assert.True(c.Reader.TryRead(out int result)); Assert.Equal(42, result); + await write; } [Fact] @@ -370,7 +364,7 @@ public void TryRead_AfterComplete_ReturnsFalse() } [Fact] - public void TryPeek_SucceedsWhenDataAvailable() + public async Task TryPeek_SucceedsWhenDataAvailable() { Channel c = CreateChannel(); @@ -379,7 +373,7 @@ public void TryPeek_SucceedsWhenDataAvailable() for (int i = 0; i < 3; i++) { // Write a value - Assert.True(c.Writer.WriteAsync(42).IsCompletedSuccessfully); + ValueTask write = c.Writer.WriteAsync(42); // Can peek at the written value Assert.True(c.Reader.TryPeek(out int peekedResult)); @@ -389,11 +383,18 @@ public void TryPeek_SucceedsWhenDataAvailable() Assert.True(c.Reader.TryRead(out int readResult)); Assert.Equal(42, readResult); + await write; + // Peeking no longer finds it Assert.False(c.Reader.TryPeek(out int noResult)); Assert.Equal(0, noResult); } + if (!BuffersItems) + { + return; + } + // Write another value Assert.True(c.Writer.WriteAsync(84).IsCompletedSuccessfully); @@ -544,7 +545,7 @@ public void Precancellation_Writing_ReturnsImmediately() Assert.True(writeTask.IsCanceled); ValueTask waitTask = c.Writer.WaitToWriteAsync(new CancellationToken(true)); - Assert.True(writeTask.IsCanceled); + Assert.True(waitTask.IsCanceled); } [Fact] @@ -558,16 +559,23 @@ public void Write_WaitToReadAsync_CompletesSynchronously() [Theory] [InlineData(false)] [InlineData(true)] - public void Precancellation_WaitToReadAsync_ReturnsImmediately(bool dataAvailable) + public async Task Precancellation_WaitToReadAsync_ReturnsImmediately(bool dataAvailable) { Channel c = CreateChannel(); + + ValueTask write = default; if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + write = c.Writer.WriteAsync(42); } ValueTask waitTask = c.Reader.WaitToReadAsync(new CancellationToken(true)); Assert.True(waitTask.IsCanceled); + + if (BuffersItems) + { + await write; + } } [Theory] @@ -634,7 +642,7 @@ public void Precancellation_ReadAsync_ReturnsImmediately(bool dataAvailable) Channel c = CreateChannel(); if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + c.Writer.WriteAsync(42); } ValueTask readTask = c.Reader.ReadAsync(new CancellationToken(true)); @@ -655,10 +663,9 @@ public async Task ReadAsync_Canceled_CanceledAsynchronously() await AssertExtensions.CanceledAsync(cts.Token, async () => await r); - if (c.Writer.TryWrite(42)) - { - Assert.Equal(42, await c.Reader.ReadAsync()); - } + ValueTask vt = c.Writer.WriteAsync(42); + Assert.Equal(42, await c.Reader.ReadAsync()); + await vt; } [Fact] @@ -823,7 +830,7 @@ public async Task ReadAllAsync_UseMoveNextAsyncAfterCompleted_ReturnsFalse(bool } [Fact] - public void ReadAllAsync_AvailableDataCompletesSynchronously() + public async Task ReadAllAsync_AvailableDataCompletesSynchronously() { Channel c = CreateChannel(); @@ -832,11 +839,12 @@ public void ReadAllAsync_AvailableDataCompletesSynchronously() { for (int i = 100; i < 110; i++) { - Assert.True(c.Writer.TryWrite(i)); + ValueTask write = c.Writer.WriteAsync(i); ValueTask vt = e.MoveNextAsync(); Assert.True(vt.IsCompletedSuccessfully); Assert.True(vt.Result); Assert.Equal(i, e.Current); + await write; } } finally @@ -859,7 +867,7 @@ public async Task ReadAllAsync_UnavailableDataCompletesAsynchronously() { ValueTask vt = e.MoveNextAsync(); Assert.False(vt.IsCompleted); - Task producer = Task.Run(() => c.Writer.TryWrite(i)); + Task producer = Task.Run(() => c.Writer.WriteAsync(i).AsTask()); Assert.True(await vt); await producer; Assert.Equal(i, e.Current); @@ -916,8 +924,7 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() { Channel c = CreateChannel(); - Assert.True(c.Writer.TryWrite(42)); - c.Writer.Complete(); + ValueTask write = c.Writer.WriteAsync(42); IAsyncEnumerable enumerable = c.Reader.ReadAllAsync(); IAsyncEnumerator e = enumerable.GetAsyncEnumerator(); @@ -925,6 +932,9 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() Assert.True(await e.MoveNextAsync()); Assert.Equal(42, e.Current); + await write; + c.Writer.Complete(); + Assert.False(await e.MoveNextAsync()); Assert.False(await e.MoveNextAsync()); @@ -942,16 +952,17 @@ public async Task ReadAllAsync_MultipleEnumerationsToEnd() [InlineData(false, true)] [InlineData(true, false)] [InlineData(true, true)] - public void ReadAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose) + public async Task ReadAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose) { Channel c = CreateChannel(); IAsyncEnumerable enumerable = c.Reader.ReadAllAsync(); for (int i = 0; i < 10; i++) { - Assert.True(c.Writer.TryWrite(i)); + ValueTask write = c.Writer.WriteAsync(i); IAsyncEnumerator e = (sameEnumerable ? enumerable : c.Reader.ReadAllAsync()).GetAsyncEnumerator(); ValueTask vt = e.MoveNextAsync(); + await write; Assert.True(vt.IsCompletedSuccessfully); Assert.True(vt.Result); Assert.Equal(i, e.Current); @@ -1015,9 +1026,10 @@ public async Task ReadAllAsync_DualConcurrentEnumeration_AllItemsEnumerated(bool public async Task ReadAllAsync_CanceledBeforeMoveNextAsync_Throws(bool dataAvailable) { Channel c = CreateChannel(); + if (dataAvailable) { - Assert.True(c.Writer.TryWrite(42)); + _ = c.Writer.WriteAsync(42); } var cts = new CancellationTokenSource(); @@ -1057,10 +1069,11 @@ public async Task WaitToReadAsync_ConsecutiveReadsSucceed() for (int i = 0; i < 5; i++) { ValueTask r = c.Reader.WaitToReadAsync(); - await c.Writer.WriteAsync(i); + ValueTask write = c.Writer.WriteAsync(i); Assert.True(await r); Assert.True(c.Reader.TryRead(out int item)); Assert.Equal(i, item); + await write; } } @@ -1155,7 +1168,8 @@ public async Task WaitToReadAsync_AwaitThenGetResult_Throws() Channel c = CreateChannel(); ValueTask read = c.Reader.WaitToReadAsync(); - Assert.True(c.Writer.TryWrite(42)); + + ValueTask write = c.Writer.WriteAsync(42); Assert.True(await read); Assert.Throws(() => read.GetAwaiter().IsCompleted); Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); @@ -1169,8 +1183,9 @@ public async Task ReadAsync_AwaitThenGetResult_Throws() Channel c = CreateChannel(); ValueTask read = c.Reader.ReadAsync(); - Assert.True(c.Writer.TryWrite(42)); + ValueTask write = c.Writer.WriteAsync(42); Assert.Equal(42, await read); + await write; Assert.Throws(() => read.GetAwaiter().IsCompleted); Assert.Throws(() => read.GetAwaiter().OnCompleted(() => { })); Assert.Throws(() => read.GetAwaiter().GetResult()); @@ -1186,7 +1201,7 @@ public async Task WaitToWriteAsync_AwaitThenGetResult_Throws() } ValueTask write = c.Writer.WaitToWriteAsync(); - await c.Reader.ReadAsync(); + ValueTask read = c.Reader.ReadAsync(); Assert.True(await write); Assert.Throws(() => write.GetAwaiter().IsCompleted); Assert.Throws(() => write.GetAwaiter().OnCompleted(() => { })); @@ -1344,7 +1359,7 @@ await Task.Factory.StartNew(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } SynchronizationContext.SetSynchronizationContext(new CustomSynchronizationContext()); @@ -1379,7 +1394,7 @@ await Task.Factory.StartNew(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await continuationRan.Task; @@ -1440,7 +1455,7 @@ await Task.Run(async () => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await Task.Factory.StartNew(() => @@ -1482,7 +1497,7 @@ await Task.Factory.StartNew(() => { Assert.False(vt.IsCompleted); Assert.False(vt.IsCompletedSuccessfully); - c.Writer.TryWrite(true); + _ = c.Writer.WriteAsync(true); } await continuationRan.Task; diff --git a/src/libraries/System.Threading.Channels/tests/ChannelTests.cs b/src/libraries/System.Threading.Channels/tests/ChannelTests.cs index 2765e66d4f346a..f42096f7890dae 100644 --- a/src/libraries/System.Threading.Channels/tests/ChannelTests.cs +++ b/src/libraries/System.Threading.Channels/tests/ChannelTests.cs @@ -62,7 +62,7 @@ public void Create_NullOptions_ThrowsArgumentException() } [Theory] - [InlineData(0)] + [InlineData(-1)] [InlineData(-2)] public void CreateBounded_InvalidBufferSizes_ThrowArgumentExceptions(int capacity) { @@ -77,7 +77,7 @@ public void BoundedChannelOptions_InvalidModes_ThrowArgumentExceptions(BoundedCh AssertExtensions.Throws("value", () => new BoundedChannelOptions(1) { FullMode = mode }); [Theory] - [InlineData(0)] + [InlineData(-1)] [InlineData(-2)] public void BoundedChannelOptions_InvalidCapacity_ThrowArgumentExceptions(int capacity) => AssertExtensions.Throws("value", () => new BoundedChannelOptions(1) { Capacity = capacity }); diff --git a/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs b/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs new file mode 100644 index 00000000000000..5de9153ca120a6 --- /dev/null +++ b/src/libraries/System.Threading.Channels/tests/RendezvousChannelTests.cs @@ -0,0 +1,317 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Microsoft.DotNet.XUnitExtensions; +using Xunit; + +namespace System.Threading.Channels.Tests +{ + public class RendezvousChannelTests : ChannelTestBase + { + protected override Channel CreateChannel() => Channel.CreateBounded(0); + + protected override Channel CreateFullChannel() => CreateChannel(); + + protected override bool BuffersItems => false; + + protected override bool HasDebuggerTypeProxy => false; + + [Fact] + public async Task Count_AlwaysZero() + { + Channel c = CreateChannel(); + + Assert.True(c.Reader.CanCount); + Assert.Equal(0, c.Reader.Count); + + var write1 = c.Writer.WriteAsync(1); + var write2 = c.Writer.WriteAsync(2); + + Assert.Equal(0, c.Reader.Count); + Assert.False(write1.IsCompleted); + Assert.False(write2.IsCompleted); + + Assert.Equal(1, await c.Reader.ReadAsync()); + + await write1; + Assert.Equal(0, c.Reader.Count); + Assert.False(write2.IsCompleted); + + Assert.Equal(2, await c.Reader.ReadAsync()); + + await write2; + Assert.Equal(0, c.Reader.Count); + } + + [Fact] + public void TryWrite_TryRead_NoPairing_ReturnsFalse() + { + Channel c = CreateChannel(); + + for (int i = 0; i < 3; i++) + { + Assert.False(c.Writer.TryWrite(42)); + Assert.False(c.Reader.TryRead(out int item)); + Assert.Equal(0, item); + } + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task TryWrite_DropXx_DropsWrite(BoundedChannelFullMode mode) + { + int? dropped = null; + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = mode }, item => dropped = item); + + for (int i = 42; i < 52; i++) + { + var waiter = c.Writer.WaitToWriteAsync(); + AssertSynchronousSuccess(waiter); + Assert.True(await waiter); + + Assert.True(c.Writer.TryWrite(i)); + Assert.Equal(i, dropped); + + dropped = null; + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + Assert.Equal(i, dropped); + } + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task TryWrite_DropXx_ReaderTakesPriority(BoundedChannelFullMode mode) + { + int? dropped = null; + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = mode }, item => dropped = item); + + for (int i = 42; i < 52; i++) + { + ValueTask reader; + + reader = c.Reader.ReadAsync(); + Assert.True(c.Writer.TryWrite(i)); + Assert.Null(dropped); + Assert.Equal(i, await reader); + + reader = c.Reader.ReadAsync(); + AssertSynchronousSuccess(c.Writer.WriteAsync(i)); + Assert.Null(dropped); + Assert.Equal(i, await reader); + } + } + + [Fact] + public async Task DroppedDelegateNotCalledOnWaitMode_SyncWrites() + { + bool dropDelegateCalled = false; + + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = BoundedChannelFullMode.Wait }, + item => + { + dropDelegateCalled = true; + }); + + ValueTask reader; + + reader = c.Reader.ReadAsync(); + Assert.True(c.Writer.TryWrite(42)); + Assert.Equal(42, await reader); + + reader = c.Reader.ReadAsync(); + AssertSynchronousSuccess(c.Writer.WriteAsync(43)); + Assert.Equal(43, await reader); + + _ = c.Writer.WriteAsync(44); + + Assert.False(dropDelegateCalled); + } + + [Theory] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public void DroppedDelegateIsNull_SyncAndAsyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = Channel.CreateBounded(new BoundedChannelOptions(0) { FullMode = boundedChannelFullMode }, itemDropped: null); + + Assert.True(c.Writer.TryWrite(1)); + AssertSynchronousSuccess(c.Writer.WriteAsync(2)); + + Assert.False(c.Reader.TryRead(out int item)); + Assert.Equal(0, item); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public void DroppedDelegateCalledAfterLockReleased_SyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = null; + bool dropDelegateCalled = false; + + c = Channel.CreateBounded(new BoundedChannelOptions(0) + { + FullMode = boundedChannelFullMode + }, (droppedItem) => + { + if (dropDelegateCalled) + { + return; + } + dropDelegateCalled = true; + + // Dropped delegate should not be called while holding the channel lock. + // Verify this by trying to write into the channel from different thread. + // If lock is held during callback, this should effectively cause deadlock. + ManualResetEventSlim mres = new(); + ThreadPool.QueueUserWorkItem(delegate + { + c.Writer.TryWrite(3); + mres.Set(); + }); + + mres.Wait(); + }); + + Assert.True(c.Writer.TryWrite(1)); + Assert.True(c.Writer.TryWrite(2)); + + Assert.True(dropDelegateCalled); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [InlineData(BoundedChannelFullMode.DropWrite)] + [InlineData(BoundedChannelFullMode.DropOldest)] + [InlineData(BoundedChannelFullMode.DropNewest)] + public async Task DroppedDelegateCalledAfterLockReleased_AsyncWrites(BoundedChannelFullMode boundedChannelFullMode) + { + Channel c = null; + bool dropDelegateCalled = false; + + c = Channel.CreateBounded(new BoundedChannelOptions(0) + { + FullMode = boundedChannelFullMode + }, (droppedItem) => + { + if (dropDelegateCalled) + { + return; + } + dropDelegateCalled = true; + + // Dropped delegate should not be called while holding the channel synchronisation lock. + // Verify this by trying to write into the channel from different thread. + // If lock is held during callback, this should effectively cause deadlock. + var mres = new ManualResetEventSlim(); + ThreadPool.QueueUserWorkItem(delegate + { + c.Writer.TryWrite(11); + mres.Set(); + }); + + mres.Wait(); + }); + + await c.Writer.WriteAsync(1); + await c.Writer.WriteAsync(2); + + Assert.True(dropDelegateCalled); + } + + [Fact] + public async Task CancelPendingWrite_Reading_DataTransferredFromCorrectWriter() + { + Channel c = CreateChannel(); + + CancellationTokenSource cts = new(); + + ValueTask write1 = c.Writer.WriteAsync(42); + ValueTask write2 = c.Writer.WriteAsync(43, cts.Token); + ValueTask write3 = c.Writer.WriteAsync(44); + + cts.Cancel(); + + Assert.Equal(42, await c.Reader.ReadAsync()); + Assert.Equal(44, await c.Reader.ReadAsync()); + + await write1; + await AssertExtensions.CanceledAsync(cts.Token, async () => await write2); + await write3; + } + + [Fact] + public async Task WaitToWriteAsync_AfterRead_ReturnsTrue() + { + Channel c = CreateChannel(); + + ValueTask write1 = c.Writer.WaitToWriteAsync(); + ValueTask write2 = c.Writer.WaitToWriteAsync(); + Assert.False(write1.IsCompleted); + Assert.False(write2.IsCompleted); + + _ = c.Reader.ReadAsync(); + + Assert.True(await write1); + Assert.True(await write2); + + ValueTask write3 = c.Writer.WaitToWriteAsync(); + AssertSynchronousSuccess(write3); + Assert.True(await write3); + } + + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + [MemberData(nameof(ThreeBools))] + public void AllowSynchronousContinuations_Reading_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations, bool cancelable, bool waitToReadAsync) + { + var c = Channel.CreateBounded(new BoundedChannelOptions(0) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + CancellationToken ct = cancelable ? new CancellationTokenSource().Token : CancellationToken.None; + + int expectedId = Environment.CurrentManagedThreadId; + Task t = waitToReadAsync ? c.Reader.WaitToReadAsync(ct).AsTask() : c.Reader.ReadAsync(ct).AsTask(); + Task r = t.ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + ValueTask write = c.Writer.WriteAsync(42); + if (!waitToReadAsync) + { + AssertSynchronousSuccess(write); + } + + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + + [ConditionalTheory] + [InlineData(false)] + [InlineData(true)] + public void AllowSynchronousContinuations_CompletionTask_ContinuationsInvokedAccordingToSetting(bool allowSynchronousContinuations) + { + if (!allowSynchronousContinuations && !PlatformDetection.IsThreadingSupported) + { + throw new SkipTestException(nameof(PlatformDetection.IsThreadingSupported)); + } + + var c = Channel.CreateBounded(new BoundedChannelOptions(0) { AllowSynchronousContinuations = allowSynchronousContinuations }); + + int expectedId = Environment.CurrentManagedThreadId; + Task r = c.Reader.Completion.ContinueWith(_ => + { + Assert.Equal(allowSynchronousContinuations, expectedId == Environment.CurrentManagedThreadId); + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); + + Assert.True(c.Writer.TryComplete()); + ((IAsyncResult)r).AsyncWaitHandle.WaitOne(); // avoid inlining the continuation + r.GetAwaiter().GetResult(); + } + } +} diff --git a/src/libraries/System.Threading.Channels/tests/Stress.cs b/src/libraries/System.Threading.Channels/tests/Stress.cs index ffae1e8ce88fbc..a876b13a931435 100644 --- a/src/libraries/System.Threading.Channels/tests/Stress.cs +++ b/src/libraries/System.Threading.Channels/tests/Stress.cs @@ -12,6 +12,7 @@ public class StressTests { public static IEnumerable ReadWriteVariations_TestData() { + // Unbounded foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync} ) foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync} ) foreach (bool singleReader in new [] {false, true}) @@ -19,15 +20,21 @@ public static IEnumerable ReadWriteVariations_TestData() foreach (bool allowSynchronousContinuations in new [] {false, true}) { Func> unbounded = o => Channel.CreateUnbounded((UnboundedChannelOptions)o); - yield return new object[] { unbounded, new UnboundedChannelOptions - { - SingleReader = singleReader, - SingleWriter = singleWriter, - AllowSynchronousContinuations = allowSynchronousContinuations - }, readDelegate, writeDelegate + yield return new object[] + { + unbounded, + new UnboundedChannelOptions + { + SingleReader = singleReader, + SingleWriter = singleWriter, + AllowSynchronousContinuations = allowSynchronousContinuations + }, + readDelegate, + writeDelegate }; } + // Bounded foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync} ) foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync} ) foreach (BoundedChannelFullMode bco in Enum.GetValues(typeof(BoundedChannelFullMode))) @@ -37,13 +44,44 @@ public static IEnumerable ReadWriteVariations_TestData() foreach (bool allowSynchronousContinuations in new [] {false, true}) { Func> bounded = o => Channel.CreateBounded((BoundedChannelOptions)o); - yield return new object[] { bounded, new BoundedChannelOptions(capacity) - { - SingleReader = singleReader, - SingleWriter = singleWriter, - AllowSynchronousContinuations = allowSynchronousContinuations, - FullMode = bco - }, readDelegate, writeDelegate + yield return new object[] + { + bounded, + new BoundedChannelOptions(capacity) + { + SingleReader = singleReader, + SingleWriter = singleWriter, + AllowSynchronousContinuations = allowSynchronousContinuations, + FullMode = bco + }, + readDelegate, + writeDelegate + }; + } + + // Rendezvous + foreach (var readDelegate in new Func, Task>[] { ReadSynchronous, ReadAsynchronous, ReadSyncAndAsync, ReadAsynchronousNoWait } ) + foreach (var writeDelegate in new Func, int, Task>[] { WriteSynchronous, WriteAsynchronous, WriteSyncAndAsync, WriteAsynchronousNoWait } ) + foreach (var bcfm in new[] { BoundedChannelFullMode.Wait, BoundedChannelFullMode.DropWrite }) + foreach (bool allowSynchronousContinuations in new [] {false, true}) + { + if (readDelegate != ReadAsynchronousNoWait && + writeDelegate != WriteAsynchronousNoWait) + { + // At least one side must be persistent. + continue; + } + + yield return new object[] + { + (Func>)(o => Channel.CreateBounded((BoundedChannelOptions)o)), + new BoundedChannelOptions(0) + { + AllowSynchronousContinuations = allowSynchronousContinuations, + FullMode = bcfm + }, + readDelegate, + writeDelegate }; } } @@ -74,9 +112,21 @@ private static async Task ReadAsynchronous(ChannelReader reader) return false; } + private static async Task ReadAsynchronousNoWait(ChannelReader reader) + { + await reader.ReadAsync(); + return true; + } + + private static async Task WriteAsynchronousNoWait(ChannelWriter writer, int value) + { + await writer.WriteAsync(value); + return true; + } + private static async Task ReadSyncAndAsync(ChannelReader reader) { - if (!reader.TryRead(out int value)) + if (!reader.TryRead(out _)) { if (await reader.WaitToReadAsync()) { diff --git a/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj b/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj index 55ef0e463ca130..6792d7d47833a4 100644 --- a/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj +++ b/src/libraries/System.Threading.Channels/tests/System.Threading.Channels.Tests.csproj @@ -1,4 +1,4 @@ - + $(NetCoreAppCurrent);$(NetFrameworkCurrent) @@ -11,6 +11,7 @@ + diff --git a/src/libraries/System.Threading.Channels/tests/TestBase.cs b/src/libraries/System.Threading.Channels/tests/TestBase.cs index 17873445b2d7c2..b2058f644a770b 100644 --- a/src/libraries/System.Threading.Channels/tests/TestBase.cs +++ b/src/libraries/System.Threading.Channels/tests/TestBase.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using Xunit; @@ -10,6 +12,12 @@ namespace System.Threading.Channels.Tests { public abstract class TestBase { + public static IEnumerable ThreeBools => + from b1 in new[] { false, true } + from b2 in new[] { false, true } + from b3 in new[] { false, true } + select new object[] { b1, b2, b3 }; + protected void AssertSynchronouslyCanceled(Task task, CancellationToken token) { Assert.Equal(TaskStatus.Canceled, task.Status);