diff --git a/src/libraries/Common/src/System/Text/ValueUtf8Converter.cs b/src/libraries/Common/src/System/Text/ValueUtf8Converter.cs index 6765ecbaad6a6d..f589aec0cac222 100644 --- a/src/libraries/Common/src/System/Text/ValueUtf8Converter.cs +++ b/src/libraries/Common/src/System/Text/ValueUtf8Converter.cs @@ -23,7 +23,7 @@ public ValueUtf8Converter(Span initialBuffer) public Span ConvertAndTerminateString(ReadOnlySpan value) { - int maxSize = Encoding.UTF8.GetMaxByteCount(value.Length) + 1; + int maxSize = checked(Encoding.UTF8.GetMaxByteCount(value.Length) + 1); if (_bytes.Length < maxSize) { Dispose(); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs index bbc79c19effdc2..907fa5c9b7ed3c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.cs @@ -993,7 +993,7 @@ private static unsafe IntPtr StringToHGlobalUTF8(string? s) int nb = Encoding.UTF8.GetMaxByteCount(s.Length); - IntPtr ptr = AllocHGlobal(nb + 1); + IntPtr ptr = AllocHGlobal(checked(nb + 1)); int nbWritten; byte* pbMem = (byte*)ptr; @@ -1040,7 +1040,7 @@ public static unsafe IntPtr StringToCoTaskMemUTF8(string? 
s) int nb = Encoding.UTF8.GetMaxByteCount(s.Length); - IntPtr ptr = AllocCoTaskMem(nb + 1); + IntPtr ptr = AllocCoTaskMem(checked(nb + 1)); int nbWritten; byte* pbMem = (byte*)ptr; diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.Sealed.cs b/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.Sealed.cs index 8a0041028df210..c76c5782d4940c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.Sealed.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.Sealed.cs @@ -69,6 +69,50 @@ private unsafe byte[] GetBytesForSmallInput(string s) return new Span(ref *pDestination, bytesWritten).ToArray(); // this overload of Span ctor doesn't validate length } + public override int GetMaxByteCount(int charCount) + { + // This is a specialization of UTF8Encoding.GetMaxByteCount + // with the assumption that the default replacement fallback + // emits 3 fallback bytes ([ EF BF BD ] = '\uFFFD') per + // malformed input char in the worst case. + + if ((uint)charCount > (int.MaxValue / MaxUtf8BytesPerChar) - 1) + { + // Move the throw out of the hot path to allow for inlining. + ThrowArgumentException(charCount); + static void ThrowArgumentException(int charCount) + { + throw new ArgumentOutOfRangeException( + paramName: nameof(charCount), + message: (charCount < 0) ? SR.ArgumentOutOfRange_NeedNonNegNum : SR.ArgumentOutOfRange_GetByteCountOverflow); + } + } + + return (charCount * MaxUtf8BytesPerChar) + MaxUtf8BytesPerChar; + } + + public override int GetMaxCharCount(int byteCount) + { + // This is a specialization of UTF8Encoding.GetMaxCharCount + // with the assumption that the default replacement fallback + // emits one fallback char ('\uFFFD') per malformed input + // byte in the worst case. + + if ((uint)byteCount > int.MaxValue - 1) + { + // Move the throw out of the hot path to allow for inlining. 
+ ThrowArgumentException(byteCount); + static void ThrowArgumentException(int byteCount) + { + throw new ArgumentOutOfRangeException( + paramName: nameof(byteCount), + message: (byteCount < 0) ? SR.ArgumentOutOfRange_NeedNonNegNum : SR.ArgumentOutOfRange_GetCharCountOverflow); + } + } + + return byteCount + 1; + } + public override string GetString(byte[] bytes) { // This method is short and can be inlined, meaning that the null check below diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.cs b/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.cs index 88ec661423ee58..7a2c577479d266 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/UTF8Encoding.cs @@ -761,8 +761,20 @@ public override int GetMaxByteCount(int charCount) throw new ArgumentOutOfRangeException(nameof(charCount), SR.ArgumentOutOfRange_NeedNonNegNum); - // Characters would be # of characters + 1 in case left over high surrogate is ? * max fallback - long byteCount = (long)charCount + 1; + // GetMaxByteCount assumes that the caller might have a stateful Encoder instance. If the + // Encoder instance already has a captured high surrogate, then one of two things will + // happen: + // + // - The next char is a low surrogate, at which point the two chars together result in 4 + // UTF-8 bytes in the output; or + // - The next char is not a low surrogate (or the input reaches EOF), at which point the + // standalone captured surrogate will go through the fallback routine. + // + // The second case is the worst-case scenario for expansion, so it's what we use for any + // pessimistic "max byte count" calculation: assume there's a captured surrogate and that + // it must fall back. 
+ + long byteCount = (long)charCount + 1; // +1 to account for captured surrogate, per above if (EncoderFallback.MaxCharCount > 1) byteCount *= EncoderFallback.MaxCharCount; @@ -782,8 +794,23 @@ public override int GetMaxCharCount(int byteCount) throw new ArgumentOutOfRangeException(nameof(byteCount), SR.ArgumentOutOfRange_NeedNonNegNum); - // Figure out our length, 1 char per input byte + 1 char if 1st byte is last byte of 4 byte surrogate pair - long charCount = ((long)byteCount + 1); + // GetMaxCharCount assumes that the caller might have a stateful Decoder instance. If the + // Decoder instance already has a captured partial UTF-8 subsequence, then one of three + // things will happen: + // + // - The next byte(s) won't complete the subsequence but will instead be consumed into + // the Decoder's internal state, resulting in no character output; or + // - The next byte(s) will complete the subsequence, and the previously captured + // subsequence and the next byte(s) will result in 1 - 2 chars output; or + // - The captured subsequence will be treated as a singular ill-formed subsequence, at + // which point the captured subsequence will go through the fallback routine. + // (See The Unicode Standard, Sec. 3.9 for more information on this.) + // + // The third case is the worst-case scenario for expansion, since it means 0 bytes of + // new input could cause any existing captured state to expand via fallback. So it's + // what we'll use for any pessimistic "max char count" calculation. + + long charCount = ((long)byteCount + 1); // +1 to account for captured subsequence, as above // Non-shortest form would fall back, so get max count from fallback. // So would 11... 
followed by 11..., so you could fall back every byte diff --git a/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxByteCount.cs b/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxByteCount.cs index 6c4c2da47acb8f..7379fa4ae8b6b6 100644 --- a/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxByteCount.cs +++ b/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxByteCount.cs @@ -15,10 +15,24 @@ public class UTF8EncodingGetMaxByteCount public void GetMaxByteCount(int charCount) { int expected = (charCount + 1) * 3; + Assert.Equal(expected, Encoding.UTF8.GetMaxByteCount(charCount)); Assert.Equal(expected, new UTF8Encoding(true, true).GetMaxByteCount(charCount)); Assert.Equal(expected, new UTF8Encoding(true, false).GetMaxByteCount(charCount)); Assert.Equal(expected, new UTF8Encoding(false, true).GetMaxByteCount(charCount)); Assert.Equal(expected, new UTF8Encoding(false, false).GetMaxByteCount(charCount)); } + + [Theory] + [InlineData(-1)] + [InlineData(int.MinValue)] + [InlineData(-1_000_000_000)] + [InlineData(-1_300_000_000)] // yields positive result when *3 + [InlineData(int.MaxValue / 3)] + [InlineData(int.MaxValue)] + public void GetMaxByteCount_NegativeTests(int charCount) + { + Assert.Throws(nameof(charCount), () => Encoding.UTF8.GetMaxByteCount(charCount)); + Assert.Throws(nameof(charCount), () => new UTF8Encoding().GetMaxByteCount(charCount)); + } } } diff --git a/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxCharCount.cs b/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxCharCount.cs index 7450d4c698e56e..3b043ae2859bfc 100644 --- a/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxCharCount.cs +++ b/src/libraries/System.Text.Encoding/tests/UTF8Encoding/UTF8EncodingGetMaxCharCount.cs @@ -15,10 +15,22 @@ public class UTF8EncodingGetMaxCharCount public void GetMaxCharCount(int byteCount) { int expected = 
byteCount + 1; + Assert.Equal(expected, Encoding.UTF8.GetMaxCharCount(byteCount)); Assert.Equal(expected, new UTF8Encoding(true, true).GetMaxCharCount(byteCount)); Assert.Equal(expected, new UTF8Encoding(true, false).GetMaxCharCount(byteCount)); Assert.Equal(expected, new UTF8Encoding(false, true).GetMaxCharCount(byteCount)); Assert.Equal(expected, new UTF8Encoding(false, false).GetMaxCharCount(byteCount)); } + + [Theory] + [InlineData(-1)] + [InlineData(int.MinValue)] + [InlineData(-1_000_000_000)] + [InlineData(int.MaxValue)] + public void GetMaxCharCount_NegativeTests(int byteCount) + { + Assert.Throws(nameof(byteCount), () => Encoding.UTF8.GetMaxCharCount(byteCount)); + Assert.Throws(nameof(byteCount), () => new UTF8Encoding().GetMaxCharCount(byteCount)); + } } }