From 96a84cc3ea500c1bb92ee68c3304583e55170d22 Mon Sep 17 00:00:00 2001 From: James Ko Date: Sun, 11 Dec 2016 14:15:09 -0500 Subject: [PATCH 1/3] Run the selector during Select.Count chains. --- src/System.Linq/src/System/Linq/Select.cs | 86 +++++++++++++++++++++-- 1 file changed, 81 insertions(+), 5 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Select.cs b/src/System.Linq/src/System/Linq/Select.cs index 2ed2dd85f85a..796208116e8d 100644 --- a/src/System.Linq/src/System/Linq/Select.cs +++ b/src/System.Linq/src/System/Linq/Select.cs @@ -155,9 +155,38 @@ public TResult[] ToArray() return builder.ToArray(); } - public List ToList() => new List(this); + public List ToList() + { + var list = new List(); + + foreach (TSource item in _source) + { + list.Add(_selector(item)); + } + + return list; + } + + public int GetCount(bool onlyIfCheap) + { + if (onlyIfCheap) + { + return -1; + } - public int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : _source.Count(); + int count = 0; + + foreach (TSource item in _source) + { + _selector(item); + checked + { + count++; + } + } + + return count; + } } internal sealed class SelectArrayIterator : Iterator, IPartition @@ -226,6 +255,14 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + if (!onlyIfCheap) + { + foreach (TSource item in _source) + { + _selector(item); + } + } + return _source.Length; } @@ -351,7 +388,17 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return _source.Count; + int count = _source.Count; + + if (!onlyIfCheap) + { + for (int i = 0; i < count; i++) + { + _selector(_source[i]); + } + } + + return count; } public IPartition Skip(int count) @@ -491,7 +538,17 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return _source.Count; + int count = _source.Count; + + if (!onlyIfCheap) + { + for (int i = 0; i < count; i++) + { + _selector(_source[i]); + } + } + + return count; } public IPartition Skip(int count) @@ -703,6 +760,14 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + if (!onlyIfCheap) + { + foreach (TSource item in _source) + { + _selector(item); + } + } + return _source.GetCount(onlyIfCheap); } } @@ -852,7 +917,18 @@ public List ToList() public int GetCount(bool onlyIfCheap) { - return Count; + int count = Count; + + if (!onlyIfCheap) + { + int end = _minIndexInclusive + count; + for (int i = _minIndexInclusive; i != end; ++i) + { + _selector(_source[i]); + } + } + + return count; } } } From 6787c112ed65c97ea870c2ad8968535a66aae20f Mon Sep 17 00:00:00 2001 From: James Ko Date: Sun, 11 Dec 2016 15:35:49 -0500 Subject: [PATCH 2/3] Add tests for the new changes. --- src/System.Linq/tests/SelectTests.cs | 60 ++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/System.Linq/tests/SelectTests.cs b/src/System.Linq/tests/SelectTests.cs index 35aed30dcc04..2714ab863962 100644 --- a/src/System.Linq/tests/SelectTests.cs +++ b/src/System.Linq/tests/SelectTests.cs @@ -1168,5 +1168,65 @@ public static IEnumerable MoveNextAfterDisposeData() yield return new object[] { new int[1] }; yield return new object[] { Enumerable.Range(1, 30) }; } + + [Theory] + [MemberData(nameof(RunSelectorDuringCountData))] + public void RunSelectorDuringCount(IEnumerable source) + { + int timesRun = 0; + var selected = source.Select(i => timesRun++); + selected.Count(); + + Assert.Equal(source.Count(), timesRun); + } + + // [Theory] + [MemberData(nameof(RunSelectorDuringCountData))] + public void RunSelectorDuringPartitionCount(IEnumerable source) + { + int timesRun = 0; + + var selected = source.Select(i => timesRun++); + + if (source.Any()) + { + selected.Skip(1).Count(); + Assert.Equal(source.Count() - 1, timesRun); + + selected.Take(source.Count() - 1).Count(); + Assert.Equal(source.Count() * 2 - 2, timesRun); + } + } + + public static IEnumerable RunSelectorDuringCountData() + { + var transforms = new Func, IEnumerable>[] + { + e => e, + e => ForceNotCollection(e), + e => ForceNotCollection(e).Skip(1), + e => ForceNotCollection(e).Where(i => true), + e => e.ToArray().Where(i => true), + e => e.ToList().Where(i => true), + e => new LinkedList(e).Where(i => true), + e => e.Select(i => i), + e => e.Take(e.Count()), + e => e.ToArray(), + e => e.ToList(), + e => new LinkedList(e) // Implements IList. + }; + + var r = new Random(unchecked((int)0x984bf1a3)); + + for (int i = 0; i <= 5; i++) + { + var enumerable = Enumerable.Range(1, i).Select(_ => r.Next()); + + foreach (var transform in transforms) + { + yield return new object[] { transform(enumerable) }; + } + } + } } } From 5d629d10ab771824ec260f9ae69cab166fdf4cb1 Mon Sep 17 00:00:00 2001 From: James Ko Date: Sun, 11 Dec 2016 16:26:13 -0500 Subject: [PATCH 3/3] Add clarifying comments. --- src/System.Linq/src/System/Linq/Select.cs | 18 ++++++++++++++++++ src/System.Linq/src/System/Linq/Where.cs | 9 +++++++++ 2 files changed, 27 insertions(+) diff --git a/src/System.Linq/src/System/Linq/Select.cs b/src/System.Linq/src/System/Linq/Select.cs index 796208116e8d..006a0c76888b 100644 --- a/src/System.Linq/src/System/Linq/Select.cs +++ b/src/System.Linq/src/System/Linq/Select.cs @@ -169,6 +169,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; @@ -255,6 +258,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (!onlyIfCheap) { foreach (TSource item in _source) @@ -388,6 +394,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + int count = _source.Count; if (!onlyIfCheap) @@ -538,6 +547,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + int count = _source.Count; if (!onlyIfCheap) @@ -760,6 +772,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (!onlyIfCheap) { foreach (TSource item in _source) @@ -917,6 +932,9 @@ public List ToList() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + int count = Count; if (!onlyIfCheap) diff --git a/src/System.Linq/src/System/Linq/Where.cs b/src/System.Linq/src/System/Linq/Where.cs index efeb04e7e584..823ab0157ba0 100644 --- a/src/System.Linq/src/System/Linq/Where.cs +++ b/src/System.Linq/src/System/Linq/Where.cs @@ -434,6 +434,9 @@ public override Iterator Clone() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; @@ -536,6 +539,9 @@ public override Iterator Clone() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1; @@ -658,6 +664,9 @@ public override void Dispose() public int GetCount(bool onlyIfCheap) { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + if (onlyIfCheap) { return -1;