From 66a2d2634b8db5e0c9ef8872c6ba92f55640fba3 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Fri, 15 Jan 2016 11:59:22 +0000 Subject: [PATCH 1/9] Combine IArrayProvider and IListProvider Anything that can be one can be the other, so merge the two interfaces. Add ToList support to OrderedPartition. --- src/System.Linq/src/System/Linq/Enumerable.cs | 66 ++++++++++++------- src/System.Linq/tests/OrderedSubsetting.cs | 24 +++++++ 2 files changed, 65 insertions(+), 25 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 8da90277886f..492569abe486 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -530,7 +530,7 @@ public override IEnumerable Select(Func s } - internal sealed class SelectArrayIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectArrayIterator : Iterator, IIListProvider { private readonly TSource[] _source; private readonly Func _selector; @@ -587,7 +587,7 @@ public List ToList() } } - internal sealed class SelectListIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectListIterator : Iterator, IIListProvider { private readonly List _source; private readonly Func _selector; @@ -653,7 +653,7 @@ public List ToList() } } - internal sealed class SelectIListIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectIListIterator : Iterator, IIListProvider { private readonly IList _source; private readonly Func _selector; @@ -1270,14 +1270,14 @@ public static IEnumerable AsEnumerable(this IEnumerable(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - IArrayProvider arrayProvider = source as IArrayProvider; + IIListProvider arrayProvider = source as IIListProvider; return arrayProvider != null ? arrayProvider.ToArray() : new Buffer(source).ToArray(); } public static List ToList(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - IListProvider listProvider = source as IListProvider; + IIListProvider listProvider = source as IIListProvider; return listProvider != null ? listProvider.ToList() : new List(source); } @@ -1787,7 +1787,7 @@ public static IEnumerable Range(int start, int count) return new RangeIterator(start, count); } - private sealed class RangeIterator : Iterator, IArrayProvider, IListProvider, IPartition + private sealed class RangeIterator : Iterator, IPartition { private readonly int _start; private readonly int _end; @@ -1901,7 +1901,7 @@ public static IEnumerable Repeat(TResult element, int count) return new RepeatIterator(element, count); } - private sealed class RepeatIterator : Iterator, IArrayProvider, IListProvider, IPartition + private sealed class RepeatIterator : Iterator, IPartition { private readonly int _count; private int _sent; @@ -3310,22 +3310,16 @@ public static decimal Average(this IEnumerable source, Func - /// An iterator that can produce an array through an optimized path. + /// An iterator that can produce an array or an through an optimized path. /// - internal interface IArrayProvider + internal interface IIListProvider { /// /// Produce an array of the sequence through an optimized path. /// /// The array. TElement[] ToArray(); - } - /// - /// An iterator that can produce a through an optimized path. - /// - internal interface IListProvider - { /// /// Produce a of the sequence through an optimized path. /// @@ -3358,7 +3352,7 @@ public interface ILookup : IEnumerable bool Contains(TKey key); } - public class Lookup : IEnumerable>, ILookup, IArrayProvider>, IListProvider> + public class Lookup : IEnumerable>, ILookup, IIListProvider> { private IEqualityComparer _comparer; private Grouping[] _groupings; @@ -3429,7 +3423,7 @@ public IEnumerator> GetEnumerator() } } - IGrouping[] IArrayProvider>.ToArray() + IGrouping[] IIListProvider>.ToArray() { IGrouping[] array = new IGrouping[_count]; int index = 0; @@ -3446,7 +3440,7 @@ IGrouping[] IArrayProvider>.ToArray() return array; } - List> IListProvider>.ToList() + List> IIListProvider>.ToList() { List> list = new List>(_count); Grouping g = _lastGrouping; @@ -3786,7 +3780,7 @@ IEnumerator IEnumerable.GetEnumerator() } } - internal class GroupedEnumerable : IEnumerable>, IArrayProvider>, IListProvider> + internal class GroupedEnumerable : IEnumerable>, IIListProvider> { private IEnumerable _source; private Func _keySelector; @@ -3816,18 +3810,18 @@ IEnumerator IEnumerable.GetEnumerator() public IGrouping[] ToArray() { - IArrayProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + IIListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); return lookup.ToArray(); } public List> ToList() { - IListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + IIListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); return lookup.ToList(); } } - internal interface IPartition : IEnumerable, IArrayProvider + internal interface IPartition : IEnumerable, IIListProvider { IPartition Skip(int count); @@ -3846,7 +3840,7 @@ internal interface IPartition : IEnumerable, IArrayProvider< TElement LastOrDefault(); } - internal sealed class EmptyPartition : IPartition, IListProvider, IEnumerator + internal sealed class EmptyPartition : IPartition, IEnumerator { public EmptyPartition() { @@ -4015,9 +4009,14 @@ public TElement[] ToArray() { return _source.ToArray(_minIndex, _maxIndex); } + + public List ToList() + { + return _source.ToList(_minIndex, _maxIndex); + } } - internal abstract class OrderedEnumerable : IOrderedEnumerable, IArrayProvider, IListProvider, IPartition + internal abstract class OrderedEnumerable : IOrderedEnumerable, IPartition { internal IEnumerable source; @@ -4108,6 +4107,23 @@ internal TElement[] ToArray(int minIdx, int maxIdx) return array; } + internal List ToList(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (count <= minIdx) return new List(0); + if (count <= maxIdx) maxIdx = count - 1; + if (minIdx == maxIdx) return new List(1) { GetEnumerableSorter().ElementAt(buffer.items, count, minIdx) }; + int[] map = SortedMap(buffer, minIdx, maxIdx); + List list = new List(maxIdx - minIdx + 1); + while (minIdx <= maxIdx) + { + list.Add(buffer.items[map[minIdx]]); + ++minIdx; + } + return list; + } + private EnumerableSorter GetEnumerableSorter() { return GetEnumerableSorter(null); @@ -4656,7 +4672,7 @@ internal struct Buffer internal Buffer(IEnumerable source) { - IArrayProvider iterator = source as IArrayProvider; + IIListProvider iterator = source as IIListProvider; if (iterator != null) { TElement[] array = iterator.ToArray(); diff --git a/src/System.Linq/tests/OrderedSubsetting.cs b/src/System.Linq/tests/OrderedSubsetting.cs index cf25c4b6203e..7d56573a7c8f 100644 --- a/src/System.Linq/tests/OrderedSubsetting.cs +++ b/src/System.Linq/tests/OrderedSubsetting.cs @@ -302,24 +302,48 @@ public void ToArray() Assert.Equal(Enumerable.Range(10, 20), Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(10).Take(20).ToArray()); } + [Fact] + public void ToList() + { + Assert.Equal(Enumerable.Range(10, 20), Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(10).Take(20).ToList()); + } + [Fact] public void EmptyToArray() { Assert.Empty(Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(100).ToArray()); } + [Fact] + public void EmptyToList() + { + Assert.Empty(Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(100).ToList()); + } + [Fact] public void AttemptedMoreArray() { Assert.Equal(Enumerable.Range(0, 20), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Take(30).ToArray()); } + [Fact] + public void AttemptedMoreToList() + { + Assert.Equal(Enumerable.Range(0, 20), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Take(30).ToList()); + } + [Fact] public void SingleElementToArray() { Assert.Equal(Enumerable.Repeat(10, 1), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Skip(10).Take(1).ToArray()); } + [Fact] + public void SingleElementToList() + { + Assert.Equal(Enumerable.Repeat(10, 1), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Skip(10).Take(1).ToList()); + } + [Fact] public void EnumeratorDoesntContinue() { From a180ec72d19b57a338a7265140b3dc9f1b304e27 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Fri, 15 Jan 2016 15:54:42 +0000 Subject: [PATCH 2/9] Have IList optimised result of Skip() partitionable. Optimisation of Skip() for IList sources from #4551 fits with optimisations of Skip() and Take() for other sources from #2401. Combine the approaches, extending how the result of Skip() on a list handles subsequent operations. --- src/System.Linq/src/System/Linq/Enumerable.cs | 123 ++++++- src/System.Linq/tests/SkipTests.cs | 220 ++++++++++++ src/System.Linq/tests/TakeTests.cs | 322 +++++++++++++++++- 3 files changed, 659 insertions(+), 6 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 492569abe486..3df4114623c0 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -836,6 +836,8 @@ public static IEnumerable Take(this IEnumerable sourc if (count <= 0) return new EmptyPartition(); IPartition partition = source as IPartition; if (partition != null) return partition.Take(count); + IList sourceList = source as IList; + if (sourceList != null) return new SkipListIterator(sourceList, 0, count - 1); return TakeIterator(source, count); } @@ -889,14 +891,127 @@ public static IEnumerable Skip(this IEnumerable sourc IPartition partition = source as IPartition; if (partition != null) return partition.Skip(count); IList sourceList = source as IList; - return sourceList != null ? SkipList(sourceList, count) : SkipIterator(source, count); + return sourceList != null ? new SkipListIterator(sourceList, count, int.MaxValue) : SkipIterator(source, count); } - private static IEnumerable SkipList(IList source, int count) + private sealed class SkipListIterator : Iterator, IPartition { - while (count < source.Count) + private readonly IList _source; + private readonly int _minIndex; + private readonly int _maxIndex; + private int _index; + + public SkipListIterator(IList source, int minIndex, int maxIndex) + { + Debug.Assert(source != null); + Debug.Assert(minIndex >= 0); + Debug.Assert(minIndex <= maxIndex); + _source = source; + _minIndex = minIndex; + _maxIndex = maxIndex; + } + + public override Iterator Clone() + { + return new SkipListIterator(_source, _minIndex, _maxIndex); + } + + public override bool MoveNext() { - yield return source[count++]; + switch(state) + { + case 1: + _index = _minIndex; + state = 2; + goto case 2; + case 2: + if (_index <= _maxIndex && _index < _source.Count) + { + current = _source[_index]; + ++_index; + return true; + } + break; + } + Dispose(); + return false; + } + + public IPartition Skip(int count) + { + int minIndex = _minIndex + count; + return minIndex >= _maxIndex + ? (IPartition)new EmptyPartition() + : new SkipListIterator(_source, minIndex, _maxIndex); + } + + public IPartition Take(int count) + { + int maxIndex = _minIndex + count - 1; + if (maxIndex >= _maxIndex) maxIndex = _maxIndex; + return new SkipListIterator(_source, _minIndex, maxIndex); + } + + public TSource ElementAt(int index) + { + if ((uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex) throw Error.ArgumentOutOfRange("index"); + return _source[_minIndex + index]; + } + + public TSource ElementAtOrDefault(int index) + { + return (uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex ? default(TSource) : _source[_minIndex + index]; + } + + public TSource First() + { + if (_source.Count <= _minIndex) throw Error.NoElements(); + return _source[_minIndex]; + } + + public TSource FirstOrDefault() + { + return _source.Count <= _minIndex ? default(TSource) : _source[_minIndex]; + } + + public TSource Last() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) throw Error.NoElements(); + return _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; + } + + public TSource LastOrDefault() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) return default(TSource); + return _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; + } + + public TSource[] ToArray() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) return new TSource[0]; + if (lastIndex > _maxIndex) lastIndex = _maxIndex; + TSource[] array = new TSource[lastIndex - _minIndex + 1]; + int curIdx = _minIndex; + for (int i = 0; i != array.Length; ++i) + { + array[i] = _source[curIdx]; + ++curIdx; + } + return array; + } + + public List ToList() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) return new List(0); + if (lastIndex > _maxIndex) lastIndex = _maxIndex; + List list = new List(lastIndex - _minIndex + 1); + for (int i = _minIndex; i <= lastIndex; ++i) + list.Add(_source[i]); + return list; } } diff --git a/src/System.Linq/tests/SkipTests.cs b/src/System.Linq/tests/SkipTests.cs index a49570be9015..e8fa2c44b3c9 100644 --- a/src/System.Linq/tests/SkipTests.cs +++ b/src/System.Linq/tests/SkipTests.cs @@ -213,5 +213,225 @@ public void ForcedToEnumeratorDoesntEnumerateIList() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void FollowWithTake() + { + var source = new[] { 5, 6, 7, 8 }; + var expected = new[] { 6, 7 }; + Assert.Equal(expected, source.Skip(1).Take(2)); + } + + [Fact] + public void FollowWithTakeNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(5, 4); + var expected = new[] { 6, 7 }; + Assert.Equal(expected, source.Skip(1).Take(2)); + } + + [Fact] + public void FollowWithSkip() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var expected = new[] { 4, 5, 6 }; + Assert.Equal(expected, source.Skip(1).Skip(2).Skip(-4)); + } + + [Fact] + public void FollowWithSkipNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(1, 6); + var expected = new[] { 4, 5, 6 }; + Assert.Equal(expected, source.Skip(1).Skip(2).Skip(-4)); + } + + [Fact] + public void ElementAt() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAt(0)); + Assert.Equal(4, remaining.ElementAt(1)); + Assert.Equal(6, remaining.ElementAt(3)); + Assert.Throws("index", () => remaining.ElementAt(-1)); + Assert.Throws("index", () => remaining.ElementAt(4)); + } + + [Fact] + public void ElementAtNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAt(0)); + Assert.Equal(4, remaining.ElementAt(1)); + Assert.Equal(6, remaining.ElementAt(3)); + Assert.Throws("index", () => remaining.ElementAt(-1)); + Assert.Throws("index", () => remaining.ElementAt(4)); + } + + [Fact] + public void ElementAtOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAtOrDefault(0)); + Assert.Equal(4, remaining.ElementAtOrDefault(1)); + Assert.Equal(6, remaining.ElementAtOrDefault(3)); + Assert.Equal(0, remaining.ElementAtOrDefault(-1)); + Assert.Equal(0, remaining.ElementAtOrDefault(4)); + } + + [Fact] + public void ElementAtOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAtOrDefault(0)); + Assert.Equal(4, remaining.ElementAtOrDefault(1)); + Assert.Equal(6, remaining.ElementAtOrDefault(3)); + Assert.Equal(0, remaining.ElementAtOrDefault(-1)); + Assert.Equal(0, remaining.ElementAtOrDefault(4)); + } + + [Fact] + public void First() + { + var source = new []{ 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Skip(0).First()); + Assert.Equal(3, source.Skip(2).First()); + Assert.Equal(5, source.Skip(4).First()); + Assert.Throws(() => source.Skip(5).First()); + } + + [Fact] + public void FirstNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Skip(0).First()); + Assert.Equal(3, source.Skip(2).First()); + Assert.Equal(5, source.Skip(4).First()); + Assert.Throws(() => source.Skip(5).First()); + } + + [Fact] + public void FirstOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Skip(0).FirstOrDefault()); + Assert.Equal(3, source.Skip(2).FirstOrDefault()); + Assert.Equal(5, source.Skip(4).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).FirstOrDefault()); + } + + [Fact] + public void FirstOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Skip(0).FirstOrDefault()); + Assert.Equal(3, source.Skip(2).FirstOrDefault()); + Assert.Equal(5, source.Skip(4).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).FirstOrDefault()); + } + + [Fact] + public void Last() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(5, source.Skip(0).Last()); + Assert.Equal(5, source.Skip(1).Last()); + Assert.Equal(5, source.Skip(4).Last()); + Assert.Throws(() => source.Skip(5).Last()); + } + + [Fact] + public void LastNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(5, source.Skip(0).Last()); + Assert.Equal(5, source.Skip(1).Last()); + Assert.Equal(5, source.Skip(4).Last()); + Assert.Throws(() => source.Skip(5).Last()); + } + + [Fact] + public void LastOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(5, source.Skip(0).LastOrDefault()); + Assert.Equal(5, source.Skip(1).LastOrDefault()); + Assert.Equal(5, source.Skip(4).LastOrDefault()); + Assert.Equal(0, source.Skip(5).LastOrDefault()); + } + + [Fact] + public void LastOrDefaultNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(5, source.Skip(0).LastOrDefault()); + Assert.Equal(5, source.Skip(1).LastOrDefault()); + Assert.Equal(5, source.Skip(4).LastOrDefault()); + Assert.Equal(0, source.Skip(5).LastOrDefault()); + } + + [Fact] + public void ToArray() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToArray()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToArray()); + Assert.Equal(5, source.Skip(4).ToArray().Single()); + Assert.Empty(source.Skip(5).ToArray()); + Assert.Empty(source.Skip(40).ToArray()); + } + + [Fact] + public void ToArrayNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToArray()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToArray()); + Assert.Equal(5, source.Skip(4).ToArray().Single()); + Assert.Empty(source.Skip(5).ToArray()); + Assert.Empty(source.Skip(40).ToArray()); + } + + [Fact] + public void ToList() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToList()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToList()); + Assert.Equal(5, source.Skip(4).ToList().Single()); + Assert.Empty(source.Skip(5).ToList()); + Assert.Empty(source.Skip(40).ToList()); + } + + [Fact] + public void ToListNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToList()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToList()); + Assert.Equal(5, source.Skip(4).ToList().Single()); + Assert.Empty(source.Skip(5).ToList()); + Assert.Empty(source.Skip(40).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var source = new[] { 1, 2, 3, 4, 5 }; + var remaining = source.Skip(1); + Assert.Equal(remaining, remaining); + } + + [Fact] + public void RepeatEnumeratingNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + var remaining = source.Skip(1); + Assert.Equal(remaining, remaining); + } } } diff --git a/src/System.Linq/tests/TakeTests.cs b/src/System.Linq/tests/TakeTests.cs index e767a74dadbd..aeebf7400009 100644 --- a/src/System.Linq/tests/TakeTests.cs +++ b/src/System.Linq/tests/TakeTests.cs @@ -9,6 +9,12 @@ namespace System.Linq.Tests { public class TakeTests : EnumerableTests { + private static IEnumerable GuaranteeNotIList(IEnumerable source) + { + foreach (T element in source) + yield return element; + } + [Fact] public void SameResultsRepeatCallsIntQuery() { @@ -19,6 +25,16 @@ where x > Int32.MinValue Assert.Equal(q.Take(9), q.Take(9)); } + [Fact] + public void SameResultsRepeatCallsIntQueryIList() + { + var q = (from x in new[] { 9999, 0, 888, -1, 66, -777, 1, 2, -12345 } + where x > Int32.MinValue + select x).ToList(); + + Assert.Equal(q.Take(9), q.Take(9)); + } + [Fact] public void SameResultsRepeatCallsStringQuery() { @@ -29,6 +45,16 @@ public void SameResultsRepeatCallsStringQuery() Assert.Equal(q.Take(7), q.Take(7)); } + [Fact] + public void SameResultsRepeatCallsStringQueryIList() + { + var q = (from x in new[] { "!@#$%^", "C", "AAA", "", "Calling Twice", "SoS", String.Empty } + where !String.IsNullOrEmpty(x) + select x).ToList(); + + Assert.Equal(q.Take(7), q.Take(7)); + } + [Fact] public void SourceEmptyCountPositive() { @@ -36,6 +62,13 @@ public void SourceEmptyCountPositive() Assert.Empty(source.Take(5)); } + [Fact] + public void SourceEmptyCountPositiveNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(0, 0); + Assert.Empty(source.Take(5)); + } + [Fact] public void SourceNonEmptyCountNegative() { @@ -43,6 +76,13 @@ public void SourceNonEmptyCountNegative() Assert.Empty(source.Take(-5)); } + [Fact] + public void SourceNonEmptyCountNegativeNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + Assert.Empty(source.Take(-5)); + } + [Fact] public void SourceNonEmptyCountZero() { @@ -50,6 +90,13 @@ public void SourceNonEmptyCountZero() Assert.Empty(source.Take(0)); } + [Fact] + public void SourceNonEmptyCountZeroNotIList() + { + var source = GuaranteeNotIList(new []{ 2, 5, 9, 1 }); + Assert.Empty(source.Take(0)); + } + [Fact] public void SourceNonEmptyCountOne() { @@ -59,14 +106,31 @@ public void SourceNonEmptyCountOne() Assert.Equal(expected, source.Take(1)); } + [Fact] + public void SourceNonEmptyCountOneNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + int[] expected = { 2 }; + + Assert.Equal(expected, source.Take(1)); + } + [Fact] public void SourceNonEmptyTakeAllExactly() { int[] source = { 2, 5, 9, 1 }; - + Assert.Equal(source, source.Take(source.Length)); } + [Fact] + public void SourceNonEmptyTakeAllExactlyNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + + Assert.Equal(source, source.Take(source.Count())); + } + [Fact] public void SourceNonEmptyTakeAllButOne() { @@ -76,6 +140,15 @@ public void SourceNonEmptyTakeAllButOne() Assert.Equal(expected, source.Take(3)); } + [Fact] + public void SourceNonEmptyTakeAllButOneNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + int[] expected = { 2, 5, 9 }; + + Assert.Equal(expected, source.Take(3)); + } + [Fact] public void SourceNonEmptyTakeExcessive() { @@ -83,7 +156,15 @@ public void SourceNonEmptyTakeExcessive() Assert.Equal(source, source.Take(source.Length + 1)); } - + + [Fact] + public void SourceNonEmptyTakeExcessiveNotIList() + { + var source = GuaranteeNotIList(new int?[] { 2, 5, null, 9, 1 }); + + Assert.Equal(source, source.Take(source.Count() + 1)); + } + [Fact] public void ThrowsOnNullSource() { @@ -99,5 +180,242 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ForcedToEnumeratorDoesntEnumerateIList() + { + var iterator = NumberRangeGuaranteedNotCollectionType(0, 3).ToList().Take(2); + // Don't insist on this behaviour, but check its correct if it happens + var en = iterator as IEnumerator; + Assert.False(en != null && en.MoveNext()); + } + + [Fact] + public void FollowWithTake() + { + var source = new[] { 5, 6, 7, 8 }; + var expected = new[] { 5, 6 }; + Assert.Equal(expected, source.Take(5).Take(3).Take(2).Take(40)); + } + + [Fact] + public void FollowWithTakeNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(5, 4); + var expected = new[] { 5, 6 }; + Assert.Equal(expected, source.Take(5).Take(3).Take(2)); + } + + [Fact] + public void FollowWithSkip() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var expected = new[] { 3, 4, 5 }; + Assert.Equal(expected, source.Take(5).Skip(2).Skip(-4)); + } + + [Fact] + public void FollowWithSkipNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(1, 6); + var expected = new[] { 3, 4, 5 }; + Assert.Equal(expected, source.Take(5).Skip(2).Skip(-4)); + } + + [Fact] + public void ElementAt() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAt(0)); + Assert.Equal(3, taken.ElementAt(2)); + Assert.Throws("index", () => taken.ElementAt(-1)); + Assert.Throws("index", () => taken.ElementAt(3)); + } + + [Fact] + public void ElementAtNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAt(0)); + Assert.Equal(3, taken.ElementAt(2)); + Assert.Throws("index", () => taken.ElementAt(-1)); + Assert.Throws("index", () => taken.ElementAt(3)); + } + + [Fact] + public void ElementAtOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAtOrDefault(0)); + Assert.Equal(3, taken.ElementAtOrDefault(2)); + Assert.Equal(0, taken.ElementAtOrDefault(-1)); + Assert.Equal(0, taken.ElementAtOrDefault(3)); + } + + [Fact] + public void ElementAtOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAtOrDefault(0)); + Assert.Equal(3, taken.ElementAtOrDefault(2)); + Assert.Equal(0, taken.ElementAtOrDefault(-1)); + Assert.Equal(0, taken.ElementAtOrDefault(3)); + } + + [Fact] + public void First() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).First()); + Assert.Equal(1, source.Take(4).First()); + Assert.Equal(1, source.Take(40).First()); + Assert.Throws(() => source.Take(0).First()); + Assert.Throws(() => source.Skip(5).Take(10).First()); + } + + [Fact] + public void FirstNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).First()); + Assert.Equal(1, source.Take(4).First()); + Assert.Equal(1, source.Take(40).First()); + Assert.Throws(() => source.Take(0).First()); + Assert.Throws(() => source.Skip(5).Take(10).First()); + } + + [Fact] + public void FirstOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).FirstOrDefault()); + Assert.Equal(1, source.Take(4).FirstOrDefault()); + Assert.Equal(1, source.Take(40).FirstOrDefault()); + Assert.Equal(0, source.Take(0).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).Take(10).FirstOrDefault()); + } + + [Fact] + public void FirstOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).FirstOrDefault()); + Assert.Equal(1, source.Take(4).FirstOrDefault()); + Assert.Equal(1, source.Take(40).FirstOrDefault()); + Assert.Equal(0, source.Take(0).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).Take(10).FirstOrDefault()); + } + + [Fact] + public void Last() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).Last()); + Assert.Equal(5, source.Take(5).Last()); + Assert.Equal(5, source.Take(40).Last()); + Assert.Throws(() => source.Take(0).Last()); + Assert.Throws(() => Array.Empty().Take(40).Last()); + } + + [Fact] + public void LastNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).Last()); + Assert.Equal(5, source.Take(5).Last()); + Assert.Equal(5, source.Take(40).Last()); + Assert.Throws(() => source.Take(0).Last()); + Assert.Throws(() => GuaranteeNotIList(Array.Empty()).Take(40).Last()); + } + + [Fact] + public void LastOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).LastOrDefault()); + Assert.Equal(5, source.Take(5).LastOrDefault()); + Assert.Equal(5, source.Take(40).LastOrDefault()); + Assert.Equal(0, source.Take(0).LastOrDefault()); + Assert.Equal(0, Array.Empty().Take(40).LastOrDefault()); + } + + [Fact] + public void LastOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).LastOrDefault()); + Assert.Equal(5, source.Take(5).LastOrDefault()); + Assert.Equal(5, source.Take(40).LastOrDefault()); + Assert.Equal(0, source.Take(0).LastOrDefault()); + Assert.Equal(0, GuaranteeNotIList(Array.Empty()).Take(40).LastOrDefault()); + } + + [Fact] + public void ToArray() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToArray()); + Assert.Equal(1, source.Take(1).ToArray().Single()); + Assert.Empty(source.Take(0).ToArray()); + Assert.Empty(source.Take(-10).ToArray()); + } + + [Fact] + public void ToArrayNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToArray()); + Assert.Equal(1, source.Take(1).ToArray().Single()); + Assert.Empty(source.Take(0).ToArray()); + Assert.Empty(source.Take(-10).ToArray()); + } + + [Fact] + public void ToList() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToList()); + Assert.Equal(1, source.Take(1).ToList().Single()); + Assert.Empty(source.Take(0).ToList()); + Assert.Empty(source.Take(-10).ToList()); + } + + [Fact] + public void ToListNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToList()); + Assert.Equal(1, source.Take(1).ToList().Single()); + Assert.Empty(source.Take(0).ToList()); + Assert.Empty(source.Take(-10).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var source = new[] { 1, 2, 3, 4, 5 }; + var taken = source.Take(3); + Assert.Equal(taken, taken); + } + + [Fact] + public void RepeatEnumeratingNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + var taken = source.Take(3); + Assert.Equal(taken, taken); + } } } From 7ac195cd9eb0ad1647ec1e791882dc9a779652f8 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Fri, 15 Jan 2016 20:47:30 +0000 Subject: [PATCH 3/9] Use partitioning on list-based select iterators. The creation of SelectListIterator allows for partitioning to be used with the list-based select iterators, improving some subsequent operations on them. --- src/System.Linq/src/System/Linq/Enumerable.cs | 173 ++++++++++++++++-- .../tests/ElementAtOrDefaultTests.cs | 31 ++++ src/System.Linq/tests/ElementAtTests.cs | 31 ++++ src/System.Linq/tests/FirstOrDefaultTests.cs | 22 +++ src/System.Linq/tests/FirstTests.cs | 25 +++ src/System.Linq/tests/LastOrDefaultTests.cs | 22 +++ src/System.Linq/tests/LastTests.cs | 25 +++ src/System.Linq/tests/SkipTests.cs | 34 ++++ src/System.Linq/tests/TakeTests.cs | 34 ++++ 9 files changed, 380 insertions(+), 17 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 3df4114623c0..f72f062b573c 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -530,7 +530,7 @@ public override IEnumerable Select(Func s } - internal sealed class SelectArrayIterator : Iterator, IIListProvider + internal sealed class SelectArrayIterator : Iterator, IPartition { private readonly TSource[] _source; private readonly Func _selector; @@ -585,9 +585,56 @@ public List ToList() } return results; } + + public IEnumerable Skip(int count) + { + return count == 0 + ? (IEnumerable)new SelectArrayIterator(_source, _selector) + : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IEnumerable Take(int count) + { + return count >= _source.Length + ? (IEnumerable)new SelectArrayIterator(_source, _selector) + : new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public TResult ElementAt(int index) + { + if ((uint)index >= (uint)_source.Length) throw Error.ArgumentOutOfRange("index"); + return _selector(_source[index]); + } + + public TResult ElementAtOrDefault(int index) + { + return (uint)index >= (uint)_source.Length ? default(TResult) : _selector(_source[index]); + } + + public TResult First() + { + if (_source.Length == 0) throw Error.NoElements(); + return _selector(_source[0]); + } + + public TResult FirstOrDefault() + { + return _source.Length == 0 ? default(TResult) : _selector(_source[0]); + } + + public TResult Last() + { + if (_source.Length == 0) throw Error.NoElements(); + return _selector(_source[_source.Length - 1]); + } + + public TResult LastOrDefault() + { + return _source.Length == 0 ? default(TResult) : _selector(_source[_source.Length - 1]); + } } - internal sealed class SelectListIterator : Iterator, IIListProvider + internal sealed class SelectListIterator : Iterator, IPartition { private readonly List _source; private readonly Func _selector; @@ -651,9 +698,54 @@ public List ToList() } return results; } + + public IEnumerable Skip(int count) + { + return count == 0 + ? (IEnumerable)new SelectListIterator(_source, _selector) + : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IEnumerable Take(int count) + { + return new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public TResult ElementAt(int index) + { + // out of range throws correct exception with correct parameter name + return _selector(_source[index]); + } + + public TResult ElementAtOrDefault(int index) + { + return (uint)index >= (uint)_source.Count ? default(TResult) : _selector(_source[index]); + } + + public TResult First() + { + if (_source.Count == 0) throw Error.NoElements(); + return _selector(_source[0]); + } + + public TResult FirstOrDefault() + { + return _source.Count == 0 ? default(TResult) : _selector(_source[0]); + } + + public TResult Last() + { + if (_source.Count == 0) throw Error.NoElements(); + return _selector(_source[_source.Count - 1]); + } + + public TResult LastOrDefault() + { + return _source.Count == 0 ? default(TResult) : _selector(_source[_source.Count - 1]); + } } - internal sealed class SelectIListIterator : Iterator, IIListProvider + internal sealed class SelectIListIterator : Iterator, IPartition { private readonly IList _source; private readonly Func _selector; @@ -727,6 +819,53 @@ public List ToList() } return results; } + + public IEnumerable Skip(int count) + { + return count == 0 + ? (IEnumerable)new SelectIListIterator(_source, _selector) + : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IEnumerable Take(int count) + { + return new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public TResult ElementAt(int index) + { + // IList implementation should throw correct argument with correct parameter name + // but lean on the side of caution and assume some do not. + if ((uint)index >= (uint)_source.Count) throw Error.ArgumentOutOfRange("index"); + return _selector(_source[index]); + } + + public TResult ElementAtOrDefault(int index) + { + return (uint)index >= (uint)_source.Count ? default(TResult) : _selector(_source[index]); + } + + public TResult First() + { + if (_source.Count == 0) throw Error.NoElements(); + return _selector(_source[0]); + } + + public TResult FirstOrDefault() + { + return _source.Count == 0 ? default(TResult) : _selector(_source[0]); + } + + public TResult Last() + { + if (_source.Count == 0) throw Error.NoElements(); + return _selector(_source[_source.Count - 1]); + } + + public TResult LastOrDefault() + { + return _source.Count == 0 ? default(TResult) : _selector(_source[_source.Count - 1]); + } } //public static IEnumerable Where(this IEnumerable source, Func predicate) { @@ -937,7 +1076,7 @@ public override bool MoveNext() return false; } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { int minIndex = _minIndex + count; return minIndex >= _maxIndex @@ -945,7 +1084,7 @@ public IPartition Skip(int count) : new SkipListIterator(_source, minIndex, _maxIndex); } - public IPartition Take(int count) + public IEnumerable Take(int count) { int maxIndex = _minIndex + count - 1; if (maxIndex >= _maxIndex) maxIndex = _maxIndex; @@ -1964,13 +2103,13 @@ public List ToList() return list; } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { if (count >= _end - _start) return new EmptyPartition(); return new RangeIterator(_start + count, _end - _start - count); } - public IPartition Take(int count) + public IEnumerable Take(int count) { int curCount = _end - _start; if (count > curCount) count = curCount; @@ -2068,13 +2207,13 @@ public List ToList() return list; } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { if (count >= _count) return new EmptyPartition(); return new RepeatIterator(current, _count - count); } - public IPartition Take(int count) + public IEnumerable Take(int count) { if (count > _count) count = _count; return new RepeatIterator(current, count); @@ -3938,9 +4077,9 @@ public List> ToList() internal interface IPartition : IEnumerable, IIListProvider { - IPartition Skip(int count); + IEnumerable Skip(int count); - IPartition Take(int count); + IEnumerable Take(int count); TElement ElementAt(int index); @@ -3998,12 +4137,12 @@ void IDisposable.Dispose() // Do nothing. } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { return new EmptyPartition(); } - public IPartition Take(int count) + public IEnumerable Take(int count) { return new EmptyPartition(); } @@ -4072,7 +4211,7 @@ IEnumerator IEnumerable.GetEnumerator() return GetEnumerator(); } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { int minIndex = _minIndex + count; return minIndex >= _maxIndex @@ -4080,7 +4219,7 @@ public IPartition Skip(int count) : new OrderedPartition(_source, minIndex, _maxIndex); } - public IPartition Take(int count) + public IEnumerable Take(int count) { int maxIndex = _minIndex + count - 1; if (maxIndex >= _maxIndex) maxIndex = _maxIndex; @@ -4265,12 +4404,12 @@ IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerabl return result; } - public IPartition Skip(int count) + public IEnumerable Skip(int count) { return new OrderedPartition(this, count, int.MaxValue); } - public IPartition Take(int count) + public IEnumerable Take(int count) { return new OrderedPartition(this, 0, count - 1); } diff --git a/src/System.Linq/tests/ElementAtOrDefaultTests.cs b/src/System.Linq/tests/ElementAtOrDefaultTests.cs index a0fc8902b362..52c14c6c3a4e 100644 --- a/src/System.Linq/tests/ElementAtOrDefaultTests.cs +++ b/src/System.Linq/tests/ElementAtOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -131,5 +132,35 @@ public void NullSource() { Assert.Throws("source", () => ((IEnumerable)null).ElementAtOrDefault(2)); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAtOrDefault(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } } } diff --git a/src/System.Linq/tests/ElementAtTests.cs b/src/System.Linq/tests/ElementAtTests.cs index b7c3ef9258f5..dc2a9212e884 100644 --- a/src/System.Linq/tests/ElementAtTests.cs +++ b/src/System.Linq/tests/ElementAtTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -131,5 +132,35 @@ public void NullSource() { Assert.Throws("source", () => ((IEnumerable)null).ElementAt(2)); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } } } diff --git a/src/System.Linq/tests/FirstOrDefaultTests.cs b/src/System.Linq/tests/FirstOrDefaultTests.cs index 4a6d2369257d..748b9c41c988 100644 --- a/src/System.Linq/tests/FirstOrDefaultTests.cs +++ b/src/System.Linq/tests/FirstOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -194,5 +195,26 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).FirstOrDefault(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new int[0].Select(i => i * 2 + 1).FirstOrDefault()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new List(0).Select(i => i * 2 + 1).FirstOrDefault()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(11, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1).FirstOrDefault()); + } } } diff --git a/src/System.Linq/tests/FirstTests.cs b/src/System.Linq/tests/FirstTests.cs index abcda8b354a6..a8a9f5e7fb54 100644 --- a/src/System.Linq/tests/FirstTests.cs +++ b/src/System.Linq/tests/FirstTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -191,5 +192,29 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).First(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).First()); + var emptySource = new int[0].Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).First()); + var emptySource = new List(0).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(11, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).First()); + var emptySource = new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } } } diff --git a/src/System.Linq/tests/LastOrDefaultTests.cs b/src/System.Linq/tests/LastOrDefaultTests.cs index 1b2413b3266e..43f5fc67df85 100644 --- a/src/System.Linq/tests/LastOrDefaultTests.cs +++ b/src/System.Linq/tests/LastOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests.LegacyTests @@ -244,5 +245,26 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).LastOrDefault(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new int[0].Select(i => i * 2 + 1).LastOrDefault()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new List(0).Select(i => i * 2 + 1).LastOrDefault()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(17, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1).LastOrDefault()); + } } } diff --git a/src/System.Linq/tests/LastTests.cs b/src/System.Linq/tests/LastTests.cs index 7512bfbe60bb..9ed77153fc2d 100644 --- a/src/System.Linq/tests/LastTests.cs +++ b/src/System.Linq/tests/LastTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -239,5 +240,29 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).Last(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).Last()); + var emptySource = new int[0].Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).Last()); + var emptySource = new List(0).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(17, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).Last()); + var emptySource = new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } } } diff --git a/src/System.Linq/tests/SkipTests.cs b/src/System.Linq/tests/SkipTests.cs index e8fa2c44b3c9..5bb6889bafb9 100644 --- a/src/System.Linq/tests/SkipTests.cs +++ b/src/System.Linq/tests/SkipTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; using Xunit.Abstractions; @@ -433,5 +434,38 @@ public void RepeatEnumeratingNotList() var remaining = source.Skip(1); Assert.Equal(remaining, remaining); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } } } diff --git a/src/System.Linq/tests/TakeTests.cs b/src/System.Linq/tests/TakeTests.cs index aeebf7400009..36d65813a4f7 100644 --- a/src/System.Linq/tests/TakeTests.cs +++ b/src/System.Linq/tests/TakeTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -417,5 +418,38 @@ public void RepeatEnumeratingNotList() var taken = source.Take(3); Assert.Equal(taken, taken); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } } } From aaa362686e90fcc8d853d92bc47f0ae545b946a7 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Sat, 16 Jan 2016 14:55:57 +0000 Subject: [PATCH 4/9] Remove unused paths in Set. Set never has Add called after Remove. Remove paths and fields only necessary in that case. Add debug-only check on this assumption, so any new uses that break this assumption are flagged to the developer. --- src/System.Linq/src/System/Linq/Enumerable.cs | 60 ++++++++----------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index f72f062b573c..86d4277e0c74 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -3895,8 +3895,10 @@ internal class Set private int[] _buckets; private Slot[] _slots; private int _count; - private int _freeList; - private IEqualityComparer _comparer; + private readonly IEqualityComparer _comparer; +#if DEBUG + private bool _haveRemoved; +#endif public Set(IEqualityComparer comparer) { @@ -3904,18 +3906,36 @@ public Set(IEqualityComparer comparer) _comparer = comparer; _buckets = new int[7]; _slots = new Slot[7]; - _freeList = -1; } // If value is not in set, add it and return true; otherwise return false public bool Add(TElement value) { - return !Find(value, true); +#if DEBUG + Debug.Assert(!_haveRemoved, "This set assumes no adds after a removal. If your use requires adds after removal undo that optimization."); +#endif + int hashCode = InternalGetHashCode(value); + for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) + { + if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return false; + } + if (_count == _slots.Length) Resize(); + int index = _count; + _count++; + int bucket = hashCode % _buckets.Length; + _slots[index].hashCode = hashCode; + _slots[index].value = value; + _slots[index].next = _buckets[bucket] - 1; + _buckets[bucket] = index + 1; + return true; } // If value is in set, remove it and return true; otherwise return false public bool Remove(TElement value) { +#if DEBUG + _haveRemoved = true; +#endif int hashCode = InternalGetHashCode(value); int bucket = hashCode % _buckets.Length; int last = -1; @@ -3933,44 +3953,12 @@ public bool Remove(TElement value) } _slots[i].hashCode = -1; _slots[i].value = default(TElement); - _slots[i].next = _freeList; - _freeList = i; return true; } } return false; } - private bool Find(TElement value, bool add) - { - int hashCode = InternalGetHashCode(value); - for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) - { - if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return true; - } - if (add) - { - int index; - if (_freeList >= 0) - { - index = _freeList; - _freeList = _slots[index].next; - } - else - { - if (_count == _slots.Length) Resize(); - index = _count; - _count++; - } - int bucket = hashCode % _buckets.Length; - _slots[index].hashCode = hashCode; - _slots[index].value = value; - _slots[index].next = _buckets[bucket] - 1; - _buckets[bucket] = index + 1; - } - return false; - } - private void Resize() { int newSize = checked(_count * 2 + 1); From 02f0e8c375cb8df094c0bce975835c36f90964aa Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Sat, 16 Jan 2016 16:48:20 +0000 Subject: [PATCH 5/9] Implement IIListProvider on Distinct() and Union(). With the changes to Set it's now easy to give these methods and optimised ToArray and ToList(). --- src/System.Linq/src/System/Linq/Enumerable.cs | 197 ++++++++++++++++-- src/System.Linq/tests/DistinctTests.cs | 28 +++ src/System.Linq/tests/UnionTests.cs | 31 +++ 3 files changed, 242 insertions(+), 14 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 86d4277e0c74..6d430a98d7f5 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -1392,43 +1392,189 @@ private static IEnumerable ZipIterator(IEnume public static IEnumerable Distinct(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, null); + return new DistinctIterator(source, null); } public static IEnumerable Distinct(this IEnumerable source, IEqualityComparer comparer) { if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, comparer); + return new DistinctIterator(source, comparer); } - private static IEnumerable DistinctIterator(IEnumerable source, IEqualityComparer comparer) + private sealed class DistinctIterator : Iterator, IIListProvider { - Set set = new Set(comparer); - foreach (TSource element in source) - if (set.Add(element)) yield return element; + private readonly IEnumerable _source; + private readonly IEqualityComparer _comparer; + private Set _set; + private IEnumerator _enumerator; + + public DistinctIterator(IEnumerable source, IEqualityComparer comparer) + { + _source = source; + _comparer = comparer; + } + + public override Iterator Clone() + { + return new DistinctIterator(_source, _comparer); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _set = new Set(_comparer); + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource element = _enumerator.Current; + if (_set.Add(element)) + { + current = element; + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + private Set FillList() + { + Set set = new Set(_comparer); + foreach (TSource element in _source) + set.Add(element); + return set; + } + + public TSource[] ToArray() + { + return FillList().ToArray(); + } + + public List ToList() + { + return FillList().ToList(); + } } public static IEnumerable Union(this IEnumerable first, IEnumerable second) { if (first == null) throw Error.ArgumentNull("first"); if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, null); + return new UnionIterator(first, second, null); } public static IEnumerable Union(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { if (first == null) throw Error.ArgumentNull("first"); if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, comparer); + return new UnionIterator(first, second, comparer); } - private static IEnumerable UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + private sealed class UnionIterator : Iterator, IIListProvider { - Set set = new Set(comparer); - foreach (TSource element in first) - if (set.Add(element)) yield return element; - foreach (TSource element in second) - if (set.Add(element)) yield return element; + private readonly IEnumerable _first; + private readonly IEnumerable _second; + private readonly IEqualityComparer _comparer; + private Set _set; + private IEnumerator _enumerator; + + public UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + _first = first; + _second = second; + _comparer = comparer; + } + + public override Iterator Clone() + { + return new UnionIterator(_first, _second, _comparer); + } + + private bool GetNext() + { + while (_enumerator.MoveNext()) + { + TSource element = _enumerator.Current; + if (_set.Add(element)) + { + current = element; + return true; + } + } + return false; + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _set = new Set(_comparer); + _enumerator = _first.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (GetNext()) + return true; + _enumerator.Dispose(); + _enumerator = _second.GetEnumerator(); + state = 3; + goto case 3; + case 3: + if (GetNext()) + return true; + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + private Set FillList() + { + Set set = new Set(_comparer); + foreach (TSource element in _first) + set.Add(element); + foreach (TSource element in _second) + set.Add(element); + return set; + } + + public TSource[] ToArray() + { + return FillList().ToArray(); + } + + public List ToList() + { + return FillList().ToList(); + } } public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) @@ -3975,6 +4121,29 @@ private void Resize() _slots = newSlots; } + internal TElement[] ToArray() + { +#if DEBUG + Debug.Assert(!_haveRemoved, "Optimised ToArray cannot be called if Remove has been called."); +#endif + TElement[] array = new TElement[_count]; + for (int i = 0; i != array.Length; ++i) + array[i] = _slots[i].value; + return array; + } + + internal List ToList() + { +#if DEBUG + Debug.Assert(!_haveRemoved, "Optimised ToList cannot be called if Remove has been called."); +#endif + int count = _count; + List list = new List(count); + for (int i = 0; i != count; ++i) + list.Add(_slots[i].value); + return list; + } + internal int InternalGetHashCode(TElement value) { // Handle comparer implementations that throw when passed null diff --git a/src/System.Linq/tests/DistinctTests.cs b/src/System.Linq/tests/DistinctTests.cs index 4f954a69b6ee..32bcf9e5928c 100644 --- a/src/System.Linq/tests/DistinctTests.cs +++ b/src/System.Linq/tests/DistinctTests.cs @@ -198,5 +198,33 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + int?[] expected = { 1, 2, null }; + + Assert.Equal(expected, source.Distinct().ToArray()); + } + + [Fact] + public void ToList() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + int?[] expected = { 1, 2, null }; + + Assert.Equal(expected, source.Distinct().ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + + var result = source.Distinct(); + + Assert.Equal(result, result); + } } } diff --git a/src/System.Linq/tests/UnionTests.cs b/src/System.Linq/tests/UnionTests.cs index 6decc34defba..b16143d90647 100644 --- a/src/System.Linq/tests/UnionTests.cs +++ b/src/System.Linq/tests/UnionTests.cs @@ -205,5 +205,36 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + string[] expected = { "Bob", "Robert", "Tim", "Matt", "miT", "ttaM", "Charlie", "Bbo" }; + + Assert.Equal(expected, first.Union(second).ToArray()); + } + + [Fact] + public void ToList() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + string[] expected = { "Bob", "Robert", "Tim", "Matt", "miT", "ttaM", "Charlie", "Bbo" }; + + Assert.Equal(expected, first.Union(second).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + + var result = first.Union(second); + + Assert.Equal(result, result); + } } } From 68d50364fa4b72059af9cfc2fd10b13c0c9700a5 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Sat, 16 Jan 2016 17:30:50 +0000 Subject: [PATCH 6/9] Tests of IIListProvider on empty GroupBy results. --- src/System.Linq/tests/GroupByTests.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/System.Linq/tests/GroupByTests.cs b/src/System.Linq/tests/GroupByTests.cs index 7c5b8f719d4c..791a1e6c0e6c 100644 --- a/src/System.Linq/tests/GroupByTests.cs +++ b/src/System.Linq/tests/GroupByTests.cs @@ -630,5 +630,17 @@ public void GroupingWithResultsToList() Assert.Equal(4, groupedList.Count); Assert.Equal(source.GroupBy(r => r.Name, (r, e) => e), groupedList); } + + [Fact] + public void EmptyGroupingToArray() + { + Assert.Empty(Enumerable.Empty().GroupBy(i => i).ToArray()); + } + + [Fact] + public void EmptyGroupingToList() + { + Assert.Empty(Enumerable.Empty().GroupBy(i => i).ToList()); + } } } \ No newline at end of file From e52b017902305e5687ba0e68241890a14fca3751 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Sun, 17 Jan 2016 02:18:58 +0000 Subject: [PATCH 7/9] Add IIListProvider support to ReverseIterator Improve performance of Reverse().ToArray() and Reverse().ToList(). --- src/System.Linq/src/System/Linq/Enumerable.cs | 64 +++++++++++++++++-- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 6d430a98d7f5..9626803cd823 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -1624,13 +1624,69 @@ private static IEnumerable ExceptIterator(IEnumerable public static IEnumerable Reverse(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - return ReverseIterator(source); + return new ReverseIterator(source); } - private static IEnumerable ReverseIterator(IEnumerable source) + private sealed class ReverseIterator : Iterator, IIListProvider { - Buffer buffer = new Buffer(source); - for (int i = buffer.count - 1; i >= 0; i--) yield return buffer.items[i]; + private readonly IEnumerable _source; + private Buffer _buffer; + private int _index; + + public ReverseIterator(IEnumerable source) + { + Debug.Assert(source != null); + _source = source; + } + + public override Iterator Clone() + { + return new ReverseIterator(_source); + } + + public override bool MoveNext() + { + switch(state) + { + case 1: + _buffer = new Buffer(_source); + _index = _buffer.count - 1; + state = 2; + goto case 2; + case 2: + if (_index >= 0) + { + current = _buffer.items[_index]; + --_index; + return true; + } + break; + } + Dispose(); + return false; + } + + public TSource[] ToArray() + { + Buffer buffer = new Buffer(_source); + int count = buffer.count; + TSource[] sourceArray = buffer.items; + TSource[] array = new TSource[count]; + for (int i = 0, sourceIdx = count - 1; i != array.Length; ++i, --sourceIdx) + array[i] = sourceArray[sourceIdx]; + return array; + } + + public List ToList() + { + Buffer buffer = new Buffer(_source); + int count = buffer.count; + TSource[] sourceArray = buffer.items; + List list = new List(count); + for (int i = count - 1; i >= 0; --i) + list.Add(sourceArray[i]); + return list; + } } public static bool SequenceEqual(this IEnumerable first, IEnumerable second) From 77b27c6ff4828835066ccedfd6dfb187b22268b7 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Sun, 17 Jan 2016 18:20:21 +0000 Subject: [PATCH 8/9] Have Select() on an IPartition implement IPartition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With an IPartition-implementing result for Select, the advantages of IPartition results can carry through to results. In particular the sequence .OrderBy(…).Select(…).Skip(pageOffset).Take(pageSize) can pass through the partial-sorting capability. --- src/System.Linq/src/System/Linq/Enumerable.cs | 699 ++++++++++-------- src/System.Linq/tests/OrderedSubsetting.cs | 30 + src/System.Linq/tests/ReverseTests.cs | 26 + 3 files changed, 427 insertions(+), 328 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 9626803cd823..be639ebe992c 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Threading; namespace System.Linq { @@ -48,6 +47,8 @@ public static IEnumerable Select(this IEnumerable iterator = source as Iterator; if (iterator != null) return iterator.Select(selector); + IPartition partition = source as IPartition; + if (partition != null) return new SelectIPartitionIterator(partition, selector); IList ilist = source as IList; if (ilist != null) { @@ -530,7 +531,7 @@ public override IEnumerable Select(Func s } - internal sealed class SelectArrayIterator : Iterator, IPartition + internal sealed class SelectArrayIterator : Iterator, IPartition, IIListProvider { private readonly TSource[] _source; private readonly Func _selector; @@ -586,55 +587,58 @@ public List ToList() return results; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { return count == 0 - ? (IEnumerable)new SelectArrayIterator(_source, _selector) - : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + ? (IPartition)new SelectArrayIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); } - public IEnumerable Take(int count) + public IPartition Take(int count) { return count >= _source.Length - ? (IEnumerable)new SelectArrayIterator(_source, _selector) - : new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + ? (IPartition)new SelectArrayIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); } - public TResult ElementAt(int index) + public bool TryGetElementAt(int index, out TResult element) { - if ((uint)index >= (uint)_source.Length) throw Error.ArgumentOutOfRange("index"); - return _selector(_source[index]); - } + if ((uint)index >= (uint)_source.Length) + { + element = default(TResult); + return false; + } - public TResult ElementAtOrDefault(int index) - { - return (uint)index >= (uint)_source.Length ? default(TResult) : _selector(_source[index]); + element = _selector(_source[index]); + return true; } - public TResult First() + public bool TryGetFirst(out TResult first) { - if (_source.Length == 0) throw Error.NoElements(); - return _selector(_source[0]); - } + if (_source.Length == 0) + { + first = default(TResult); + return false; + } - public TResult FirstOrDefault() - { - return _source.Length == 0 ? default(TResult) : _selector(_source[0]); + first = _selector(_source[0]); + return true; } - public TResult Last() + public bool TryGetLast(out TResult last) { - if (_source.Length == 0) throw Error.NoElements(); - return _selector(_source[_source.Length - 1]); - } + if (_source.Length == 0) + { + last = default(TResult); + return false; + } - public TResult LastOrDefault() - { - return _source.Length == 0 ? default(TResult) : _selector(_source[_source.Length - 1]); + last = _selector(_source[_source.Length - 1]); + return true; } } - internal sealed class SelectListIterator : Iterator, IPartition + internal sealed class SelectListIterator : Iterator, IPartition, IIListProvider { private readonly List _source; private readonly Func _selector; @@ -699,53 +703,56 @@ public List ToList() return results; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { return count == 0 - ? (IEnumerable)new SelectListIterator(_source, _selector) - : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + ? (IPartition)new SelectListIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); } - public IEnumerable Take(int count) + public IPartition Take(int count) { - return new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + return new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); } - public TResult ElementAt(int index) + public bool TryGetElementAt(int index, out TResult element) { - // out of range throws correct exception with correct parameter name - return _selector(_source[index]); - } + if ((uint)index >= (uint)_source.Count) + { + element = default(TResult); + return false; + } - public TResult ElementAtOrDefault(int index) - { - return (uint)index >= (uint)_source.Count ? default(TResult) : _selector(_source[index]); + element = _selector(_source[index]); + return true; } - public TResult First() + public bool TryGetFirst(out TResult first) { - if (_source.Count == 0) throw Error.NoElements(); - return _selector(_source[0]); - } + if (_source.Count == 0) + { + first = default(TResult); + return false; + } - public TResult FirstOrDefault() - { - return _source.Count == 0 ? default(TResult) : _selector(_source[0]); + first = _selector(_source[0]); + return true; } - public TResult Last() + public bool TryGetLast(out TResult result) { - if (_source.Count == 0) throw Error.NoElements(); - return _selector(_source[_source.Count - 1]); - } + if (_source.Count == 0) + { + result = default(TResult); + return false; + } - public TResult LastOrDefault() - { - return _source.Count == 0 ? default(TResult) : _selector(_source[_source.Count - 1]); + result = _selector(_source[_source.Count - 1]); + return true; } } - internal sealed class SelectIListIterator : Iterator, IPartition + internal sealed class SelectIListIterator : Iterator, IPartition, IIListProvider { private readonly IList _source; private readonly Func _selector; @@ -820,77 +827,158 @@ public List ToList() return results; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { return count == 0 - ? (IEnumerable)new SelectIListIterator(_source, _selector) - : new SelectEnumerableIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + ? (IPartition)new SelectIListIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IPartition Take(int count) + { + return new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public bool TryGetElementAt(int index, out TResult element) + { + if ((uint)index >= (uint)_source.Count) + { + element = default(TResult); + return false; + } + + element = _selector(_source[index]); + return true; + } + + public bool TryGetFirst(out TResult first) + { + if (_source.Count == 0) + { + first = default(TResult); + return false; + } + + first = _selector(_source[0]); + return true; + } + + public bool TryGetLast(out TResult last) + { + if (_source.Count == 0) + { + last = default(TResult); + return false; + } + + last = _selector(_source[_source.Count - 1]); + return true; + } + } + + internal sealed class SelectIPartitionIterator : Iterator, IPartition + { + private readonly IPartition _source; + private readonly Func _selector; + private IEnumerator _enumerator; + + public SelectIPartitionIterator(IPartition source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; } - public IEnumerable Take(int count) + public override Iterator Clone() { - return new SelectEnumerableIterator(new SkipListIterator(_source, 0, count - 1), _selector); + return new SelectIPartitionIterator(_source, _selector); } - public TResult ElementAt(int index) + public override bool MoveNext() { - // IList implementation should throw correct argument with correct parameter name - // but lean on the side of caution and assume some do not. - if ((uint)index >= (uint)_source.Count) throw Error.ArgumentOutOfRange("index"); - return _selector(_source[index]); + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (_enumerator.MoveNext()) + { + current = _selector(_enumerator.Current); + return true; + } + Dispose(); + break; + } + return false; } - public TResult ElementAtOrDefault(int index) + public override void Dispose() { - return (uint)index >= (uint)_source.Count ? default(TResult) : _selector(_source[index]); + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); } - public TResult First() + public override IEnumerable Select(Func selector) { - if (_source.Count == 0) throw Error.NoElements(); - return _selector(_source[0]); + return new SelectIPartitionIterator(_source, CombineSelectors(_selector, selector)); } - public TResult FirstOrDefault() + public IPartition Skip(int count) { - return _source.Count == 0 ? default(TResult) : _selector(_source[0]); + return new SelectIPartitionIterator(_source.Skip(count), _selector); } - public TResult Last() + public IPartition Take(int count) { - if (_source.Count == 0) throw Error.NoElements(); - return _selector(_source[_source.Count - 1]); + return new SelectIPartitionIterator(_source.Take(count), _selector); } - public TResult LastOrDefault() + public bool TryGetElementAt(int index, out TResult element) { - return _source.Count == 0 ? default(TResult) : _selector(_source[_source.Count - 1]); + TSource input; + if (_source.TryGetElementAt(index, out input)) + { + element = _selector(input); + return true; + } + + element = default(TResult); + return false; } - } - //public static IEnumerable Where(this IEnumerable source, Func predicate) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (predicate == null) throw Error.ArgumentNull("predicate"); - // return WhereIterator(source, predicate); - //} + public bool TryGetFirst(out TResult first) + { + TSource input; + if (_source.TryGetFirst(out input)) + { + first = _selector(input); + return true; + } - //static IEnumerable WhereIterator(IEnumerable source, Func predicate) { - // foreach (TSource element in source) { - // if (predicate(element)) yield return element; - // } - //} + first = default(TResult); + return false; + } - //public static IEnumerable Select(this IEnumerable source, Func selector) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (selector == null) throw Error.ArgumentNull("selector"); - // return SelectIterator(source, selector); - //} + public bool TryGetLast(out TResult last) + { + TSource input; + if (_source.TryGetLast(out input)) + { + last = _selector(input); + return true; + } - //static IEnumerable SelectIterator(IEnumerable source, Func selector) { - // foreach (TSource element in source) { - // yield return selector(element); - // } - //} + last = default(TResult); + return false; + } + } public static IEnumerable SelectMany(this IEnumerable source, Func> selector) { @@ -1033,7 +1121,7 @@ public static IEnumerable Skip(this IEnumerable sourc return sourceList != null ? new SkipListIterator(sourceList, count, int.MaxValue) : SkipIterator(source, count); } - private sealed class SkipListIterator : Iterator, IPartition + private sealed class SkipListIterator : Iterator, IPartition, IIListProvider { private readonly IList _source; private readonly int _minIndex; @@ -1076,7 +1164,7 @@ public override bool MoveNext() return false; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { int minIndex = _minIndex + count; return minIndex >= _maxIndex @@ -1084,47 +1172,48 @@ public IEnumerable Skip(int count) : new SkipListIterator(_source, minIndex, _maxIndex); } - public IEnumerable Take(int count) + public IPartition Take(int count) { int maxIndex = _minIndex + count - 1; if (maxIndex >= _maxIndex) maxIndex = _maxIndex; return new SkipListIterator(_source, _minIndex, maxIndex); } - public TSource ElementAt(int index) + public bool TryGetElementAt(int index, out TSource element) { - if ((uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex) throw Error.ArgumentOutOfRange("index"); - return _source[_minIndex + index]; - } + if ((uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex) + { + element = default(TSource); + return false; + } - public TSource ElementAtOrDefault(int index) - { - return (uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex ? default(TSource) : _source[_minIndex + index]; + element = _source[_minIndex + index]; + return true; } - public TSource First() + public bool TryGetFirst(out TSource first) { - if (_source.Count <= _minIndex) throw Error.NoElements(); - return _source[_minIndex]; - } + if (_source.Count <= _minIndex) + { + first = default(TSource); + return false; + } - public TSource FirstOrDefault() - { - return _source.Count <= _minIndex ? default(TSource) : _source[_minIndex]; + first = _source[_minIndex]; + return true; } - public TSource Last() + public bool TryGetLast(out TSource last) { int lastIndex = _source.Count - 1; - if (lastIndex < _minIndex) throw Error.NoElements(); - return _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; - } + if (lastIndex < _minIndex) + { + last = default(TSource); + return false; + } - public TSource LastOrDefault() - { - int lastIndex = _source.Count - 1; - if (lastIndex < _minIndex) return default(TSource); - return _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; + last = _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; + return true; } public TSource[] ToArray() @@ -1908,17 +1997,25 @@ public static TSource First(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.First(); - IList list = source as IList; - if (list != null) + if (partition != null) { - if (list.Count > 0) return list[0]; + TSource result; + if (partition.TryGetFirst(out result)) + return result; } else { - using (IEnumerator e = source.GetEnumerator()) + IList list = source as IList; + if (list != null) { - if (e.MoveNext()) return e.Current; + if (list.Count > 0) return list[0]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) return e.Current; + } } } throw Error.NoElements(); @@ -1941,7 +2038,12 @@ public static TSource FirstOrDefault(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.FirstOrDefault(); + if (partition != null) + { + TSource result; + partition.TryGetFirst(out result); + return result; + } IList list = source as IList; if (list != null) { @@ -1974,25 +2076,33 @@ public static TSource Last(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.Last(); - IList list = source as IList; - if (list != null) + if (partition != null) { - int count = list.Count; - if (count > 0) return list[count - 1]; + TSource result; + if (partition.TryGetLast(out result)) + return result; } else { - using (IEnumerator e = source.GetEnumerator()) + IList list = source as IList; + if (list != null) { - if (e.MoveNext()) + int count = list.Count; + if (count > 0) return list[count - 1]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) { - TSource result; - do + if (e.MoveNext()) { - result = e.Current; - } while (e.MoveNext()); - return result; + TSource result; + do + { + result = e.Current; + } while (e.MoveNext()); + return result; + } } } } @@ -2040,7 +2150,12 @@ public static TSource LastOrDefault(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.LastOrDefault(); + if (partition != null) + { + TSource result; + partition.TryGetLast(out result); + return result; + } IList list = source as IList; if (list != null) { @@ -2191,17 +2306,25 @@ public static TSource ElementAt(this IEnumerable source, int i { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAt(index); - IList list = source as IList; - if (list != null) return list[index]; - if (index >= 0) + if (partition != null) { - using (IEnumerator e = source.GetEnumerator()) + TSource result; + if (partition.TryGetElementAt(index, out result)) + return result; + } + else + { + IList list = source as IList; + if (list != null) return list[index]; + if (index >= 0) { - while (e.MoveNext()) + using (IEnumerator e = source.GetEnumerator()) { - if (index == 0) return e.Current; - index--; + while (e.MoveNext()) + { + if (index == 0) return e.Current; + index--; + } } } } @@ -2212,7 +2335,12 @@ public static TSource ElementAtOrDefault(this IEnumerable sour { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAtOrDefault(index); + if (partition != null) + { + TSource result; + partition.TryGetElementAt(index, out result); + return result; + } if (index >= 0) { IList list = source as IList; @@ -2243,7 +2371,7 @@ public static IEnumerable Range(int start, int count) return new RangeIterator(start, count); } - private sealed class RangeIterator : Iterator, IPartition + private sealed class RangeIterator : Iterator, IPartition, IIListProvider { private readonly int _start; private readonly int _end; @@ -2305,48 +2433,41 @@ public List ToList() return list; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { if (count >= _end - _start) return new EmptyPartition(); return new RangeIterator(_start + count, _end - _start - count); } - public IEnumerable Take(int count) + public IPartition Take(int count) { int curCount = _end - _start; if (count > curCount) count = curCount; return new RangeIterator(_start, count); } - public int ElementAt(int index) + public bool TryGetElementAt(int index, out int element) { - if ((uint)index >= (uint)(_end - _start)) throw Error.ArgumentOutOfRange("index"); - return _start + index; - } - - public int ElementAtOrDefault(int index) - { - return (uint)index >= (uint)(_end - _start) ? 0 : _start + index; - } - - public int First() - { - return _start; - } + if ((uint)index >= (uint)(_end - _start)) + { + element = 0; + return false; + } - public int FirstOrDefault() - { - return _start; + element = _start + index; + return true; } - public int Last() + public bool TryGetFirst(out int first) { - return _end - 1; + first = _start; + return true; } - public int LastOrDefault() + public bool TryGetLast(out int last) { - return _end - 1; + last = _end - 1; + return true; } } @@ -2357,7 +2478,7 @@ public static IEnumerable Repeat(TResult element, int count) return new RepeatIterator(element, count); } - private sealed class RepeatIterator : Iterator, IPartition + private sealed class RepeatIterator : Iterator, IPartition, IIListProvider { private readonly int _count; private int _sent; @@ -2409,47 +2530,40 @@ public List ToList() return list; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { if (count >= _count) return new EmptyPartition(); return new RepeatIterator(current, _count - count); } - public IEnumerable Take(int count) + public IPartition Take(int count) { if (count > _count) count = _count; return new RepeatIterator(current, count); } - public TResult ElementAt(int index) - { - if ((uint)index >= (uint)_count) throw Error.ArgumentOutOfRange("index"); - return current; - } - - public TResult ElementAtOrDefault(int index) - { - return (uint)index >= (uint)_count ? default(TResult) : current; - } - - public TResult First() + public bool TryGetElementAt(int index, out TResult element) { - return current; - } + if ((uint)index >= (uint)_count) + { + element = default(TResult); + return false; + } - public TResult FirstOrDefault() - { - return current; + element = current; + return true; } - public TResult Last() + public bool TryGetFirst(out TResult first) { - return current; + first = current; + return true; } - public TResult LastOrDefault() + public bool TryGetLast(out TResult last) { - return current; + last = current; + return true; } } @@ -4288,26 +4402,20 @@ public List> ToList() } } - internal interface IPartition : IEnumerable, IIListProvider + internal interface IPartition : IEnumerable { - IEnumerable Skip(int count); - - IEnumerable Take(int count); + IPartition Skip(int count); - TElement ElementAt(int index); + IPartition Take(int count); - TElement ElementAtOrDefault(int index); + bool TryGetElementAt(int index, out TElement element); - TElement First(); + bool TryGetFirst(out TElement first); - TElement FirstOrDefault(); - - TElement Last(); - - TElement LastOrDefault(); + bool TryGetLast(out TElement last); } - internal sealed class EmptyPartition : IPartition, IEnumerator + internal sealed class EmptyPartition : IPartition, IEnumerator, IIListProvider { public EmptyPartition() { @@ -4350,44 +4458,32 @@ void IDisposable.Dispose() // Do nothing. } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { return new EmptyPartition(); } - public IEnumerable Take(int count) + public IPartition Take(int count) { return new EmptyPartition(); } - public TElement ElementAt(int index) - { - throw Error.ArgumentOutOfRange("index"); - } - - public TElement ElementAtOrDefault(int index) + public bool TryGetElementAt(int index, out TElement element) { - return default(TElement); - } - - public TElement First() - { - throw Error.NoElements(); - } - - public TElement FirstOrDefault() - { - return default(TElement); + element = default(TElement); + return false; } - public TElement Last() + public bool TryGetFirst(out TElement first) { - throw Error.NoElements(); + first = default(TElement); + return false; } - public TElement LastOrDefault() + public bool TryGetLast(out TElement last) { - return default(TElement); + last = default(TElement); + return false; } public TElement[] ToArray() @@ -4401,7 +4497,7 @@ public List ToList() } } - internal sealed class OrderedPartition : IPartition + internal sealed class OrderedPartition : IPartition, IIListProvider { private readonly OrderedEnumerable _source; private readonly int _minIndex; @@ -4424,7 +4520,7 @@ IEnumerator IEnumerable.GetEnumerator() return GetEnumerator(); } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { int minIndex = _minIndex + count; return minIndex >= _maxIndex @@ -4432,44 +4528,32 @@ public IEnumerable Skip(int count) : new OrderedPartition(_source, minIndex, _maxIndex); } - public IEnumerable Take(int count) + public IPartition Take(int count) { int maxIndex = _minIndex + count - 1; if (maxIndex >= _maxIndex) maxIndex = _maxIndex; return new OrderedPartition(_source, _minIndex, maxIndex); } - public TElement ElementAt(int index) + public bool TryGetElementAt(int index, out TElement element) { - if ((uint)index > (uint)_maxIndex - _minIndex) throw Error.ArgumentOutOfRange("index"); - return _source.ElementAt(index + _minIndex); - } - - public TElement ElementAtOrDefault(int index) - { - return (uint)index <= (uint)_maxIndex - _minIndex ? _source.ElementAtOrDefault(index + _minIndex) : default(TElement); - } - - public TElement First() - { - TElement result; - if (!_source.TryGetElementAt(_minIndex, out result)) throw Error.NoElements(); - return result; - } + if ((uint)index > (uint)_maxIndex - _minIndex) + { + element = default(TElement); + return false; + } - public TElement FirstOrDefault() - { - return _source.ElementAtOrDefault(_minIndex); + return _source.TryGetElementAt(index + _minIndex, out element); } - public TElement Last() + public bool TryGetFirst(out TElement first) { - return _source.Last(_minIndex, _maxIndex); + return _source.TryGetElementAt(_minIndex, out first); } - public TElement LastOrDefault() + public bool TryGetLast(out TElement element) { - return _source.LastOrDefault(_minIndex, _maxIndex); + return _source.TryGetLast(_minIndex, _maxIndex, out element); } public TElement[] ToArray() @@ -4483,7 +4567,7 @@ public List ToList() } } - internal abstract class OrderedEnumerable : IOrderedEnumerable, IPartition + internal abstract class OrderedEnumerable : IOrderedEnumerable, IPartition, IIListProvider { internal IEnumerable source; @@ -4617,12 +4701,12 @@ IOrderedEnumerable IOrderedEnumerable.CreateOrderedEnumerabl return result; } - public IEnumerable Skip(int count) + public IPartition Skip(int count) { return new OrderedPartition(this, count, int.MaxValue); } - public IEnumerable Take(int count) + public IPartition Take(int count) { return new OrderedPartition(this, 0, count - 1); } @@ -4644,21 +4728,7 @@ public bool TryGetElementAt(int index, out TElement result) return false; } - public TElement ElementAt(int index) - { - TElement result; - if (!TryGetElementAt(index, out result)) throw Error.ArgumentOutOfRange("index"); - return result; - } - - public TElement ElementAtOrDefault(int index) - { - TElement result; - TryGetElementAt(index, out result); - return result; - } - - private bool TryGetFirst(out TElement result) + public bool TryGetFirst(out TElement result) { CachingComparer comparer = GetComparer(); using (IEnumerator e = source.GetEnumerator()) @@ -4680,20 +4750,6 @@ private bool TryGetFirst(out TElement result) } } - public TElement FirstOrDefault() - { - TElement result; - TryGetFirst(out result); - return result; - } - - public TElement First() - { - TElement result; - if (!TryGetFirst(out result)) throw Error.NoElements(); - return result; - } - public TElement First(Func predicate) { CachingComparer comparer = GetComparer(); @@ -4736,29 +4792,16 @@ public TElement FirstOrDefault(Func predicate) } } - public TElement Last() + public bool TryGetLast(out TElement element) { CachingComparer comparer = GetComparer(); using (IEnumerator e = source.GetEnumerator()) { - if (!e.MoveNext()) throw Error.NoElements(); - TElement value = e.Current; - comparer.SetElement(value); - while (e.MoveNext()) + if (!e.MoveNext()) { - TElement x = e.Current; - if (comparer.Compare(x, false) >= 0) value = x; + element = default(TElement); + return false; } - return value; - } - } - - public TElement LastOrDefault() - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) return default(TElement); TElement value = e.Current; comparer.SetElement(value); while (e.MoveNext()) @@ -4766,28 +4809,28 @@ public TElement LastOrDefault() TElement x = e.Current; if (comparer.Compare(x, false) >= 0) value = x; } - return value; + element = value; + return true; } + } - public TElement Last(int minIdx, int maxIdx) + public bool TryGetLast(int minIdx, int maxIdx, out TElement last) { Buffer buffer = new Buffer(source); int count = buffer.count; - if (minIdx >= count) throw Error.NoElements(); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); + if (minIdx >= count) + { + last = default(TElement); + return false; + } + if (maxIdx < count - 1) + last = GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); // If we're here, we want the same results we would have got from // Last(), but we've already buffered our source. - return Last(buffer); - } - - public TElement LastOrDefault(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (minIdx >= count) return default(TElement); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); - return Last(buffer); + else + last = Last(buffer); + return true; } private TElement Last(Buffer buffer) diff --git a/src/System.Linq/tests/OrderedSubsetting.cs b/src/System.Linq/tests/OrderedSubsetting.cs index 7d56573a7c8f..75e4695d2afd 100644 --- a/src/System.Linq/tests/OrderedSubsetting.cs +++ b/src/System.Linq/tests/OrderedSubsetting.cs @@ -351,5 +351,35 @@ public void EnumeratorDoesntContinue() while (enumerator.MoveNext()) { } Assert.False(enumerator.MoveNext()); } + + [Fact] + public void SubsetAfterSelect() + { + var page = Enumerable.Range(0, 50).Shuffle().OrderBy(i => i).Select(i => i * 2).Skip(20).Take(5); + Assert.Equal(new[] { 40, 42, 44, 46, 48 }, page); + Assert.Equal(new[] { 41, 43, 45, 47, 49 }, page.Select(i => i + 1)); + Assert.Equal(40, page.First()); + Assert.Equal(48, page.Last()); + Assert.Equal(42, page.ElementAt(1)); + Assert.Throws("index", () => page.ElementAt(20)); + page = Enumerable.Range(0, 50).Shuffle().OrderBy(i => i).Select(i => i * 2).Skip(100).Take(5); + Assert.Throws(() => page.First()); + Assert.Throws(() => page.Last()); + } + + [Fact] + public void SubsetAfterSelectEnumeratorDoesntContinue() + { + var enumerator = NumberRangeGuaranteedNotCollectionType(0, 3).Shuffle().OrderBy(i => i).Select(i => i * 2).Take(1).GetEnumerator(); + while (enumerator.MoveNext()) { } + Assert.False(enumerator.MoveNext()); + } + + [Fact] + public void SubsetAfterSelectRepeatEnumerating() + { + var page = NumberRangeGuaranteedNotCollectionType(0, 3).Shuffle().OrderBy(i => i).Select(i => i * 2).Take(1); + Assert.Equal(page, page); + } } } \ No newline at end of file diff --git a/src/System.Linq/tests/ReverseTests.cs b/src/System.Linq/tests/ReverseTests.cs index d61bd7557b8f..8aceb80fe497 100644 --- a/src/System.Linq/tests/ReverseTests.cs +++ b/src/System.Linq/tests/ReverseTests.cs @@ -77,5 +77,31 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + int?[] source = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }; + int?[] expected = new int?[] { 9, null, 100, 9, 0, null, 5, 0, -10 }; + + Assert.Equal(expected, source.Reverse().ToArray()); + } + + [Fact] + public void ToList() + { + int?[] source = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }; + int?[] expected = new int?[] { 9, null, 100, 9, 0, null, 5, 0, -10 }; + + Assert.Equal(expected, source.Reverse().ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var reversed = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }.Reverse(); + + Assert.Equal(reversed, reversed); + } } } From 209edd14f90b94d667928d5f5be0185a097addee Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Mon, 18 Jan 2016 15:31:51 +0000 Subject: [PATCH 9/9] Don't use capacity when constructing empty lists. As per https://github.com/JonHanna/corefx/commit/66a2d2634b8db5e0c9ef8872c6ba92f55640fba3#commitcomment-15508545 --- src/System.Linq/src/System/Linq/Enumerable.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index be639ebe992c..be10b6c4ed8a 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -1234,7 +1234,7 @@ public TSource[] ToArray() public List ToList() { int lastIndex = _source.Count - 1; - if (lastIndex < _minIndex) return new List(0); + if (lastIndex < _minIndex) return new List(); if (lastIndex > _maxIndex) lastIndex = _maxIndex; List list = new List(lastIndex - _minIndex + 1); for (int i = _minIndex; i <= lastIndex; ++i) @@ -4493,7 +4493,7 @@ public TElement[] ToArray() public List ToList() { - return new List(0); + return new List(); } } @@ -4662,7 +4662,7 @@ internal List ToList(int minIdx, int maxIdx) { Buffer buffer = new Buffer(source); int count = buffer.count; - if (count <= minIdx) return new List(0); + if (count <= minIdx) return new List(); if (count <= maxIdx) maxIdx = count - 1; if (minIdx == maxIdx) return new List(1) { GetEnumerableSorter().ElementAt(buffer.items, count, minIdx) }; int[] map = SortedMap(buffer, minIdx, maxIdx);