diff --git a/src/System.Linq/src/System/Linq/Select.cs b/src/System.Linq/src/System/Linq/Select.cs index 2ed2dd85f85a..006a0c76888b 100644 --- a/src/System.Linq/src/System/Linq/Select.cs +++ b/src/System.Linq/src/System/Linq/Select.cs @@ -155,9 +155,41 @@ public TResult[] ToArray() return builder.ToArray(); } - public List ToList() => new List(this); + public List ToList() + { + var list = new List(); + + foreach (TSource item in _source) + { + list.Add(_selector(item)); + } + + return list; + } + + public int GetCount(bool onlyIfCheap) + { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + if (onlyIfCheap) + { + return -1; + } - public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : _source.Count(); + int count = 0; + + foreach (TSource item in _source) + { + _selector(item); + checked + { + count++; + } + } + + return count; + } } internal sealed class SelectArrayIterator : Iterator, IPartition @@ -226,6 +258,17 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + if (!onlyIfCheap) + { + foreach (TSource item in _source) + { + _selector(item); + } + } + return _source.Length; } @@ -351,7 +394,20 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return _source.Count; + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + int count = _source.Count; + + if (!onlyIfCheap) + { + for (int i = 0; i < count; i++) + { + _selector(_source[i]); + } + } + + return count; } public IPartition Skip(int count) @@ -491,7 +547,20 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return _source.Count; + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + int count = _source.Count; + + if (!onlyIfCheap) + { + for (int i = 0; i < count; i++) + { + _selector(_source[i]); + } + } + + return count; } public IPartition Skip(int count) @@ -703,6 +772,17 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + if (!onlyIfCheap) + { + foreach (TSource item in _source) + { + _selector(item); + } + } + return _source.GetCount(onlyIfCheap); } } @@ -852,7 +932,21 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return Count; + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + int count = Count; + + if (!onlyIfCheap) + { + int end = _minIndexInclusive + count; + for (int i = _minIndexInclusive; i != end; ++i) + { + _selector(_source[i]); + } + } + + return count; } } } diff --git a/src/System.Linq/src/System/Linq/Where.cs b/src/System.Linq/src/System/Linq/Where.cs index efeb04e7e584..823ab0157ba0 100644 --- a/src/System.Linq/src/System/Linq/Where.cs +++ b/src/System.Linq/src/System/Linq/Where.cs @@ -434,6 +434,9 @@ public override Iterator Clone() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; @@ -536,6 +539,9 @@ public override Iterator Clone() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; @@ -658,6 +664,9 @@ public override void Dispose() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; diff --git a/src/System.Linq/tests/SelectTests.cs b/src/System.Linq/tests/SelectTests.cs index 35aed30dcc04..2714ab863962 100644 --- a/src/System.Linq/tests/SelectTests.cs +++ b/src/System.Linq/tests/SelectTests.cs @@ -1168,5 +1168,65 @@ public static IEnumerable MoveNextAfterDisposeData() yield return new object[] { new int[1] }; yield return new object[] { Enumerable.Range(1, 30) }; } + + [Theory] + [MemberData(nameof(RunSelectorDuringCountData))] + public void RunSelectorDuringCount(IEnumerable source) + { + int timesRun = 0; + var selected = source.Select(i => timesRun++); + selected.Count(); + + Assert.Equal(source.Count(), timesRun); + } + + // [Theory] + [MemberData(nameof(RunSelectorDuringCountData))] + public void RunSelectorDuringPartitionCount(IEnumerable source) + { + int timesRun = 0; + + var selected = source.Select(i => timesRun++); + + if (source.Any()) + { + selected.Skip(1).Count(); + Assert.Equal(source.Count() - 1, timesRun); + + selected.Take(source.Count() - 1).Count(); + Assert.Equal(source.Count() * 2 - 2, timesRun); + } + } + + public static IEnumerable RunSelectorDuringCountData() + { + var transforms = new Func, IEnumerable>[] + { + e => e, + e => ForceNotCollection(e), + e => ForceNotCollection(e).Skip(1), + e => ForceNotCollection(e).Where(i => true), + e => e.ToArray().Where(i => true), + e => e.ToList().Where(i => true), + e => new LinkedList(e).Where(i => true), + e => e.Select(i => i), + e => e.Take(e.Count()), + e => e.ToArray(), + e => e.ToList(), + e => new LinkedList(e) // Implements IList. + }; + + var r = new Random(unchecked((int)0x984bf1a3)); + + for (int i = 0; i <= 5; i++) + { + var enumerable = Enumerable.Range(1, i).Select(_ => r.Next()); + + foreach (var transform in transforms) + { + yield return new object[] { transform(enumerable) }; + } + } + } } }