Skip to content

WordPiece and Bert Tokenizer Design review #7281

@tarekgh

Description

@tarekgh

Proposal

The proposal omits the overridden properties and methods that are defined in the abstraction we already reviewed earlier.

WordPiece Tokenizer

namespace Microsoft.ML.Tokenizers
{
    public partial class WordPieceTokenizer : Tokenizer
    {
        public static WordPieceTokenizer Create(
                        string vocabFilePath,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord)

        public static WordPieceTokenizer Create(
                        Stream vocabStream,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord)

        public static async Task<WordPieceTokenizer> CreateAsync(
                        Stream vocabStream,
                        PreTokenizer? preTokenizer = null,
                        Normalizer? normalizer = null,
                        IReadOnlyDictionary<string, int>? specialTokens = null,
                        string unknownToken = "[UNK]",
                        string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
                        int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
                        CancellationToken cancellationToken = default)

        /// <summary>
        /// Gets the unknown token.
        /// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
        /// </summary>
        public string UnknownToken { get; }

        /// <summary>
        /// Gets the unknown token ID.
        /// A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead.
        /// </summary>
        public int UnknownTokenId { get; }

        /// <summary>
        /// Gets the prefix to use for sub-words that are not the first part of a word.
        /// </summary>
        public string ContinuingSubwordPrefix { get; }

        /// <summary>
        /// Gets the maximum number of characters to authorize in a single word.
        /// </summary>
        public int MaxInputCharsPerWord { get; }

        /// <summary>
        /// Gets the special tokens and their corresponding ids.
        /// </summary>
        public IReadOnlyDictionary<string, int>? SpecialTokens { get; }

        /// <summary>
        /// Decodes the given IDs back to a string.
        /// </summary>
        /// <param name="ids">The list of ids that we want to decode.</param>
        /// <param name="skipSpecialTokens">Indicates whether to skip the special tokens during decoding.</param>
        /// <returns>The decoded string.</returns>
        public string Decode(IEnumerable<int> ids, bool skipSpecialTokens)

        public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool skipSpecialTokens, 
                   out int idsConsumed, out int charsWritten)
    }
}

Bert Tokenizer

namespace Microsoft.ML.Tokenizers
{
    public sealed partial class BertTokenizer : WordPieceTokenizer
    {
        public static BertTokenizer Create(
                    string vocabFilePath,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false)    

        public static BertTokenizer Create(
                    Stream vocabStream,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false)

        public static async Task<BertTokenizer> CreateAsync(
                    Stream vocabStream,
                    bool doLowerCase = true,
                    bool doBasicTokenization = true,
                    bool splitOnSpecialTokens = true,
                    string unknownToken = "[UNK]",
                    string sepToken = "[SEP]",
                    string padToken = "[PAD]",
                    string clsToken = "[CLS]",
                    string maskToken = "[MASK]",
                    bool tokenizeChineseChars = true,
                    bool stripAccents = false)

