From 16cf9914d778fe28bd110c9902596e3349df517b Mon Sep 17 00:00:00 2001 From: danmosemsft Date: Wed, 18 Jul 2018 13:16:12 -0700 Subject: [PATCH] Port of 'Ensure ConcurrentBag's TryTake is linearizable' again --- .../Collections/Concurrent/ConcurrentBag.cs | 177 ++++++++++++------ .../tests/ConcurrentBagTests.cs | 56 ++++++ 2 files changed, 176 insertions(+), 57 deletions(-) diff --git a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs index 4cc955c76cc0..6554f6133141 100644 --- a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs +++ b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentBag.cs @@ -32,9 +32,11 @@ namespace System.Collections.Concurrent public class ConcurrentBag : IProducerConsumerCollection, IReadOnlyCollection { /// The per-bag, per-thread work-stealing queues. - private ThreadLocal _locals; + private readonly ThreadLocal _locals; /// The head work stealing queue in a linked list of queues. private volatile WorkStealingQueue _workStealingQueues; + /// Number of times any list transitions from empty to non-empty. + private long _emptyToNonEmptyListTransitionCount; /// Initializes a new instance of the class. public ConcurrentBag() @@ -62,7 +64,7 @@ public ConcurrentBag(IEnumerable collection) WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: true); foreach (T item in collection) { - queue.LocalPush(item); + queue.LocalPush(item, ref _emptyToNonEmptyListTransitionCount); } } @@ -72,7 +74,9 @@ public ConcurrentBag(IEnumerable collection) /// The object to be added to the /// . The value can be a null reference /// (Nothing in Visual Basic) for reference types. 
- public void Add(T item) => GetCurrentThreadWorkStealingQueue(forceCreate: true).LocalPush(item); + public void Add(T item) => + GetCurrentThreadWorkStealingQueue(forceCreate: true) + .LocalPush(item, ref _emptyToNonEmptyListTransitionCount); /// /// Attempts to add an object to the . @@ -176,22 +180,55 @@ private bool TrySteal(out T result, bool take) CDSCollectionETWBCLProvider.Log.ConcurrentBag_TryPeekSteals(); } - // If there's no local queue for this thread, just start from the head queue - // and try to steal from each queue until we get a result. - WorkStealingQueue localQueue = GetCurrentThreadWorkStealingQueue(forceCreate: false); - if (localQueue == null) + while (true) { - return TryStealFromTo(_workStealingQueues, null, out result, take); - } + // We need to track whether any lists transition from empty to non-empty both before + // and after we attempt the steal in case we don't get an item: + // + // If we don't get an item, we need to handle the possibility of a race condition that led to + // an item being added to a list after we already looked at it in a way that breaks + // linearizability. For example, say there are three threads 0, 1, and 2, each with their own + // list that's currently empty. We could then have the following series of operations: + // - Thread 2 adds an item, such that there's now 1 item in the bag. + // - Thread 1 sees that the count is 1 and does a Take. Its local list is empty, so it tries to + // steal from list 0, but it's empty. Before it can steal from Thread 2, it's pre-empted. + // - Thread 0 adds an item. The count is now 2. + // - Thread 2 takes an item, which comes from its local queue. The count is now 1. + // - Thread 1 continues to try to steal from 2, finds it's empty, and fails its take, even though + // at any given time during its take the count was >= 1. Oops. + // This is particularly problematic for wrapper types that track count using their own synchronization, + // e.g. 
BlockingCollection, and thus expect that a take will always be successful if the number of items + // is known to be > 0. + // + // We work around this by looking at the number of times any list transitions from == 0 to > 0, + // checking that before and after the steal attempts. We don't care about > 0 to > 0 transitions, + // because a steal from a list with > 0 elements would have been successful. + long initialEmptyToNonEmptyCounts = Interlocked.Read(ref _emptyToNonEmptyListTransitionCount); + + // If there's no local queue for this thread, just start from the head queue + // and try to steal from each queue until we get a result. If there is a local queue from this thread, + // then start from the next queue after it, and then iterate around back from the head to this queue, + // not including it. + WorkStealingQueue localQueue = GetCurrentThreadWorkStealingQueue(forceCreate: false); + bool gotItem = localQueue == null ? + TryStealFromTo(_workStealingQueues, null, out result, take) : + (TryStealFromTo(localQueue._nextQueue, null, out result, take) || TryStealFromTo(_workStealingQueues, localQueue, out result, take)); + if (gotItem) + { + return true; + } - // If there is a local queue from this thread, then start from the next queue - // after it, and then iterate around back from the head to this queue, not including it. - return - TryStealFromTo(localQueue._nextQueue, null, out result, take) || - TryStealFromTo(_workStealingQueues, localQueue, out result, take); + if (Interlocked.Read(ref _emptyToNonEmptyListTransitionCount) == initialEmptyToNonEmptyCounts) + { + // The version number matched, so we didn't get an item and we're confident enough + // in our steal attempt to say so. + return false; + } - // TODO: Investigate storing the queues in an array instead of a linked list, and then - // randomly choosing a starting location from which to start iterating. + // Some list transitioned from empty to non-empty between just before the steal and now. 
+ // Since we don't know if it caused a race condition like the above description, we + // have little choice but to try to steal again. + } } /// @@ -684,7 +721,7 @@ internal bool IsEmpty /// Add new item to the tail of the queue. /// /// The item to add. - internal void LocalPush(T item) + internal void LocalPush(T item, ref long emptyToNonEmptyListTransitionCount) { Debug.Assert(Environment.CurrentManagedThreadId == _ownerThreadId); bool lockTaken = false; @@ -701,7 +738,7 @@ internal void LocalPush(T item) _currentOp = (int)Operation.None; // set back to None temporarily to avoid a deadlock lock (this) { - Debug.Assert(_tailIndex == int.MaxValue, "No other thread should be changing _tailIndex"); + Debug.Assert(_tailIndex == tail, "No other thread should be changing _tailIndex"); // Rather than resetting to zero, we'll just mask off the bits we don't care about. // This way we don't need to rearrange the items already in the queue; they'll be found @@ -711,22 +748,31 @@ internal void LocalPush(T item) // bits are set, so all of the bits we're keeping will also be set. Thus it's impossible // for the head to end up > than the tail, since you can't set any more bits than all of them. _headIndex = _headIndex & _mask; - _tailIndex = tail = _tailIndex & _mask; + _tailIndex = tail = tail & _mask; Debug.Assert(_headIndex <= _tailIndex); - _currentOp = (int)Operation.Add; + Interlocked.Exchange(ref _currentOp, (int)Operation.Add); // ensure subsequent reads aren't reordered before this } } - // We'd like to take the fast path that doesn't require locking, if possible. It's not possible if another - // thread is currently requesting that the whole bag synchronize, e.g. a ToArray operation. It's also - // not possible if there are fewer than two spaces available. One space is necessary for obvious reasons: - // to store the element we're trying to push. The other is necessary due to synchronization with steals. 
- // A stealing thread first increments _headIndex to reserve the slot at its old value, and then tries to - // read from that slot. We could potentially have a race condition whereby _headIndex is incremented just - // before this check, in which case we could overwrite the element being stolen as that slot would appear - // to be empty. Thus, we only allow the fast path if there are two empty slots. - if (!_frozen && tail < (_headIndex + _mask)) + // We'd like to take the fast path that doesn't require locking, if possible. It's not possible if: + // - another thread is currently requesting that the whole bag synchronize, e.g. a ToArray operation + // - if there are fewer than two spaces available. One space is necessary for obvious reasons: + // to store the element we're trying to push. The other is necessary due to synchronization with steals. + // A stealing thread first increments _headIndex to reserve the slot at its old value, and then tries to + // read from that slot. We could potentially have a race condition whereby _headIndex is incremented just + // before this check, in which case we could overwrite the element being stolen as that slot would appear + // to be empty. Thus, we only allow the fast path if there are two empty slots. + // - if there are <= 1 elements in the list. We need to be able to successfully track transitions from + // empty to non-empty in a way that other threads can check, and we can only do that tracking + // correctly if we synchronize with steals when it's possible a steal could take the last item + // in the list just as we're adding. It's possible at this point that there's currently an active steal + // operation happening but that it hasn't yet incremented the head index, such that we could read a value + // for the head that is one less than the accurate value. 
However, only one steal could possibly be doing so, as steals + // take the lock, and another steal couldn't then increment the head index further because it'll see that + // there's currently an add operation in progress and wait until the add completes. + int head = _headIndex; // read after _currentOp set to Add + if (!_frozen && head < tail - 1 & tail < (head + _mask)) { _array[tail & _mask] = item; _tailIndex = tail + 1; @@ -737,8 +783,8 @@ internal void LocalPush(T item) _currentOp = (int)Operation.None; // set back to None to avoid a deadlock Monitor.Enter(this, ref lockTaken); - int head = _headIndex; - int count = _tailIndex - _headIndex; + head = _headIndex; + int count = tail - head; // this count is stable, as we're holding the lock // If we're full, expand the array. if (count >= _mask) @@ -767,6 +813,14 @@ internal void LocalPush(T item) _array[tail & _mask] = item; _tailIndex = tail + 1; + // Now that the item has been added, if we were at 0 (now at 1) item, + // increase the empty to non-empty transition count. + if (count == 0) + { + // We just transitioned from empty to non-empty, so increment the transition count. + Interlocked.Increment(ref emptyToNonEmptyListTransitionCount); + } + // Update the count to avoid overflow. We can trust _stealCount here, // as we're inside the lock and it's only manipulated there. _addTakeCount -= _stealCount; @@ -908,41 +962,50 @@ internal bool TryLocalPeek(out T result) /// true to take the item; false to simply peek at it internal bool TrySteal(out T result, bool take) { - // Fast-path check to see if the queue is empty. - if (_headIndex < _tailIndex) + lock (this) { - // Anything other than empty requires synchronization. 
- lock (this) + int head = _headIndex; // _headIndex is only manipulated under the lock + if (take) { - int head = _headIndex; - if (take) + // If there are <= 2 items in the list, we need to ensure no add operation + // is in progress, as add operations need to accurately count transitions + // from empty to non-empty, and they can only do that if there are no concurrent + // steal operations happening at the time. + if (head < _tailIndex - 1 && _currentOp != (int)Operation.Add) { - // Increment head to tentatively take an element: a full fence is used to ensure the read - // of _tailIndex doesn't move earlier, as otherwise we could potentially end up stealing - // the same element that's being popped locally. - Interlocked.Exchange(ref _headIndex, unchecked(head + 1)); - - // If there's an element to steal, do it. - if (head < _tailIndex) + var spinner = new SpinWait(); + do { - int idx = head & _mask; - result = _array[idx]; - _array[idx] = default(T); - _stealCount++; - return true; - } - else - { - // We contended with the local thread and lost the race, so restore the head. - _headIndex = head; + spinner.SpinOnce(); } + while (_currentOp == (int)Operation.Add); } - else if (head < _tailIndex) + + // Increment head to tentatively take an element: a full fence is used to ensure the read + // of _tailIndex doesn't move earlier, as otherwise we could potentially end up stealing + // the same element that's being popped locally. + Interlocked.Exchange(ref _headIndex, unchecked(head + 1)); + + // If there's an element to steal, do it. + if (head < _tailIndex) { - // Peek, if there's an element available - result = _array[head & _mask]; + int idx = head & _mask; + result = _array[idx]; + _array[idx] = default(T); + _stealCount++; return true; } + else + { + // We contended with the local thread and lost the race, so restore the head. 
+ _headIndex = head; + } + } + else if (head < _tailIndex) + { + // Peek, if there's an element available + result = _array[head & _mask]; + return true; } } diff --git a/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs b/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs index c18ecc524a47..3bce75bf997c 100644 --- a/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs +++ b/src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs @@ -92,6 +92,62 @@ public static void AddManyItems_ThenTakeOnDifferentThread_ItemsOutputInExpectedO }, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).GetAwaiter().GetResult(); } + [Fact] + public static void SingleProducerAdding_MultiConsumerTaking_SemaphoreThrottling_AllTakesSucceed() + { + var bag = new ConcurrentBag(); + var s = new SemaphoreSlim(0); + CountdownEvent ce = null; + const int ItemCount = 200_000; + + int producerNextValue = 0; + Action producer = null; + producer = delegate + { + ThreadPool.QueueUserWorkItem(delegate + { + bag.Add(producerNextValue++); + s.Release(); + if (producerNextValue < ItemCount) + { + producer(); + } + else + { + ce.Signal(); + } + }); + }; + + int consumed = 0; + Action consumer = null; + consumer = delegate + { + ThreadPool.QueueUserWorkItem(delegate + { + if (s.Wait(0)) + { + Assert.True(bag.TryTake(out _), "There's an item available, but we couldn't take it."); + Interlocked.Increment(ref consumed); + } + else if (Volatile.Read(ref consumed) >= ItemCount) + { + ce.Signal(); + return; + } + + consumer(); + }); + }; + + // one producer, two consumers + ce = new CountdownEvent(3); + producer(); + consumer(); + consumer(); + ce.Wait(); + } + [Theory] [InlineData(0)] [InlineData(1)]