diff --git a/src/System.Text.Encodings.Web/src/System.Text.Encodings.Web.csproj b/src/System.Text.Encodings.Web/src/System.Text.Encodings.Web.csproj index baebf8fa349b..636feac1183a 100644 --- a/src/System.Text.Encodings.Web/src/System.Text.Encodings.Web.csproj +++ b/src/System.Text.Encodings.Web/src/System.Text.Encodings.Web.csproj @@ -23,6 +23,7 @@ + diff --git a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/DefaultJavaScriptEncoderBasicLatin.cs b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/DefaultJavaScriptEncoderBasicLatin.cs index db25c3d6d63f..0fb8c6bb6a4e 100644 --- a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/DefaultJavaScriptEncoderBasicLatin.cs +++ b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/DefaultJavaScriptEncoderBasicLatin.cs @@ -73,7 +73,6 @@ public override bool WillEncode(int unicodeScalar) return NeedsEscaping((char)unicodeScalar); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public override unsafe int FindFirstCharacterToEncode(char* text, int textLength) { if (text == null) @@ -81,114 +80,261 @@ public override unsafe int FindFirstCharacterToEncode(char* text, int textLength throw new ArgumentNullException(nameof(text)); } + Debug.Assert(textLength >= 0); + + if (textLength == 0) + { + goto AllAllowed; + } + int idx = 0; + short* ptr = (short*)text; + short* end = ptr + (uint)textLength; + +#if NETCOREAPP + if (Sse2.IsSupported && textLength >= Vector128.Count) + { + goto VectorizedEntry; + } + + Sequential: +#endif + Debug.Assert(textLength > 0 && ptr < end); + + do + { + Debug.Assert(text <= ptr && ptr < (text + textLength)); + + if (NeedsEscaping(*(char*)ptr)) + { + goto Return; + } + + ptr++; + idx++; + } + while (ptr < end); + + AllAllowed: + idx = -1; + + Return: + return idx; #if NETCOREAPP - if (Sse2.IsSupported) + VectorizedEntry: + int index; + short* vectorizedEnd; + + if (textLength >= 2 * Vector128.Count) { - short* startingAddress = (short*)text; - while (textLength - 8 >= idx) + vectorizedEnd = end - 2 * Vector128.Count; + + do { - Debug.Assert(startingAddress >= text && startingAddress <= (text + textLength - 8)); + Debug.Assert(text <= ptr && ptr <= (text + textLength - 2 * Vector128.Count)); - // Load the next 8 characters. - Vector128 sourceValue = Sse2.LoadVector128(startingAddress); + // Load the next 16 characters, combine them to one byte vector. + // Chars that don't cleanly convert to ASCII bytes will get converted (saturated) to + // somewhere in the range [0x7F, 0xFF], which the NeedsEscaping method will detect. + Vector128 sourceValue = Sse2.PackSignedSaturate( + Sse2.LoadVector128(ptr), + Sse2.LoadVector128(ptr + Vector128.Count)); - // Check if any of the 8 characters need to be escaped. - Vector128 mask = Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue); + // Check if any of the 16 characters need to be escaped. + index = NeedsEscaping(sourceValue); - int index = Sse2.MoveMask(mask.AsByte()); - // If index == 0, that means none of the 8 characters needed to be escaped. + // If index == 0, that means none of the 16 characters needed to be escaped. // TrailingZeroCount is relatively expensive, avoid it if possible. if (index != 0) { - // Found at least one character that needs to be escaped, figure out the index of - // the first one found that needed to be escaped within the 8 characters. - Debug.Assert(index > 0 && index <= 65_535); - int tzc = BitOperations.TrailingZeroCount(index); - Debug.Assert(tzc % 2 == 0 && tzc >= 0 && tzc <= 16); - idx += tzc >> 1; - goto Return; + goto VectorizedFound; } - idx += 8; - startingAddress += 8; - } - // Process the remaining characters. - Debug.Assert(textLength - idx < 8); + ptr += 2 * Vector128.Count; + } + while (ptr <= vectorizedEnd); } -#endif - for (; idx < textLength; idx++) + vectorizedEnd = end - Vector128.Count; + + Vectorized: + // PERF: JIT produces better code for do-while as for a while-loop (no spills) + if (ptr <= vectorizedEnd) { - Debug.Assert((text + idx) <= (text + textLength)); - if (NeedsEscaping(*(text + idx))) + do { - goto Return; + Debug.Assert(text <= ptr && ptr <= (text + textLength - Vector128.Count)); + + // Load the next 8 characters + a dummy known that it must not be escaped. + // Put the dummy second, so it's easier for GetIndexOfFirstNeedToEscape. + Vector128 sourceValue = Sse2.PackSignedSaturate( + Sse2.LoadVector128(ptr), + Vector128.Create((short)'A')); // max. one "iteration", so no need to cache this vector + + index = NeedsEscaping(sourceValue); + + // If index == 0, that means none of the 16 bytes needed to be escaped. + // TrailingZeroCount is relatively expensive, avoid it if possible. + if (index != 0) + { + goto VectorizedFound; + } + + ptr += Vector128.Count; } + while (ptr <= vectorizedEnd); } - idx = -1; // All characters are allowed. + // Process the remaining characters. + Debug.Assert(end - ptr < Vector128.Count); - Return: + // Process the remaining elements vectorized, only if the remaining count + // is above thresholdForRemainingVectorized, otherwise process them sequential. + // Threshold found by testing. + const int thresholdForRemainingVectorized = 5; + if (ptr < end - thresholdForRemainingVectorized) + { + ptr = vectorizedEnd; + goto Vectorized; + } + + idx = CalculateIndex(ptr, text); + + if (idx < textLength) + { + goto Sequential; + } + + goto AllAllowed; + + VectorizedFound: + idx = GetIndexOfFirstNeedToEscape(index); + idx += CalculateIndex(ptr, text); return idx; + + static int CalculateIndex(short* ptr, char* text) + { + // Subtraction with short* results in a idiv, so use byte* and shift + return (int)(((byte*)ptr - (byte*)text) >> 1); + } +#endif } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public override unsafe int FindFirstCharacterToEncodeUtf8(ReadOnlySpan utf8Text) { - fixed (byte* ptr = utf8Text) + fixed (byte* pValue = utf8Text) { + uint textLength = (uint)utf8Text.Length; + + if (textLength == 0) + { + goto AllAllowed; + } + int idx = 0; + byte* ptr = pValue; + byte* end = ptr + textLength; #if NETCOREAPP - if (Sse2.IsSupported) - { - sbyte* startingAddress = (sbyte*)ptr; - while (utf8Text.Length - 16 >= idx) - { - Debug.Assert(startingAddress >= ptr && startingAddress <= (ptr + utf8Text.Length - 16)); - - // Load the next 16 bytes. - Vector128 sourceValue = Sse2.LoadVector128(startingAddress); - - // Check if any of the 16 bytes need to be escaped. - Vector128 mask = Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue); - - int index = Sse2.MoveMask(mask); - // If index == 0, that means none of the 16 bytes needed to be escaped. - // TrailingZeroCount is relatively expensive, avoid it if possible. - if (index != 0) - { - // Found at least one byte that needs to be escaped, figure out the index of - // the first one found that needed to be escaped within the 16 bytes. - int tzc = BitOperations.TrailingZeroCount(index); - Debug.Assert(tzc >= 0 && tzc <= 16); - idx += tzc; - goto Return; - } - idx += 16; - startingAddress += 16; - } - // Process the remaining bytes. - Debug.Assert(utf8Text.Length - idx < 16); + if (Sse2.IsSupported && textLength >= Vector128.Count) + { + goto Vectorized; } + + Sequential: #endif + Debug.Assert(textLength > 0 && ptr < end); - for (; idx < utf8Text.Length; idx++) + do { - Debug.Assert((ptr + idx) <= (ptr + utf8Text.Length)); - if (NeedsEscaping(*(ptr + idx))) + Debug.Assert(pValue <= ptr && ptr < (pValue + utf8Text.Length)); + + if (NeedsEscaping(*ptr)) { goto Return; } + + ptr++; + idx++; } + while (ptr < end); - idx = -1; // All bytes are allowed. + AllAllowed: + idx = -1; Return: return idx; + +#if NETCOREAPP + Vectorized: + byte* vectorizedEnd = end - Vector128.Count; + int index; + + do + { + Debug.Assert(pValue <= ptr && ptr <= (pValue + utf8Text.Length - Vector128.Count)); + // Load the next 16 bytes + Vector128 sourceValue = Sse2.LoadVector128((sbyte*)ptr); + + index = NeedsEscaping(sourceValue); + + // If index == 0, that means none of the 16 bytes needed to be escaped. + // TrailingZeroCount is relatively expensive, avoid it if possible. + if (index != 0) + { + goto VectorizedFound; + } + + ptr += Vector128.Count; + } + while (ptr <= vectorizedEnd); + + // Process the remaining elements. + Debug.Assert(end - ptr < Vector128.Count); + + // Process the remaining elements vectorized, only if the remaining count + // is above thresholdForRemainingVectorized, otherwise process them sequential. + const int thresholdForRemainingVectorized = 4; + if (ptr < end - thresholdForRemainingVectorized) + { + // PERF: duplicate instead of jumping at the beginning of the previous loop + // otherwise all the static data (vectors) will be re-assigned to registers, + // so they are re-used. + + Debug.Assert(pValue <= vectorizedEnd && vectorizedEnd <= (pValue + utf8Text.Length - Vector128.Count)); + + // Load the last 16 bytes + Vector128 sourceValue = Sse2.LoadVector128((sbyte*)vectorizedEnd); + + index = NeedsEscaping(sourceValue); + if (index != 0) + { + ptr = vectorizedEnd; + goto VectorizedFound; + } + + idx = -1; + goto Return; + } + + idx = CalculateIndex(ptr, pValue); + + if (idx < textLength) + { + goto Sequential; + } + + goto AllAllowed; + + VectorizedFound: + idx = GetIndexOfFirstNeedToEscape(index); + idx += CalculateIndex(ptr, pValue); + return idx; + + static int CalculateIndex(byte* ptr, byte* pValue) => (int)(ptr - pValue); +#endif } } @@ -285,5 +431,35 @@ public override unsafe bool TryEncodeUnicodeScalar(int unicodeScalar, char* buff [MethodImpl(MethodImplOptions.AggressiveInlining)] private static bool NeedsEscaping(char value) => value > LastAsciiCharacter || AllowList[value] == 0; + +#if NETCOREAPP + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int NeedsEscaping(Vector128 sourceValue) + { + Debug.Assert(Sse2.IsSupported); + + // Check if any of the 16 bytes need to be escaped. + Vector128 mask = Ssse3.IsSupported + ? Ssse3Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue) + : Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue); + + int index = Sse2.MoveMask(mask.AsByte()); + return index; + } + + // PERF: don't manually inline or call this method in NeedsEscaping + // as the resulting asm won't be great + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetIndexOfFirstNeedToEscape(int index) + { + // Found at least one byte that needs to be escaped, figure out the index of + // the first one found that needed to be escaped within the 16 bytes. + Debug.Assert(index > 0 && index <= 65_535); + int tzc = BitOperations.TrailingZeroCount(index); + Debug.Assert(tzc >= 0 && tzc <= 16); + + return tzc; + } +#endif } } diff --git a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Sse2Helper.cs b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Sse2Helper.cs index f36cce2a16ad..b524bd436df1 100644 --- a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Sse2Helper.cs +++ b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Sse2Helper.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Runtime.CompilerServices; -using System.Numerics; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.X86; @@ -51,23 +50,6 @@ public static Vector128 CreateEscapingMask_UnsafeRelaxedJavaScriptEncoder return mask; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static Vector128 CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(Vector128 sourceValue) - { - Debug.Assert(Sse2.IsSupported); - - Vector128 mask = CreateEscapingMask_UnsafeRelaxedJavaScriptEncoder(sourceValue); - - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_ampersandMaskInt16)); - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_apostropheMaskInt16)); - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_plusSignMaskInt16)); - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_lessThanSignMaskInt16)); - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_greaterThanSignMaskInt16)); - mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_graveAccentMaskInt16)); - - return mask; - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector128 CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(Vector128 sourceValue) { @@ -103,13 +85,7 @@ public static Vector128 CreateAsciiMask(Vector128 sourceValue) private static readonly Vector128 s_nullMaskInt16 = Vector128.Zero; private static readonly Vector128 s_spaceMaskInt16 = Vector128.Create((short)' '); private static readonly Vector128 s_quotationMarkMaskInt16 = Vector128.Create((short)'"'); - private static readonly Vector128 s_ampersandMaskInt16 = Vector128.Create((short)'&'); - private static readonly Vector128 s_apostropheMaskInt16 = Vector128.Create((short)'\''); - private static readonly Vector128 s_plusSignMaskInt16 = Vector128.Create((short)'+'); - private static readonly Vector128 s_lessThanSignMaskInt16 = Vector128.Create((short)'<'); - private static readonly Vector128 s_greaterThanSignMaskInt16 = Vector128.Create((short)'>'); private static readonly Vector128 s_reverseSolidusMaskInt16 = Vector128.Create((short)'\\'); - private static readonly Vector128 s_graveAccentMaskInt16 = Vector128.Create((short)'`'); private static readonly Vector128 s_tildeMaskInt16 = Vector128.Create((short)'~'); private static readonly Vector128 s_maxAsciiCharacterMaskInt16 = Vector128.Create((short)0x7F); // Delete control character diff --git a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Ssse3Helper.cs b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Ssse3Helper.cs new file mode 100644 index 000000000000..d832b6c7ee51 --- /dev/null +++ b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/Ssse3Helper.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace System.Text.Encodings.Web +{ + internal static class Ssse3Helper + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(Vector128 sourceValue) + { + // To check if an input byte needs to be escaped or not, we create bit-mask. + // Therefore we split the input byte into the low- and high-nibble, which will get + // the row-/column-index in the bit-mask. + // The bit-mask-matrix looks like + // high-nibble + // low-nibble 0 1 2 3 4 5 6 7 8 9 A B C D E F + // 0 1 1 0 0 0 0 1 0 1 1 1 1 1 1 1 1 + // 1 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 2 1 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 3 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 4 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 5 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 6 1 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 7 1 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 8 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // 9 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // A 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // B 1 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 + // C 1 1 0 1 0 1 0 0 1 1 1 1 1 1 1 1 + // D 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 + // E 1 1 0 1 0 0 0 0 1 1 1 1 1 1 1 1 + // F 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + // + // where 1 denotes the neeed for escaping, while 0 means no escaping needed. + // For high-nibbles in the range 8..F every input needs to be escaped, so we + // can omit them in the bit-mask, thus only high-nibbles in the range 0..7 need + // to be considered, hence the entries in the bit-mask can be of type byte. + // + // In the Bitmask (see above) for each row (= low-nibble) a bit-mask for the + // high-nibbles (= columns) is created. + + Debug.Assert(Ssse3.IsSupported); + + Vector128 highNibbles = Sse2.And(Sse2.ShiftRightLogical(sourceValue.AsInt32(), 4).AsSByte(), s_nibbleMaskSByte); + Vector128 lowNibbles = Sse2.And(sourceValue, s_nibbleMaskSByte); + + Vector128 bitMask = Ssse3.Shuffle(s_bitMask, lowNibbles); + Vector128 bitPositions = Ssse3.Shuffle(s_bitPosLookup, highNibbles); + + Vector128 mask = Sse2.And(bitPositions, bitMask); + + mask = Sse2.CompareEqual(s_nullMaskSByte, Sse2.CompareEqual(s_nullMaskSByte, mask)); + return mask; + } + + private static readonly Vector128 s_nibbleMaskSByte = Vector128.Create((sbyte)0xF); + private static readonly Vector128 s_nullMaskSByte = Vector128.Zero; + + // See comment above in method CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin + // for description of the bit-mask. + private static readonly Vector128 s_bitMask = Vector128.Create( + 0b_01000011, // low-nibble 0 + 0b_00000011, // low-nibble 1 + 0b_00000111, // low-nibble 2 + 0b_00000011, // low-nibble 3 + 0b_00000011, // low-nibble 4 + 0b_00000011, // low-nibble 5 + 0b_00000111, // low-nibble 6 + 0b_00000111, // low-nibble 7 + 0b_00000011, // low-nibble 8 + 0b_00000011, // low-nibble 9 + 0b_00000011, // low-nibble A + 0b_00000111, // low-nibble B + 0b_00101011, // low-nibble C + 0b_00000011, // low-nibble D + 0b_00001011, // low-nibble E + 0b_10000011 // low-nibble F + ).AsSByte(); + + // To check if a bit in a bitmask from the Bitmask is set, in a sequential code + // we would do ((1 << bitIndex) & bitmask) != 0 + // As there is no hardware instrinic for such a shift, we use a lookup that + // stores the shifted bitpositions. + // So (1 << bitIndex) becomes BitPosLook[bitIndex], which is simd-friendly. + // + // A bitmask from the Bitmask (above) is created only for values 0..7 (one byte), + // so to avoid a explicit check for values outside 0..7, i.e. + // high nibbles 8..F, we use a bitpos that always results in escaping. + private static readonly Vector128 s_bitPosLookup = Vector128.Create( + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, // high-nibble 0..7 + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // high-nibble 8..F + ).AsSByte(); + } +} diff --git a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/UnsafeRelaxedJavaScriptEncoder.cs b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/UnsafeRelaxedJavaScriptEncoder.cs index 424ee5443501..eae2ff5f5fd3 100644 --- a/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/UnsafeRelaxedJavaScriptEncoder.cs +++ b/src/System.Text.Encodings.Web/src/System/Text/Encodings/Web/UnsafeRelaxedJavaScriptEncoder.cs @@ -53,7 +53,6 @@ public override bool WillEncode(int unicodeScalar) return !_allowedCharacters.IsUnicodeScalarAllowed(unicodeScalar); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public override unsafe int FindFirstCharacterToEncode(char* text, int textLength) { if (text == null) @@ -135,7 +134,6 @@ public override unsafe int FindFirstCharacterToEncode(char* text, int textLength return idx; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public override unsafe int FindFirstCharacterToEncodeUtf8(ReadOnlySpan utf8Text) { fixed (byte* ptr = utf8Text)