        /// <summary>
        /// Gets a value indicating whether the tokenizer should lowercase the input text.
        /// </summary>
        public bool DoLowerCase { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should perform basic tokenization, such as cleaning the text, normalizing it, lowercasing, etc.
        /// </summary>
        public bool DoBasicTokenization { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should split on the special tokens or treat special tokens as normal text.
        /// </summary>
        public bool SplitOnSpecialTokens { get; }

        /// <summary>
        /// Gets the separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification 
        /// or for a text and a question for question answering.
        /// It is also used as the last token of a sequence built with special tokens.
        /// </summary>
        public string SepToken { get; }

        /// <summary>
        /// Gets the separator token Id
        /// </summary>
        public int SepTokenId { get; }

        /// <summary>
        /// Gets the token used for padding, for example when batching sequences of different lengths
        /// </summary>
        public string PadToken { get; }

        /// <summary>
        /// Gets the padding token Id.
        /// </summary>
        public int PadTokenId { get; }

        /// <summary>
        /// Gets the classifier token which is used when doing sequence classification (classification of the whole sequence 
        /// instead of per-token classification).
        /// It is the first token of the sequence when built with special tokens.
        /// </summary>
        public string ClsToken { get; }

        /// <summary>
        /// Gets the classifier token Id
        /// </summary>
        public int ClsTokenId { get; }

        /// <summary>
        /// Gets the mask token used for masking values. This is the token used when training this model with masked language modeling.
        /// This is the token which the model will try to predict.
        /// </summary>
        public string MaskToken { get; }

        /// <summary>
        /// Gets the mask token Id
        /// </summary>
        public int MaskTokenId { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should split the Chinese characters into tokens.
        /// </summary>
        public bool TokenizeChineseChars { get; }

        /// <summary>
        /// Gets a value indicating whether the tokenizer should strip accents characters.
        /// </summary>
        public bool StripAccents { get; }

        public IReadOnlyList<int> EncodeToIds(string text, bool addSpecialTokens, 
                      bool considerPreTokenization = true, bool considerNormalization = true)

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addSpecialTokens, 
                      bool considerPreTokenization = true, bool considerNormalization = true)

        public IReadOnlyList<int> EncodeToIds(string text, int maxTokenCount, bool addSpecialTokens, out string? normalizedText, 
                       out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)

        public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, int maxTokenCount, bool addSpecialTokens, 
                        out string? normalizedText,  out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)

        /// <summary>
        /// Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and 
        /// adding special tokens. A BERT sequence has the following format:
        ///     - single sequence: `[CLS] tokenIds0 [SEP]`
        ///     - pair of sequences: `[CLS] tokenIds0 [SEP] tokenIds1 [SEP]`
        /// </summary>
        /// <param name="tokenIds0">List of IDs to which the special tokens will be added.</param>
        /// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
        /// <returns>The list of IDs with special tokens added.</returns>
        /// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
        public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)

        public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, 
                             IEnumerable<int>? tokenIds1 = null)

        /// <summary>
        /// Retrieves the special tokens mask from a list of IDs.
        /// </summary>
        /// <param name="tokenIds0">List of IDs.</param>
        /// <param name="tokenIds1">Optional second list of IDs for sequence pairs.</param>
        /// <param name="alreadyHasSpecialTokens">Indicate whether or not the token list is already formatted with special tokens 
        /// for the model.</param>
        /// <returns>A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.</returns>
        /// <exception cref="ArgumentNullException"></exception>
        public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null, 
                    bool alreadyHasSpecialTokens = false)

        public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, 
                     IEnumerable<int>? tokenIds1 = null, bool alreadyHasSpecialTokens = false)

        /// <summary>
        /// Create a mask from the two sequences passed to be used in a sequence-pair classification task. 
        /// A BERT sequence pair mask has the following format:
        ///         0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        ///         | first sequence    | second sequence |
        /// If <paramref name="tokenIds1"/> is null, this method only returns the first portion of the type ids (0s).
        /// </summary>
        /// <param name="tokenIds0">List of token IDs for the first sequence.</param>
        /// <param name="tokenIds1">Optional list of token IDs for the second sequence.</param>
        /// <returns>List of token type IDs according to the given sequence(s).</returns>
        /// <exception cref="ArgumentNullException">When <paramref name="tokenIds0"/> is null.</exception>
        public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, IEnumerable<int>? tokenIds1 = null)

        public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds0, Span<int> buffer, out int written, 
                         IEnumerable<int>? tokenIds1 = null)
    }
}

PreTokenizer Factory methods

namespace Microsoft.ML.Tokenizers
{
    public abstract partial class PreTokenizer
    {
        // @"\w+|[\p{P}]"
        public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

        // @"\w+|[^\w\s]+"
        public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

        // @"\S+"
        public static PreTokenizer CreateWhiteSpacePreTokenizer(IReadOnlyDictionary<string, int>? specialTokensEncoder = null)

    }
}

Metadata

Metadata

Assignees

Labels

Tokenizers · api-approved (API was approved in API review, it can be implemented) · blocking (Marks issues that we want to fast track in order to unblock other important work)

Type

No type

Projects

No projects

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions