Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 310 additions & 0 deletions src/System.Linq/src/System/Linq/Partition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -389,5 +389,315 @@ public int GetCount(bool onlyIfCheap)
return Count;
}
}

private sealed class EnumerablePartition<TSource> : Iterator<TSource>, IPartition<TSource>
{
private readonly IEnumerable<TSource> _source;
private readonly int _minIndexInclusive;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting choice to represent bounds.
Why not uint _start and uint _length ? I think the math might get simpler.
You can still use uint.MaxValue as an "unlimited" marker.

Copy link
Copy Markdown
Contributor Author

@jamesqo jamesqo Nov 24, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VSadov I originally had something like that, actually, using ints. But since all of the other partitions used _minIndexInclusive and _maxIndexInclusive I decided to follow their lead.

Why not uint _start and uint _length? I think the math might get simpler.

I used ints for the fields, since it's more idiomatic. If we know the length (Take has been called at least once) then the count will fit in an int anyway. Also, it'd still be possible that _start could overflow during a chained Skip / Take, and we'd have to wrap ourselves in another iterator.

private readonly int _maxIndexInclusive; // -1 if we want everything past _minIndexInclusive.
// If this is -1, it's impossible to set a limit on the count.
private IEnumerator<TSource> _enumerator;

// Builds a partition over `source` covering indices [minIndexInclusive, maxIndexInclusive].
// A maxIndexInclusive of -1 means "no upper bound".
internal EnumerablePartition(IEnumerable<TSource> source, int minIndexInclusive, int maxIndexInclusive)
{
    // Basic argument checks (debug builds only).
    Debug.Assert(source != null);
    Debug.Assert(!(source is IList<TSource>), $"The caller needs to check for {nameof(IList<TSource>)}.");
    Debug.Assert(minIndexInclusive >= 0);
    Debug.Assert(maxIndexInclusive >= -1);

    // The upper index never grows, but it is allowed to be int.MaxValue so that enumerables with
    // more than 2B elements can be partitioned, e.g. e.Skip(1).Take(int.MaxValue).
    // In that case the lower index must be non-zero; otherwise the count would overflow.
    Debug.Assert(maxIndexInclusive == -1 || (maxIndexInclusive - minIndexInclusive < int.MaxValue), $"{nameof(Limit)} will overflow!");

    // When an upper bound is present, the range must not be inverted.
    Debug.Assert(maxIndexInclusive == -1 || minIndexInclusive <= maxIndexInclusive);

    _maxIndexInclusive = maxIndexInclusive;
    _minIndexInclusive = minIndexInclusive;
    _source = source;
}

// True when _maxIndexInclusive is not the -1 sentinel (e.g. at least one Take call was made),
// which means there is an upper bound on how many elements this partition can have.
private bool HasLimit
{
    get { return _maxIndexInclusive != -1; }
}

// The upper bound on the element count: the number of indices in
// [_minIndexInclusive, _maxIndexInclusive].
private int Limit
{
    get { return (_maxIndexInclusive + 1) - _minIndexInclusive; }
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value is only meaningful if HasLimit is true, right? Should we add a Debug.Assert(HasLimit) to the getter?


// Produces a fresh partition over the same source and index range.
public override Iterator<TSource> Clone() =>
    new EnumerablePartition<TSource>(_source, _minIndexInclusive, _maxIndexInclusive);

public int GetCount(bool onlyIfCheap)
{
if (onlyIfCheap)
{
return -1;
Copy link
Copy Markdown
Member

@stephentoub stephentoub Nov 30, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: looking at a code coverage report, this is never hit. Missing a test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add one.

Copy link
Copy Markdown
Contributor Author

@jamesqo jamesqo Dec 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephentoub Made the change. Something to note: While adding coverage for this, I also realized I had forgotten to override Select in this type. (Select checks for an Iterator first, then an IPartition, and the default implementation of Iterator.Select is to return a SelectEnumerableIterator rather than a SelectIPartitionIterator. Thus it's necessary to override Select here to return the latter for better perf.) So thank you for pointing that out.

}

if (!HasLimit)
{
// If HasLimit is false, we contain everything past _minIndexInclusive.
// Therefore, we have to iterate the whole enumerable.
return Math.Max(_source.Count() - _minIndexInclusive, 0);
}

using (IEnumerator<TSource> en = _source.GetEnumerator())
{
// We only want to iterate up to _maxIndexInclusive + 1.
// Past that, we know the enumerable will be able to fit this partition,
// so the count will just be _maxIndexInclusive + 1 - _minIndexInclusive.

// Note that it is possible for _maxIndexInclusive to be int.MaxValue here,
// so + 1 may result in signed integer overflow. We need to handle this.
// At the same time, however, we are guaranteed that our max count can fit
// in an int because if that is true, then _minIndexInclusive must > 0.

uint count = SkipAndCount((uint)_maxIndexInclusive + 1, en);
Debug.Assert(count != (uint)int.MaxValue + 1 || _minIndexInclusive > 0, "Our return value will be incorrect.");
return Math.Max((int)count - _minIndexInclusive, 0);
}

}

public override bool MoveNext()
{
// Cases where GetEnumerator has not been called or Dispose has already
// been called need to be handled explicitly, due to the default: clause.
int taken = _state - 3;
if (taken < -2)
{
Dispose();
return false;
}

switch (_state)
{
case 1:
_enumerator = _source.GetEnumerator();
_state = 2;
goto case 2;
case 2:
if (!SkipBeforeFirst(_enumerator))
{
// Reached the end before we finished skipping.
break;
}

_state = 3;
goto default;
default:
if ((!HasLimit || taken < Limit) && _enumerator.MoveNext())
{
if (HasLimit)
{
// If we are taking an unknown number of elements, it's important not to increment _state.
// _state - 3 may eventually end up overflowing & we'll hit the Dispose branch even though
Copy link
Copy Markdown
Member

@stephentoub stephentoub Nov 30, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive would it be to have an outerloop test that overflowed to test this condition? Probably prohibitive?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well it would be 2B * 3 virtual method calls (MoveNext, Current, and MoveNext on the iterator) plus all of the other logic, so yes, probably too expensive for a test.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. In a few places in our tests we do use reflection to put the object into a state that would make it hit such cases quickly, e.g.
https://github.com/dotnet/corefx/blob/master/src/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/HandleCollectorTests.cs#L63
Would it be worth doing something like that here? Or we don't think there's enough risk to warrant that kind of fragility?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be worth doing something like that here? Or we don't think there's enough risk to warrant that kind of fragility?

It is low risk, but at the same time I think it is unlikely the field name of _state will change. I will add a test.

// we haven't finished enumerating.
_state++;
}
_current = _enumerator.Current;
return true;
}

break;
}

Dispose();
return false;
}

// Projects this partition through `selector` using the IPartition-aware select iterator.
public override IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector) =>
    new SelectIPartitionIterator<TSource, TResult>(this, selector);

// Returns a partition that skips `count` more elements than this one does.
public IPartition<TSource> Skip(int count)
{
    // May wrap around to a negative value; both branches below account for that.
    int minIndex = _minIndexInclusive + count;

    if (HasLimit)
    {
        // With an upper bound, the unsigned comparison also catches an overflowed (negative)
        // minIndex: the bound fits in an int, so it must be smaller than the wrapped value.
        if ((uint)minIndex > (uint)_maxIndexInclusive)
        {
            return EmptyPartition<TSource>.Instance;
        }
    }
    else if (minIndex < 0)
    {
        // No upper bound and the new start no longer fits in a positive int, so wrap this
        // partition in another one. Example: e.Skip(int.MaxValue).Skip(int.MaxValue).
        return new EnumerablePartition<TSource>(this, count, -1);
    }

    Debug.Assert(minIndex >= 0, $"We should have taken care of all cases when {nameof(minIndex)} overflows.");
    return new EnumerablePartition<TSource>(_source, minIndex, _maxIndexInclusive);
}

// Returns a partition limited to at most `count` elements of this one.
public IPartition<TSource> Take(int count)
{
    // May wrap around to a negative value; both branches below account for that.
    int maxIndex = _minIndexInclusive + count - 1;

    if (HasLimit)
    {
        // The existing bound is already at least as tight as the requested one (the unsigned
        // comparison also covers an overflowed maxIndex), so nothing changes.
        if ((uint)maxIndex >= (uint)_maxIndexInclusive)
        {
            return this;
        }
    }
    else if (maxIndex < 0)
    {
        // No upper bound and the new end no longer fits in a positive int, so wrap this
        // partition in another one. The distance from _minIndexInclusive (count - 1) still
        // fits in an int. Example: e.Skip(50).Take(int.MaxValue).
        return new EnumerablePartition<TSource>(this, 0, count - 1);
    }

    // Without an existing bound we can never return `this`: the rest of the enumerable past
    // _minIndexInclusive can be arbitrarily long, so a new bounded partition is required.
    Debug.Assert(maxIndex >= 0, $"We should have taken care of all cases when {nameof(maxIndex)} overflows.");
    return new EnumerablePartition<TSource>(_source, _minIndexInclusive, maxIndex);
}

public TSource TryGetElementAt(int index, out bool found)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have avoided optimizing the TryGet* functions if I could have, since they take a lot of code and are not as valuable to optimize as the other functions. But, currently they come with implementing IPartition.

{
// If the index is negative or >= our max count, return early.
if (index >= 0 && (!HasLimit || index < Limit))
{
using (IEnumerator<TSource> en = _source.GetEnumerator())
{
Debug.Assert(_minIndexInclusive + index >= 0, $"Adding {nameof(index)} caused {nameof(_minIndexInclusive)} to overflow.");

if (SkipBefore(_minIndexInclusive + index, en) && en.MoveNext())
{
found = true;
return en.Current;
}
}
}

found = false;
return default(TSource);
}

// Attempts to fetch the first element of the partition.
// `found` reports whether the source had an element at _minIndexInclusive.
public TSource TryGetFirst(out bool found)
{
    using (IEnumerator<TSource> cursor = _source.GetEnumerator())
    {
        // Advance past the skipped prefix; this fails if the source ends too early.
        if (SkipBefore(_minIndexInclusive, cursor))
        {
            if (cursor.MoveNext())
            {
                found = true;
                return cursor.Current;
            }
        }
    }

    found = false;
    return default(TSource);
}

public TSource TryGetLast(out bool found)
{
using (IEnumerator<TSource> en = _source.GetEnumerator())
{
if (SkipBeforeFirst(en) && en.MoveNext())
{
int remaining = Limit - 1; // Max number of items left, not counting the current element.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, this is why you don't assert HasLimit in Limit's getter.

int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.
TSource result;

do
{
remaining--;
result = en.Current;
}
while (remaining >= comparand && en.MoveNext());

found = true;
return result;
}
}

found = false;
return default(TSource);
}

public TSource[] ToArray()
{
using (IEnumerator<TSource> en = _source.GetEnumerator())
{
if (SkipBeforeFirst(en) && en.MoveNext())
{
int remaining = Limit - 1; // Max number of items left, not counting the current element.
int comparand = HasLimit ? 0 : int.MinValue; // If we don't have an upper bound, have the comparison always return true.

int maxCapacity = HasLimit ? Limit : int.MaxValue;
var builder = new LargeArrayBuilder<TSource>(maxCapacity);

do
{
remaining--;
builder.Add(en.Current);
}
while (remaining >= comparand && en.MoveNext());
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're dishing out an extra unnecessary write/branch per iteration for the loop if HasLimit is false (e.g. if only Skip has been called, and not Take). It could be specialized like this:

if (!HasLimit)
{
    do { builder.Add(en.Current); } while (en.MoveNext());
}
else
{
    do { remaining--; builder.Add(en.Current); } while (remaining >= 0 && en.MoveNext());
}

I figured it wasn't worth the extra code though.


return builder.ToArray();
}
}

return Array.Empty<TSource>();
}

// Materializes the partition into a new list.
public List<TSource> ToList()
{
    var result = new List<TSource>();

    using (IEnumerator<TSource> cursor = _source.GetEnumerator())
    {
        if (SkipBeforeFirst(cursor) && cursor.MoveNext())
        {
            // Max number of items left, not counting the current element.
            int remaining = Limit - 1;

            // Without an upper bound, int.MinValue makes the "remaining" check always pass.
            int floor = HasLimit ? 0 : int.MinValue;

            while (true)
            {
                remaining--;
                result.Add(cursor.Current);

                // Stop once the budget is spent (checked first, so the source is not advanced
                // needlessly) or the source runs out.
                if (remaining < floor || !cursor.MoveNext())
                {
                    break;
                }
            }
        }
    }

    return result;
}

// Advances `en` past the first _minIndexInclusive elements; the result is whatever
// SkipBefore reports for that index.
private bool SkipBeforeFirst(IEnumerator<TSource> en)
{
    return SkipBefore(_minIndexInclusive, en);
}

// Tries to advance `en` by `index` elements; true only if exactly that many were skipped.
private static bool SkipBefore(int index, IEnumerator<TSource> en)
{
    int skipped = SkipAndCount(index, en);
    return skipped == index;
}

// Signed convenience wrapper: checks that `index` is non-negative (debug builds only), then
// defers to the uint overload and converts its result back to an int.
private static int SkipAndCount(int index, IEnumerator<TSource> en)
{
    Debug.Assert(index >= 0);

    uint skipped = SkipAndCount((uint)index, en);
    return (int)skipped;
}

private static uint SkipAndCount(uint index, IEnumerator<TSource> en)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to add a uint overload of this function to support scenarios where _maxIndexInclusive is exactly int.MaxValue & we call Count(), for example e.Skip(1).Take(int.MaxValue).Count(). We don't lose any validation at callsites other than GetCount though, since the signed wrapper asserts.

{
Debug.Assert(en != null);

for (uint i = 0; i < index; i++)
{
if (!en.MoveNext())
{
return i;
}
}

return index;
}
}
}
}
21 changes: 1 addition & 20 deletions src/System.Linq/src/System/Linq/Skip.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,7 @@ public static IEnumerable<TSource> Skip<TSource>(this IEnumerable<TSource> sourc
return new ListPartition<TSource>(sourceList, count, int.MaxValue);
}

return SkipIterator(source, count);
}

private static IEnumerable<TSource> SkipIterator<TSource>(IEnumerable<TSource> source, int count)
{
using (IEnumerator<TSource> e = source.GetEnumerator())
{
while (count > 0 && e.MoveNext())
{
count--;
}

if (count <= 0)
{
while (e.MoveNext())
{
yield return e.Current;
}
}
}
return new EnumerablePartition<TSource>(source, count, -1);
}

public static IEnumerable<TSource> SkipWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
Expand Down
14 changes: 1 addition & 13 deletions src/System.Linq/src/System/Linq/Take.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,7 @@ public static IEnumerable<TSource> Take<TSource>(this IEnumerable<TSource> sourc
return new ListPartition<TSource>(sourceList, 0, count - 1);
}

return TakeIterator(source, count);
}

private static IEnumerable<TSource> TakeIterator<TSource>(IEnumerable<TSource> source, int count)
{
foreach (TSource element in source)
{
yield return element;
if (--count == 0)
{
break;
}
}
return new EnumerablePartition<TSource>(source, 0, count - 1);
}

public static IEnumerable<TSource> TakeWhile<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
Expand Down
Loading