From 9857e37c02aeccbeb6170ffce799513a57c273d6 Mon Sep 17 00:00:00 2001 From: Jon Hanna Date: Mon, 2 May 2016 04:09:28 +0100 Subject: [PATCH] Optimise multiple Append and Prepend calls. Use a more compact form for multiple calls to Append and Prepend, that shares a collection of appended and/or prepended elements between instances. --- src/System.Linq/src/System.Linq.csproj | 1 + .../src/System/Linq/AppendPrepend.cs | 430 ++++++++++++++++++ .../src/System/Linq/Concatenate.cs | 40 -- src/System.Linq/tests/AppendPrependTests.cs | 142 +++++- 4 files changed, 572 insertions(+), 41 deletions(-) create mode 100644 src/System.Linq/src/System/Linq/AppendPrepend.cs diff --git a/src/System.Linq/src/System.Linq.csproj b/src/System.Linq/src/System.Linq.csproj index bac6913d1036..ed4ef9447403 100644 --- a/src/System.Linq/src/System.Linq.csproj +++ b/src/System.Linq/src/System.Linq.csproj @@ -26,6 +26,7 @@ + diff --git a/src/System.Linq/src/System/Linq/AppendPrepend.cs b/src/System.Linq/src/System/Linq/AppendPrepend.cs new file mode 100644 index 000000000000..d0298e67d55f --- /dev/null +++ b/src/System.Linq/src/System/Linq/AppendPrepend.cs @@ -0,0 +1,430 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable Append(this IEnumerable source, TSource element) + { + if (source == null) + { + throw Error.ArgumentNull(nameof(source)); + } + + AppendPrependIterator appendable = source as AppendPrependIterator; + if (appendable != null) + { + return appendable.Append(element); + } + + return new AppendPrepend1Iterator(source, element, true); + } + + public static IEnumerable Prepend(this IEnumerable source, TSource element) + { + if (source == null) + { + throw Error.ArgumentNull(nameof(source)); + } + + AppendPrependIterator appendable = source as AppendPrependIterator; + if (appendable != null) + { + return appendable.Prepend(element); + } + + return new AppendPrepend1Iterator(source, element, false); + } + + private abstract class AppendPrependIterator : Iterator, IIListProvider + { + protected readonly IEnumerable _source; + protected IEnumerator _enumerator; + + protected AppendPrependIterator(IEnumerable source) + { + Debug.Assert(source != null); + _source = source; + } + + protected void GetSourceEnumerator() + { + Debug.Assert(_enumerator == null); + _enumerator = _source.GetEnumerator(); + } + + public abstract AppendPrependIterator Append(TSource item); + + public abstract AppendPrependIterator Prepend(TSource item); + + protected bool LoadFromEnumerator() + { + if (_enumerator.MoveNext()) + { + _current = _enumerator.Current; + return true; + } + + Dispose(); + return false; + } + + public override void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + + base.Dispose(); + } + + public abstract TSource[] ToArray(); + + public abstract List ToList(); + + public abstract int GetCount(bool onlyIfCheap); + } + + private class AppendPrepend1Iterator : AppendPrependIterator + { + private readonly TSource _item; + private readonly bool _appending; + + public AppendPrepend1Iterator(IEnumerable source, TSource item, bool appending) + : base(source) + { + _item = item; + _appending = appending; + } + + public override Iterator Clone() => new AppendPrepend1Iterator(_source, _item, _appending); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _state = 2; + if (!_appending) + { + _current = _item; + return true; + } + + goto case 2; + case 2: + GetSourceEnumerator(); + _state = 3; + goto case 3; + case 3: + if (LoadFromEnumerator()) + { + return true; + } + + if (_appending) + { + _current = _item; + return true; + } + + break; + } + + Dispose(); + return false; + } + + public override AppendPrependIterator Append(TSource item) + { + if (_appending) + { + return new AppendPrependN(_source, null, new SingleLinkedNode(_item, item)); + } + else + { + return new AppendPrependN(_source, new SingleLinkedNode(_item), new SingleLinkedNode(item)); + } + } + + public override AppendPrependIterator Prepend(TSource item) + { + if (_appending) + { + return new AppendPrependN(_source, new SingleLinkedNode(item), new SingleLinkedNode(_item)); + } + else + { + return new AppendPrependN(_source, new SingleLinkedNode(_item, item), null); + } + } + + public override TSource[] ToArray() + { + int count = GetCount(onlyIfCheap: true); + if (count == -1) + { + return EnumerableHelpers.ToArray(this); + } + + TSource[] array = new TSource[count]; + int index; + if (_appending) + { + index = 0; + } + else + { + array[0] = _item; + index = 1; + } + + ICollection sourceCollection = _source as ICollection; + if (sourceCollection != null) + { + sourceCollection.CopyTo(array, index); + } + else + { + foreach (TSource item in _source) + { + array[index] = item; + ++index; + } + } + + if (_appending) + { + array[array.Length - 1] = _item; + } + + return array; + } + + public override List ToList() + { + int count = GetCount(onlyIfCheap: true); + List list = count == -1 ? new List() : new List(count); + if (!_appending) + { + list.Add(_item); + } + + list.AddRange(_source); + if (_appending) + { + list.Add(_item); + } + + return list; + } + + public override int GetCount(bool onlyIfCheap) + { + IIListProvider listProv = _source as IIListProvider; + if (listProv != null) + { + int count = listProv.GetCount(onlyIfCheap); + return count == -1 ? -1 : count + 1; + } + + return !onlyIfCheap || _source is ICollection ? _source.Count() + 1 : -1; + } + } + + private sealed class SingleLinkedNode + { + public SingleLinkedNode(TSource first, TSource second) + { + Linked = new SingleLinkedNode(first); + Item = second; + Count = 2; + } + + public SingleLinkedNode(TSource item) + { + Item = item; + Count = 1; + } + + private SingleLinkedNode(SingleLinkedNode linked, TSource item) + { + Debug.Assert(linked != null); + Linked = linked; + Item = item; + Count = linked.Count + 1; + } + + public TSource Item { get; } + + public SingleLinkedNode Linked { get; } + + public int Count { get; } + + public SingleLinkedNode Add(TSource item) => new SingleLinkedNode(this, item); + + public IEnumerator GetEnumerator() + { + TSource[] array = new TSource[Count]; + int index = Count; + for (SingleLinkedNode node = this; node != null; node = node.Linked) + { + --index; + array[index] = node.Item; + } + + Debug.Assert(index == 0); + return ((IEnumerable)array).GetEnumerator(); + } + } + + private class AppendPrependN : AppendPrependIterator + { + private readonly SingleLinkedNode _prepended; + private readonly SingleLinkedNode _appended; + private SingleLinkedNode _node; + + public AppendPrependN(IEnumerable source, SingleLinkedNode prepended, SingleLinkedNode appended) + : base(source) + { + Debug.Assert(prepended != null || appended != null); + _prepended = prepended; + _appended = appended; + } + + public override Iterator Clone() => new AppendPrependN(_source, _prepended, _appended); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _node = _prepended; + _state = 2; + goto case 2; + case 2: + if (_node != null) + { + _current = _node.Item; + _node = _node.Linked; + return true; + } + + GetSourceEnumerator(); + _state = 3; + goto case 3; + case 3: + if (LoadFromEnumerator()) + { + return true; + } + + if (_appended == null) + { + return false; + } + + _enumerator = _appended.GetEnumerator(); + _state = 4; + goto case 4; + case 4: + return LoadFromEnumerator(); + } + + Dispose(); + return false; + } + + public override AppendPrependIterator Append(TSource item) + { + return new AppendPrependN(_source, _prepended, _appended != null ? _appended.Add(item) : new SingleLinkedNode(item)); + } + + public override AppendPrependIterator Prepend(TSource item) + { + return new AppendPrependN(_source, _prepended != null ? _prepended.Add(item) : new SingleLinkedNode(item), _appended); + } + + public override TSource[] ToArray() + { + int count = GetCount(onlyIfCheap: true); + if (count == -1) + { + return EnumerableHelpers.ToArray(this); + } + + TSource[] array = new TSource[count]; + int index = 0; + for (SingleLinkedNode node = _prepended; node != null; node = node.Linked) + { + array[index] = node.Item; + ++index; + } + + ICollection sourceCollection = _source as ICollection; + if (sourceCollection != null) + { + sourceCollection.CopyTo(array, index); + } + else + { + foreach (TSource item in _source) + { + array[index] = item; + ++index; + } + } + + index = array.Length; + for (SingleLinkedNode node = _appended; node != null; node = node.Linked) + { + --index; + array[index] = node.Item; + } + + return array; + } + + public override List ToList() + { + int count = GetCount(onlyIfCheap: true); + List list = count == -1 ? new List() : new List(count); + for (SingleLinkedNode node = _prepended; node != null; node = node.Linked) + { + list.Add(node.Item); + } + + list.AddRange(_source); + if (_appended != null) + { + IEnumerator e = _appended.GetEnumerator(); + while (e.MoveNext()) + { + list.Add(e.Current); + } + } + + return list; + } + + public override int GetCount(bool onlyIfCheap) + { + IIListProvider listProv = _source as IIListProvider; + if (listProv != null) + { + int count = listProv.GetCount(onlyIfCheap); + return count == -1 ? -1 : count + (_appended == null ? 0 : _appended.Count) + (_prepended == null ? 0 : _prepended.Count); + } + + return !onlyIfCheap || _source is ICollection ? _source.Count() + (_appended == null ? 0 : _appended.Count) + (_prepended == null ? 0 : _prepended.Count) : -1; + } + } + } +} diff --git a/src/System.Linq/src/System/Linq/Concatenate.cs b/src/System.Linq/src/System/Linq/Concatenate.cs index 40bd1f5f32e8..319aa69af133 100644 --- a/src/System.Linq/src/System/Linq/Concatenate.cs +++ b/src/System.Linq/src/System/Linq/Concatenate.cs @@ -9,46 +9,6 @@ namespace System.Linq { public static partial class Enumerable { - public static IEnumerable Append(this IEnumerable source, TSource element) - { - if (source == null) - { - throw Error.ArgumentNull(nameof(source)); - } - - return AppendIterator(source, element); - } - - private static IEnumerable AppendIterator(IEnumerable source, TSource element) - { - foreach (TSource e1 in source) - { - yield return e1; - } - - yield return element; - } - - public static IEnumerable Prepend(this IEnumerable source, TSource element) - { - if (source == null) - { - throw Error.ArgumentNull(nameof(source)); - } - - return PrependIterator(source, element); - } - - private static IEnumerable PrependIterator(IEnumerable source, TSource element) - { - yield return element; - - foreach (TSource e1 in source) - { - yield return e1; - } - } - public static IEnumerable Concat(this IEnumerable first, IEnumerable second) { if (first == null) diff --git a/src/System.Linq/tests/AppendPrependTests.cs b/src/System.Linq/tests/AppendPrependTests.cs index eaa2663e8e32..0f22a286d95d 100644 --- a/src/System.Linq/tests/AppendPrependTests.cs +++ b/src/System.Linq/tests/AppendPrependTests.cs @@ -15,7 +15,7 @@ public void SameResultsRepeatCallsIntQueryAppend() { var q1 = from x1 in new int?[] { 2, 3, null, 2, null, 4, 5 } select x1; - + Assert.Equal(q1.Append(42), q1.Append(42)); Assert.Equal(q1.Append(42), q1.Concat(new int?[] { 42 })); } @@ -50,6 +50,15 @@ public void SameResultsRepeatCallsStringQueryPrepend() Assert.Equal(q1.Prepend("hi"), (new string[] { "hi" }).Concat(q1)); } + [Fact] + public void RepeatIteration() + { + var q = Enumerable.Range(3, 4).Append(12); + Assert.Equal(q, q); + q = q.Append(14); + Assert.Equal(q, q); + } + [Fact] public void EmptyAppend() { @@ -93,6 +102,15 @@ public void ForcedToEnumeratorDoesntEnumerateAppend() Assert.False(en != null && en.MoveNext()); } + [Fact] + public void ForcedToEnumeratorDoesntEnumerateMultipleAppendsAndPrepends() + { + var iterator = NumberRangeGuaranteedNotCollectionType(0, 3).Append(4).Append(5).Prepend(-1).Prepend(-2); + // Don't insist on this behaviour, but check it's correct if it happens + var en = iterator as IEnumerator; + Assert.False(en != null && en.MoveNext()); + } + [Fact] public void SourceNull() { @@ -115,5 +133,127 @@ public void Combined() Assert.Equal(v2.ToArray(), "dcba".ToArray()); } + + [Fact] + public void AppendCombinations() + { + var source = Enumerable.Range(0, 3).Append(3).Append(4); + var app0a = source.Append(5); + var app0b = source.Append(6); + var app1aa = app0a.Append(7); + var app1ab = app0a.Append(8); + var app1ba = app0b.Append(9); + var app1bb = app0b.Append(10); + + Assert.Equal(new[] { 0, 1, 2, 3, 4, 5 }, app0a); + Assert.Equal(new[] { 0, 1, 2, 3, 4, 6 }, app0b); + Assert.Equal(new[] { 0, 1, 2, 3, 4, 5, 7 }, app1aa); + Assert.Equal(new[] { 0, 1, 2, 3, 4, 5, 8 }, app1ab); + Assert.Equal(new[] { 0, 1, 2, 3, 4, 6, 9 }, app1ba); + Assert.Equal(new[] { 0, 1, 2, 3, 4, 6, 10 }, app1bb); + } + + [Fact] + public void PrependCombinations() + { + var source = Enumerable.Range(2, 2).Prepend(1).Prepend(0); + var pre0a = source.Prepend(5); + var pre0b = source.Prepend(6); + var pre1aa = pre0a.Prepend(7); + var pre1ab = pre0a.Prepend(8); + var pre1ba = pre0b.Prepend(9); + var pre1bb = pre0b.Prepend(10); + + Assert.Equal(new[] { 5, 0, 1, 2, 3 }, pre0a); + Assert.Equal(new[] { 6, 0, 1, 2, 3 }, pre0b); + Assert.Equal(new[] { 7, 5, 0, 1, 2, 3 }, pre1aa); + Assert.Equal(new[] { 8, 5, 0, 1, 2, 3 }, pre1ab); + Assert.Equal(new[] { 9, 6, 0, 1, 2, 3 }, pre1ba); + Assert.Equal(new[] { 10, 6, 0, 1, 2, 3 }, pre1bb); + } + + [Fact] + public void Append1ToArrayToList() + { + var source = Enumerable.Range(0, 2).Append(2); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + + source = Enumerable.Range(0, 2).ToList().Append(2); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(0, 2).Append(2); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + } + + [Fact] + public void Prepend1ToArrayToList() + { + var source = Enumerable.Range(1, 2).Prepend(0); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + + source = Enumerable.Range(1, 2).ToList().Prepend(0); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(1, 2).Prepend(0); + Assert.Equal(Enumerable.Range(0, 3), source.ToList()); + Assert.Equal(Enumerable.Range(0, 3), source.ToArray()); + } + + [Fact] + public void AppendNToArrayToList() + { + var source = Enumerable.Range(0, 2).Append(2).Append(3); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + + source = Enumerable.Range(0, 2).ToList().Append(2).Append(3); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(0, 2).Append(2).Append(3); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + } + + [Fact] + public void PrependNToArrayToList() + { + var source = Enumerable.Range(2, 2).Prepend(1).Prepend(0); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + + source = Enumerable.Range(2, 2).ToList().Prepend(1).Prepend(0); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(2, 2).Prepend(1).Prepend(0); + Assert.Equal(Enumerable.Range(0, 4), source.ToList()); + Assert.Equal(Enumerable.Range(0, 4), source.ToArray()); + } + + [Fact] + public void AppendPrependToArrayToList() + { + var source = Enumerable.Range(2, 2).Prepend(1).Append(4).Prepend(0).Append(5); + Assert.Equal(Enumerable.Range(0, 6), source.ToList()); + Assert.Equal(Enumerable.Range(0, 6), source.ToArray()); + + source = Enumerable.Range(2, 2).ToList().Prepend(1).Append(4).Prepend(0).Append(5); + Assert.Equal(Enumerable.Range(0, 6), source.ToList()); + Assert.Equal(Enumerable.Range(0, 6), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(2, 2).Append(4).Prepend(1).Append(5).Prepend(0); + Assert.Equal(Enumerable.Range(0, 6), source.ToList()); + Assert.Equal(Enumerable.Range(0, 6), source.ToArray()); + + source = NumberRangeGuaranteedNotCollectionType(2, 2).Prepend(1).Prepend(0).Append(4).Append(5); + Assert.Equal(Enumerable.Range(0, 6), source.ToList()); + Assert.Equal(Enumerable.Range(0, 6), source.ToArray()); + } } }