diff --git a/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs index d549f093260543..b59f063cbebed5 100644 --- a/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs @@ -124,10 +124,10 @@ public override int GetCount(bool onlyIfCheap) if (_source is Iterator iterator) { int count = iterator.GetCount(onlyIfCheap); - return count == -1 ? -1 : count + 1; + return count == -1 ? -1 : checked(count + 1); } - return !onlyIfCheap || _source is ICollection ? _source.Count() + 1 : -1; + return !onlyIfCheap || _source is ICollection ? checked(_source.Count() + 1) : -1; } public override TSource? TryGetFirst(out bool found) @@ -277,10 +277,10 @@ public override int GetCount(bool onlyIfCheap) if (_source is Iterator iterator) { int count = iterator.GetCount(onlyIfCheap); - return count == -1 ? -1 : count + _appendCount + _prependCount; + return count == -1 ? -1 : checked(count + _appendCount + _prependCount); } - return !onlyIfCheap || _source is ICollection ? _source.Count() + _appendCount + _prependCount : -1; + return !onlyIfCheap || _source is ICollection ? checked(_source.Count() + _appendCount + _prependCount) : -1; } public override bool Contains(TSource value) diff --git a/src/libraries/System.Linq/tests/AppendPrependTests.cs b/src/libraries/System.Linq/tests/AppendPrependTests.cs index 3a7a2e5ed699a3..d54846b2055bfa 100644 --- a/src/libraries/System.Linq/tests/AppendPrependTests.cs +++ b/src/libraries/System.Linq/tests/AppendPrependTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; using System.Collections.Generic; using Xunit; @@ -9,6 +10,25 @@ namespace System.Linq.Tests { public class AppendPrependTests : EnumerableTests { + // Mock collection for testing overflow without allocating memory + private sealed class MockCollection : ICollection + { + private readonly int _count; + + public MockCollection(int count) => _count = count; + + public int Count => _count; + public bool IsReadOnly => true; + + public void Add(T item) => throw new NotSupportedException(); + public void Clear() => throw new NotSupportedException(); + public bool Contains(T item) => false; + public void CopyTo(T[] array, int arrayIndex) { } + public bool Remove(T item) => throw new NotSupportedException(); + public IEnumerator GetEnumerator() => Enumerable.Empty().GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + [Fact] public void SameResultsRepeatCallsIntQueryAppend() { @@ -263,5 +283,59 @@ public void AppendPrepend_First_Last_ElementAt() Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(42, 1).Append(84).ElementAt(1)); Assert.Equal(84, NumberRangeGuaranteedNotCollectionType(84, 1).Prepend(42).ElementAt(1)); } + + [Fact] + public void AppendOverflowCount() + { + // AppendPrepend1Iterator overflow when source has GetCount optimization + var source = Enumerable.Repeat(0, int.MaxValue); + var appended = source.Append(1); + Assert.Throws(() => appended.Count()); + } + + [Fact] + public void PrependOverflowCount() + { + // AppendPrepend1Iterator overflow when source has GetCount optimization + var source = Enumerable.Repeat(0, int.MaxValue); + var prepended = source.Prepend(1); + Assert.Throws(() => prepended.Count()); + } + + [Fact] + public void AppendPrependNOverflowCount() + { + // AppendPrependN overflow when source has GetCount optimization + var source = Enumerable.Repeat(0, int.MaxValue); + var result = source.Append(1).Prepend(2); + Assert.Throws(() => result.Count()); + } + + [Fact] + public void AppendOverflowCountWithICollection() + { + // AppendPrepend1Iterator overflow when source is ICollection + var source = new MockCollection(int.MaxValue); + var appended = source.Append(1); + Assert.Throws(() => appended.Count()); + } + + [Fact] + public void PrependOverflowCountWithICollection() + { + // AppendPrepend1Iterator overflow when source is ICollection + var source = new MockCollection(int.MaxValue); + var prepended = source.Prepend(1); + Assert.Throws(() => prepended.Count()); + } + + [Fact] + public void AppendPrependNOverflowCountWithICollection() + { + // AppendPrependN overflow when source is ICollection + var source = new MockCollection(int.MaxValue); + var result = source.Append(1).Prepend(2); + Assert.Throws(() => result.Count()); + } } }