Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eng/Versions.props
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
<MicrosoftMLTensorFlowTestModelsVersion>0.0.13-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLTestDatabasesVersion>0.0.6-test</MicrosoftMLTestDatabasesVersion>
<MicrosoftMLTestModelsVersion>0.0.7-test</MicrosoftMLTestModelsVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.25110.1</MicrosoftMLTestTokenizersVersion>
<MicrosoftMLTestTokenizersVersion>2.0.0-beta.25126.1</MicrosoftMLTestTokenizersVersion>
<SystemDataSqlClientVersion>4.9.0</SystemDataSqlClientVersion>
<SystemDataSQLiteCoreVersion>1.0.118</SystemDataSQLiteCoreVersion>
<XunitCombinatorialVersion>1.6.24</XunitCombinatorialVersion>
Expand Down
67 changes: 64 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,67 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool
specialTokens);
}

internal SentencePieceBaseModel(SentencePieceOptions options)
{
    // Validate the options object and the members this model cannot operate without.
    // Note (CA2208): the paramName passed to ArgumentNullException must name an actual
    // parameter of this method. nameof(options.Vocabulary) would report "Vocabulary",
    // which is not a parameter here, so we report 'options' and carry the offending
    // property name in the exception message instead.
    if (options is null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    if (options.Vocabulary is null)
    {
        throw new ArgumentNullException(nameof(options), $"{nameof(options.Vocabulary)} must not be null.");
    }

    if (options.BeginningOfSentenceToken is null)
    {
        throw new ArgumentNullException(nameof(options), $"{nameof(options.BeginningOfSentenceToken)} must not be null.");
    }

    if (options.EndOfSentenceToken is null)
    {
        throw new ArgumentNullException(nameof(options), $"{nameof(options.EndOfSentenceToken)} must not be null.");
    }

    if (options.UnknownToken is null)
    {
        throw new ArgumentNullException(nameof(options), $"{nameof(options.UnknownToken)} must not be null.");
    }

    // Copy the scalar configuration straight from the options onto the model.
    AddBeginningOfSentence = options.AddBeginningOfSentence;
    AddEndOfSentence = options.AddEndOfSentence;
    BeginningOfSentenceToken = options.BeginningOfSentenceToken;
    EndOfSentenceToken = options.EndOfSentenceToken;
    UnknownToken = options.UnknownToken;
    AddDummyPrefix = options.AddDummyPrefix;
    EscapeWhiteSpaces = options.EscapeWhiteSpaces;
    TreatWhitespaceAsSuffix = options.TreatWhitespaceAsSuffix;
    ByteFallback = options.ByteFallback;
    SpecialTokens = options.SpecialTokens;

    if (SpecialTokens is not null && SpecialTokens.Count > 0)
    {
        // Build the forward (text -> id) and reverse (id -> text) special-token maps.
        InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
        SpecialTokensReverse = new Dictionary<int, string>();

        foreach (var item in SpecialTokens)
        {
            InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
            SpecialTokensReverse.Add(item.Value, item.Key);
        }

        // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
        SpecialTokensRegex = new Regex(string.Join("|", SpecialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
    }

    // The normalizer mirrors the whitespace/prefix handling configured above.
    Normalizer = new SentencePieceNormalizer(
        options.PrecompiledNormalizationData,
        options.RemoveExtraWhiteSpaces,
        options.AddDummyPrefix, options.EscapeWhiteSpaces,
        options.TreatWhitespaceAsSuffix,
        SpecialTokens);
}

internal Regex? SpecialTokensRegex { get; }

internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
Expand Down Expand Up @@ -91,11 +152,11 @@ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool

public string UnknownToken { get; }

public int BeginningOfSentenceId { get; }
public int BeginningOfSentenceId { get; set; }

public int EndOfSentenceId { get; }
public int EndOfSentenceId { get; set; }

public int UnknownId { get; }
public int UnknownId { get; set; }

public SentencePieceNormalizer? Normalizer { get; }

Expand Down
46 changes: 46 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,52 @@ internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos,
OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
}

internal SentencePieceBpeModel(SentencePieceOptions options) : base(options)
{
    // The BPE variant performs no precompiled normalization; reject it up front.
    if (options.PrecompiledNormalizationData is not null)
    {
        throw new NotSupportedException("Normalization data is not supported for SentencePieceBpeModel.");
    }

    Debug.Assert(options.Vocabulary is not null);

    // Ids are assigned sequentially in enumeration order: the first entry gets id 0,
    // the second id 1, and so on. Every entry is registered as a Normal piece.
    int nextId = 0;
    foreach ((string token, float score) in options.Vocabulary!)
    {
        _vocab.Add(new StringSpanOrdinalKey(token), (nextId, score, (byte)ModelProto.Types.SentencePiece.Types.Type.Normal));
        _vocabReverse.Add(nextId, token);
        nextId++;
    }

    if (options.ByteFallback)
    {
        // Byte fallback requires the <0xNN> byte pieces; we probe <0x00> and treat its
        // id as the base offset for all 256 byte tokens.
        if (!_vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) byteBase))
        {
            throw new ArgumentException("'ByteFallback' is enabled but the vocabulary must include a special token for each byte value (0-255) in the format <0xNN>, where NN represents the byte's hexadecimal value.");
        }

        ByteCodeToIdOffset = byteBase.Id;
        OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
    }

    // The three sentinel tokens must all resolve to ids in the vocabulary.
    UnknownId = GetRequiredId(options.UnknownToken, "unknown");
    BeginningOfSentenceId = GetRequiredId(options.BeginningOfSentenceToken, "beginning of sentence");
    EndOfSentenceId = GetRequiredId(options.EndOfSentenceToken, "end of sentence");

    // Looks up a token that the vocabulary is required to contain; throws otherwise.
    int GetRequiredId(string token, string description)
        => _vocab.TryGetValue(token, out (int Id, float Score, byte Type) entry)
            ? entry.Id
            : throw new ArgumentException($"The vocabulary must include the {description} token '{token}'.");
}

