diff --git a/src/System.Linq/src/System/Linq/Range.SpeedOpt.cs b/src/System.Linq/src/System/Linq/Range.SpeedOpt.cs index b6437812c673..29a6f0c02507 100644 --- a/src/System.Linq/src/System/Linq/Range.SpeedOpt.cs +++ b/src/System.Linq/src/System/Linq/Range.SpeedOpt.cs @@ -12,7 +12,7 @@ private sealed partial class RangeIterator : IPartition { public override IEnumerable Select(Func selector) { - return new SelectIPartitionIterator(this, selector); + return new SelectRangeIterator(_start, _end, selector); } public int[] ToArray() diff --git a/src/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs b/src/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs index 1764351bc831..a55cfe209130 100644 --- a/src/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs +++ b/src/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Diagnostics; namespace System.Linq { @@ -39,6 +40,8 @@ public List ToList() public IPartition Skip(int count) { + Debug.Assert(count > 0); + if (count >= _count) { return EmptyPartition.Instance; @@ -49,6 +52,8 @@ public IPartition Skip(int count) public IPartition Take(int count) { + Debug.Assert(count > 0); + if (count >= _count) { return this; diff --git a/src/System.Linq/src/System/Linq/Select.SpeedOpt.cs b/src/System.Linq/src/System/Linq/Select.SpeedOpt.cs index ff5a3a551556..b226a75abd79 100644 --- a/src/System.Linq/src/System/Linq/Select.SpeedOpt.cs +++ b/src/System.Linq/src/System/Linq/Select.SpeedOpt.cs @@ -125,8 +125,13 @@ public IPartition Skip(int count) return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) => - count >= _source.Length ? (IPartition)this : new SelectListPartitionIterator(_source, _selector, 0, count - 1); + public IPartition Take(int count) + { + Debug.Assert(count > 0); + return count >= _source.Length ? + (IPartition)this : + new SelectListPartitionIterator(_source, _selector, 0, count - 1); + } public TResult TryGetElementAt(int index, out bool found) { @@ -157,6 +162,132 @@ public TResult TryGetLast(out bool found) } } + private sealed partial class SelectRangeIterator : Iterator, IPartition + { + private readonly int _start; + private readonly int _end; + private readonly Func _selector; + + public SelectRangeIterator(int start, int end, Func selector) + { + Debug.Assert(start < end); + Debug.Assert((end - start) <= int.MaxValue); + Debug.Assert(selector != null); + + _start = start; + _end = end; + _selector = selector; + } + + public override Iterator Clone() => + new SelectRangeIterator(_start, _end, _selector); + + public override bool MoveNext() + { + if (_state < 1 || _state == (_end - _start + 1)) + { + Dispose(); + return false; + } + + int index = _state++ - 1; + Debug.Assert(_start < _end - index); + _current = _selector(_start + index); + return true; + } + + public override IEnumerable Select(Func selector) => + new SelectRangeIterator(_start, _end, CombineSelectors(_selector, selector)); + + public TResult[] ToArray() + { + var results = new TResult[_end - _start]; + int srcIndex = _start; + for (int i = 0; i < results.Length; i++) + { + results[i] = _selector(srcIndex++); + } + + return results; + } + + public List ToList() + { + var results = new List(_end - _start); + for (int i = _start; i != _end; i++) + { + results.Add(_selector(i)); + } + + return results; + } + + public int GetCount(bool onlyIfCheap) + { + // In case someone uses Count() to force evaluation of the selector, + // run it provided `onlyIfCheap` is false. + if (!onlyIfCheap) + { + for (int i = _start; i != _end; i++) + { + _selector(i); + } + } + + return _end - _start; + } + + public IPartition Skip(int count) + { + Debug.Assert(count > 0); + + if (count >= (_end - _start)) + { + return EmptyPartition.Instance; + } + + return new SelectRangeIterator(_start + count, _end, _selector); + } + + public IPartition Take(int count) + { + Debug.Assert(count > 0); + + if (count >= (_end - _start)) + { + return this; + } + + return new SelectRangeIterator(_start, _start + count, _selector); + } + + public TResult TryGetElementAt(int index, out bool found) + { + if ((uint)index < (uint)(_end - _start)) + { + found = true; + return _selector(_start + index); + } + + found = false; + return default; + } + + public TResult TryGetFirst(out bool found) + { + Debug.Assert(_end > _start); + found = true; + return _selector(_start); + } + + public TResult TryGetLast(out bool found) + { + Debug.Assert(_end > _start); + found = true; + return _selector(_end - 1); + } + } + private sealed partial class SelectListIterator : IPartition { public TResult[] ToArray() @@ -212,8 +343,11 @@ public IPartition Skip(int count) return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) => - new SelectListPartitionIterator(_source, _selector, 0, count - 1); + public IPartition Take(int count) + { + Debug.Assert(count > 0); + return new SelectListPartitionIterator(_source, _selector, 0, count - 1); + } public TResult TryGetElementAt(int index, out bool found) { @@ -308,8 +442,11 @@ public IPartition Skip(int count) return new SelectListPartitionIterator(_source, _selector, count, int.MaxValue); } - public IPartition Take(int count) => - new SelectListPartitionIterator(_source, _selector, 0, count - 1); + public IPartition Take(int count) + { + Debug.Assert(count > 0); + return new SelectListPartitionIterator(_source, _selector, 0, count - 1); + } public TResult TryGetElementAt(int index, out bool found) { @@ -413,8 +550,11 @@ public IPartition Skip(int count) return new SelectIPartitionIterator(_source.Skip(count), _selector); } - public IPartition Take(int count) => - new SelectIPartitionIterator(_source.Take(count), _selector); + public IPartition Take(int count) + { + Debug.Assert(count > 0); + return new SelectIPartitionIterator(_source.Take(count), _selector); + } public TResult TryGetElementAt(int index, out bool found) { @@ -579,6 +719,7 @@ public IPartition Skip(int count) public IPartition Take(int count) { + Debug.Assert(count > 0); int maxIndex = _minIndexInclusive + count - 1; return (uint)maxIndex >= (uint)_maxIndexInclusive ? this : new SelectListPartitionIterator(_source, _selector, _minIndexInclusive, maxIndex); }