From 25e8936595e3502946f004f9489a9f8b3759e265 Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed
Date: Wed, 26 Feb 2025 18:45:07 -0800
Subject: [PATCH 1/2] Create SentencePieceTokenizer from options object

---
 eng/Versions.props                         |   2 +-
 .../Model/SentencePieceBaseModel.cs        |  67 +++-
 .../Model/SentencePieceBpeModel.cs         |  46 +++
 .../Model/SentencePieceOptions.cs          | 113 +++++++
 .../Model/SentencePieceTokenizer.cs        |  24 ++
 .../Model/SentencePieceUnigramModel.cs     | 105 ++++--
 .../UnigramTests.cs                        | 312 +++++++++++++++++-
 7 files changed, 638 insertions(+), 31 deletions(-)
 create mode 100644 src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs

diff --git a/eng/Versions.props b/eng/Versions.props
index b6de8327b2..522e5f3489 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -100,7 +100,7 @@
     0.0.13-test
     0.0.6-test
     0.0.7-test
-    2.0.0-beta.25110.1
+    2.0.0-beta.25126.1
     4.9.0
     1.0.118
     1.6.24
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
index a1553ec4cd..0f8a76ddfc 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -59,6 +59,67 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
                 specialTokens);
         }
 
+        internal SentencePieceBaseModel(SentencePieceOptions options)
+        {
+            if (options is null)
+            {
+                throw new ArgumentNullException(nameof(options));
+            }
+
+            if (options.Vocabulary is null)
+            {
+                throw new ArgumentNullException(nameof(options.Vocabulary));
+            }
+
+            if (options.BeginningOfSentenceToken is null)
+            {
+                throw new ArgumentNullException(nameof(options.BeginningOfSentenceToken));
+            }
+
+            if (options.EndOfSentenceToken is null)
+            {
+                throw new ArgumentNullException(nameof(options.EndOfSentenceToken));
+            }
+
+            if (options.UnknownToken is null)
+            {
+                throw new ArgumentNullException(nameof(options.UnknownToken));
+            }
+
+            AddBeginningOfSentence = options.AddBeginningOfSentence;
+            AddEndOfSentence = options.AddEndOfSentence;
+            BeginningOfSentenceToken = options.BeginningOfSentenceToken;
+            EndOfSentenceToken = options.EndOfSentenceToken;
+            UnknownToken = options.UnknownToken;
+            AddDummyPrefix = options.AddDummyPrefix;
+            EscapeWhiteSpaces = options.EscapeWhiteSpaces;
+            TreatWhitespaceAsSuffix = options.TreatWhitespaceAsSuffix;
+            ByteFallback = options.ByteFallback;
+            SpecialTokens = options.SpecialTokens;
+
+            if (SpecialTokens is not null && SpecialTokens.Count > 0)
+            {
+                InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
+                SpecialTokensReverse = new Dictionary<int, string>();
+
+                foreach (var item in SpecialTokens)
+                {
+                    InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
+                    SpecialTokensReverse.Add(item.Value, item.Key);
+                }
+
+                // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
+                SpecialTokensRegex = new Regex(string.Join("|", SpecialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
+            }
+
+            Normalizer = new SentencePieceNormalizer(
+                            options.PrecompiledNormalizationData,
+                            options.RemoveExtraWhiteSpaces,
+                            options.AddDummyPrefix, options.EscapeWhiteSpaces,
+                            options.TreatWhitespaceAsSuffix,
+                            SpecialTokens);
+        }
+
         internal Regex? SpecialTokensRegex { get; }
 
         internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
@@ -91,11 +152,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
 
         public string UnknownToken { get; }
 
-        public int BeginningOfSentenceId { get; }
+        public int BeginningOfSentenceId { get; set; }
 
-        public int EndOfSentenceId { get; }
+        public int EndOfSentenceId { get; set; }
 
-        public int UnknownId { get; }
+        public int UnknownId { get; set; }
 
         public SentencePieceNormalizer? Normalizer { get; }
 
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
index 85c85c1677..2823226bbc 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
@@ -41,6 +41,52 @@ internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos,
             OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
         }
 
+        internal SentencePieceBpeModel(SentencePieceOptions options) : base(options)
+        {
+            if (options.PrecompiledNormalizationData is not null)
+            {
+                throw new NotSupportedException("Normalization data is not supported for SentencePieceBpeModel.");
+            }
+
+            Debug.Assert(options.Vocabulary is not null);
+
+            int id = 0;
+            foreach (var item in options.Vocabulary!)
+            {
+                _vocab.Add(new StringSpanOrdinalKey(item.Key), (id, item.Value, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
+                _vocabReverse.Add(id++, item.Key);
+            }
+
+            if (options.ByteFallback)
+            {
+                if (!_vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value))
+                {
+                    throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
+                }
+
+                ByteCodeToIdOffset = value.Id;
+                OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
+            }
+
+            if (!_vocab.TryGetValue(options.UnknownToken, out (int Id, float Score, byte Type) unknownToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the unknown token '{options.UnknownToken}'.");
+            }
+            UnknownId = unknownToken.Id;
+
+            if (!_vocab.TryGetValue(options.BeginningOfSentenceToken, out (int Id, float Score, byte Type) beginOfSentenceToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the beginning of sentence token '{options.BeginningOfSentenceToken}'.");
+            }
+            BeginningOfSentenceId = beginOfSentenceToken.Id;
+
+            if (!_vocab.TryGetValue(options.EndOfSentenceToken, out (int Id, float Score, byte Type) endOfSentenceToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the end of sentence token '{options.EndOfSentenceToken}'.");
+            }
+            EndOfSentenceId = endOfSentenceToken.Id;
+        }
+
         public override IReadOnlyDictionary<string, int> Vocabulary
         {
             get
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
new file mode 100644
index 0000000000..d4a073fb23
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
@@ -0,0 +1,113 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Tokenizers
+{
+#pragma warning disable MSML_NoInstanceInitializers
+    /// <summary>
+    /// The type of the SentencePiece model.
+    /// </summary>
+    public enum SentencePieceModelType
+    {
+        /// <summary>
+        /// The model type is not defined.
+        /// </summary>
+        Undefined,
+
+        /// <summary>
+        /// The model type is Byte Pair Encoding (Bpe) model.
+        /// </summary>
+        Bpe,
+
+        /// <summary>
+        /// The model type is Unigram model.
+        /// </summary>
+        Unigram,
+    }
+
+    /// <summary>
+    /// Options for the SentencePiece tokenizer.
+    /// </summary>
+    public class SentencePieceOptions
+    {
+        /// <summary>
+        /// The type of the SentencePiece model.
+        /// </summary>
+        public SentencePieceModelType ModelType { get; set; }
+
+        /// <summary>
+        /// Determines whether the model uses a byte fallback strategy to encode unknown tokens as byte sequences.
+        /// </summary>
+        /// <remarks>
+        /// The vocabulary must include a special token for each byte value (0-255) in the format <0xNN>,
+        /// where NN represents the byte's hexadecimal value (e.g., <0x41> for byte value 65).
+        /// </remarks>
+        public bool ByteFallback { get; set; }
+
+        /// <summary>
+        /// Indicate emitting the prefix character (e.g. U+2581) at the beginning of the sentence token during the normalization and encoding.
+        /// </summary>
+        public bool AddDummyPrefix { get; set; }
+
+        /// <summary>
+        /// Indicate if the spaces should be replaced with character U+2581 during the normalization and encoding.
+        /// </summary>
+        public bool EscapeWhiteSpaces { get; set; } = true;
+
+        /// <summary>
+        /// Indicate emitting the character U+2581 at the end of the last sentence token instead of at the beginning of the sentence token during the normalization and encoding.
+        /// </summary>
+        public bool TreatWhitespaceAsSuffix { get; set; }
+
+        /// <summary>
+        /// Indicate removing extra white spaces from the original string during the normalization.
+        /// </summary>
+        public bool RemoveExtraWhiteSpaces { get; set; }
+
+        /// <summary>
+        /// Indicate emitting the beginning of sentence token during the encoding.
+        /// </summary>
+        public bool AddBeginningOfSentence { get; set; } = true;
+
+        /// <summary>
+        /// Indicate emitting the end of sentence token during the encoding.
+        /// </summary>
+        public bool AddEndOfSentence { get; set; }
+
+        /// <summary>
+        /// The beginning of sentence token.
+        /// </summary>
+        public string BeginningOfSentenceToken { get; set; } = "<s>";
+
+        /// <summary>
+        /// The end of sentence token.
+        /// </summary>
+        public string EndOfSentenceToken { get; set; } = "</s>";
+
+        /// <summary>
+        /// The unknown token.
+        /// </summary>
+        public string UnknownToken { get; set; } = "<unk>";
+
+        /// <summary>
+        /// The data used for string normalization.
+        /// </summary>
+        public byte[]? PrecompiledNormalizationData { get; set; }
+
+        /// <summary>
+        /// Represent the vocabulary.
+        /// The list should be sorted by the token id. Every entry represents a token and its score.
+        /// </summary>
+        public IEnumerable<KeyValuePair<string, float>>? Vocabulary { get; set; }
+
+        /// <summary>
+        /// The special tokens.
+        /// Special tokens remain intact during encoding and are not split into sub-tokens.
+        /// </summary>
+        public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
+    }
+#pragma warning restore MSML_NoInstanceInitializers
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index f41516e270..9090bfa9fe 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -30,6 +30,16 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
             };
         }
 
+        internal SentencePieceTokenizer(SentencePieceOptions options)
+        {
+            _model = options.ModelType switch
+            {
+                SentencePieceModelType.Bpe => new SentencePieceBpeModel(options),
+                SentencePieceModelType.Unigram => new SentencePieceUnigramModel(options),
+                _ => throw new ArgumentException($"The model type '{options.ModelType}' is not supported.", nameof(options.ModelType))
+            };
+        }
+
         /// <summary>
         /// The special tokens.
         /// </summary>
@@ -457,5 +467,19 @@ public static SentencePieceTokenizer Create(
 
             return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
         }
+
+        /// <summary>
+        /// Creates an instance of SentencePieceTokenizer.
+        /// </summary>
+        /// <param name="options">The options to use for the sentence piece tokenizer.</param>
+        public static SentencePieceTokenizer Create(SentencePieceOptions options)
+        {
+            if (options is null)
+            {
+                throw new ArgumentNullException(nameof(options));
+            }
+
+            return new SentencePieceTokenizer(options);
+        }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
index a578362f73..dcd9c29f47 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
@@ -66,6 +66,64 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos
             _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
         }
 
+        public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
+        {
+            _vocab = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance);
+            // _vocabReverse = new (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[];
+
+            // 250_000 is a big initial capacity used to avoid reallocation during the initialization.
+            List<(string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)> vocabReverse = new(250_000);
+
+            _minScore = float.MaxValue;
+            _maxScore = float.MinValue;
+
+            int id = 0;
+            foreach (KeyValuePair<string, float> kvp in options.Vocabulary!)
+            {
+                _vocab.Add(kvp.Key, id++);
+                vocabReverse.Add((kvp.Key, kvp.Value, ModelProto.Types.SentencePiece.Types.Type.Normal));
+                _minScore = Math.Min(_minScore, kvp.Value);
+                _maxScore = Math.Max(_maxScore, kvp.Value);
+            }
+
+            _vocabReverse = vocabReverse.ToArray();
+
+            if (options.ByteFallback)
+            {
+                if (!_vocab.TryGetValue("<0x00>", out id))
+                {
+                    throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
+                }
+
+                ByteCodeToIdOffset = id;
+                OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
+                MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF; // from <0x00> to <0xFF>.
+            }
+
+            _trie = new DoubleArrayTrie(_vocab);
+
+            _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0);
+            _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
+
+            if (!_vocab.TryGetValue(options.UnknownToken, out int unknownToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the unknown token '{options.UnknownToken}'.");
+            }
+            UnknownId = unknownToken;
+
+            if (!_vocab.TryGetValue(options.BeginningOfSentenceToken, out int beginOfSentenceToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the beginning of sentence token '{options.BeginningOfSentenceToken}'.");
+            }
+            BeginningOfSentenceId = beginOfSentenceToken;
+
+            if (!_vocab.TryGetValue(options.EndOfSentenceToken, out int endOfSentenceToken))
+            {
+                throw new ArgumentException($"The vocabulary must include the end of sentence token '{options.EndOfSentenceToken}'.");
+            }
+            EndOfSentenceId = endOfSentenceToken;
+        }
+
         public override IReadOnlyDictionary<string, int> Vocabulary => new ReadOnlyDictionary<string, int>(_vocab);
 
         public int MaxIdByteFallbackId { get; }
@@ -114,39 +172,39 @@ public override bool TryMapIdToToken(int id, out string? token)
             return true;
         }
 
-        private void StoreNormalizedTextFromEnd(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringIndexFromEnd)
+        private void StoreNormalizedTextFromEnd(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringCountFromEnd)
         {
-            int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd;
+            int remainingLength = normalizedString.Length - normalizedStringCountFromEnd;
 
             if (text.Length > remainingLength)
             {
                 char[] utf16NormalizedString = ArrayPool<char>.Shared.Rent(normalizedString.Length << 1);
-                normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd));
+                normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringCountFromEnd));
 
                 ArrayPool<char>.Shared.Return(normalizedString);
                 normalizedString = utf16NormalizedString;
             }
 
-            text.CopyTo(normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - text.Length));
-            normalizedStringIndexFromEnd += text.Length;
+            text.CopyTo(normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd - text.Length));
+            normalizedStringCountFromEnd += text.Length;
         }
 
-        private void StoreNormalizedTextFromEnd(ReadOnlySpan<byte> utf8Bytes, ref char[] normalizedString, ref int normalizedStringIndexFromEnd)
+        private void StoreNormalizedTextFromEnd(ReadOnlySpan<byte> utf8Bytes, ref char[] normalizedString, ref int normalizedStringCountFromEnd)
         {
-            int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd;
+            int remainingLength = normalizedString.Length - normalizedStringCountFromEnd;
             int expectedCount = Helpers.GetUtf16LengthFromUtf8Bytes(utf8Bytes);
 
             if (expectedCount > remainingLength)
             {
                 char[] utf16NormalizedString = ArrayPool<char>.Shared.Rent(normalizedString.Length << 1);
-                normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd));
+                normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringCountFromEnd));
 
                 ArrayPool<char>.Shared.Return(normalizedString);
                 normalizedString = utf16NormalizedString;
             }
 
-            bool res = Helpers.ConvertUtf8ToUtf16(utf8Bytes, normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - expectedCount), out int bytesConsumed, out int charsWritten);
+            bool res = Helpers.ConvertUtf8ToUtf16(utf8Bytes, normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd - expectedCount), out int bytesConsumed, out int charsWritten);
             Debug.Assert(res);
             Debug.Assert(bytesConsumed == utf8Bytes.Length);
             Debug.Assert(charsWritten == expectedCount);
 
-            normalizedStringIndexFromEnd += expectedCount;
+            normalizedStringCountFromEnd += expectedCount;
         }
 
         private void StoreNormalizedText(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringIndex)
@@ -1251,14 +1309,15 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
             Debug.Assert(maxTokenCount > 0);
 
             charConsumedFromEnd = 0;
-            int normalizedStringIndexFromEnd = 0;
+            int normalizedStringCountFromEnd = 0;
 
             (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, SpecialTokensRegex!).ToArray();
 
             if (splits.Length == 0)
             {
-                GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
-                normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - charConsumedFromEnd).ToString() : null;
+                GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
+                normalizedText = normalizedString is not null ? normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
+
+                return;
             }
 
             (int Offset, int Length) current = splits[splits.Length - 1];
@@ -1266,7 +1325,7 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
             // Last part is not a special token
             if (current.Offset + current.Length < text.Length)
             {
-                GetIndexByTokenCountFromEndInternal(text.Slice(current.Offset + current.Length), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+                GetIndexByTokenCountFromEndInternal(text.Slice(current.Offset + current.Length), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
             }
 
             for (int i = splits.Length - 1; i >= 0; i--)
@@ -1285,17 +1344,17 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
 
                 if (normalizedString is not null)
                 {
-                    StoreNormalizedTextFromEnd(text.Slice(current.Offset, current.Length), ref normalizedString, ref normalizedStringIndexFromEnd);
+                    StoreNormalizedTextFromEnd(text.Slice(current.Offset, current.Length), ref normalizedString, ref normalizedStringCountFromEnd);
                 }
 
                 if (current.Offset > 0)
                 {
                     int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0;
-                    GetIndexByTokenCountFromEndInternal(text.Slice(start, current.Offset - start), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+                    GetIndexByTokenCountFromEndInternal(text.Slice(start, current.Offset - start), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
                 }
             }
 
-            normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null;
+            normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
         }
 
         private void GetIndexByTokenCountFromEndWithoutSpecialTokens(
@@ -1309,11 +1368,11 @@ private void GetIndexByTokenCountFromEndWithoutSpecialTokens(
             int maxTokenCount)
         {
             charConsumedFromEnd = 0;
-            int normalizedStringIndexFromEnd = 0;
+            int normalizedStringCountFromEnd = 0;
 
-            GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+            GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
 
-            normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null;
+            normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
         }
 
         private void GetIndexByTokenCountFromEndInternal(
@@ -1322,7 +1381,7 @@ private void GetIndexByTokenCountFromEndInternal(
             ref int tokenCount,
             int[] buffer,
             ref char[]? normalizedString,
-            ref int normalizedStringIndexFromEnd,
+            ref int normalizedStringCountFromEnd,
             ref int charConsumedFromEnd,
             int maxTokenCount)
         {
@@ -1381,11 +1440,11 @@ private void GetIndexByTokenCountFromEndInternal(
             {
                 if (considerNormalization)
                 {
-                    StoreNormalizedTextFromEnd(normalizationSpan, ref normalizedString, ref normalizedStringIndexFromEnd);
+                    StoreNormalizedTextFromEnd(normalizationSpan, ref normalizedString, ref normalizedStringCountFromEnd);
                 }
                 else
                 {
-                    StoreNormalizedTextFromEnd(text, ref normalizedString, ref normalizedStringIndexFromEnd);
+                    StoreNormalizedTextFromEnd(text, ref normalizedString, ref normalizedStringCountFromEnd);
                 }
             }
 
diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
index b90ab7a414..fc6714f89d 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
@@ -8,6 +8,8 @@
 using System.IO;
 using System.Linq;
 using System.Reflection.Metadata;
+using System.Text;
+using System.Text.Json;
 using Microsoft.ML.Tokenizers;
 using Xunit;
 
@@ -17,6 +19,7 @@ public class UnigramTests
     {
         private static SentencePieceTokenizer _unigramTokenizer = CreateUnigramTokenizer();
         private static SentencePieceTokenizer _unigramTokenizerWithSpecialTokens = CreateUnigramTokenizerWithSpecialTokens();
+        private static SentencePieceTokenizer _unigramTokenizerFromJson = CreateUnigramTokenizerFromJson();
 
         private static SentencePieceTokenizer CreateUnigramTokenizer()
         {
@@ -25,6 +28,72 @@ private static SentencePieceTokenizer CreateUnigramTokenizer()
             return SentencePieceTokenizer.Create(remoteStream);
         }
 
+        private static SentencePieceTokenizer CreateUnigramTokenizerFromJson()
+        {
+            // @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/tokenizer.json?download=true";
+            using Stream jsonModelStream = File.OpenRead(Path.Combine(@"Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json"));
+            using var reader = new StreamReader(jsonModelStream, Encoding.UTF8);
+            string json = reader.ReadToEnd();
+            using JsonDocument doc = JsonDocument.Parse(json);
+            JsonElement root = doc.RootElement;
+
+            SentencePieceOptions options = new SentencePieceOptions();
+            options.ModelType = SentencePieceModelType.Unigram;
+            options.EscapeWhiteSpaces = true;
+            options.AddDummyPrefix = true;
+
+            options.BeginningOfSentenceToken = "<s>";
+            options.EndOfSentenceToken = "</s>";
+            options.UnknownToken = "<unk>";
+
+            options.SpecialTokens = new Dictionary<string, int>
+            {
+                { "<s>", 0 },
+                { "<pad>", 1 },
+                { "</s>", 2 },
+                { "<unk>", 3 },
+                { "<mask>", 250001 }
+            };
+
+            if (root.TryGetProperty("normalizer", out JsonElement normalizerElement) && normalizerElement.GetProperty("type").GetString() == "Precompiled")
+            {
+                string? precompiledCharsMap = normalizerElement.GetProperty("precompiled_charsmap").GetString();
+                if (precompiledCharsMap is not null)
+                {
+                    byte[] bytes = Convert.FromBase64String(precompiledCharsMap);
+                    options.PrecompiledNormalizationData = bytes;
+                }
+            }
+
+            options.Vocabulary = GetVocabulary(root);
+            return SentencePieceTokenizer.Create(options);
+        }
+
+        private static IEnumerable<KeyValuePair<string, float>> GetVocabulary(JsonElement root)
+        {
+            if (root.TryGetProperty("model", out JsonElement modelElement) &&
+                modelElement.TryGetProperty("vocab", out JsonElement vocabElement) &&
+                vocabElement.ValueKind == JsonValueKind.Array)
+            {
+                foreach (JsonElement token in vocabElement.EnumerateArray())
+                {
+                    if (token.ValueKind == JsonValueKind.Array && token.GetArrayLength() == 2)
+                    {
+                        string? tokenString = token[0].GetString();
+                        if (tokenString is null)
+                        {
+                            throw new InvalidOperationException("Invalid model vocabulary format");
+                        }
+                        yield return new KeyValuePair<string, float>(tokenString, token[1].GetSingle());
+                    }
+                }
+            }
+            else
+            {
+                throw new InvalidOperationException("Invalid model vocabulary format");
+            }
+        }
+
         private static SentencePieceTokenizer CreateUnigramTokenizerWithSpecialTokens()
         {
             // @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/sentencepiece.bpe.model?download=true";
@@ -163,7 +232,7 @@ public static IEnumerable<object[]> UnigramTestData()
 
             yield return new object[]
             {
-                "xyz東京",  // Latin-Japanese
+                "xyz東京", // Latin-Japanese
                 "▁xyz東京",
                 "xyz東京",
                 new int[] { 1021, 32188, 22887 },
@@ -274,35 +343,92 @@ private void Validate((IEnumerable<int> Ids, IEnumerable<string> Tokens, IEnumer
             Assert.Equal(offsets, extracted.Offsets);
         }
 
+        /// <summary>
+        /// The tokenizer created from the JSON file (_unigramTokenizerFromJson) has its ids shifted by 1 compared to the tokenizer created from the sentencepiece.bpe.model file.
+        /// </summary>
+        /// <param name="ids"></param>
+        /// <returns></returns>
+        private int[] GetShiftedIds(int[] ids)
+        {
+            int[] shiftedIds = new int[ids.Length];
+            foreach (int i in Enumerable.Range(0, ids.Length))
+            {
+                if (ids[i] == _unigramTokenizer.UnknownId)
+                {
+                    shiftedIds[i] = _unigramTokenizerFromJson.UnknownId;
+                }
+                else if (ids[i] == _unigramTokenizer.BeginningOfSentenceId)
+                {
+                    shiftedIds[i] = _unigramTokenizerFromJson.BeginningOfSentenceId;
+                }
+                else if (ids[i] == _unigramTokenizer.EndOfSentenceId)
+                {
+                    shiftedIds[i] = _unigramTokenizerFromJson.EndOfSentenceId;
+                }
+                else
+                {
+                    shiftedIds[i] = ids[i] + 1;
+                }
+            }
+
+            return shiftedIds;
+        }
+
         [Theory]
         [MemberData(nameof(UnigramTestData))]
         public void EncodeToTokensTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
         {
+            int[] shiftedIds = GetShiftedIds(ids);
+
             Assert.True(decodedString is not null); // to make the compiler happy
 
             IReadOnlyList<EncodedToken> result = _unigramTokenizer.EncodeToTokens(inputText, out string? normalized);
             (IEnumerable<int> Ids, IEnumerable<string> Tokens, IEnumerable<Range> Offsets) extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, _unigramTokenizer.AddBeginningOfSentence, _unigramTokenizer.AddEndOfSentence);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText, out normalized);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, _unigramTokenizerFromJson.AddBeginningOfSentence, _unigramTokenizerFromJson.AddEndOfSentence);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized);
             extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, _unigramTokenizer.AddBeginningOfSentence, _unigramTokenizer.AddEndOfSentence);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText.AsSpan(), out normalized);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, _unigramTokenizerFromJson.AddBeginningOfSentence, _unigramTokenizerFromJson.AddEndOfSentence);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             result = _unigramTokenizer.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
             extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, false);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, true, false);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
             extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, false);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, true, false);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             result = _unigramTokenizer.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
             extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, true);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, true, true);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
             extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, true);
             Validate(extracted, ids, tokens, offsets);
 
+            result = _unigramTokenizerFromJson.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, true, true);
+            Validate(extracted, shiftedIds, tokens, offsets);
+
             string newString = $"{_unigramTokenizer.BeginningOfSentenceToken}{inputText}{inputText}{_unigramTokenizer.EndOfSentenceToken}";
             result = _unigramTokenizerWithSpecialTokens.EncodeToTokens(newString, out normalized, addBeginningOfSentence: false, addEndOfSentence: false);
             extracted = ExtractedIds(_unigramTokenizerWithSpecialTokens, result, normalizedText, false, false);
@@ -322,12 +448,34 @@ public void EncodeToTokensTest(string inputText, string normalizedText, string d
             Array.Copy(tokens, 0, expectedTokens, tokens.Length + 2, tokens.Length);
             expectedTokens[tokens.Length * 2 + 2] = _unigramTokenizerWithSpecialTokens.EndOfSentenceToken;
             Assert.Equal(expectedTokens, extracted.Tokens);
+
+            newString = $"{_unigramTokenizerFromJson.BeginningOfSentenceToken}{inputText}<mask>{inputText}{_unigramTokenizerFromJson.EndOfSentenceToken}";
+            result = _unigramTokenizerFromJson.EncodeToTokens(newString, out normalized, addBeginningOfSentence: false, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizerFromJson, result, normalizedText, false, false);
+
+            expectedIds = new int[ids.Length * 2 + 3];
+            expectedIds[0] = _unigramTokenizerFromJson.BeginningOfSentenceId;
+            Array.Copy(shiftedIds, 0, expectedIds, 1, shiftedIds.Length);
+            expectedIds[shiftedIds.Length + 1] = _unigramTokenizerFromJson.SpecialTokens!["<mask>"];
+            Array.Copy(shiftedIds, 0, expectedIds, shiftedIds.Length + 2, shiftedIds.Length);
+            expectedIds[shiftedIds.Length * 2 + 2] = _unigramTokenizerFromJson.EndOfSentenceId;
+            Assert.Equal(expectedIds, extracted.Ids);
+
+            expectedTokens = new string[tokens.Length * 2 + 3];
+            expectedTokens[0] = _unigramTokenizerFromJson.BeginningOfSentenceToken;
+            Array.Copy(tokens, 0, expectedTokens, 1, tokens.Length);
+            expectedTokens[tokens.Length + 1] = "<mask>";
+            Array.Copy(tokens, 0, expectedTokens, tokens.Length + 2, tokens.Length);
+            expectedTokens[tokens.Length * 2 + 2] = _unigramTokenizerFromJson.EndOfSentenceToken;
+            Assert.Equal(expectedTokens, extracted.Tokens);
         }
 
         [Theory]
         [MemberData(nameof(UnigramTestData))]
         public void EncodeToIdsTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
         {
+            int[] shiftedIds = GetShiftedIds(ids);
+
             Assert.True(decodedString is not null); // to make the compiler happy
             Assert.True(tokens is not null); // to make the compiler happy
             Assert.True(offsets is not null); // to make the compiler happy
@@ -337,6 +485,11 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
             result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false);
             Assert.Equal(ids, result);
 
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false);
+            Assert.Equal(shiftedIds, result);
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false);
+            Assert.Equal(shiftedIds, result);
+
             result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: false);
             List<int> ints = result is List<int> list ? list : result.ToList();
             if (ints.Count > 0)
@@ -345,6 +498,14 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
             }
             Assert.Equal(ids, ints);
 
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: false);
+            ints = result is List<int> list1 ? list1 : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+            }
+            Assert.Equal(shiftedIds, ints);
+
             result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: false);
             ints = result is List<int> ? (List<int>)result : result.ToList();
             if (ints.Count > 0)
@@ -353,6 +514,14 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
             }
             Assert.Equal(ids, ints);
 
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: false);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+            }
+            Assert.Equal(shiftedIds, ints);
+
             result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true);
             ints = result is List<int> ? (List<int>)result : result.ToList();
             if (ints.Count > 0)
@@ -362,6 +531,15 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
             }
             Assert.Equal(ids, ints);
 
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+                ints.RemoveAt(ints.Count - 1);
+            }
+            Assert.Equal(shiftedIds, ints);
+
             result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true);
             ints = result is List<int> ? (List<int>)result : result.ToList();
             if (ints.Count > 0)
@@ -371,16 +549,33 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
             }
             Assert.Equal(ids, ints);
 
+            result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+                ints.RemoveAt(ints.Count - 1);
+            }
+            Assert.Equal(shiftedIds, ints);
+
             for (int i = 1; i <= ids.Length; i++)
             {
                 result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out string? normalized, out int charConsumed);
                 Assert.Equal(ids.Take(i), result);
                 Assert.Equal(normalizedText, normalized);
 
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(shiftedIds.Take(i), result);
+                Assert.Equal(normalizedText, normalized);
+
                 result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
                 Assert.Equal(ids.Take(i), result);
                 Assert.Equal(normalizedText, normalized);
 
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(shiftedIds.Take(i), result);
+                Assert.Equal(normalizedText, normalized);
+
                 result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
                 ints = result is List<int> ? (List<int>)result : result.ToList();
                 if (ints.Count > 0)
@@ -397,6 +592,22 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
                     Assert.Equal(normalizedText, normalized);
                 }
 
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
+                ints = result is List<int> ? (List<int>)result : result.ToList();
+                if (ints.Count > 0)
+                {
+                    ints.RemoveAt(0);
+                }
+                if (ints.Count > shiftedIds.Length)
+                {
+                    ints.RemoveAt(ints.Count - 1);
+                }
+                Assert.Equal(shiftedIds.Take(i - 1), ints); // Exclude the counted BoS token
+                if (normalized is not null)
+                {
+                    Assert.Equal(normalizedText, normalized);
+                }
+
                 result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
                 ints = result is List<int> ? (List<int>)result : result.ToList();
                 if (ints.Count > 0)
@@ -412,6 +623,22 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
                 {
                     Assert.Equal(normalizedText, normalized);
                 }
+
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
+                ints = result is List<int> ? (List<int>)result : result.ToList();
+                if (ints.Count > 0)
+                {
+                    ints.RemoveAt(0);
+                }
+                if (ints.Count > shiftedIds.Length)
+                {
+                    ints.RemoveAt(ints.Count - 1);
+                }
+                Assert.Equal(shiftedIds.Take(i - 1), ints); // Exclude the counted BoS token
+                if (normalized is not null)
+                {
+                    Assert.Equal(normalizedText, normalized);
+                }
             }
 
             inputText = $"{_unigramTokenizerWithSpecialTokens.BeginningOfSentenceToken}{inputText}<mask>{inputText}{_unigramTokenizerWithSpecialTokens.EndOfSentenceToken}";
@@ -433,6 +660,25 @@ public void EncodeToIdsTest(string inputText, string normalizedText, string deco
                 Assert.Equal(expectedIds.Take(i), result);
                 Assert.Equal(expectedNormalized, normalized);
             }
+
+            expectedIds = new int[shiftedIds.Length * 2 + 3];
+            expectedIds[0] = _unigramTokenizerFromJson.BeginningOfSentenceId;
+            Array.Copy(shiftedIds, 0, expectedIds, 1, shiftedIds.Length);
+            expectedIds[shiftedIds.Length + 1] = _unigramTokenizerFromJson.SpecialTokens!["<mask>"];
+            Array.Copy(shiftedIds, 0, expectedIds, shiftedIds.Length + 2, shiftedIds.Length);
+            expectedIds[shiftedIds.Length * 2 + 2] = _unigramTokenizerFromJson.EndOfSentenceId;
+            expectedNormalized = $"{_unigramTokenizerFromJson.BeginningOfSentenceToken}{normalizedText}<mask>{normalizedText}{_unigramTokenizerFromJson.EndOfSentenceToken}";
+
+            for (int i = 1; i <= expectedIds.Length; i++)
+            {
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out string? normalized, out int charConsumed);
+                Assert.Equal(expectedIds.Take(i), result);
+                Assert.Equal(expectedNormalized, normalized);
+
+                result = _unigramTokenizerFromJson.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(expectedIds.Take(i), result);
+                Assert.Equal(expectedNormalized, normalized);
+            }
         }
 
         [Theory]
         [MemberData(nameof(UnigramTestData))]
         public void GetIndexByTokenCountTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
         {
@@ -443,6 +689,7 @@ public void GetIndexByTokenCountTest(string inputText, string normalizedText, st
             Assert.True(tokens is not null); // to make the compiler happy
             Assert.True(offsets is not null); // to make the compiler happy
 
+            int[] shiftedIds = GetShiftedIds(ids);
             int totalTokens = ids.Length;
 
             for (int i = 1; i <= totalTokens; i++)
@@ -453,23 +700,47 @@ public void GetIndexByTokenCountTest(string inputText, string normalizedText, st
                 IReadOnlyList<int> ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
                 Assert.Equal(ids, ids1.Concat(ids2).ToList());
 
+                index = _unigramTokenizerFromJson.GetIndexByTokenCount(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(0, index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(shiftedIds, ids1.Concat(ids2).ToList());
+
                 index = _unigramTokenizer.GetIndexByTokenCount(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, out normalized, out charConsumed);
                 Assert.Equal(normalizedText, normalized);
                 ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
                 ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
                 Assert.Equal(ids, ids1.Concat(ids2).ToList());
 
+                index = _unigramTokenizerFromJson.GetIndexByTokenCount(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(shiftedIds, ids1.Concat(ids2).ToList());
+
                 index = _unigramTokenizer.GetIndexByTokenCountFromEnd(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, considerNormalization: true, out normalized, out charConsumed);
                 Assert.Equal(normalizedText, normalized);
                 ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
                 ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
                 Assert.Equal(ids, ids1.Concat(ids2).ToList());
 
+                index = _unigramTokenizerFromJson.GetIndexByTokenCountFromEnd(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, considerNormalization: true, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(0, index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(shiftedIds, ids1.Concat(ids2).ToList());
+
                 index = _unigramTokenizer.GetIndexByTokenCountFromEnd(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, considerNormalization: true, out normalized, out charConsumed);
                 Assert.Equal(normalizedText, normalized);
                 ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
                 ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
                 Assert.Equal(ids, ids1.Concat(ids2).ToList());
+
+                index = _unigramTokenizerFromJson.GetIndexByTokenCountFromEnd(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: 1, considerNormalization: true, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizerFromJson.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(shiftedIds, ids1.Concat(ids2).ToList());
             }
         }
 
@@ -482,19 +753,25 @@ public void DecodeTest(string inputText, string normalizedText, string decodedSt
             Assert.True(inputText is not null); // to make the compiler happy
             Assert.True(normalizedText is not null); // to make the compiler happy
 
-            string result = _unigramTokenizer.Decode(ids, considerSpecialTokens: false);
+            DecodeWithTokenizerTest(_unigramTokenizer, decodedString, ids);
+            DecodeWithTokenizerTest(_unigramTokenizerFromJson, decodedString, GetShiftedIds(ids));
+        }
+
+        private static void DecodeWithTokenizerTest(SentencePieceTokenizer tokenizer, string decodedString, int[] ids)
+        {
+            string result = tokenizer.Decode(ids, considerSpecialTokens: false);
             Assert.Equal(decodedString, result);
 
             char[] buffer = new char[decodedString.Length];
-            OperationStatus status = _unigramTokenizer.Decode(ids, buffer, considerSpecialTokens: false, out int idsConsumed, out int charsWritten);
+            OperationStatus status = tokenizer.Decode(ids, buffer, considerSpecialTokens: false, out int idsConsumed, out int charsWritten);
             Assert.Equal(OperationStatus.Done, status);
             Assert.Equal(ids.Length, idsConsumed);
             Assert.Equal(decodedString, buffer.AsSpan().Slice(0, charsWritten).ToString());
 
             for (int i = 0; i < decodedString.Length - 1; i++)
             {
-                status = _unigramTokenizer.Decode(ids, buffer.AsSpan().Slice(0, i), considerSpecialTokens: false, out idsConsumed, out charsWritten);
+                status = tokenizer.Decode(ids, buffer.AsSpan().Slice(0, i), considerSpecialTokens: false, out idsConsumed, out charsWritten);
                 Assert.Equal(OperationStatus.DestinationTooSmall, status);
                 Assert.Equal(decodedString.AsSpan().Slice(0, charsWritten).ToString(), buffer.AsSpan().Slice(0, charsWritten).ToString());
             }
@@ -510,5 +787,32 @@ public void SpecialTokensTest()
             Assert.Equal("</s>", _unigramTokenizer.EndOfSentenceToken);
             Assert.Equal(2, _unigramTokenizer.EndOfSentenceId);
         }
+
+        [Fact]
+        public void JsonTokenizerSpecialTokensTest()
+        {
+            Assert.Equal("<unk>", _unigramTokenizerFromJson.UnknownToken);
+            Assert.Equal(3, _unigramTokenizerFromJson.UnknownId);
+            Assert.Equal("<s>", _unigramTokenizerFromJson.BeginningOfSentenceToken);
+            Assert.Equal(0, _unigramTokenizerFromJson.BeginningOfSentenceId);
+            Assert.Equal("</s>", _unigramTokenizerFromJson.EndOfSentenceToken);
+            Assert.Equal(2, _unigramTokenizerFromJson.EndOfSentenceId);
+
+            var specialTokens = new Dictionary<string, int>
+            {
+                { "<s>", 0 },
+                { "<pad>", 1 },
+                { "</s>", 2 },
+                { "<unk>", 3 },
+                { "<mask>", 250001 }
+            };
+
+            Assert.Equal(specialTokens, _unigramTokenizerFromJson.SpecialTokens);
+            Assert.Equal(0, _unigramTokenizerFromJson.Vocabulary["<s>"]);
+            Assert.Equal(1, _unigramTokenizerFromJson.Vocabulary["<pad>"]);
+            Assert.Equal(2, _unigramTokenizerFromJson.Vocabulary["</s>"]);
+            Assert.Equal(3, _unigramTokenizerFromJson.Vocabulary["<unk>"]);
+            Assert.Equal(250001, _unigramTokenizerFromJson.Vocabulary["<mask>"]);
+        }
     }
 }

From f97b7831d7ca150606e22df5e0ecfd698de2473d Mon Sep 17 00:00:00 2001
From: Tarek Mahmoud Sayed
Date: Thu, 27 Feb 2025 11:19:10 -0800
Subject: [PATCH 2/2] Address the feedback

---
 .../Model/SentencePieceBpeModel.cs     |  4 +--
 .../Model/SentencePieceOptions.cs      | 27 +++++++++++--------
 .../Model/SentencePieceUnigramModel.cs | 10 +++----
 .../UnigramTests.cs                    |  4 +--
 4 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
index 2823226bbc..d33dede6c8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
@@ -53,8 +53,8 @@ internal SentencePieceBpeModel(SentencePieceOptions options) : base(options)
             int id = 0;
             foreach (var item in options.Vocabulary!)
             {
-                _vocab.Add(new StringSpanOrdinalKey(item.Key), (id, item.Value, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
-                _vocabReverse.Add(id++, item.Key);
+                _vocab.Add(new StringSpanOrdinalKey(item.Token), (id, item.Score, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
+                _vocabReverse.Add(id++, item.Token);
             }
 
             if (options.ByteFallback)
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
index d4a073fb23..87285d59b7 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
@@ -15,23 +15,26 @@ public enum SentencePieceModelType
         /// <summary>
         /// The model type is not defined.
         /// </summary>
-        Undefined,
+        Undefined = 0,
 
         /// <summary>
         /// The model type is Byte Pair Encoding (Bpe) model.
         /// </summary>
-        Bpe,
+        Bpe = 1,
 
         /// <summary>
        /// The model type is Unigram model.
         /// </summary>
-        Unigram,
+        Unigram = 2,
     }
 
     /// <summary>
     /// Options for the SentencePiece tokenizer.
     /// </summary>
+    /// <remarks>
+    /// The options are used to configure the SentencePiece tokenizer. Serialization is not guaranteed for this type.
+    /// </remarks>
-    public class SentencePieceOptions
+    public sealed class SentencePieceOptions
     {
         /// <summary>
         /// The type of the SentencePiece model.
@@ -53,7 +56,7 @@ public class SentencePieceOptions
         public bool AddDummyPrefix { get; set; }
 
         /// <summary>
-        /// Indicate if the spaces should be replaced with character U+2581 during the normalization and encoding.
+        /// Indicate if the spaces should be replaced with character U+2581 during the normalization and encoding. Default value is `true`.
         /// </summary>
         public bool EscapeWhiteSpaces { get; set; } = true;
 
@@ -68,7 +71,7 @@ public class SentencePieceOptions
         public bool RemoveExtraWhiteSpaces { get; set; }
 
         /// <summary>
-        /// Indicate emitting the beginning of sentence token during the encoding.
+        /// Indicate emitting the beginning of sentence token during the encoding. Default value is `true`.
         /// </summary>
         public bool AddBeginningOfSentence { get; set; } = true;
 
@@ -78,17 +81,17 @@ public class SentencePieceOptions
         public bool AddEndOfSentence { get; set; }
 
         /// <summary>
-        /// The beginning of sentence token.
+        /// The beginning of sentence token. Default value is `<s>`.
         /// </summary>
         public string BeginningOfSentenceToken { get; set; } = "<s>";
 
         /// <summary>
-        /// The end of sentence token.
+        /// The end of sentence token. Default value is `</s>`.
         /// </summary>
         public string EndOfSentenceToken { get; set; } = "</s>";
 
         /// <summary>
-        /// The unknown token.
+        /// The unknown token. Default value is `<unk>`.
         /// </summary>
         public string UnknownToken { get; set; } = "<unk>";
 
@@ -99,9 +102,11 @@ public class SentencePieceOptions
 
         /// <summary>
         /// Represent the vocabulary.
-        /// The list should be sorted by the token id. Every entry represents a token and its score.
+        /// The list should be sorted by token ID, with entries passed in the order that corresponds to their IDs. In other words,
+        /// the first entry in the list will be mapped to ID 0, the second entry to ID 1, the third to ID 2, and so on.
+        /// Each entry represents a token and its corresponding score.
         /// </summary>
-        public IEnumerable<KeyValuePair<string, float>>? Vocabulary { get; set; }
+        public IEnumerable<(string Token, float Score)>? Vocabulary { get; set; }
 
         /// <summary>
         /// The special tokens.
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
index dcd9c29f47..dca346caea 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
@@ -78,12 +78,12 @@ public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
             _maxScore = float.MinValue;
 
             int id = 0;
-            foreach (KeyValuePair<string, float> kvp in options.Vocabulary!)
+            foreach ((string Token, float Score) item in options.Vocabulary!)
             {
-                _vocab.Add(kvp.Key, id++);
-                vocabReverse.Add((kvp.Key, kvp.Value, ModelProto.Types.SentencePiece.Types.Type.Normal));
-                _minScore = Math.Min(_minScore, kvp.Value);
-                _maxScore = Math.Max(_maxScore, kvp.Value);
+                _vocab.Add(item.Token, id++);
+                vocabReverse.Add((item.Token, item.Score, ModelProto.Types.SentencePiece.Types.Type.Normal));
+                _minScore = Math.Min(_minScore, item.Score);
+                _maxScore = Math.Max(_maxScore, item.Score);
             }
 
             _vocabReverse = vocabReverse.ToArray();
diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
index fc6714f89d..2e948a36ea 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
@@ -69,7 +69,7 @@ private static SentencePieceTokenizer CreateUnigramTokenizerFromJson()
             return SentencePieceTokenizer.Create(options);
         }
 
-        private static IEnumerable<KeyValuePair<string, float>> GetVocabulary(JsonElement root)
+        private static IEnumerable<(string Token, float Score)> GetVocabulary(JsonElement root)
         {
             if (root.TryGetProperty("model", out JsonElement modelElement) &&
                 modelElement.TryGetProperty("vocab", out JsonElement vocabElement) &&
                 vocabElement.ValueKind == JsonValueKind.Array)
@@ -84,7 +84,7 @@ private static IEnumerable<(string Token, float Score)> GetVocabulary(JsonElemen
                         {
                             throw new InvalidOperationException("Invalid model vocabulary format");
                         }
-                        yield return new KeyValuePair<string, float>(tokenString, token[1].GetSingle());
+                        yield return (tokenString, token[1].GetSingle());
                     }
                 }
             }
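
For context, here is a minimal usage sketch of the options-based API these two patches introduce. It is not part of the patch: it assumes the final (PATCH 2/2) tuple-based Vocabulary shape, and the toy vocabulary, scores, and the ids in the comments are hypothetical, not taken from a real model. A real caller would supply the full scored vocabulary of a trained model, e.g. parsed out of a tokenizer.json as the tests above do.

    using System;
    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    class SentencePieceOptionsExample
    {
        static void Main()
        {
            // Hypothetical toy vocabulary. List order defines the token ids:
            // the first entry maps to id 0, the second to id 1, and so on.
            var vocabulary = new List<(string Token, float Score)>
            {
                ("<s>", 0f),            // id 0, beginning of sentence
                ("<pad>", 0f),          // id 1
                ("</s>", 0f),           // id 2, end of sentence
                ("<unk>", 0f),          // id 3, unknown
                ("\u2581Hello", -1.5f), // id 4 (U+2581 marks a word-leading space)
                ("\u2581World", -2.0f), // id 5
            };

            var options = new SentencePieceOptions
            {
                ModelType = SentencePieceModelType.Unigram,
                Vocabulary = vocabulary,
                // These tokens must be present in the vocabulary, otherwise the
                // model constructors throw ArgumentException.
                BeginningOfSentenceToken = "<s>",
                EndOfSentenceToken = "</s>",
                UnknownToken = "<unk>",
                AddDummyPrefix = true, // normalize "Hello World" to "\u2581Hello\u2581World"
            };

            SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(options);

            // With this toy vocabulary the ids would likely come out as
            // [0 (<s>), 4 (\u2581Hello), 5 (\u2581World)], since
            // AddBeginningOfSentence defaults to true.
            IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello World");
            Console.WriteLine(string.Join(", ", ids));
        }
    }

Setting ModelType to SentencePieceModelType.Bpe instead routes Create to SentencePieceBpeModel; note that PrecompiledNormalizationData is only supported for the Unigram model, as the Bpe constructor in PATCH 1/2 rejects it.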