public override IReadOnlyDictionary<string, int> Vocabulary
{
get
Expand Down
118 changes: 118 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;

namespace Microsoft.ML.Tokenizers
{
#pragma warning disable MSML_NoInstanceInitializers
/// <summary>
/// The type of the SentencePiece model.
/// </summary>
public enum SentencePieceModelType
{
    /// <summary>
    /// The model type is not defined.
    /// </summary>
    Undefined = 0,

    /// <summary>
    /// The model type is the Byte Pair Encoding (BPE) model.
    /// </summary>
    Bpe = 1,

    /// <summary>
    /// The model type is the Unigram model.
    /// </summary>
    Unigram = 2,
}

/// <summary>
/// Options for the SentencePiece tokenizer.
/// </summary>
/// <remarks>
/// The options are used to configure the SentencePiece tokenizer. Serialization is not guaranteed for this type.
/// </remarks>
public sealed class SentencePieceOptions
{
/// <summary>
/// The type of the SentencePiece model.
/// </summary>
public SentencePieceModelType ModelType { get; set; }

/// <summary>
/// Determines whether the model uses a byte fallback strategy to encode unknown tokens as byte sequences.
/// </summary>
/// <remarks>
/// The vocabulary must include a special token for each byte value (0-255) in the format &lt;0xNN&gt;,
/// where NN represents the byte's hexadecimal value (e.g., &lt;0x41&gt; for byte value 65).
/// </remarks>
public bool ByteFallback { get; set; }

/// <summary>
/// Indicates whether to emit the dummy prefix character (e.g. U+2581) at the beginning of the input during normalization and encoding.
/// </summary>
public bool AddDummyPrefix { get; set; }

/// <summary>
/// Indicates whether spaces should be replaced with the character U+2581 during normalization and encoding. Default value is `true`.
/// </summary>
public bool EscapeWhiteSpaces { get; set; } = true;

/// <summary>
/// Indicates whether to emit the character U+2581 at the end of the input instead of at the beginning during normalization and encoding.
/// </summary>
public bool TreatWhitespaceAsSuffix { get; set; }

/// <summary>
/// Indicates whether to remove extra white spaces from the original string during normalization.
/// </summary>
public bool RemoveExtraWhiteSpaces { get; set; }

/// <summary>
/// Indicates whether to emit the beginning of sentence token during encoding. Default value is `true`.
/// </summary>
public bool AddBeginningOfSentence { get; set; } = true;

/// <summary>
/// Indicates whether to emit the end of sentence token during encoding.
/// </summary>
public bool AddEndOfSentence { get; set; }

/// <summary>
/// The beginning of sentence token. Default value is `&lt;s&gt;`.
/// </summary>
public string BeginningOfSentenceToken { get; set; } = "<s>";

/// <summary>
/// The end of sentence token. Default value is `&lt;/s&gt;`.
/// </summary>
public string EndOfSentenceToken { get; set; } = "</s>";

/// <summary>
/// The unknown token. Default value is `&lt;unk&gt;`.
/// </summary>
public string UnknownToken { get; set; } = "<unk>";

/// <summary>
/// The data used for string normalization.
/// </summary>
public byte[]? PrecompiledNormalizationData { get; set; }

/// <summary>
/// Represents the vocabulary.
/// The list should be sorted by token ID, with entries passed in the order that corresponds to their IDs. In other words,
/// the first entry in the list will be mapped to ID 0, the second entry to ID 1, the third to ID 2, and so on.
/// Each entry represents a token and its corresponding score.
/// </summary>
public IEnumerable<(string Token, float Score)>? Vocabulary { get; set; }

/// <summary>
/// The special tokens.
/// Special tokens remain intact during encoding and are not split into sub-tokens.
/// </summary>
public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
}
#pragma warning restore MSML_NoInstanceInitializers
}
24 changes: 24 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos,
};
}

internal SentencePieceTokenizer(SentencePieceOptions options)
{
    // Guard here as well: this internal constructor can be reached without going
    // through the public Create(options) factory, and we dereference options below.
    if (options is null)
    {
        throw new ArgumentNullException(nameof(options));
    }

    // Select the concrete model implementation from the configured model type.
    // Note (CA2208): the ArgumentException paramName must be an actual parameter of
    // this method; nameof(options.ModelType) would report "ModelType", which is not.
    _model = options.ModelType switch
    {
        SentencePieceModelType.Bpe => new SentencePieceBpeModel(options),
        SentencePieceModelType.Unigram => new SentencePieceUnigramModel(options),
        _ => throw new ArgumentException($"The model type '{options.ModelType}' is not supported.", nameof(options))
    };
}

/// <summary>
/// The special tokens.
/// </summary>
Expand Down Expand Up @@ -457,5 +467,19 @@ public static SentencePieceTokenizer Create(

return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
}

/// <summary>
/// Creates an instance of SentencePieceTokenizer.
/// </summary>
/// <param name="options">The options to use for the sentence piece tokenizer.</param>
public static SentencePieceTokenizer Create(SentencePieceOptions options)
    => options is null
        ? throw new ArgumentNullException(nameof(options))
        : new SentencePieceTokenizer(options);
}
}
Loading