diff --git a/eng/Versions.props b/eng/Versions.props
index b6de8327b2..522e5f3489 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -100,7 +100,7 @@
0.0.13-test
0.0.6-test
0.0.7-test
- 2.0.0-beta.25110.1
+ 2.0.0-beta.25126.1
4.9.0
1.0.118
1.6.24
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
index a1553ec4cd..0f8a76ddfc 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -59,6 +59,67 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
specialTokens);
}
+ internal SentencePieceBaseModel(SentencePieceOptions options)
+ {
+ if (options is null)
+ {
+ throw new ArgumentNullException(nameof(options));
+ }
+
+ if (options.Vocabulary is null)
+ {
+ throw new ArgumentNullException(nameof(options.Vocabulary));
+ }
+
+ if (options.BeginningOfSentenceToken is null)
+ {
+ throw new ArgumentNullException(nameof(options.BeginningOfSentenceToken));
+ }
+
+ if (options.EndOfSentenceToken is null)
+ {
+ throw new ArgumentNullException(nameof(options.EndOfSentenceToken));
+ }
+
+ if (options.UnknownToken is null)
+ {
+ throw new ArgumentNullException(nameof(options.UnknownToken));
+ }
+
+ AddBeginningOfSentence = options.AddBeginningOfSentence;
+ AddEndOfSentence = options.AddEndOfSentence;
+ BeginningOfSentenceToken = options.BeginningOfSentenceToken;
+ EndOfSentenceToken = options.EndOfSentenceToken;
+ UnknownToken = options.UnknownToken;
+ AddDummyPrefix = options.AddDummyPrefix;
+ EscapeWhiteSpaces = options.EscapeWhiteSpaces;
+ TreatWhitespaceAsSuffix = options.TreatWhitespaceAsSuffix;
+ ByteFallback = options.ByteFallback;
+ SpecialTokens = options.SpecialTokens;
+
+ if (SpecialTokens is not null && SpecialTokens.Count > 0)
+ {
+ InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
+ SpecialTokensReverse = new Dictionary<int, string>();
+
+ foreach (var item in SpecialTokens)
+ {
+ InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
+ SpecialTokensReverse.Add(item.Value, item.Key);
+ }
+
+ // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
+ SpecialTokensRegex = new Regex(string.Join("|", SpecialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
+ }
+
+ Normalizer = new SentencePieceNormalizer(
+ options.PrecompiledNormalizationData,
+ options.RemoveExtraWhiteSpaces,
+ options.AddDummyPrefix, options.EscapeWhiteSpaces,
+ options.TreatWhitespaceAsSuffix,
+ SpecialTokens);
+ }
+
internal Regex? SpecialTokensRegex { get; }
internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
@@ -91,11 +152,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
public string UnknownToken { get; }
- public int BeginningOfSentenceId { get; }
+ public int BeginningOfSentenceId { get; set; }
- public int EndOfSentenceId { get; }
+ public int EndOfSentenceId { get; set; }
- public int UnknownId { get; }
+ public int UnknownId { get; set; }
public SentencePieceNormalizer? Normalizer { get; }
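Reviewer note: the `SpecialTokensRegex` built in the new constructor is a plain alternation of the escaped special tokens, used to split the input so that special tokens survive encoding intact. A minimal standalone sketch of the same construction (the token set here is hypothetical, and the types are not the library's internals):

```csharp
using System;
using System.Linq;
using System.Text.RegularExpressions;

class SpecialTokenRegexSketch
{
    static void Main()
    {
        // Hypothetical special-token set; real models define their own.
        string[] specialTokens = { "<s>", "</s>", "<unk>" };

        // Same construction as in the constructor above: escape each token and
        // join with '|'. Matching runs in O(N), so no regex timeout is needed.
        var regex = new Regex(
            string.Join("|", specialTokens.Select(Regex.Escape)),
            RegexOptions.Compiled);

        // Spans between matches are tokenized normally; the matches themselves
        // map directly to their special-token IDs.
        foreach (Match m in regex.Matches("<s>hello world</s>"))
        {
            Console.WriteLine($"'{m.Value}' at offset {m.Index}");
        }
    }
}
```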
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
index 85c85c1677..d33dede6c8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
@@ -41,6 +41,52 @@ internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos,
OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
}
+ internal SentencePieceBpeModel(SentencePieceOptions options) : base(options)
+ {
+ if (options.PrecompiledNormalizationData is not null)
+ {
+ throw new NotSupportedException("Normalization data is not supported for SentencePieceBpeModel.");
+ }
+
+ Debug.Assert(options.Vocabulary is not null);
+
+ int id = 0;
+ foreach (var item in options.Vocabulary!)
+ {
+ _vocab.Add(new StringSpanOrdinalKey(item.Token), (id, item.Score, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
+ _vocabReverse.Add(id++, item.Token);
+ }
+
+ if (options.ByteFallback)
+ {
+ if (!_vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value))
+ {
+ throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
+ }
+
+ ByteCodeToIdOffset = value.Id;
+ OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
+ }
+
+ if (!_vocab.TryGetValue(options.UnknownToken, out (int Id, float Score, byte Type) unknownToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the unknown token '{options.UnknownToken}'.");
+ }
+ UnknownId = unknownToken.Id;
+
+ if (!_vocab.TryGetValue(options.BeginningOfSentenceToken, out (int Id, float Score, byte Type) beginOfSentenceToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the beginning of sentence token '{options.BeginningOfSentenceToken}'.");
+ }
+ BeginningOfSentenceId = beginOfSentenceToken.Id;
+
+ if (!_vocab.TryGetValue(options.EndOfSentenceToken, out (int Id, float Score, byte Type) endOfSentenceToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the end of sentence token '{options.EndOfSentenceToken}'.");
+ }
+ EndOfSentenceId = endOfSentenceToken.Id;
+ }
+
public override IReadOnlyDictionary<string, int> Vocabulary
{
get
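The `<0x00>` lookup above records only the ID of the first byte token as `ByteCodeToIdOffset`, so the 256 byte-fallback entries must be contiguous in the vocabulary (the ID of `<0xNN>` equals `ByteCodeToIdOffset + NN`). A hypothetical helper, not part of this PR, that generates entries in the expected format:

```csharp
using System.Collections.Generic;

static class ByteFallbackEntries
{
    // Yields ("<0x00>", score) through ("<0xFF>", score). The entries must be
    // appended to the vocabulary consecutively so their IDs stay contiguous.
    public static IEnumerable<(string Token, float Score)> Generate(float score = 0f)
    {
        for (int b = 0; b <= 0xFF; b++)
        {
            yield return ($"<0x{b:X2}>", score);
        }
    }
}
```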
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
new file mode 100644
index 0000000000..87285d59b7
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
@@ -0,0 +1,118 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Tokenizers
+{
+#pragma warning disable MSML_NoInstanceInitializers
+ /// <summary>
+ /// The type of the SentencePiece model.
+ /// </summary>
+ public enum SentencePieceModelType
+ {
+ /// <summary>
+ /// The model type is not defined.
+ /// </summary>
+ Undefined = 0,
+
+ /// <summary>
+ /// The model type is the Byte Pair Encoding (Bpe) model.
+ /// </summary>
+ Bpe = 1,
+
+ /// <summary>
+ /// The model type is the Unigram model.
+ /// </summary>
+ Unigram = 2,
+ }
+
+ /// <summary>
+ /// Options for the SentencePiece tokenizer.
+ /// </summary>
+ /// <remarks>
+ /// The options are used to configure the SentencePiece tokenizer. Serialization is not guaranteed for this type.
+ /// </remarks>
+ public sealed class SentencePieceOptions
+ {
+ /// <summary>
+ /// The type of the SentencePiece model.
+ /// </summary>
+ public SentencePieceModelType ModelType { get; set; }
+
+ /// <summary>
+ /// Determines whether the model uses a byte fallback strategy to encode unknown tokens as byte sequences.
+ /// </summary>
+ /// <remarks>
+ /// The vocabulary must include a special token for each byte value (0-255) in the format <0xNN>,
+ /// where NN represents the byte's hexadecimal value (e.g., <0x41> for byte value 65).
+ /// </remarks>
+ public bool ByteFallback { get; set; }
+
+ /// <summary>
+ /// Indicates whether to emit the prefix character (e.g. U+2581) at the beginning of the sentence during normalization and encoding.
+ /// </summary>
+ public bool AddDummyPrefix { get; set; }
+
+ /// <summary>
+ /// Indicates whether spaces should be replaced with the character U+2581 during normalization and encoding. Default value is `true`.
+ /// </summary>
+ public bool EscapeWhiteSpaces { get; set; } = true;
+
+ /// <summary>
+ /// Indicates whether to emit the character U+2581 at the end of the last token instead of at the beginning of the sentence during normalization and encoding.
+ /// </summary>
+ public bool TreatWhitespaceAsSuffix { get; set; }
+
+ /// <summary>
+ /// Indicates whether to remove extra white spaces from the original string during normalization.
+ /// </summary>
+ public bool RemoveExtraWhiteSpaces { get; set; }
+
+ /// <summary>
+ /// Indicates whether to emit the beginning of sentence token during encoding. Default value is `true`.
+ /// </summary>
+ public bool AddBeginningOfSentence { get; set; } = true;
+
+ /// <summary>
+ /// Indicates whether to emit the end of sentence token during encoding.
+ /// </summary>
+ public bool AddEndOfSentence { get; set; }
+
+ /// <summary>
+ /// The beginning of sentence token. Default value is `<s>`.
+ /// </summary>
+ public string BeginningOfSentenceToken { get; set; } = "<s>";
+
+ /// <summary>
+ /// The end of sentence token. Default value is `</s>`.
+ /// </summary>
+ public string EndOfSentenceToken { get; set; } = "</s>";
+
+ /// <summary>
+ /// The unknown token. Default value is `<unk>`.
+ /// </summary>
+ public string UnknownToken { get; set; } = "<unk>";
+
+ /// <summary>
+ /// The data used for string normalization.
+ /// </summary>
+ public byte[]? PrecompiledNormalizationData { get; set; }
+
+ /// <summary>
+ /// Represents the vocabulary.
+ /// The list should be sorted by token ID, with entries passed in the order that corresponds to their IDs. In other words,
+ /// the first entry in the list will be mapped to ID 0, the second entry to ID 1, the third to ID 2, and so on.
+ /// Each entry represents a token and its corresponding score.
+ /// </summary>
+ public IEnumerable<(string Token, float Score)>? Vocabulary { get; set; }
+
+ /// <summary>
+ /// The special tokens.
+ /// Special tokens remain intact during encoding and are not split into sub-tokens.
+ /// </summary>
+ public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
+ }
+#pragma warning restore MSML_NoInstanceInitializers
+}
\ No newline at end of file
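A minimal usage sketch for the new options type, assuming a toy five-entry vocabulary (a real model ships hundreds of thousands of entries); `EncodeToIds` comes from the `Tokenizer` base class:

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

class SentencePieceOptionsSketch
{
    static void Main()
    {
        var options = new SentencePieceOptions
        {
            ModelType = SentencePieceModelType.Unigram,
            AddDummyPrefix = true,
            // Entry order defines the IDs: "<unk>" -> 0, "<s>" -> 1, "</s>" -> 2, ...
            Vocabulary = new List<(string Token, float Score)>
            {
                ("<unk>", 0f), ("<s>", 0f), ("</s>", 0f),
                ("\u2581Hello", -1.5f), ("\u2581World", -2.0f), // U+2581 marks a leading space
            },
        };

        SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(options);
        IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello World");
        Console.WriteLine(string.Join(", ", ids));
    }
}
```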
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index f41516e270..9090bfa9fe 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -30,6 +30,16 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
};
}
+ internal SentencePieceTokenizer(SentencePieceOptions options)
+ {
+ _model = options.ModelType switch
+ {
+ SentencePieceModelType.Bpe => new SentencePieceBpeModel(options),
+ SentencePieceModelType.Unigram => new SentencePieceUnigramModel(options),
+ _ => throw new ArgumentException($"The model type '{options.ModelType}' is not supported.", nameof(options.ModelType))
+ };
+ }
+
/// <summary>
/// The special tokens.
/// </summary>
@@ -457,5 +467,19 @@ public static SentencePieceTokenizer Create(
return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
}
+
+ /// <summary>
+ /// Creates an instance of SentencePieceTokenizer.
+ /// </summary>
+ /// <param name="options">The options to use for the sentence piece tokenizer.</param>
+ public static SentencePieceTokenizer Create(SentencePieceOptions options)
+ {
+ if (options is null)
+ {
+ throw new ArgumentNullException(nameof(options));
+ }
+
+ return new SentencePieceTokenizer(options);
+ }
}
}
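One behavior worth noting in review: `ModelType` defaults to `Undefined`, so `Create(options)` throws until the caller explicitly picks `Bpe` or `Unigram`. A quick sketch of what a caller sees:

```csharp
using System;
using Microsoft.ML.Tokenizers;

class ModelTypeValidationSketch
{
    static void Main()
    {
        var options = new SentencePieceOptions(); // ModelType is Undefined by default
        try
        {
            SentencePieceTokenizer.Create(options);
        }
        catch (ArgumentException e)
        {
            Console.WriteLine(e.Message); // The model type 'Undefined' is not supported.
        }
    }
}
```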
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
index a578362f73..dca346caea 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs
@@ -66,6 +66,64 @@ public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos
_vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
}
+ public SentencePieceUnigramModel(SentencePieceOptions options) : base(options)
+ {
+ _vocab = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance);
+
+ // Use a large initial capacity (250_000) to avoid reallocation during initialization.
+ List<(string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)> vocabReverse = new(250_000);
+
+ _minScore = float.MaxValue;
+ _maxScore = float.MinValue;
+
+ int id = 0;
+ foreach ((string Token, float Score) item in options.Vocabulary!)
+ {
+ _vocab.Add(item.Token, id++);
+ vocabReverse.Add((item.Token, item.Score, ModelProto.Types.SentencePiece.Types.Type.Normal));
+ _minScore = Math.Min(_minScore, item.Score);
+ _maxScore = Math.Max(_maxScore, item.Score);
+ }
+
+ _vocabReverse = vocabReverse.ToArray();
+
+ if (options.ByteFallback)
+ {
+ if (!_vocab.TryGetValue("<0x00>", out id))
+ {
+ throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
+ }
+
+ ByteCodeToIdOffset = id;
+ OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
+ MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF; // from <0x00> to <0xFF>.
+ }
+
+ _trie = new DoubleArrayTrie(_vocab);
+
+ if (!_vocab.TryGetValue(options.UnknownToken, out int unknownToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the unknown token '{options.UnknownToken}'.");
+ }
+ UnknownId = unknownToken;
+
+ if (!_vocab.TryGetValue(options.BeginningOfSentenceToken, out int beginOfSentenceToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the beginning of sentence token '{options.BeginningOfSentenceToken}'.");
+ }
+ BeginningOfSentenceId = beginOfSentenceToken;
+
+ if (!_vocab.TryGetValue(options.EndOfSentenceToken, out int endOfSentenceToken))
+ {
+ throw new ArgumentException($"The vocabulary must include the end of sentence token '{options.EndOfSentenceToken}'.");
+ }
+ EndOfSentenceId = endOfSentenceToken;
+
+ // Overwrite the BOS/EOS reverse-vocab entries only after their IDs are resolved above;
+ // doing it earlier would write to index 0, since the properties still hold their defaults.
+ _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0);
+ _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0);
+ }
+
+
public override IReadOnlyDictionary<string, int> Vocabulary => new ReadOnlyDictionary<string, int>(_vocab);
public int MaxIdByteFallbackId { get; }
@@ -114,39 +172,39 @@ public override bool TryMapIdToToken(int id, out string? token)
return true;
}
- private void StoreNormalizedTextFromEnd(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringIndexFromEnd)
+ private void StoreNormalizedTextFromEnd(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringCountFromEnd)
{
- int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd;
+ int remainingLength = normalizedString.Length - normalizedStringCountFromEnd;
if (text.Length > remainingLength)
{
char[] utf16NormalizedString = ArrayPool<char>.Shared.Rent(normalizedString.Length << 1);
- normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd));
+ normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringCountFromEnd));
ArrayPool<char>.Shared.Return(normalizedString);
normalizedString = utf16NormalizedString;
}
- text.CopyTo(normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - text.Length));
- normalizedStringIndexFromEnd += text.Length;
+ text.CopyTo(normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd - text.Length));
+ normalizedStringCountFromEnd += text.Length;
}
- private void StoreNormalizedTextFromEnd(ReadOnlySpan<byte> utf8Bytes, ref char[] normalizedString, ref int normalizedStringIndexFromEnd)
+ private void StoreNormalizedTextFromEnd(ReadOnlySpan<byte> utf8Bytes, ref char[] normalizedString, ref int normalizedStringCountFromEnd)
{
- int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd;
+ int remainingLength = normalizedString.Length - normalizedStringCountFromEnd;
int expectedCount = Helpers.GetUtf16LengthFromUtf8Bytes(utf8Bytes);
if (expectedCount > remainingLength)
{
char[] utf16NormalizedString = ArrayPool<char>.Shared.Rent(normalizedString.Length << 1);
- normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd));
+ normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringCountFromEnd));
ArrayPool<char>.Shared.Return(normalizedString);
normalizedString = utf16NormalizedString;
}
- bool res = Helpers.ConvertUtf8ToUtf16(utf8Bytes, normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - expectedCount), out int bytesConsumed, out int charsWritten);
+ bool res = Helpers.ConvertUtf8ToUtf16(utf8Bytes, normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd - expectedCount), out int bytesConsumed, out int charsWritten);
Debug.Assert(res);
Debug.Assert(bytesConsumed == utf8Bytes.Length);
Debug.Assert(charsWritten == expectedCount);
- normalizedStringIndexFromEnd += expectedCount;
+ normalizedStringCountFromEnd += expectedCount;
}
private void StoreNormalizedText(ReadOnlySpan<char> text, ref char[] normalizedString, ref int normalizedStringIndex)
@@ -1251,14 +1309,15 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
Debug.Assert(maxTokenCount > 0);
charConsumedFromEnd = 0;
- int normalizedStringIndexFromEnd = 0;
+ int normalizedStringCountFromEnd = 0;
(int Offset, int Length)[] splits = PreTokenizer.SplitText(text, SpecialTokensRegex!).ToArray();
if (splits.Length == 0)
{
- GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
- normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - charConsumedFromEnd).ToString() : null;
+ GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
+ normalizedText = normalizedString is not null ? normalizedString.AsSpan(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
+ return;
}
(int Offset, int Length) current = splits[splits.Length - 1];
@@ -1266,7 +1325,7 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
// Last part is not a special token
if (current.Offset + current.Length < text.Length)
{
- GetIndexByTokenCountFromEndInternal(text.Slice(current.Offset + current.Length), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+ GetIndexByTokenCountFromEndInternal(text.Slice(current.Offset + current.Length), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
}
for (int i = splits.Length - 1; i >= 0; i--)
@@ -1285,17 +1344,17 @@ private void GetIndexByTokenCountFromEndWithSpecialTokens(
if (normalizedString is not null)
{
- StoreNormalizedTextFromEnd(text.Slice(current.Offset, current.Length), ref normalizedString, ref normalizedStringIndexFromEnd);
+ StoreNormalizedTextFromEnd(text.Slice(current.Offset, current.Length), ref normalizedString, ref normalizedStringCountFromEnd);
}
if (current.Offset > 0)
{
int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0;
- GetIndexByTokenCountFromEndInternal(text.Slice(start, current.Offset - start), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+ GetIndexByTokenCountFromEndInternal(text.Slice(start, current.Offset - start), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
}
}
- normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null;
+ normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
}
private void GetIndexByTokenCountFromEndWithoutSpecialTokens(
@@ -1309,11 +1368,11 @@ private void GetIndexByTokenCountFromEndWithoutSpecialTokens(
int maxTokenCount)
{
charConsumedFromEnd = 0;
- int normalizedStringIndexFromEnd = 0;
+ int normalizedStringCountFromEnd = 0;
- GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount);
+ GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringCountFromEnd, ref charConsumedFromEnd, maxTokenCount);
- normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null;
+ normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringCountFromEnd).ToString() : null;
}
private void GetIndexByTokenCountFromEndInternal(
@@ -1322,7 +1381,7 @@ private void GetIndexByTokenCountFromEndInternal(
ref int tokenCount,
int[] buffer,
ref char[]? normalizedString,
- ref int normalizedStringIndexFromEnd,
+ ref int normalizedStringCountFromEnd,
ref int charConsumedFromEnd,
int maxTokenCount)
{
@@ -1381,11 +1440,11 @@ private void GetIndexByTokenCountFromEndInternal(
{
if (considerNormalization)
{
- StoreNormalizedTextFromEnd(normalizationSpan, ref normalizedString, ref normalizedStringIndexFromEnd);
+ StoreNormalizedTextFromEnd(normalizationSpan, ref normalizedString, ref normalizedStringCountFromEnd);
}
else
{
- StoreNormalizedTextFromEnd(text, ref normalizedString, ref normalizedStringIndexFromEnd);
+ StoreNormalizedTextFromEnd(text, ref normalizedString, ref normalizedStringCountFromEnd);
}
}
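The rename from `normalizedStringIndexFromEnd` to `normalizedStringCountFromEnd` is more than cosmetic: the variable is a count of valid chars already stored at the array's tail, not an index into it. A standalone sketch of the same end-anchored growth scheme (simplified; the `Math.Max` guard is an addition here, where the code above relies on doubling being sufficient):

```csharp
using System;
using System.Buffers;

static class TailBufferSketch
{
    // Writes 'text' immediately before the tail already stored in 'buffer'.
    // countFromEnd is the number of valid chars at the end of the array.
    public static void StoreFromEnd(ReadOnlySpan<char> text, ref char[] buffer, ref int countFromEnd)
    {
        int remaining = buffer.Length - countFromEnd;
        if (text.Length > remaining)
        {
            // Grow: rent a larger array and copy the existing tail so that it
            // stays anchored to the end of the new array.
            char[] bigger = ArrayPool<char>.Shared.Rent(
                Math.Max(buffer.Length * 2, countFromEnd + text.Length));
            buffer.AsSpan(buffer.Length - countFromEnd)
                  .CopyTo(bigger.AsSpan(bigger.Length - countFromEnd));
            ArrayPool<char>.Shared.Return(buffer);
            buffer = bigger;
        }

        text.CopyTo(buffer.AsSpan(buffer.Length - countFromEnd - text.Length));
        countFromEnd += text.Length;
    }
}
```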
diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
index b90ab7a414..2e948a36ea 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs
@@ -8,6 +8,8 @@
using System.IO;
using System.Linq;
using System.Reflection.Metadata;
+using System.Text;
+using System.Text.Json;
using Microsoft.ML.Tokenizers;
using Xunit;
@@ -17,6 +19,7 @@ public class UnigramTests
{
private static SentencePieceTokenizer _unigramTokenizer = CreateUnigramTokenizer();
private static SentencePieceTokenizer _unigramTokenizerWithSpecialTokens = CreateUnigramTokenizerWithSpecialTokens();
+ private static SentencePieceTokenizer _unigramTokenizerFromJson = CreateUnigramTokenizerFromJson();
private static SentencePieceTokenizer CreateUnigramTokenizer()
{
@@ -25,6 +28,72 @@ private static SentencePieceTokenizer CreateUnigramTokenizer()
return SentencePieceTokenizer.Create(remoteStream);
}
+ private static SentencePieceTokenizer CreateUnigramTokenizerFromJson()
+ {
+ // @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/tokenizer.json?download=true";
+ using Stream jsonModelStream = File.OpenRead(Path.Combine(@"Paraphrase-multilingual-MiniLM-L12-v2", "tokenizer.json"));
+ using var reader = new StreamReader(jsonModelStream, Encoding.UTF8);
+ string json = reader.ReadToEnd();
+ using JsonDocument doc = JsonDocument.Parse(json);
+ JsonElement root = doc.RootElement;
+
+ SentencePieceOptions options = new SentencePieceOptions();
+ options.ModelType = SentencePieceModelType.Unigram;
+ options.EscapeWhiteSpaces = true;
+ options.AddDummyPrefix = true;
+
+ options.BeginningOfSentenceToken = "<s>";
+ options.EndOfSentenceToken = "</s>";
+ options.UnknownToken = "<unk>";
+
+ options.SpecialTokens = new Dictionary<string, int>
+ {
+ { "<s>", 0 },
+ { "<pad>", 1 },
+ { "</s>", 2 },
+ { "<unk>", 3 },
+ { "<mask>", 250001 }
+ };
+
+ if (root.TryGetProperty("normalizer", out JsonElement normalizerElement) && normalizerElement.GetProperty("type").GetString() == "Precompiled")
+ {
+ string? precompiledCharsMap = normalizerElement.GetProperty("precompiled_charsmap").GetString();
+ if (precompiledCharsMap is not null)
+ {
+ byte[] bytes = Convert.FromBase64String(precompiledCharsMap);
+ options.PrecompiledNormalizationData = bytes;
+ }
+ }
+
+ options.Vocabulary = GetVocabulary(root);
+ return SentencePieceTokenizer.Create(options);
+ }
+
+ private static IEnumerable<(string Token, float Score)> GetVocabulary(JsonElement root)
+ {
+ if (root.TryGetProperty("model", out JsonElement modelElement) &&
+ modelElement.TryGetProperty("vocab", out JsonElement vocabElement) &&
+ vocabElement.ValueKind == JsonValueKind.Array)
+ {
+ foreach (JsonElement token in vocabElement.EnumerateArray())
+ {
+ if (token.ValueKind == JsonValueKind.Array && token.GetArrayLength() == 2)
+ {
+ string? tokenString = token[0].GetString();
+ if (tokenString is null)
+ {
+ throw new InvalidOperationException("Invalid model vocabulary format");
+ }
+ yield return (tokenString, token[1].GetSingle());
+ }
+ }
+ }
+ else
+ {
+ throw new InvalidOperationException("Invalid model vocabulary format");
+ }
+ }
+
private static SentencePieceTokenizer CreateUnigramTokenizerWithSpecialTokens()
{
// @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/sentencepiece.bpe.model?download=true";
@@ -163,7 +232,7 @@ public static IEnumerable<object[]>