diff --git a/src/System.Linq/src/System/Linq/Enumerable.cs b/src/System.Linq/src/System/Linq/Enumerable.cs index 8da90277886f..be10b6c4ed8a 100644 --- a/src/System.Linq/src/System/Linq/Enumerable.cs +++ b/src/System.Linq/src/System/Linq/Enumerable.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Threading; namespace System.Linq { @@ -48,6 +47,8 @@ public static IEnumerable Select(this IEnumerable iterator = source as Iterator; if (iterator != null) return iterator.Select(selector); + IPartition partition = source as IPartition; + if (partition != null) return new SelectIPartitionIterator(partition, selector); IList ilist = source as IList; if (ilist != null) { @@ -530,7 +531,7 @@ public override IEnumerable Select(Func s } - internal sealed class SelectArrayIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectArrayIterator : Iterator, IPartition, IIListProvider { private readonly TSource[] _source; private readonly Func _selector; @@ -585,9 +586,59 @@ public List ToList() } return results; } + + public IPartition Skip(int count) + { + return count == 0 + ? (IPartition)new SelectArrayIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IPartition Take(int count) + { + return count >= _source.Length + ? (IPartition)new SelectArrayIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public bool TryGetElementAt(int index, out TResult element) + { + if ((uint)index >= (uint)_source.Length) + { + element = default(TResult); + return false; + } + + element = _selector(_source[index]); + return true; + } + + public bool TryGetFirst(out TResult first) + { + if (_source.Length == 0) + { + first = default(TResult); + return false; + } + + first = _selector(_source[0]); + return true; + } + + public bool TryGetLast(out TResult last) + { + if (_source.Length == 0) + { + last = default(TResult); + return false; + } + + last = _selector(_source[_source.Length - 1]); + return true; + } } - internal sealed class SelectListIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectListIterator : Iterator, IPartition, IIListProvider { private readonly List _source; private readonly Func _selector; @@ -651,9 +702,57 @@ public List ToList() } return results; } + + public IPartition Skip(int count) + { + return count == 0 + ? (IPartition)new SelectListIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IPartition Take(int count) + { + return new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public bool TryGetElementAt(int index, out TResult element) + { + if ((uint)index >= (uint)_source.Count) + { + element = default(TResult); + return false; + } + + element = _selector(_source[index]); + return true; + } + + public bool TryGetFirst(out TResult first) + { + if (_source.Count == 0) + { + first = default(TResult); + return false; + } + + first = _selector(_source[0]); + return true; + } + + public bool TryGetLast(out TResult result) + { + if (_source.Count == 0) + { + result = default(TResult); + return false; + } + + result = _selector(_source[_source.Count - 1]); + return true; + } } - internal sealed class SelectIListIterator : Iterator, IArrayProvider, IListProvider + internal sealed class SelectIListIterator : Iterator, IPartition, IIListProvider { private readonly IList _source; private readonly Func _selector; @@ -727,31 +826,159 @@ public List ToList() } return results; } + + public IPartition Skip(int count) + { + return count == 0 + ? (IPartition)new SelectIListIterator(_source, _selector) + : new SelectIPartitionIterator(new SkipListIterator(_source, count, int.MaxValue), _selector); + } + + public IPartition Take(int count) + { + return new SelectIPartitionIterator(new SkipListIterator(_source, 0, count - 1), _selector); + } + + public bool TryGetElementAt(int index, out TResult element) + { + if ((uint)index >= (uint)_source.Count) + { + element = default(TResult); + return false; + } + + element = _selector(_source[index]); + return true; + } + + public bool TryGetFirst(out TResult first) + { + if (_source.Count == 0) + { + first = default(TResult); + return false; + } + + first = _selector(_source[0]); + return true; + } + + public bool TryGetLast(out TResult last) + { + if (_source.Count == 0) + { + last = default(TResult); + return false; + } + + last = _selector(_source[_source.Count - 1]); + return true; + } } - //public static IEnumerable Where(this IEnumerable source, Func predicate) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (predicate == null) throw Error.ArgumentNull("predicate"); - // return WhereIterator(source, predicate); - //} + internal sealed class SelectIPartitionIterator : Iterator, IPartition + { + private readonly IPartition _source; + private readonly Func _selector; + private IEnumerator _enumerator; - //static IEnumerable WhereIterator(IEnumerable source, Func predicate) { - // foreach (TSource element in source) { - // if (predicate(element)) yield return element; - // } - //} + public SelectIPartitionIterator(IPartition source, Func selector) + { + Debug.Assert(source != null); + Debug.Assert(selector != null); + _source = source; + _selector = selector; + } - //public static IEnumerable Select(this IEnumerable source, Func selector) { - // if (source == null) throw Error.ArgumentNull("source"); - // if (selector == null) throw Error.ArgumentNull("selector"); - // return SelectIterator(source, selector); - //} + public override Iterator Clone() + { + return new SelectIPartitionIterator(_source, _selector); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (_enumerator.MoveNext()) + { + current = _selector(_enumerator.Current); + return true; + } + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + public override IEnumerable Select(Func selector) + { + return new SelectIPartitionIterator(_source, CombineSelectors(_selector, selector)); + } + + public IPartition Skip(int count) + { + return new SelectIPartitionIterator(_source.Skip(count), _selector); + } + + public IPartition Take(int count) + { + return new SelectIPartitionIterator(_source.Take(count), _selector); + } + + public bool TryGetElementAt(int index, out TResult element) + { + TSource input; + if (_source.TryGetElementAt(index, out input)) + { + element = _selector(input); + return true; + } + + element = default(TResult); + return false; + } + + public bool TryGetFirst(out TResult first) + { + TSource input; + if (_source.TryGetFirst(out input)) + { + first = _selector(input); + return true; + } + + first = default(TResult); + return false; + } + + public bool TryGetLast(out TResult last) + { + TSource input; + if (_source.TryGetLast(out input)) + { + last = _selector(input); + return true; + } - //static IEnumerable SelectIterator(IEnumerable source, Func selector) { - // foreach (TSource element in source) { - // yield return selector(element); - // } - //} + last = default(TResult); + return false; + } + } public static IEnumerable SelectMany(this IEnumerable source, Func> selector) { @@ -836,6 +1063,8 @@ public static IEnumerable Take(this IEnumerable sourc if (count <= 0) return new EmptyPartition(); IPartition partition = source as IPartition; if (partition != null) return partition.Take(count); + IList sourceList = source as IList; + if (sourceList != null) return new SkipListIterator(sourceList, 0, count - 1); return TakeIterator(source, count); } @@ -889,14 +1118,128 @@ public static IEnumerable Skip(this IEnumerable sourc IPartition partition = source as IPartition; if (partition != null) return partition.Skip(count); IList sourceList = source as IList; - return sourceList != null ? SkipList(sourceList, count) : SkipIterator(source, count); + return sourceList != null ? new SkipListIterator(sourceList, count, int.MaxValue) : SkipIterator(source, count); } - private static IEnumerable SkipList(IList source, int count) + private sealed class SkipListIterator : Iterator, IPartition, IIListProvider { - while (count < source.Count) + private readonly IList _source; + private readonly int _minIndex; + private readonly int _maxIndex; + private int _index; + + public SkipListIterator(IList source, int minIndex, int maxIndex) + { + Debug.Assert(source != null); + Debug.Assert(minIndex >= 0); + Debug.Assert(minIndex <= maxIndex); + _source = source; + _minIndex = minIndex; + _maxIndex = maxIndex; + } + + public override Iterator Clone() + { + return new SkipListIterator(_source, _minIndex, _maxIndex); + } + + public override bool MoveNext() + { + switch(state) + { + case 1: + _index = _minIndex; + state = 2; + goto case 2; + case 2: + if (_index <= _maxIndex && _index < _source.Count) + { + current = _source[_index]; + ++_index; + return true; + } + break; + } + Dispose(); + return false; + } + + public IPartition Skip(int count) + { + int minIndex = _minIndex + count; + return minIndex >= _maxIndex + ? (IPartition)new EmptyPartition() + : new SkipListIterator(_source, minIndex, _maxIndex); + } + + public IPartition Take(int count) + { + int maxIndex = _minIndex + count - 1; + if (maxIndex >= _maxIndex) maxIndex = _maxIndex; + return new SkipListIterator(_source, _minIndex, maxIndex); + } + + public bool TryGetElementAt(int index, out TSource element) + { + if ((uint)index > (uint)_maxIndex - _minIndex || index >= _source.Count - _minIndex) + { + element = default(TSource); + return false; + } + + element = _source[_minIndex + index]; + return true; + } + + public bool TryGetFirst(out TSource first) + { + if (_source.Count <= _minIndex) + { + first = default(TSource); + return false; + } + + first = _source[_minIndex]; + return true; + } + + public bool TryGetLast(out TSource last) { - yield return source[count++]; + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) + { + last = default(TSource); + return false; + } + + last = _source[lastIndex > _maxIndex ? _maxIndex : lastIndex]; + return true; + } + + public TSource[] ToArray() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) return new TSource[0]; + if (lastIndex > _maxIndex) lastIndex = _maxIndex; + TSource[] array = new TSource[lastIndex - _minIndex + 1]; + int curIdx = _minIndex; + for (int i = 0; i != array.Length; ++i) + { + array[i] = _source[curIdx]; + ++curIdx; + } + return array; + } + + public List ToList() + { + int lastIndex = _source.Count - 1; + if (lastIndex < _minIndex) return new List(); + if (lastIndex > _maxIndex) lastIndex = _maxIndex; + List list = new List(lastIndex - _minIndex + 1); + for (int i = _minIndex; i <= lastIndex; ++i) + list.Add(_source[i]); + return list; } } @@ -1138,43 +1481,189 @@ private static IEnumerable ZipIterator(IEnume public static IEnumerable Distinct(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, null); + return new DistinctIterator(source, null); } public static IEnumerable Distinct(this IEnumerable source, IEqualityComparer comparer) { if (source == null) throw Error.ArgumentNull("source"); - return DistinctIterator(source, comparer); + return new DistinctIterator(source, comparer); } - private static IEnumerable DistinctIterator(IEnumerable source, IEqualityComparer comparer) + private sealed class DistinctIterator : Iterator, IIListProvider { - Set set = new Set(comparer); - foreach (TSource element in source) - if (set.Add(element)) yield return element; + private readonly IEnumerable _source; + private readonly IEqualityComparer _comparer; + private Set _set; + private IEnumerator _enumerator; + + public DistinctIterator(IEnumerable source, IEqualityComparer comparer) + { + _source = source; + _comparer = comparer; + } + + public override Iterator Clone() + { + return new DistinctIterator(_source, _comparer); + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _set = new Set(_comparer); + _enumerator = _source.GetEnumerator(); + state = 2; + goto case 2; + case 2: + while (_enumerator.MoveNext()) + { + TSource element = _enumerator.Current; + if (_set.Add(element)) + { + current = element; + return true; + } + } + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + private Set FillList() + { + Set set = new Set(_comparer); + foreach (TSource element in _source) + set.Add(element); + return set; + } + + public TSource[] ToArray() + { + return FillList().ToArray(); + } + + public List ToList() + { + return FillList().ToList(); + } } public static IEnumerable Union(this IEnumerable first, IEnumerable second) { if (first == null) throw Error.ArgumentNull("first"); if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, null); + return new UnionIterator(first, second, null); } public static IEnumerable Union(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { if (first == null) throw Error.ArgumentNull("first"); if (second == null) throw Error.ArgumentNull("second"); - return UnionIterator(first, second, comparer); + return new UnionIterator(first, second, comparer); } - private static IEnumerable UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) - { - Set set = new Set(comparer); - foreach (TSource element in first) - if (set.Add(element)) yield return element; - foreach (TSource element in second) - if (set.Add(element)) yield return element; + private sealed class UnionIterator : Iterator, IIListProvider + { + private readonly IEnumerable _first; + private readonly IEnumerable _second; + private readonly IEqualityComparer _comparer; + private Set _set; + private IEnumerator _enumerator; + + public UnionIterator(IEnumerable first, IEnumerable second, IEqualityComparer comparer) + { + _first = first; + _second = second; + _comparer = comparer; + } + + public override Iterator Clone() + { + return new UnionIterator(_first, _second, _comparer); + } + + private bool GetNext() + { + while (_enumerator.MoveNext()) + { + TSource element = _enumerator.Current; + if (_set.Add(element)) + { + current = element; + return true; + } + } + return false; + } + + public override bool MoveNext() + { + switch (state) + { + case 1: + _set = new Set(_comparer); + _enumerator = _first.GetEnumerator(); + state = 2; + goto case 2; + case 2: + if (GetNext()) + return true; + _enumerator.Dispose(); + _enumerator = _second.GetEnumerator(); + state = 3; + goto case 3; + case 3: + if (GetNext()) + return true; + Dispose(); + break; + } + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + base.Dispose(); + } + + private Set FillList() + { + Set set = new Set(_comparer); + foreach (TSource element in _first) + set.Add(element); + foreach (TSource element in _second) + set.Add(element); + return set; + } + + public TSource[] ToArray() + { + return FillList().ToArray(); + } + + public List ToList() + { + return FillList().ToList(); + } } public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) @@ -1224,13 +1713,69 @@ private static IEnumerable ExceptIterator(IEnumerable public static IEnumerable Reverse(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - return ReverseIterator(source); + return new ReverseIterator(source); } - private static IEnumerable ReverseIterator(IEnumerable source) + private sealed class ReverseIterator : Iterator, IIListProvider { - Buffer buffer = new Buffer(source); - for (int i = buffer.count - 1; i >= 0; i--) yield return buffer.items[i]; + private readonly IEnumerable _source; + private Buffer _buffer; + private int _index; + + public ReverseIterator(IEnumerable source) + { + Debug.Assert(source != null); + _source = source; + } + + public override Iterator Clone() + { + return new ReverseIterator(_source); + } + + public override bool MoveNext() + { + switch(state) + { + case 1: + _buffer = new Buffer(_source); + _index = _buffer.count - 1; + state = 2; + goto case 2; + case 2: + if (_index >= 0) + { + current = _buffer.items[_index]; + --_index; + return true; + } + break; + } + Dispose(); + return false; + } + + public TSource[] ToArray() + { + Buffer buffer = new Buffer(_source); + int count = buffer.count; + TSource[] sourceArray = buffer.items; + TSource[] array = new TSource[count]; + for (int i = 0, sourceIdx = count - 1; i != array.Length; ++i, --sourceIdx) + array[i] = sourceArray[sourceIdx]; + return array; + } + + public List ToList() + { + Buffer buffer = new Buffer(_source); + int count = buffer.count; + TSource[] sourceArray = buffer.items; + List list = new List(count); + for (int i = count - 1; i >= 0; --i) + list.Add(sourceArray[i]); + return list; + } } public static bool SequenceEqual(this IEnumerable first, IEnumerable second) @@ -1270,14 +1815,14 @@ public static IEnumerable AsEnumerable(this IEnumerable(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - IArrayProvider arrayProvider = source as IArrayProvider; + IIListProvider arrayProvider = source as IIListProvider; return arrayProvider != null ? arrayProvider.ToArray() : new Buffer(source).ToArray(); } public static List ToList(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); - IListProvider listProvider = source as IListProvider; + IIListProvider listProvider = source as IIListProvider; return listProvider != null ? listProvider.ToList() : new List(source); } @@ -1452,17 +1997,25 @@ public static TSource First(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.First(); - IList list = source as IList; - if (list != null) + if (partition != null) { - if (list.Count > 0) return list[0]; + TSource result; + if (partition.TryGetFirst(out result)) + return result; } else { - using (IEnumerator e = source.GetEnumerator()) + IList list = source as IList; + if (list != null) { - if (e.MoveNext()) return e.Current; + if (list.Count > 0) return list[0]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) + { + if (e.MoveNext()) return e.Current; + } } } throw Error.NoElements(); @@ -1485,7 +2038,12 @@ public static TSource FirstOrDefault(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.FirstOrDefault(); + if (partition != null) + { + TSource result; + partition.TryGetFirst(out result); + return result; + } IList list = source as IList; if (list != null) { @@ -1518,25 +2076,33 @@ public static TSource Last(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.Last(); - IList list = source as IList; - if (list != null) + if (partition != null) { - int count = list.Count; - if (count > 0) return list[count - 1]; + TSource result; + if (partition.TryGetLast(out result)) + return result; } else { - using (IEnumerator e = source.GetEnumerator()) + IList list = source as IList; + if (list != null) { - if (e.MoveNext()) + int count = list.Count; + if (count > 0) return list[count - 1]; + } + else + { + using (IEnumerator e = source.GetEnumerator()) { - TSource result; - do + if (e.MoveNext()) { - result = e.Current; - } while (e.MoveNext()); - return result; + TSource result; + do + { + result = e.Current; + } while (e.MoveNext()); + return result; + } } } } @@ -1584,7 +2150,12 @@ public static TSource LastOrDefault(this IEnumerable source) { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.LastOrDefault(); + if (partition != null) + { + TSource result; + partition.TryGetLast(out result); + return result; + } IList list = source as IList; if (list != null) { @@ -1735,17 +2306,25 @@ public static TSource ElementAt(this IEnumerable source, int i { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAt(index); - IList list = source as IList; - if (list != null) return list[index]; - if (index >= 0) + if (partition != null) { - using (IEnumerator e = source.GetEnumerator()) + TSource result; + if (partition.TryGetElementAt(index, out result)) + return result; + } + else + { + IList list = source as IList; + if (list != null) return list[index]; + if (index >= 0) { - while (e.MoveNext()) + using (IEnumerator e = source.GetEnumerator()) { - if (index == 0) return e.Current; - index--; + while (e.MoveNext()) + { + if (index == 0) return e.Current; + index--; + } } } } @@ -1756,7 +2335,12 @@ public static TSource ElementAtOrDefault(this IEnumerable sour { if (source == null) throw Error.ArgumentNull("source"); IPartition partition = source as IPartition; - if (partition != null) return partition.ElementAtOrDefault(index); + if (partition != null) + { + TSource result; + partition.TryGetElementAt(index, out result); + return result; + } if (index >= 0) { IList list = source as IList; @@ -1787,7 +2371,7 @@ public static IEnumerable Range(int start, int count) return new RangeIterator(start, count); } - private sealed class RangeIterator : Iterator, IArrayProvider, IListProvider, IPartition + private sealed class RangeIterator : Iterator, IPartition, IIListProvider { private readonly int _start; private readonly int _end; @@ -1862,35 +2446,28 @@ public IPartition Take(int count) return new RangeIterator(_start, count); } - public int ElementAt(int index) + public bool TryGetElementAt(int index, out int element) { - if ((uint)index >= (uint)(_end - _start)) throw Error.ArgumentOutOfRange("index"); - return _start + index; - } - - public int ElementAtOrDefault(int index) - { - return (uint)index >= (uint)(_end - _start) ? 0 : _start + index; - } - - public int First() - { - return _start; - } + if ((uint)index >= (uint)(_end - _start)) + { + element = 0; + return false; + } - public int FirstOrDefault() - { - return _start; + element = _start + index; + return true; } - public int Last() + public bool TryGetFirst(out int first) { - return _end - 1; + first = _start; + return true; } - public int LastOrDefault() + public bool TryGetLast(out int last) { - return _end - 1; + last = _end - 1; + return true; } } @@ -1901,7 +2478,7 @@ public static IEnumerable Repeat(TResult element, int count) return new RepeatIterator(element, count); } - private sealed class RepeatIterator : Iterator, IArrayProvider, IListProvider, IPartition + private sealed class RepeatIterator : Iterator, IPartition, IIListProvider { private readonly int _count; private int _sent; @@ -1965,35 +2542,28 @@ public IPartition Take(int count) return new RepeatIterator(current, count); } - public TResult ElementAt(int index) - { - if ((uint)index >= (uint)_count) throw Error.ArgumentOutOfRange("index"); - return current; - } - - public TResult ElementAtOrDefault(int index) - { - return (uint)index >= (uint)_count ? default(TResult) : current; - } - - public TResult First() + public bool TryGetElementAt(int index, out TResult element) { - return current; - } + if ((uint)index >= (uint)_count) + { + element = default(TResult); + return false; + } - public TResult FirstOrDefault() - { - return current; + element = current; + return true; } - public TResult Last() + public bool TryGetFirst(out TResult first) { - return current; + first = current; + return true; } - public TResult LastOrDefault() + public bool TryGetLast(out TResult last) { - return current; + last = current; + return true; } } @@ -3310,22 +3880,16 @@ public static decimal Average(this IEnumerable source, Func - /// An iterator that can produce an array through an optimized path. + /// An iterator that can produce an array or an through an optimized path. /// - internal interface IArrayProvider + internal interface IIListProvider { /// /// Produce an array of the sequence through an optimized path. /// /// The array. TElement[] ToArray(); - } - /// - /// An iterator that can produce a through an optimized path. - /// - internal interface IListProvider - { /// /// Produce a of the sequence through an optimized path. /// @@ -3358,7 +3922,7 @@ public interface ILookup : IEnumerable bool Contains(TKey key); } - public class Lookup : IEnumerable>, ILookup, IArrayProvider>, IListProvider> + public class Lookup : IEnumerable>, ILookup, IIListProvider> { private IEqualityComparer _comparer; private Grouping[] _groupings; @@ -3429,7 +3993,7 @@ public IEnumerator> GetEnumerator() } } - IGrouping[] IArrayProvider>.ToArray() + IGrouping[] IIListProvider>.ToArray() { IGrouping[] array = new IGrouping[_count]; int index = 0; @@ -3446,7 +4010,7 @@ IGrouping[] IArrayProvider>.ToArray() return array; } - List> IListProvider>.ToList() + List> IIListProvider>.ToList() { List> list = new List>(_count); Grouping g = _lastGrouping; @@ -3647,8 +4211,10 @@ internal class Set private int[] _buckets; private Slot[] _slots; private int _count; - private int _freeList; - private IEqualityComparer _comparer; + private readonly IEqualityComparer _comparer; +#if DEBUG + private bool _haveRemoved; +#endif public Set(IEqualityComparer comparer) { @@ -3656,18 +4222,36 @@ public Set(IEqualityComparer comparer) _comparer = comparer; _buckets = new int[7]; _slots = new Slot[7]; - _freeList = -1; } // If value is not in set, add it and return true; otherwise return false public bool Add(TElement value) { - return !Find(value, true); +#if DEBUG + Debug.Assert(!_haveRemoved, "This set assumes no adds after a removal. If your use requires adds after removal undo that optimization."); +#endif + int hashCode = InternalGetHashCode(value); + for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) + { + if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return false; + } + if (_count == _slots.Length) Resize(); + int index = _count; + _count++; + int bucket = hashCode % _buckets.Length; + _slots[index].hashCode = hashCode; + _slots[index].value = value; + _slots[index].next = _buckets[bucket] - 1; + _buckets[bucket] = index + 1; + return true; } // If value is in set, remove it and return true; otherwise return false public bool Remove(TElement value) { +#if DEBUG + _haveRemoved = true; +#endif int hashCode = InternalGetHashCode(value); int bucket = hashCode % _buckets.Length; int last = -1; @@ -3685,44 +4269,12 @@ public bool Remove(TElement value) } _slots[i].hashCode = -1; _slots[i].value = default(TElement); - _slots[i].next = _freeList; - _freeList = i; return true; } } return false; } - private bool Find(TElement value, bool add) - { - int hashCode = InternalGetHashCode(value); - for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i].next) - { - if (_slots[i].hashCode == hashCode && _comparer.Equals(_slots[i].value, value)) return true; - } - if (add) - { - int index; - if (_freeList >= 0) - { - index = _freeList; - _freeList = _slots[index].next; - } - else - { - if (_count == _slots.Length) Resize(); - index = _count; - _count++; - } - int bucket = hashCode % _buckets.Length; - _slots[index].hashCode = hashCode; - _slots[index].value = value; - _slots[index].next = _buckets[bucket] - 1; - _buckets[bucket] = index + 1; - } - return false; - } - private void Resize() { int newSize = checked(_count * 2 + 1); @@ -3739,6 +4291,29 @@ private void Resize() _slots = newSlots; } + internal TElement[] ToArray() + { +#if DEBUG + Debug.Assert(!_haveRemoved, "Optimised ToArray cannot be called if Remove has been called."); +#endif + TElement[] array = new TElement[_count]; + for (int i = 0; i != array.Length; ++i) + array[i] = _slots[i].value; + return array; + } + + internal List ToList() + { +#if DEBUG + Debug.Assert(!_haveRemoved, "Optimised ToList cannot be called if Remove has been called."); +#endif + int count = _count; + List list = new List(count); + for (int i = 0; i != count; ++i) + list.Add(_slots[i].value); + return list; + } + internal int InternalGetHashCode(TElement value) { // Handle comparer implementations that throw when passed null @@ -3786,7 +4361,7 @@ IEnumerator IEnumerable.GetEnumerator() } } - internal class GroupedEnumerable : IEnumerable>, IArrayProvider>, IListProvider> + internal class GroupedEnumerable : IEnumerable>, IIListProvider> { private IEnumerable _source; private Func _keySelector; @@ -3816,37 +4391,31 @@ IEnumerator IEnumerable.GetEnumerator() public IGrouping[] ToArray() { - IArrayProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + IIListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); return lookup.ToArray(); } public List> ToList() { - IListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); + IIListProvider> lookup = Lookup.Create(_source, _keySelector, _elementSelector, _comparer); return lookup.ToList(); } } - internal interface IPartition : IEnumerable, IArrayProvider + internal interface IPartition : IEnumerable { IPartition Skip(int count); IPartition Take(int count); - TElement ElementAt(int index); - - TElement ElementAtOrDefault(int index); + bool TryGetElementAt(int index, out TElement element); - TElement First(); + bool TryGetFirst(out TElement first); - TElement FirstOrDefault(); - - TElement Last(); - - TElement LastOrDefault(); + bool TryGetLast(out TElement last); } - internal sealed class EmptyPartition : IPartition, IListProvider, IEnumerator + internal sealed class EmptyPartition : IPartition, IEnumerator, IIListProvider { public EmptyPartition() { @@ -3899,34 +4468,22 @@ public IPartition Take(int count) return new EmptyPartition(); } - public TElement ElementAt(int index) - { - throw Error.ArgumentOutOfRange("index"); - } - - public TElement ElementAtOrDefault(int index) - { - return default(TElement); - } - - public TElement First() - { - throw Error.NoElements(); - } - - public TElement FirstOrDefault() + public bool TryGetElementAt(int index, out TElement element) { - return default(TElement); + element = default(TElement); + return false; } - public TElement Last() + public bool TryGetFirst(out TElement first) { - throw Error.NoElements(); + first = default(TElement); + return false; } - public TElement LastOrDefault() + public bool TryGetLast(out TElement last) { - return default(TElement); + last = default(TElement); + return false; } public TElement[] ToArray() @@ -3936,11 +4493,11 @@ public TElement[] ToArray() public List ToList() { - return new List(0); + return new List(); } } - internal sealed class OrderedPartition : IPartition + internal sealed class OrderedPartition : IPartition, IIListProvider { private readonly OrderedEnumerable _source; private readonly int _minIndex; @@ -3978,46 +4535,39 @@ public IPartition Take(int count) return new OrderedPartition(_source, _minIndex, maxIndex); } - public TElement ElementAt(int index) - { - if ((uint)index > (uint)_maxIndex - _minIndex) throw Error.ArgumentOutOfRange("index"); - return _source.ElementAt(index + _minIndex); - } - - public TElement ElementAtOrDefault(int index) + public bool TryGetElementAt(int index, out TElement element) { - return (uint)index <= (uint)_maxIndex - _minIndex ? _source.ElementAtOrDefault(index + _minIndex) : default(TElement); - } + if ((uint)index > (uint)_maxIndex - _minIndex) + { + element = default(TElement); + return false; + } - public TElement First() - { - TElement result; - if (!_source.TryGetElementAt(_minIndex, out result)) throw Error.NoElements(); - return result; + return _source.TryGetElementAt(index + _minIndex, out element); } - public TElement FirstOrDefault() + public bool TryGetFirst(out TElement first) { - return _source.ElementAtOrDefault(_minIndex); + return _source.TryGetElementAt(_minIndex, out first); } - public TElement Last() + public bool TryGetLast(out TElement element) { - return _source.Last(_minIndex, _maxIndex); + return _source.TryGetLast(_minIndex, _maxIndex, out element); } - public TElement LastOrDefault() + public TElement[] ToArray() { - return _source.LastOrDefault(_minIndex, _maxIndex); + return _source.ToArray(_minIndex, _maxIndex); } - public TElement[] ToArray() + public List ToList() { - return _source.ToArray(_minIndex, _maxIndex); + return _source.ToList(_minIndex, _maxIndex); } } - internal abstract class OrderedEnumerable : IOrderedEnumerable, IArrayProvider, IListProvider, IPartition + internal abstract class OrderedEnumerable : IOrderedEnumerable, IPartition, IIListProvider { internal IEnumerable source; @@ -4108,6 +4658,23 @@ internal TElement[] ToArray(int minIdx, int maxIdx) return array; } + internal List ToList(int minIdx, int maxIdx) + { + Buffer buffer = new Buffer(source); + int count = buffer.count; + if (count <= minIdx) return new List(); + if (count <= maxIdx) maxIdx = count - 1; + if (minIdx == maxIdx) return new List(1) { GetEnumerableSorter().ElementAt(buffer.items, count, minIdx) }; + int[] map = SortedMap(buffer, minIdx, maxIdx); + List list = new List(maxIdx - minIdx + 1); + while (minIdx <= maxIdx) + { + list.Add(buffer.items[map[minIdx]]); + ++minIdx; + } + return list; + } + private EnumerableSorter GetEnumerableSorter() { return GetEnumerableSorter(null); @@ -4161,21 +4728,7 @@ public bool TryGetElementAt(int index, out TElement result) return false; } - public TElement ElementAt(int index) - { - TElement result; - if (!TryGetElementAt(index, out result)) throw Error.ArgumentOutOfRange("index"); - return result; - } - - public TElement ElementAtOrDefault(int index) - { - TElement result; - TryGetElementAt(index, out result); - return result; - } - - private bool TryGetFirst(out TElement result) + public bool TryGetFirst(out TElement result) { CachingComparer comparer = GetComparer(); using (IEnumerator e = source.GetEnumerator()) @@ -4197,20 +4750,6 @@ private bool TryGetFirst(out TElement result) } } - public TElement FirstOrDefault() - { - TElement result; - TryGetFirst(out result); - return result; - } - - public TElement First() - { - TElement result; - if (!TryGetFirst(out result)) throw Error.NoElements(); - return result; - } - public TElement First(Func predicate) { CachingComparer comparer = GetComparer(); @@ -4253,29 +4792,16 @@ public TElement FirstOrDefault(Func predicate) } } - public TElement Last() + public bool TryGetLast(out TElement element) { CachingComparer comparer = GetComparer(); using (IEnumerator e = source.GetEnumerator()) { - if (!e.MoveNext()) throw Error.NoElements(); - TElement value = e.Current; - comparer.SetElement(value); - while (e.MoveNext()) + if (!e.MoveNext()) { - TElement x = e.Current; - if (comparer.Compare(x, false) >= 0) value = x; + element = default(TElement); + return false; } - return value; - } - } - - public TElement LastOrDefault() - { - CachingComparer comparer = GetComparer(); - using (IEnumerator e = source.GetEnumerator()) - { - if (!e.MoveNext()) return default(TElement); TElement value = e.Current; comparer.SetElement(value); while (e.MoveNext()) @@ -4283,28 +4809,28 @@ public TElement LastOrDefault() TElement x = e.Current; if (comparer.Compare(x, false) >= 0) value = x; } - return value; + element = value; + return true; } + } - public TElement Last(int minIdx, int maxIdx) + public bool TryGetLast(int minIdx, int maxIdx, out TElement last) { Buffer buffer = new Buffer(source); int count = buffer.count; - if (minIdx >= count) throw Error.NoElements(); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); + if (minIdx >= count) + { + last = default(TElement); + return false; + } + if (maxIdx < count - 1) + last = GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); // If we're here, we want the same results we would have got from // Last(), but we've already buffered our source. - return Last(buffer); - } - - public TElement LastOrDefault(int minIdx, int maxIdx) - { - Buffer buffer = new Buffer(source); - int count = buffer.count; - if (minIdx >= count) return default(TElement); - if (maxIdx < count - 1) return GetEnumerableSorter().ElementAt(buffer.items, count, maxIdx); - return Last(buffer); + else + last = Last(buffer); + return true; } private TElement Last(Buffer buffer) @@ -4656,7 +5182,7 @@ internal struct Buffer internal Buffer(IEnumerable source) { - IArrayProvider iterator = source as IArrayProvider; + IIListProvider iterator = source as IIListProvider; if (iterator != null) { TElement[] array = iterator.ToArray(); diff --git a/src/System.Linq/tests/DistinctTests.cs b/src/System.Linq/tests/DistinctTests.cs index 4f954a69b6ee..32bcf9e5928c 100644 --- a/src/System.Linq/tests/DistinctTests.cs +++ b/src/System.Linq/tests/DistinctTests.cs @@ -198,5 +198,33 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + int?[] expected = { 1, 2, null }; + + Assert.Equal(expected, source.Distinct().ToArray()); + } + + [Fact] + public void ToList() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + int?[] expected = { 1, 2, null }; + + Assert.Equal(expected, source.Distinct().ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + int?[] source = { 1, 1, 1, 2, 2, 2, null, null }; + + var result = source.Distinct(); + + Assert.Equal(result, result); + } } } diff --git a/src/System.Linq/tests/ElementAtOrDefaultTests.cs b/src/System.Linq/tests/ElementAtOrDefaultTests.cs index a0fc8902b362..52c14c6c3a4e 100644 --- a/src/System.Linq/tests/ElementAtOrDefaultTests.cs +++ b/src/System.Linq/tests/ElementAtOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -131,5 +132,35 @@ public void NullSource() { Assert.Throws("source", () => ((IEnumerable)null).ElementAtOrDefault(2)); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAtOrDefault(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Equal(0, source.ElementAtOrDefault(-1)); + Assert.Equal(0, source.ElementAtOrDefault(4)); + } } } diff --git a/src/System.Linq/tests/ElementAtTests.cs b/src/System.Linq/tests/ElementAtTests.cs index b7c3ef9258f5..dc2a9212e884 100644 --- a/src/System.Linq/tests/ElementAtTests.cs +++ b/src/System.Linq/tests/ElementAtTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -131,5 +132,35 @@ public void NullSource() { Assert.Throws("source", () => ((IEnumerable)null).ElementAt(2)); } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + for (int i = 0; i != 4; ++i) + Assert.Equal((i + 1) * 2, source.ElementAt(i)); + Assert.Throws("index", () => source.ElementAt(-1)); + Assert.Throws("index", () => source.ElementAt(4)); + } } } diff --git a/src/System.Linq/tests/FirstOrDefaultTests.cs b/src/System.Linq/tests/FirstOrDefaultTests.cs index 4a6d2369257d..748b9c41c988 100644 --- a/src/System.Linq/tests/FirstOrDefaultTests.cs +++ b/src/System.Linq/tests/FirstOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -194,5 +195,26 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).FirstOrDefault(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new int[0].Select(i => i * 2 + 1).FirstOrDefault()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new List(0).Select(i => i * 2 + 1).FirstOrDefault()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(11, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).FirstOrDefault()); + Assert.Equal(0, new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1).FirstOrDefault()); + } } } diff --git a/src/System.Linq/tests/FirstTests.cs b/src/System.Linq/tests/FirstTests.cs index abcda8b354a6..a8a9f5e7fb54 100644 --- a/src/System.Linq/tests/FirstTests.cs +++ b/src/System.Linq/tests/FirstTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -191,5 +192,29 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).First(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).First()); + var emptySource = new int[0].Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(11, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).First()); + var emptySource = new List(0).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(11, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).First()); + var emptySource = new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.First()); + } } } diff --git a/src/System.Linq/tests/GroupByTests.cs b/src/System.Linq/tests/GroupByTests.cs index 7c5b8f719d4c..791a1e6c0e6c 100644 --- a/src/System.Linq/tests/GroupByTests.cs +++ b/src/System.Linq/tests/GroupByTests.cs @@ -630,5 +630,17 @@ public void GroupingWithResultsToList() Assert.Equal(4, groupedList.Count); Assert.Equal(source.GroupBy(r => r.Name, (r, e) => e), groupedList); } + + [Fact] + public void EmptyGroupingToArray() + { + Assert.Empty(Enumerable.Empty().GroupBy(i => i).ToArray()); + } + + [Fact] + public void EmptyGroupingToList() + { + Assert.Empty(Enumerable.Empty().GroupBy(i => i).ToList()); + } } } \ No newline at end of file diff --git a/src/System.Linq/tests/LastOrDefaultTests.cs b/src/System.Linq/tests/LastOrDefaultTests.cs index 1b2413b3266e..43f5fc67df85 100644 --- a/src/System.Linq/tests/LastOrDefaultTests.cs +++ b/src/System.Linq/tests/LastOrDefaultTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests.LegacyTests @@ -244,5 +245,26 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).LastOrDefault(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new int[0].Select(i => i * 2 + 1).LastOrDefault()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new List(0).Select(i => i * 2 + 1).LastOrDefault()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(17, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).LastOrDefault()); + Assert.Equal(0, new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1).LastOrDefault()); + } } } diff --git a/src/System.Linq/tests/LastTests.cs b/src/System.Linq/tests/LastTests.cs index 7512bfbe60bb..9ed77153fc2d 100644 --- a/src/System.Linq/tests/LastTests.cs +++ b/src/System.Linq/tests/LastTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests @@ -239,5 +240,29 @@ public void NullPredicate() Func predicate = null; Assert.Throws("predicate", () => Enumerable.Range(0, 3).Last(predicate)); } + + [Fact] + public void ArraySelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.Select(i => i * 2 + 1).Last()); + var emptySource = new int[0].Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } + + [Fact] + public void ListSelectSource() + { + Assert.Equal(17, new[] { 5, 6, 7, 8 }.ToList().Select(i => i * 2 + 1).Last()); + var emptySource = new List(0).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } + + [Fact] + public void IListSelectSource() + { + Assert.Equal(17, new ReadOnlyCollection(new[] { 5, 6, 7, 8 }).Select(i => i * 2 + 1).Last()); + var emptySource = new ReadOnlyCollection(new int[0]).Select(i => i * 2 + 1); + Assert.Throws(() => emptySource.Last()); + } } } diff --git a/src/System.Linq/tests/OrderedSubsetting.cs b/src/System.Linq/tests/OrderedSubsetting.cs index cf25c4b6203e..75e4695d2afd 100644 --- a/src/System.Linq/tests/OrderedSubsetting.cs +++ b/src/System.Linq/tests/OrderedSubsetting.cs @@ -302,24 +302,48 @@ public void ToArray() Assert.Equal(Enumerable.Range(10, 20), Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(10).Take(20).ToArray()); } + [Fact] + public void ToList() + { + Assert.Equal(Enumerable.Range(10, 20), Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(10).Take(20).ToList()); + } + [Fact] public void EmptyToArray() { Assert.Empty(Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(100).ToArray()); } + [Fact] + public void EmptyToList() + { + Assert.Empty(Enumerable.Range(0, 100).Shuffle().OrderBy(i => i).Skip(100).ToList()); + } + [Fact] public void AttemptedMoreArray() { Assert.Equal(Enumerable.Range(0, 20), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Take(30).ToArray()); } + [Fact] + public void AttemptedMoreToList() + { + Assert.Equal(Enumerable.Range(0, 20), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Take(30).ToList()); + } + [Fact] public void SingleElementToArray() { Assert.Equal(Enumerable.Repeat(10, 1), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Skip(10).Take(1).ToArray()); } + [Fact] + public void SingleElementToList() + { + Assert.Equal(Enumerable.Repeat(10, 1), Enumerable.Range(0, 20).Shuffle().OrderBy(i => i).Skip(10).Take(1).ToList()); + } + [Fact] public void EnumeratorDoesntContinue() { @@ -327,5 +351,35 @@ public void EnumeratorDoesntContinue() while (enumerator.MoveNext()) { } Assert.False(enumerator.MoveNext()); } + + [Fact] + public void SubsetAfterSelect() + { + var page = Enumerable.Range(0, 50).Shuffle().OrderBy(i => i).Select(i => i * 2).Skip(20).Take(5); + Assert.Equal(new[] { 40, 42, 44, 46, 48 }, page); + Assert.Equal(new[] { 41, 43, 45, 47, 49 }, page.Select(i => i + 1)); + Assert.Equal(40, page.First()); + Assert.Equal(48, page.Last()); + Assert.Equal(42, page.ElementAt(1)); + Assert.Throws("index", () => page.ElementAt(20)); + page = Enumerable.Range(0, 50).Shuffle().OrderBy(i => i).Select(i => i * 2).Skip(100).Take(5); + Assert.Throws(() => page.First()); + Assert.Throws(() => page.Last()); + } + + [Fact] + public void SubsetAfterSelectEnumeratorDoesntContinue() + { + var enumerator = NumberRangeGuaranteedNotCollectionType(0, 3).Shuffle().OrderBy(i => i).Select(i => i * 2).Take(1).GetEnumerator(); + while (enumerator.MoveNext()) { } + Assert.False(enumerator.MoveNext()); + } + + [Fact] + public void SubsetAfterSelectRepeatEnumerating() + { + var page = NumberRangeGuaranteedNotCollectionType(0, 3).Shuffle().OrderBy(i => i).Select(i => i * 2).Take(1); + Assert.Equal(page, page); + } } } \ No newline at end of file diff --git a/src/System.Linq/tests/ReverseTests.cs b/src/System.Linq/tests/ReverseTests.cs index d61bd7557b8f..8aceb80fe497 100644 --- a/src/System.Linq/tests/ReverseTests.cs +++ b/src/System.Linq/tests/ReverseTests.cs @@ -77,5 +77,31 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + int?[] source = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }; + int?[] expected = new int?[] { 9, null, 100, 9, 0, null, 5, 0, -10 }; + + Assert.Equal(expected, source.Reverse().ToArray()); + } + + [Fact] + public void ToList() + { + int?[] source = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }; + int?[] expected = new int?[] { 9, null, 100, 9, 0, null, 5, 0, -10 }; + + Assert.Equal(expected, source.Reverse().ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var reversed = new int?[] { -10, 0, 5, null, 0, 9, 100, null, 9 }.Reverse(); + + Assert.Equal(reversed, reversed); + } } } diff --git a/src/System.Linq/tests/SkipTests.cs b/src/System.Linq/tests/SkipTests.cs index a49570be9015..5bb6889bafb9 100644 --- a/src/System.Linq/tests/SkipTests.cs +++ b/src/System.Linq/tests/SkipTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; using Xunit.Abstractions; @@ -213,5 +214,258 @@ public void ForcedToEnumeratorDoesntEnumerateIList() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void FollowWithTake() + { + var source = new[] { 5, 6, 7, 8 }; + var expected = new[] { 6, 7 }; + Assert.Equal(expected, source.Skip(1).Take(2)); + } + + [Fact] + public void FollowWithTakeNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(5, 4); + var expected = new[] { 6, 7 }; + Assert.Equal(expected, source.Skip(1).Take(2)); + } + + [Fact] + public void FollowWithSkip() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var expected = new[] { 4, 5, 6 }; + Assert.Equal(expected, source.Skip(1).Skip(2).Skip(-4)); + } + + [Fact] + public void FollowWithSkipNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(1, 6); + var expected = new[] { 4, 5, 6 }; + Assert.Equal(expected, source.Skip(1).Skip(2).Skip(-4)); + } + + [Fact] + public void ElementAt() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAt(0)); + Assert.Equal(4, remaining.ElementAt(1)); + Assert.Equal(6, remaining.ElementAt(3)); + Assert.Throws("index", () => remaining.ElementAt(-1)); + Assert.Throws("index", () => remaining.ElementAt(4)); + } + + [Fact] + public void ElementAtNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAt(0)); + Assert.Equal(4, remaining.ElementAt(1)); + Assert.Equal(6, remaining.ElementAt(3)); + Assert.Throws("index", () => remaining.ElementAt(-1)); + Assert.Throws("index", () => remaining.ElementAt(4)); + } + + [Fact] + public void ElementAtOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAtOrDefault(0)); + Assert.Equal(4, remaining.ElementAtOrDefault(1)); + Assert.Equal(6, remaining.ElementAtOrDefault(3)); + Assert.Equal(0, remaining.ElementAtOrDefault(-1)); + Assert.Equal(0, remaining.ElementAtOrDefault(4)); + } + + [Fact] + public void ElementAtOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var remaining = source.Skip(2); + Assert.Equal(3, remaining.ElementAtOrDefault(0)); + Assert.Equal(4, remaining.ElementAtOrDefault(1)); + Assert.Equal(6, remaining.ElementAtOrDefault(3)); + Assert.Equal(0, remaining.ElementAtOrDefault(-1)); + Assert.Equal(0, remaining.ElementAtOrDefault(4)); + } + + [Fact] + public void First() + { + var source = new []{ 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Skip(0).First()); + Assert.Equal(3, source.Skip(2).First()); + Assert.Equal(5, source.Skip(4).First()); + Assert.Throws(() => source.Skip(5).First()); + } + + [Fact] + public void FirstNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Skip(0).First()); + Assert.Equal(3, source.Skip(2).First()); + Assert.Equal(5, source.Skip(4).First()); + Assert.Throws(() => source.Skip(5).First()); + } + + [Fact] + public void FirstOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Skip(0).FirstOrDefault()); + Assert.Equal(3, source.Skip(2).FirstOrDefault()); + Assert.Equal(5, source.Skip(4).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).FirstOrDefault()); + } + + [Fact] + public void FirstOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Skip(0).FirstOrDefault()); + Assert.Equal(3, source.Skip(2).FirstOrDefault()); + Assert.Equal(5, source.Skip(4).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).FirstOrDefault()); + } + + [Fact] + public void Last() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(5, source.Skip(0).Last()); + Assert.Equal(5, source.Skip(1).Last()); + Assert.Equal(5, source.Skip(4).Last()); + Assert.Throws(() => source.Skip(5).Last()); + } + + [Fact] + public void LastNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(5, source.Skip(0).Last()); + Assert.Equal(5, source.Skip(1).Last()); + Assert.Equal(5, source.Skip(4).Last()); + Assert.Throws(() => source.Skip(5).Last()); + } + + [Fact] + public void LastOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(5, source.Skip(0).LastOrDefault()); + Assert.Equal(5, source.Skip(1).LastOrDefault()); + Assert.Equal(5, source.Skip(4).LastOrDefault()); + Assert.Equal(0, source.Skip(5).LastOrDefault()); + } + + [Fact] + public void LastOrDefaultNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(5, source.Skip(0).LastOrDefault()); + Assert.Equal(5, source.Skip(1).LastOrDefault()); + Assert.Equal(5, source.Skip(4).LastOrDefault()); + Assert.Equal(0, source.Skip(5).LastOrDefault()); + } + + [Fact] + public void ToArray() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToArray()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToArray()); + Assert.Equal(5, source.Skip(4).ToArray().Single()); + Assert.Empty(source.Skip(5).ToArray()); + Assert.Empty(source.Skip(40).ToArray()); + } + + [Fact] + public void ToArrayNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToArray()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToArray()); + Assert.Equal(5, source.Skip(4).ToArray().Single()); + Assert.Empty(source.Skip(5).ToArray()); + Assert.Empty(source.Skip(40).ToArray()); + } + + [Fact] + public void ToList() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToList()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToList()); + Assert.Equal(5, source.Skip(4).ToList().Single()); + Assert.Empty(source.Skip(5).ToList()); + Assert.Empty(source.Skip(40).ToList()); + } + + [Fact] + public void ToListNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Skip(0).ToList()); + Assert.Equal(new[] { 2, 3, 4, 5 }, source.Skip(1).ToList()); + Assert.Equal(5, source.Skip(4).ToList().Single()); + Assert.Empty(source.Skip(5).ToList()); + Assert.Empty(source.Skip(40).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var source = new[] { 1, 2, 3, 4, 5 }; + var remaining = source.Skip(1); + Assert.Equal(remaining, remaining); + } + + [Fact] + public void RepeatEnumeratingNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + var remaining = source.Skip(1); + Assert.Equal(remaining, remaining); + } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + Assert.Equal(new[] { 6, 8 }, source.Skip(2)); + Assert.Empty(source.Skip(4)); + Assert.Empty(source.Skip(20)); + Assert.Equal(source, source.Skip(0)); + Assert.Equal(source, source.Skip(-1)); + } } } diff --git a/src/System.Linq/tests/TakeTests.cs b/src/System.Linq/tests/TakeTests.cs index e767a74dadbd..36d65813a4f7 100644 --- a/src/System.Linq/tests/TakeTests.cs +++ b/src/System.Linq/tests/TakeTests.cs @@ -3,12 +3,19 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using Xunit; namespace System.Linq.Tests { public class TakeTests : EnumerableTests { + private static IEnumerable GuaranteeNotIList(IEnumerable source) + { + foreach (T element in source) + yield return element; + } + [Fact] public void SameResultsRepeatCallsIntQuery() { @@ -19,6 +26,16 @@ where x > Int32.MinValue Assert.Equal(q.Take(9), q.Take(9)); } + [Fact] + public void SameResultsRepeatCallsIntQueryIList() + { + var q = (from x in new[] { 9999, 0, 888, -1, 66, -777, 1, 2, -12345 } + where x > Int32.MinValue + select x).ToList(); + + Assert.Equal(q.Take(9), q.Take(9)); + } + [Fact] public void SameResultsRepeatCallsStringQuery() { @@ -29,6 +46,16 @@ public void SameResultsRepeatCallsStringQuery() Assert.Equal(q.Take(7), q.Take(7)); } + [Fact] + public void SameResultsRepeatCallsStringQueryIList() + { + var q = (from x in new[] { "!@#$%^", "C", "AAA", "", "Calling Twice", "SoS", String.Empty } + where !String.IsNullOrEmpty(x) + select x).ToList(); + + Assert.Equal(q.Take(7), q.Take(7)); + } + [Fact] public void SourceEmptyCountPositive() { @@ -36,6 +63,13 @@ public void SourceEmptyCountPositive() Assert.Empty(source.Take(5)); } + [Fact] + public void SourceEmptyCountPositiveNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(0, 0); + Assert.Empty(source.Take(5)); + } + [Fact] public void SourceNonEmptyCountNegative() { @@ -43,6 +77,13 @@ public void SourceNonEmptyCountNegative() Assert.Empty(source.Take(-5)); } + [Fact] + public void SourceNonEmptyCountNegativeNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + Assert.Empty(source.Take(-5)); + } + [Fact] public void SourceNonEmptyCountZero() { @@ -50,6 +91,13 @@ public void SourceNonEmptyCountZero() Assert.Empty(source.Take(0)); } + [Fact] + public void SourceNonEmptyCountZeroNotIList() + { + var source = GuaranteeNotIList(new []{ 2, 5, 9, 1 }); + Assert.Empty(source.Take(0)); + } + [Fact] public void SourceNonEmptyCountOne() { @@ -59,14 +107,31 @@ public void SourceNonEmptyCountOne() Assert.Equal(expected, source.Take(1)); } + [Fact] + public void SourceNonEmptyCountOneNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + int[] expected = { 2 }; + + Assert.Equal(expected, source.Take(1)); + } + [Fact] public void SourceNonEmptyTakeAllExactly() { int[] source = { 2, 5, 9, 1 }; - + Assert.Equal(source, source.Take(source.Length)); } + [Fact] + public void SourceNonEmptyTakeAllExactlyNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + + Assert.Equal(source, source.Take(source.Count())); + } + [Fact] public void SourceNonEmptyTakeAllButOne() { @@ -76,6 +141,15 @@ public void SourceNonEmptyTakeAllButOne() Assert.Equal(expected, source.Take(3)); } + [Fact] + public void SourceNonEmptyTakeAllButOneNotIList() + { + var source = GuaranteeNotIList(new[] { 2, 5, 9, 1 }); + int[] expected = { 2, 5, 9 }; + + Assert.Equal(expected, source.Take(3)); + } + [Fact] public void SourceNonEmptyTakeExcessive() { @@ -83,7 +157,15 @@ public void SourceNonEmptyTakeExcessive() Assert.Equal(source, source.Take(source.Length + 1)); } - + + [Fact] + public void SourceNonEmptyTakeExcessiveNotIList() + { + var source = GuaranteeNotIList(new int?[] { 2, 5, null, 9, 1 }); + + Assert.Equal(source, source.Take(source.Count() + 1)); + } + [Fact] public void ThrowsOnNullSource() { @@ -99,5 +181,275 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ForcedToEnumeratorDoesntEnumerateIList() + { + var iterator = NumberRangeGuaranteedNotCollectionType(0, 3).ToList().Take(2); + // Don't insist on this behaviour, but check its correct if it happens + var en = iterator as IEnumerator; + Assert.False(en != null && en.MoveNext()); + } + + [Fact] + public void FollowWithTake() + { + var source = new[] { 5, 6, 7, 8 }; + var expected = new[] { 5, 6 }; + Assert.Equal(expected, source.Take(5).Take(3).Take(2).Take(40)); + } + + [Fact] + public void FollowWithTakeNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(5, 4); + var expected = new[] { 5, 6 }; + Assert.Equal(expected, source.Take(5).Take(3).Take(2)); + } + + [Fact] + public void FollowWithSkip() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var expected = new[] { 3, 4, 5 }; + Assert.Equal(expected, source.Take(5).Skip(2).Skip(-4)); + } + + [Fact] + public void FollowWithSkipNotIList() + { + var source = NumberRangeGuaranteedNotCollectionType(1, 6); + var expected = new[] { 3, 4, 5 }; + Assert.Equal(expected, source.Take(5).Skip(2).Skip(-4)); + } + + [Fact] + public void ElementAt() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAt(0)); + Assert.Equal(3, taken.ElementAt(2)); + Assert.Throws("index", () => taken.ElementAt(-1)); + Assert.Throws("index", () => taken.ElementAt(3)); + } + + [Fact] + public void ElementAtNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAt(0)); + Assert.Equal(3, taken.ElementAt(2)); + Assert.Throws("index", () => taken.ElementAt(-1)); + Assert.Throws("index", () => taken.ElementAt(3)); + } + + [Fact] + public void ElementAtOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5, 6 }; + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAtOrDefault(0)); + Assert.Equal(3, taken.ElementAtOrDefault(2)); + Assert.Equal(0, taken.ElementAtOrDefault(-1)); + Assert.Equal(0, taken.ElementAtOrDefault(3)); + } + + [Fact] + public void ElementAtOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5, 6 }); + var taken = source.Take(3); + Assert.Equal(1, taken.ElementAtOrDefault(0)); + Assert.Equal(3, taken.ElementAtOrDefault(2)); + Assert.Equal(0, taken.ElementAtOrDefault(-1)); + Assert.Equal(0, taken.ElementAtOrDefault(3)); + } + + [Fact] + public void First() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).First()); + Assert.Equal(1, source.Take(4).First()); + Assert.Equal(1, source.Take(40).First()); + Assert.Throws(() => source.Take(0).First()); + Assert.Throws(() => source.Skip(5).Take(10).First()); + } + + [Fact] + public void FirstNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).First()); + Assert.Equal(1, source.Take(4).First()); + Assert.Equal(1, source.Take(40).First()); + Assert.Throws(() => source.Take(0).First()); + Assert.Throws(() => source.Skip(5).Take(10).First()); + } + + [Fact] + public void FirstOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).FirstOrDefault()); + Assert.Equal(1, source.Take(4).FirstOrDefault()); + Assert.Equal(1, source.Take(40).FirstOrDefault()); + Assert.Equal(0, source.Take(0).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).Take(10).FirstOrDefault()); + } + + [Fact] + public void FirstOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).FirstOrDefault()); + Assert.Equal(1, source.Take(4).FirstOrDefault()); + Assert.Equal(1, source.Take(40).FirstOrDefault()); + Assert.Equal(0, source.Take(0).FirstOrDefault()); + Assert.Equal(0, source.Skip(5).Take(10).FirstOrDefault()); + } + + [Fact] + public void Last() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).Last()); + Assert.Equal(5, source.Take(5).Last()); + Assert.Equal(5, source.Take(40).Last()); + Assert.Throws(() => source.Take(0).Last()); + Assert.Throws(() => Array.Empty().Take(40).Last()); + } + + [Fact] + public void LastNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).Last()); + Assert.Equal(5, source.Take(5).Last()); + Assert.Equal(5, source.Take(40).Last()); + Assert.Throws(() => source.Take(0).Last()); + Assert.Throws(() => GuaranteeNotIList(Array.Empty()).Take(40).Last()); + } + + [Fact] + public void LastOrDefault() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(1, source.Take(1).LastOrDefault()); + Assert.Equal(5, source.Take(5).LastOrDefault()); + Assert.Equal(5, source.Take(40).LastOrDefault()); + Assert.Equal(0, source.Take(0).LastOrDefault()); + Assert.Equal(0, Array.Empty().Take(40).LastOrDefault()); + } + + [Fact] + public void LastOrDefaultNotIList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(1, source.Take(1).LastOrDefault()); + Assert.Equal(5, source.Take(5).LastOrDefault()); + Assert.Equal(5, source.Take(40).LastOrDefault()); + Assert.Equal(0, source.Take(0).LastOrDefault()); + Assert.Equal(0, GuaranteeNotIList(Array.Empty()).Take(40).LastOrDefault()); + } + + [Fact] + public void ToArray() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToArray()); + Assert.Equal(1, source.Take(1).ToArray().Single()); + Assert.Empty(source.Take(0).ToArray()); + Assert.Empty(source.Take(-10).ToArray()); + } + + [Fact] + public void ToArrayNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToArray()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToArray()); + Assert.Equal(1, source.Take(1).ToArray().Single()); + Assert.Empty(source.Take(0).ToArray()); + Assert.Empty(source.Take(-10).ToArray()); + } + + [Fact] + public void ToList() + { + var source = new[] { 1, 2, 3, 4, 5 }; + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToList()); + Assert.Equal(1, source.Take(1).ToList().Single()); + Assert.Empty(source.Take(0).ToList()); + Assert.Empty(source.Take(-10).ToList()); + } + + [Fact] + public void ToListNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(5).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4, 5 }, source.Take(40).ToList()); + Assert.Equal(new[] { 1, 2, 3, 4 }, source.Take(4).ToList()); + Assert.Equal(1, source.Take(1).ToList().Single()); + Assert.Empty(source.Take(0).ToList()); + Assert.Empty(source.Take(-10).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + var source = new[] { 1, 2, 3, 4, 5 }; + var taken = source.Take(3); + Assert.Equal(taken, taken); + } + + [Fact] + public void RepeatEnumeratingNotList() + { + var source = GuaranteeNotIList(new[] { 1, 2, 3, 4, 5 }); + var taken = source.Take(3); + Assert.Equal(taken, taken); + } + + [Fact] + public void ArraySelectSource() + { + var source = new[] { 1, 2, 3, 4 }.Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } + + [Fact] + public void ListSelectSource() + { + var source = new[] { 1, 2, 3, 4 }.ToList().Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } + + [Fact] + public void IListSelectSource() + { + var source = new ReadOnlyCollection(new[] { 1, 2, 3, 4 }).Select(i => i * 2); + Assert.Equal(new[] { 2, 4 }, source.Take(2)); + Assert.Empty(source.Take(0)); + Assert.Empty(source.Take(-20)); + Assert.Equal(source, source.Take(4)); + Assert.Equal(source, source.Take(40)); + } } } diff --git a/src/System.Linq/tests/UnionTests.cs b/src/System.Linq/tests/UnionTests.cs index 6decc34defba..b16143d90647 100644 --- a/src/System.Linq/tests/UnionTests.cs +++ b/src/System.Linq/tests/UnionTests.cs @@ -205,5 +205,36 @@ public void ForcedToEnumeratorDoesntEnumerate() var en = iterator as IEnumerator; Assert.False(en != null && en.MoveNext()); } + + [Fact] + public void ToArray() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + string[] expected = { "Bob", "Robert", "Tim", "Matt", "miT", "ttaM", "Charlie", "Bbo" }; + + Assert.Equal(expected, first.Union(second).ToArray()); + } + + [Fact] + public void ToList() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + string[] expected = { "Bob", "Robert", "Tim", "Matt", "miT", "ttaM", "Charlie", "Bbo" }; + + Assert.Equal(expected, first.Union(second).ToList()); + } + + [Fact] + public void RepeatEnumerating() + { + string[] first = { "Bob", "Robert", "Tim", "Matt", "miT" }; + string[] second = { "ttaM", "Charlie", "Bbo" }; + + var result = first.Union(second); + + Assert.Equal(result, result); + } } }