From fffd49fe1c35544b388322608d1e16374f3832e6 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sat, 16 May 2015 06:05:45 -0400 Subject: [PATCH] Add ConcurrentDictionary GetOrAdd/AddOrUpdate overloads with generic arg Adds one overload to each of GetOrAdd and AddOrUpdate. These overloads accept a generic argument that is passed through to any invocations of the supplied delegates, enabling developers to avoid delegate/closure allocations when more input is needed than just the key or existing value. For AddOrUpdate, there are two existing overloads with delegates; this only provide a new overload for the one that accepts two delegates. --- .../Concurrent/ConcurrentDictionary.cs | 94 ++++++++++++++++++- .../tests/ConcurrentDictionaryTests.cs | 49 +++++++--- 2 files changed, 126 insertions(+), 17 deletions(-) diff --git a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs index 212d5f44196c..3233dabcb52e 100644 --- a/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs +++ b/src/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs @@ -973,6 +973,37 @@ public TValue GetOrAdd(TKey key, Func valueFactory) return resultingValue; } + /// + /// Adds a key/value pair to the + /// if the key does not already exist. + /// + /// The key of the element to add. + /// The function used to generate a value for the key + /// An argument value to pass into . + /// is a null reference + /// (Nothing in Visual Basic). + /// is a null reference + /// (Nothing in Visual Basic). + /// The dictionary contains too many + /// elements. + /// The value for the key. This will be either the existing value for the key if the + /// key is already in the dictionary, or the new value for the key as returned by valueFactory + /// if the key was not in the dictionary. + public TValue GetOrAdd(TKey key, Func valueFactory, TArg factoryArgument) + { + if (key == null) throw new ArgumentNullException("key"); + if (valueFactory == null) throw new ArgumentNullException("valueFactory"); + + int hashcode = _comparer.GetHashCode(key); + + TValue resultingValue; + if (!TryGetValueInternal(key, hashcode, out resultingValue)) + { + TryAddInternal(key, hashcode, valueFactory(key, factoryArgument), false, true, out resultingValue); + } + return resultingValue; + } + /// /// Adds a key/value pair to the /// if the key does not already exist. @@ -999,6 +1030,59 @@ public TValue GetOrAdd(TKey key, TValue value) return resultingValue; } + /// + /// Adds a key/value pair to the if the key does not already + /// exist, or updates a key/value pair in the if the key + /// already exists. + /// + /// The key to be added or whose value should be updated + /// The function used to generate a value for an absent key + /// The function used to generate a new value for an existing key + /// based on the key's existing value + /// An argument to pass into and . + /// is a null reference + /// (Nothing in Visual Basic). + /// is a null reference + /// (Nothing in Visual Basic). + /// is a null reference + /// (Nothing in Visual Basic). + /// The dictionary contains too many + /// elements. + /// The new value for the key. This will be either be the result of addValueFactory (if the key was + /// absent) or the result of updateValueFactory (if the key was present). + public TValue AddOrUpdate( + TKey key, Func addValueFactory, Func updateValueFactory, TArg factoryArgument) + { + if (key == null) throw new ArgumentNullException("key"); + if (addValueFactory == null) throw new ArgumentNullException("addValueFactory"); + if (updateValueFactory == null) throw new ArgumentNullException("updateValueFactory"); + + int hashcode = _comparer.GetHashCode(key); + + while (true) + { + TValue oldValue; + if (TryGetValueInternal(key, hashcode, out oldValue)) + { + // key exists, try to update + TValue newValue = updateValueFactory(key, oldValue, factoryArgument); + if (TryUpdateInternal(key, hashcode, newValue, oldValue)) + { + return newValue; + } + } + else + { + // key doesn't exist, try to add + TValue resultingValue; + if (TryAddInternal(key, hashcode, addValueFactory(key, factoryArgument), false, true, out resultingValue)) + { + return resultingValue; + } + } + } + } + /// /// Adds a key/value pair to the if the key does not already /// exist, or updates a key/value pair in the if the key @@ -1030,16 +1114,17 @@ public TValue AddOrUpdate(TKey key, Func addValueFactory, Func { TValue oldValue; if (TryGetValueInternal(key, hashcode, out oldValue)) - //key exists, try to update { + // key exists, try to update TValue newValue = updateValueFactory(key, oldValue); if (TryUpdate(key, newValue, oldValue)) { return newValue; } } - else //try add + else { + // key doesn't exist, try to add TValue resultingValue; if (TryAddInternal(key, hashcode, addValue, false, true, out resultingValue)) { diff --git a/src/System.Collections.Concurrent/tests/ConcurrentDictionaryTests.cs b/src/System.Collections.Concurrent/tests/ConcurrentDictionaryTests.cs index 8dea0f1a6f7e..0764bb5bf5ac 100644 --- a/src/System.Collections.Concurrent/tests/ConcurrentDictionaryTests.cs +++ b/src/System.Collections.Concurrent/tests/ConcurrentDictionaryTests.cs @@ -447,25 +447,33 @@ private static void TestGetOrAddOrUpdate(int cLevel, int initSize, int threads, { if (isAdd) { - //call either of the two overloads of GetOrAdd - if (j + ii % 2 == 0) + //call one of the overloads of GetOrAdd + switch (j % 3) { - dict.GetOrAdd(j, -j); - } - else - { - dict.GetOrAdd(j, x => -x); + case 0: + dict.GetOrAdd(j, -j); + break; + case 1: + dict.GetOrAdd(j, x => -x); + break; + case 2: + dict.GetOrAdd(j, (x,m) => x * m, -1); + break; } } else { - if (j + ii % 2 == 0) - { - dict.AddOrUpdate(j, -j, (k, v) => -j); - } - else + switch (j % 3) { - dict.AddOrUpdate(j, (k) => -k, (k, v) => -k); + case 0: + dict.AddOrUpdate(j, -j, (k, v) => -j); + break; + case 1: + dict.AddOrUpdate(j, (k) => -k, (k, v) => -k); + break; + case 2: + dict.AddOrUpdate(j, (k,m) => k*m, (k, v, m) => k * m, -1); + break; } } } @@ -621,6 +629,12 @@ public static void TestExceptions() () => dictionary[null] = 1); // "TestExceptions: FAILED. this[] didn't throw ANE when null key is passed"); + Assert.Throws( + () => dictionary.GetOrAdd(null, (k,m) => 0, 0)); + // "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed"); + Assert.Throws( + () => dictionary.GetOrAdd("1", null, 0)); + // "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null valueFactory is passed"); Assert.Throws( () => dictionary.GetOrAdd(null, (k) => 0)); // "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed"); @@ -631,6 +645,15 @@ public static void TestExceptions() () => dictionary.GetOrAdd(null, 0)); // "TestExceptions: FAILED. GetOrAdd didn't throw ANE when null key is passed"); + Assert.Throws( + () => dictionary.AddOrUpdate(null, (k, m) => 0, (k, v, m) => 0, 42)); + // "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null key is passed"); + Assert.Throws( + () => dictionary.AddOrUpdate("1", (k, m) => 0, null, 42)); + // "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null updateFactory is passed"); + Assert.Throws( + () => dictionary.AddOrUpdate("1", null, (k, v, m) => 0, 42)); + // "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null addFactory is passed"); Assert.Throws( () => dictionary.AddOrUpdate(null, (k) => 0, (k, v) => 0)); // "TestExceptions: FAILED. AddOrUpdate didn't throw ANE when null key is passed");