diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index d799d45a39..20cfe7f38b 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -3,8 +3,10 @@
 // See the LICENSE file in the project root for more information.

 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.IO;
+using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -34,20 +36,21 @@ private set
             {
                 _unknownToken = value;

-                if (value is null)
+                if (VocabReverse.TryGetValue(0, out string? v))
                 {
-                    if (VocabReverse.TryGetValue(0, out string? v))
+                    if (v == value)
                     {
-                        VocabReverse.Remove(0);
-                        if (_vocab.TryGetValue(v, out int id))
-                        {
-                            _vocab.Remove(v);
-                        }
+                        return;
                     }
+
+                    VocabReverse.Remove(0);
+                    _vocab.Remove(new StringSpanOrdinalKey(v));
                 }
-                else
+
+                if (value is not null)
                 {
-                    _vocab[value] = 0;
+                    _vocab[new StringSpanOrdinalKey(value)] = 0;
                     VocabReverse[0] = value;
                 }
             }
@@ -68,7 +71,6 @@ private set
         /// </summary>
         public bool FuseUnknownTokens { get; }

-
         /// <summary>
         /// Construct a new Bpe model object to use for text encoding.
         /// </summary>
@@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
             ContinuingSubwordPrefix = continuingSubwordPrefix;
             EndOfWordSuffix = endOfWordSuffix;

-            (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
-            _vocab = vocab1 ?? new Dictionary<string, int>();
-            Cache = new Cache<string, Word>();
+            (Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
+            _vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
+            Cache = new StringSpanOrdinalKeyCache<Word>();

             VocabReverse = new();

-            foreach (KeyValuePair<string, int> kvp in Vocab)
+            foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
             {
-                VocabReverse.Add(kvp.Value, kvp.Key);
+                VocabReverse.Add(kvp.Value, kvp.Key.Data!);
             }

-            if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
-            {
-                unknownToken = unkToken;
-            }
-
-            UnknownToken = unknownToken;
+            UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);

             int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
@@ -197,7 +195,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
        /// <param name="text">The text to split.</param>
        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
        /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
-        public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
+        public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);

        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
@@ -205,7 +203,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
-        public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
+        public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);

        /// <summary>
        /// Map the token to encoded Id.
        /// </summary>
        /// <param name="token">The token to map to the Id.</param>
        /// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
        /// <returns>The mapped Id of the token.</returns>
@@ -213,15 +211,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
-        public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
-        {
-            if (_vocab.TryGetValue(token, out int value))
-            {
-                return value;
-            }
-
-            return null;
-        }
+        public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;

        /// <summary>
        /// Map the encoded Id to the token.
        /// </summary>
@@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
        /// <summary>
        /// Gets the dictionary mapping tokens to Ids.
        /// </summary>
-        public IReadOnlyDictionary<string, int> Vocab => _vocab;
+        public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

        /// Read the given files to extract the vocab and merges
-        internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
+        internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
        {
-            Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
+            JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
+            Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;

            return (dic, ConvertMergesToHashmap(merges));
        }
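Aside: the converter-in-options pattern used by ReadModelData above can be seen in isolation. The sketch below is a minimal standalone illustration and is not part of the PR; the VocabKey type, VocabKeyConverter, and the sample JSON are all hypothetical stand-ins for StringSpanOrdinalKey and its converter, showing only how a property-name JsonConverter lets vocab JSON deserialize directly into a custom-keyed dictionary (requires .NET 7+ for ReadAsPropertyName/WriteAsPropertyName).

    using System;
    using System.Collections.Generic;
    using System.Text.Json;
    using System.Text.Json.Serialization;

    // Hypothetical stand-in for StringSpanOrdinalKey: a string wrapper used as a dictionary key.
    internal readonly struct VocabKey : IEquatable<VocabKey>
    {
        public readonly string Data;
        public VocabKey(string data) => Data = data;
        public bool Equals(VocabKey other) => string.Equals(Data, other.Data, StringComparison.Ordinal);
        public override bool Equals(object? obj) => obj is VocabKey k && Equals(k);
        public override int GetHashCode() => StringComparer.Ordinal.GetHashCode(Data);
    }

    // Property-name converter so {"hello":0} maps to Dictionary<VocabKey, int>.
    internal sealed class VocabKeyConverter : JsonConverter<VocabKey>
    {
        public override VocabKey ReadAsPropertyName(ref Utf8JsonReader reader, Type t, JsonSerializerOptions o) => new(reader.GetString()!);
        public override void WriteAsPropertyName(Utf8JsonWriter writer, VocabKey value, JsonSerializerOptions o) => writer.WritePropertyName(value.Data);
        public override VocabKey Read(ref Utf8JsonReader reader, Type t, JsonSerializerOptions o) => new(reader.GetString()!);
        public override void Write(Utf8JsonWriter writer, VocabKey value, JsonSerializerOptions o) => writer.WriteStringValue(value.Data);
    }

    internal static class VocabJsonDemo
    {
        public static void Main()
        {
            var options = new JsonSerializerOptions { Converters = { new VocabKeyConverter() } };
            var vocab = JsonSerializer.Deserialize<Dictionary<VocabKey, int>>("{\"hello\":0,\"world\":1}", options);
            Console.WriteLine(vocab![new VocabKey("world")]); // 1
        }
    }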
        /// The vocabulary assigns a number to each token.
-        private readonly Dictionary<string, int> _vocab;
+        private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
+
+        private Dictionary<string, int>? _vocabOriginal;

        /// Contains the mapping between Pairs and their (rank, newId).
        internal Dictionary<Pair<int>, (int, int)> Merges { get; }

        /// Contains the cache for optimizing the encoding step.
-        internal Cache<string, Word>? Cache { get; }
+        internal StringSpanOrdinalKeyCache<Word>? Cache { get; }

        internal static readonly int DefaultCacheCapacity = 10_000;
@@ -309,9 +302,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(
            return merges;
        }

-        /// Reset the cache.
-        internal void ClearCache() => Cache?.Clear();
-
        private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -327,38 +317,68 @@ internal string CharToString(char c)
            return s;
        }

-        internal Word MergeWord(string w)
+        internal Word MergeWord(ReadOnlySpan<char> w)
        {
            Word word = Word.WithCapacity(w.Length);
            (int Id, int Len)? unk = null;
            int i = 0;

+            Span<char> buffer = stackalloc char[256];
+            scoped ReadOnlySpan<char> s;
+
            while (i < w.Length)
            {
                int length;
-                string s;

                if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
                {
                    length = 2;
-                    s = w.Substring(i, length);
+                    s = w.Slice(i, 2);
                }
                else
                {
                    length = 1;
-                    s = CharToString(w[i]);
+                    s = w.Slice(i, 1);
                }

                // Add the `continuing_subword_prefix` if relevant
                if (i > 0 && ContinuingSubwordPrefix is not null)
                {
-                    s = $"{ContinuingSubwordPrefix}{s}";
+                    if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
+                    {
+                        ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
+                        s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
+                        s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
+                    }
+                    else
+                    {
+#if NETCOREAPP
+                        s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
+#else
+                        string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
+                        s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
+#endif
+                    }
                }

                // Add the `end_of_word_suffix` if relevant
                if (i + length >= w.Length && EndOfWordSuffix is not null)
                {
-                    s = $"{s}{EndOfWordSuffix}";
+                    if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
+                    {
+                        s.CopyTo(buffer);
+                        EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
+                        s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
+                    }
+                    else
+                    {
+#if NETCOREAPP
+                        s = $"{s}{EndOfWordSuffix}".AsSpan();
+#else
+                        string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
+                        s = $"{s1}{EndOfWordSuffix}".AsSpan();
+#endif
+                    }
                }

                if (_vocab.TryGetValue(s, out int id))
@@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
            Word word;

            if (Cache is not null)
            {
-                if (Cache.TryGet(text, out word))
+                if (Cache.TryGetValue(text, out word))
                {
                    return WordToTokens(ref word);
                }

-                word = MergeWord(text);
+                word = MergeWord(text.AsSpan());
                Cache.Set(text, word);
            }
            else
            {
-                word = MergeWord(text);
+                word = MergeWord(text.AsSpan());
            }

            return WordToTokens(ref word);
@@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
            return word.SymbolsCount;
        }

-        internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
+        internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
        {
            Word word;

            if (Cache is not null)
            {
-                if (Cache.TryGet(text, out Word hit))
+                if (Cache.TryGetValue(text, out Word hit))
                {
                    return WordToIds(ref hit, accumulatedIds);
                }

                word = MergeWord(text);
-                Cache.Set(text, word);
+                Cache.Set(text.ToString(), word);
            }
            else
            {
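Aside: the stack-buffer-with-fallback pattern in MergeWord above is worth seeing in isolation: concatenate a prefix and a span into a stackalloc scratch buffer when it fits, and only allocate a string when it does not. This is a minimal sketch under that assumption; the names are illustrative, not from the PR.

    using System;

    internal static class SpanConcatDemo
    {
        // Returns "prefix + s", writing into `buffer` when possible to avoid a string allocation.
        internal static ReadOnlySpan<char> WithPrefix(string prefix, ReadOnlySpan<char> s, Span<char> buffer)
        {
            if (prefix.Length + s.Length <= buffer.Length)
            {
                prefix.AsSpan().CopyTo(buffer);
                s.CopyTo(buffer.Slice(prefix.Length));
                return buffer.Slice(0, prefix.Length + s.Length);
            }

            // Rare fallback: the concatenation doesn't fit, so pay for one string allocation.
            return string.Concat(prefix, s.ToString()).AsSpan();
        }

        internal static void Main()
        {
            Span<char> scratch = stackalloc char[256];
            ReadOnlySpan<char> merged = WithPrefix("##", "word".AsSpan(), scratch);
            Console.WriteLine(merged.ToString()); // ##word
        }
    }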
diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
index b10d211ea6..065676621e 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,112 +4,53 @@

 using System;
 using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading;

 namespace Microsoft.ML.Tokenizers
 {
     internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
     {
+        private readonly int _capacity;
+        private readonly Dictionary<TKey, TValue> _map;
+        private object SyncObj => _map;
+
        internal Cache() : this(Bpe.DefaultCacheCapacity) { }

        internal Cache(int capacity)
        {
-            Capacity = capacity;
-            Map = new Dictionary<TKey, TValue>(Capacity);
+            _capacity = capacity;
+            _map = new Dictionary<TKey, TValue>(capacity);
        }

-        private readonly object _lock = new();
-
-        internal Dictionary<TKey, TValue> Map { get; set; }
-
-        internal int Capacity { get; set; }
-
-        internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);
-
-        internal void Clear()
+        internal bool TryGetValue(TKey key, out TValue value)
        {
-            lock (_lock)
+            lock (SyncObj)
            {
-                Map.Clear();
+                return _map.TryGetValue(key, out value!);
            }
        }

-        internal List<TValue> GetValues(IEnumerable<TKey> keys)
-        {
-            List<TValue> values = new();
-            lock (_lock)
-            {
-                foreach (TKey key in keys)
-                {
-                    if (Map.TryGetValue(key, out TValue? value))
-                    {
-                        values.Add(value);
-                    }
-                }
-            }
-
-            return values;
-        }
-
-        internal bool TryGet(TKey key, out TValue value)
-        {
-            lock (_lock)
-            {
-                return Map.TryGetValue(key, out value!);
-            }
-        }
-
-        internal void SetValues(IEnumerable<(TKey, TValue)> entries)
-        {
-            lock (_lock)
-            {
-                foreach ((TKey, TValue) entry in entries)
-                {
-                    if (Capacity <= Map.Count)
-                    {
-                        break;
-                    }
-                    Map[entry.Item1] = entry.Item2;
-                }
-            }
-        }
-
-        internal void Set(TKey k, TValue v)
+        internal TValue GetOrAdd(TKey key, TValue value)
        {
-            lock (_lock)
+            lock (SyncObj)
            {
-                if (Capacity > Map.Count)
+                if (_map.TryGetValue(key, out TValue? v))
                {
-                    Map[k] = v;
+                    return v!;
                }
-            }
-        }

-        internal KeyValuePair<TKey, TValue>[] ToArray()
-        {
-            lock (_lock)
-            {
-                return Map.ToArray();
+                _map[key] = value;
+                return value;
            }
        }

-        internal TValue GetOrAdd(TKey key, TValue value)
+        internal void Set(TKey key, TValue value)
        {
-            lock (_lock)
+            lock (SyncObj)
            {
-                if (Map.TryGetValue(key, out TValue? v))
+                if (_map.Count < _capacity)
                {
-                    return v;
+                    _map[key] = value;
                }
-
-                if (Capacity > Map.Count)
-                {
-                    Map[key] = value;
-                }
-
-                return value;
            }
        }
    }
}
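Usage note: GetOrAdd doubles as a "lookup with default" for the merge-rank cache in EnglishRoberta further down in this diff (_mergeRanks.GetOrAdd(pair, int.MaxValue)) — a miss inserts and returns the sentinel, so later lookups stay O(1). A small sketch of that semantics; the merge pair and rank below are made-up data.

    using Microsoft.ML.Tokenizers; // assumes access to the internal Cache<TKey, TValue> above

    internal static class MergeRankDemo
    {
        internal static void Run()
        {
            var ranks = new Cache<(string, string), int>(capacity: 100);
            ranks.Set(("he", "llo"), 3); // a known merge pair with rank 3 (invented for the example)

            int r1 = ranks.GetOrAdd(("he", "llo"), int.MaxValue); // 3: already cached
            int r2 = ranks.GetOrAdd(("xx", "yy"), int.MaxValue);  // int.MaxValue: miss, sentinel cached

            // A rank of int.MaxValue means "no such merge", so the encoder stops merging that pair.
        }
    }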
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index ea9fa884a8..3155c778ec 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers
     public sealed class EnglishRoberta : Model
     {
         private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
-        private readonly IReadOnlyDictionary<string, int> _vocab;
-        private readonly SortedDictionary<int, string> _vocabReverse;
+        private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
+        private Dictionary<string, int>? _vocabOriginal;
+        private readonly SortedDictionary<int, StringSpanOrdinalKey> _vocabReverse;
        private readonly Cache<(string, string), int> _mergeRanks;
        private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
        private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
        private readonly string[] _charToString;
-        private readonly Cache<string, List<Token>> _cache;
+        private readonly StringSpanOrdinalKeyCache<List<Token>> _cache;

        /// <summary>
        /// Indicate if want to filter the unsupported characters during the decoding.
        /// </summary>
@@ -77,7 +78,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
            }

            _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new StringSpanOrdinalKeyCache<List<Token>>();
        }

        /// <summary>
@@ -118,13 +119,13 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
            }

            _unicodeToByte = _byteToUnicode.Reverse();
-            _cache = new Cache<string, List<Token>>();
+            _cache = new StringSpanOrdinalKeyCache<List<Token>>();
        }

        /// <summary>
        /// Gets the dictionary mapping tokens to Ids.
        /// </summary>
-        public IReadOnlyDictionary<string, int> Vocab => _vocab;
+        public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

        //
        // Public Model interfaces implementation
        //
@@ -145,14 +146,15 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
            if (_vocabReverse.TryGetValue(id, out var value))
            {
+                string v = value.Data!;
                if (FilterUnsupportedChars)
                {
-                    char[] buffer = ArrayPool<char>.Shared.Rent(value.Length);
+                    char[] buffer = ArrayPool<char>.Shared.Rent(v.Length);
                    int i = 0;

-                    for (int j = 0; j < value.Length; j++)
+                    for (int j = 0; j < v.Length; j++)
                    {
-                        if (_unicodeToByte.TryGetValue(value[j], out var c))
+                        if (_unicodeToByte.TryGetValue(v[j], out var c))
                        {
                            buffer[i++] = c;
                        }
@@ -164,7 +166,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
                }
                else
                {
-                    return value;
+                    return v;
                }
            }
@@ -205,7 +207,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
                return Array.Empty<Token>();
            }

-            if (_cache.TryGet(text, out List<Token>? hit))
+            if (_cache.TryGetValue(text, out List<Token>? hit))
            {
                ArrayPool<char>.Shared.Return(token);
                ArrayPool<int>.Shared.Return(indexMapping);
@@ -225,7 +227,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
        /// <param name="text">The text to split.</param>
        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
        /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
-        public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);
+        public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIds(text, accumulatedIds);

        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
@@ -233,16 +235,16 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
-        public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
+        public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIds(text, null);

-        private int EncodeToIds(string text, IList<int>? accumulatedIds)
+        private int EncodeToIds(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
        {
-            if (string.IsNullOrEmpty(text))
+            if (text.IsEmpty)
            {
                return 0;
            }

-            if (_cache.TryGet(text, out List<Token>? hit))
+            if (_cache.TryGetValue(text, out List<Token>? hit))
            {
                if (accumulatedIds is not null)
                {
@@ -255,17 +257,41 @@ private int EncodeToIds(string text, IList<int>? accumulatedIds)
                return hit.Count;
            }

-            // If the cache doesn't have the text, then encode it and add it to the cache
-            IReadOnlyList<Token> tokens = Encode(text);
+            char[] token = ArrayPool<char>.Shared.Rent(text.Length);
+            int[] indexMapping = ArrayPool<int>.Shared.Rent(text.Length);
+
+            int newTokenIndex = 0;
+            for (int i = 0; i < text.Length; i++)
+            {
+                if (_byteToUnicode.TryGetValue(text[i], out var value))
+                {
+                    token[newTokenIndex] = value;
+                    indexMapping[newTokenIndex] = i;
+                    newTokenIndex++;
+                }
+            }
+
+            if (newTokenIndex == 0)
+            {
+                ArrayPool<char>.Shared.Return(token);
+                ArrayPool<int>.Shared.Return(indexMapping);
+                return 0;
+            }
+
+            List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+            _cache.Set(text.ToString(), result);
+            ArrayPool<char>.Shared.Return(token);
+            ArrayPool<int>.Shared.Return(indexMapping);
+
            if (accumulatedIds is not null)
            {
-                foreach (var t in tokens)
+                foreach (var t in result)
                {
                    accumulatedIds.Add(t.Id);
                }
            }

-            return tokens.Count;
+            return result.Count;
        }

        /// <summary>
        /// Map the token to encoded Id.
        /// </summary>
        /// <param name="token">The token to map to the Id.</param>
        /// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
        /// <returns>The mapped Id of the token.</returns>
@@ -274,7 +300,7 @@ private int EncodeToIds(string text, IList<int>? accumulatedIds)
-        public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
+        public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;

        /// <summary>
        /// Convert a list of tokens Ids to highest occurrence rankings.
        /// </summary>
@@ -397,12 +423,13 @@ private IReadOnlyList<Token> ModifyTokenListOffsets(IReadOnlyList<Token> tokens,
        private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
            HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);

-        private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
+        private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStream)
        {
-            Dictionary<string, int>? vocab;
+            Dictionary<StringSpanOrdinalKey, int>? vocab;
            try
            {
-                vocab = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabularyStream) as Dictionary<string, int>;
+                JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
+                vocab = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocabularyStream, options) as Dictionary<StringSpanOrdinalKey, int>;
            }
            catch (Exception e)
            {
@@ -416,22 +443,22 @@ private Dictionary<StringSpanOrdinalKey, int> GetVocabulary(Stream vocabularyStr

            if (_vocabIdToHighestOccurrence.BosWord is not null)
            {
-                vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
            }

            if (_vocabIdToHighestOccurrence.EosWord is not null)
            {
-                vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
            }

            if (_vocabIdToHighestOccurrence.UnkWord is not null)
            {
-                vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
            }

            if (_vocabIdToHighestOccurrence.PadWord is not null)
            {
-                vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex;
+                vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
            }

            return vocab;
@@ -510,7 +537,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
            if (token.Length == 1)
            {
                string tokenValue = _charToString[token[0]];
-                return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
+                return new List<Token> { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
            }

            List<string> word = new(token.Length);
@@ -539,7 +566,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
                // get the most frequent bi-gram pair
                var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
-                if (!_mergeRanks.TryGet((first, second), out int _))
+                if (!_mergeRanks.TryGetValue((first, second), out int _))
                {
                    break;
                }
@@ -599,7 +626,7 @@ private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)

            foreach (string w in word)
            {
-                tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
+                tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
                index += w.Length;
            }
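Aside: the rent/transform/return shape of the rewritten EncodeToIds is the standard ArrayPool idiom. A self-contained sketch of the same shape follows; the char-to-char mapping table is invented for illustration, and the try/finally hardening is an extra precaution rather than a claim about the PR's exact structure.

    using System;
    using System.Buffers;
    using System.Collections.Generic;

    internal static class PooledFilterDemo
    {
        // Keeps only the chars of `text` present in `map`, mirroring the
        // rent -> fill -> use -> return pattern of EncodeToIds above.
        internal static int FilterAndMap(ReadOnlySpan<char> text, IReadOnlyDictionary<char, char> map, IList<char> output)
        {
            char[] buffer = ArrayPool<char>.Shared.Rent(text.Length);
            int count = 0;
            try
            {
                for (int i = 0; i < text.Length; i++)
                {
                    if (map.TryGetValue(text[i], out char mapped))
                    {
                        buffer[count++] = mapped;
                    }
                }

                for (int i = 0; i < count; i++)
                {
                    output.Add(buffer[i]);
                }

                return count;
            }
            finally
            {
                ArrayPool<char>.Shared.Return(buffer); // always hand the buffer back to the pool
            }
        }
    }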
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
index 16eecc4aa4..815bd04a0b 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -31,14 +31,15 @@ public abstract class Model
     /// This method does the default implementation that uses the Encode method to get the token's Ids.
     /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
     /// </remarks>
-    public virtual void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
+    public virtual void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
    {
        if (accumulatedIds is null)
        {
            throw new ArgumentNullException(nameof(accumulatedIds));
        }

-        var tokens = Encode(text);
+        // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
+        var tokens = Encode(text.ToString());
        foreach (var token in tokens)
        {
            accumulatedIds.Add(token.Id);
@@ -55,7 +56,7 @@ public virtual void EncodeToIds(string text, bool isSpecialToken, IList<int> acc
     /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
     /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
     /// </remarks>
-    public virtual int CountTokens(string text, bool isSpecialToken)
+    public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
    {
        var ids = new List<int>();
        EncodeToIds(text, isSpecialToken, ids);
@@ -68,7 +69,7 @@ public virtual int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
    /// <param name="token">The token to map to Id</param>
    /// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
    /// <returns>The mapped Id of the token.</returns>
-    public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true);
+    public abstract int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true);

    /// <summary>
    /// Map the encoded Id to the token.
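Aside: to illustrate the contract change, here is a minimal hypothetical Model subclass that overrides the span-based virtuals instead of falling back on the allocating defaults. The char-per-token scheme is invented for the example, and it assumes the remaining abstract/virtual surface is Encode plus MapIdToToken as suggested elsewhere in this diff — treat it as a sketch, not the library's actual API.

    using System;
    using System.Collections.Generic;
    using Microsoft.ML.Tokenizers;

    // Hypothetical toy model: every character is its own token, with the char code as its Id.
    internal sealed class CharModel : Model
    {
        public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
        {
            var tokens = new List<Token>(text.Length);
            for (int i = 0; i < text.Length; i++)
            {
                tokens.Add(new Token(text[i], text[i].ToString(), (i, 1)));
            }
            return tokens;
        }

        // Span-based override: no string materialization, unlike the base default.
        public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
        {
            foreach (char c in text)
            {
                accumulatedIds.Add(c);
            }
        }

        public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => text.Length;

        public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) =>
            token.Length == 1 ? token[0] : null;

        public override string? MapIdToToken(int id, bool considerSpecialTokens = true) =>
            ((char)id).ToString();
    }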
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 0696efd9b0..60e9282a81 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -19,12 +19,14 @@ namespace Microsoft.ML.Tokenizers
     /// </summary>
     public sealed class Tiktoken : Model
     {
-        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
-        private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder = null!;
-        private readonly LruCache<string, int[]>? _cache;
-        private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
+        private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder;
+        private readonly Dictionary<int, ReadOnlyMemory<byte>> _decoder;
+        private readonly LruCache<int[]> _cache;
+        private readonly Dictionary<StringSpanOrdinalKey, int>? _specialTokensEncoder;
+        private Dictionary<string, int>? _specialTokensEncoderOriginal;
        private readonly Dictionary<int, string>? _specialTokensDecoder;
-        private readonly Dictionary<string, int> _vocab = null!;
+        private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;
+        private IReadOnlyDictionary<string, int>? _vocabOriginal;

        /// <summary>
        /// Create a new Tiktoken tokenizer's model object.
        /// </summary>
@@ -34,7 +36,7 @@ public sealed class Tiktoken : Model
        /// <param name="cacheSize">The size of the cache to use.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
-        public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
+        public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
            this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
        {
        }
@@ -47,7 +49,7 @@ public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialT
        /// <param name="cacheSize">The size of the cache to use.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
-        public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
+        public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
            this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
        {
        }
@@ -63,9 +65,9 @@ public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTok
        internal Tiktoken(
            Dictionary<ReadOnlyMemory<byte>, int> encoder,
            Dictionary<int, ReadOnlyMemory<byte>> decoder,
-            Dictionary<string, int> vocab,
+            Dictionary<StringSpanOrdinalKey, int> vocab,
            IReadOnlyDictionary<string, int>? specialTokens,
-            int cacheSize = LruCache<string, int[]>.DefaultCacheSize) : this(cacheSize)
+            int cacheSize = LruCache<int[]>.DefaultCacheSize)
        {
            _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
            _decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
@@ -73,24 +75,21 @@ internal Tiktoken(
            Debug.Assert(encoder.Count == decoder.Count);

-            _specialTokensEncoder = specialTokens;
-            if (_specialTokensEncoder is not null)
-            {
-                _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
-            }
+            _encoder = encoder!;
+            _decoder = decoder!;
+            _vocab = vocab!;
+            _cache = new LruCache<int[]>(cacheSize);
+
+            (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
        }

-        private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
+        private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream)
        {
            try
            {
+                _cache = new LruCache<int[]>(cacheSize);
                (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
-
-                _specialTokensEncoder = specialTokens;
-                if (_specialTokensEncoder is not null)
-                {
-                    _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
-                }
+                (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
            }
            finally
            {
@@ -101,17 +100,15 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTo
            }
        }

-        private Tiktoken(int cacheSize)
+        private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
        {
-            if (cacheSize < 0)
+            if (specialTokens is not null)
            {
-                throw new ArgumentOutOfRangeException(nameof(cacheSize));
+                var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
+                return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
            }

-            if (cacheSize > 0)
-            {
-                _cache = new LruCache<string, int[]>(cacheSize);
-            }
+            return (null, null);
        }
        /// <summary>
        /// Create a new Tiktoken tokenizer's model object asynchronously.
        /// </summary>
@@ -125,7 +122,7 @@ private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?
        public static async Task<Tiktoken> CreateAsync(
            Stream vocabStream,
            IReadOnlyDictionary<string, int>? specialTokens = null,
-            int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
+            int cacheSize = LruCache<int[]>.DefaultCacheSize,
            CancellationToken cancellationToken = default)
        {
            if (vocabStream is null)
            {
                throw new ArgumentNullException(nameof(vocabStream));
            }

-            (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
+            (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) =
                await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);

            return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
@@ -150,7 +147,7 @@ public static async Task<Tiktoken> CreateAsync(
        public static async Task<Tiktoken> CreateAsync(
            string vocabFilePath,
            IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
-            int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
+            int cacheSize = LruCache<int[]>.DefaultCacheSize,
            CancellationToken cancellationToken = default)
        {
            if (vocabFilePath is null)
            {
                throw new ArgumentNullException(nameof(vocabFilePath));
            }
@@ -170,11 +167,11 @@ public static async Task<Tiktoken> CreateAsync(
        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
        /// <returns>Map of byte[] to integer token id</returns>
-        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
+        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<StringSpanOrdinalKey, int>, Dictionary<int, ReadOnlyMemory<byte>>)> LoadTikTokenBpeAsync(
            Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
        {
            var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
-            var vocab = new Dictionary<string, int>();
+            var vocab = new Dictionary<StringSpanOrdinalKey, int>();
            var decoder = new Dictionary<int, ReadOnlyMemory<byte>>();

            try
@@ -212,7 +209,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
                        if (decodedToken.IndexOf('\uFFFD') < 0)
                        {
-                            vocab[decodedToken] = rank;
+                            vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
                        }
                    }
                    else
@@ -230,12 +227,6 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
            return (encoder, vocab, decoder);
        }

-        /// <summary>
-        /// Gets the dictionary mapping special tokens to Ids.
-        /// </summary>
-        /// <returns>The dictionary mapping special tokens to Ids.</returns>
-        public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoder;
-
        /// <summary>
        /// Encode a split text string to a list of tokens.
        /// </summary>
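Aside, for readers unfamiliar with the format LoadTikTokenBpeAsync parses: each line of a tiktoken vocab file is a base64-encoded token followed by a space and its integer rank. A simplified, hedged sketch of that per-line parse follows — the real loader streams asynchronously, uses invariant-culture parsing helpers, and also builds the reverse decoder.

    using System;
    using System.Collections.Generic;
    using System.Text;

    internal static class TiktokenLineParser
    {
        // Parses one "<base64-token> <rank>" line into the encoder/vocab maps.
        // NB: a real encoder dictionary needs a byte-wise comparer for its
        // ReadOnlyMemory<byte> keys (see ReadOnlyMemoryByteComparer in this PR).
        internal static void AddLine(string line, Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab)
        {
            int spaceIndex = line.IndexOf(' ');
            byte[] tokenBytes = Convert.FromBase64String(line.Substring(0, spaceIndex));
            int rank = int.Parse(line.Substring(spaceIndex + 1));

            encoder[tokenBytes.AsMemory()] = rank;

            // Only tokens that round-trip through UTF-8 cleanly get a string entry;
            // '\uFFFD' is the replacement character emitted for invalid sequences.
            string decoded = Encoding.UTF8.GetString(tokenBytes);
            if (decoded.IndexOf('\uFFFD') < 0)
            {
                vocab[decoded] = rank;
            }
        }
    }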
@@ -253,12 +244,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken)

            if (isSpecialToken)
            {
-                if (_specialTokensEncoder is null)
-                {
-                    throw new InvalidOperationException($"The tokenizer doesn't have special tokens");
-                }
-
-                if (_specialTokensEncoder.TryGetValue(text, out int id))
+                if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
                {
                    return new List<Token> { new(id, text, (0, text.Length)) };
                }

                throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
            }

-            if (_cache?.Lookup(text, out int[] ids) is true)
+            if (_cache.TryGetValue(text, out int[]? ids))
            {
                tokens = new Token[ids.Length];
                tokens[0] = new Token(ids[0], text, (0, text.Length));
@@ -290,7 +276,7 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken)
            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
            Debug.Assert(encodedIds.Length > 0);
-            _cache?.Add(text, encodedIds);
+            _cache.Add(text, encodedIds);

            tokens = new Token[encodedIds.Length];
            tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
@@ -305,21 +291,21 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken)
        }

        /// <summary>
-        /// Encode a split text string to a list of Ids.
+        /// Encode text to a list of Ids.
        /// </summary>
        /// <param name="text">The text to encode.</param>
        /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
        /// <param name="accumulatedIds">The list of accumulated Ids.</param>
-        public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds)
+        public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds)
        {
-            if (string.IsNullOrEmpty(text))
+            if (text.IsEmpty)
            {
                return;
            }

            if (isSpecialToken)
            {
-                if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(text, out int id))
+                if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
                {
                    accumulatedIds.Add(id);
                }

                return;
            }

-            if (_cache?.Lookup(text, out int[] tokenIds) is true)
+            if (_cache.TryGetValue(text, out int[]? tokenIds))
            {
                accumulatedIds.AddRange(tokenIds);
                return;
@@ -340,10 +326,10 @@ public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, I
            }

            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
-            int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
+            int encodedLength = GetUtf8Bytes(text, arrayPoolArray);

            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-            _cache?.Add(text, encodedIds);
+            _cache.Add(text.ToString(), encodedIds);

            accumulatedIds.AddRange(encodedIds);
@@ -354,12 +340,12 @@ public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, I
        /// <summary>
        /// Get the number of tokens that the input text will be encoded to.
        /// </summary>
-        /// <param name="text">The text to encode.</param>
+        /// <param name="text">The text to tokenize.</param>
        /// <param name="isSpecialToken">Indicate if the token is special token.</param>
        /// <returns>The number of tokens that the input text will be encoded to.</returns>
-        public override int CountTokens(string text, bool isSpecialToken)
+        public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
        {
-            if (string.IsNullOrEmpty(text))
+            if (text.IsEmpty)
            {
                return 0;
            }

            if (isSpecialToken && _specialTokensEncoder is not null)
            {
                return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
            }

-            if (_cache?.Lookup(text, out int[] ids) is true)
+            if (_cache.TryGetValue(text, out int[] ids))
            {
                return ids.Length;
            }
@@ -380,10 +366,10 @@ public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
            }

            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
-            int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
+            int encodedLength = GetUtf8Bytes(text, arrayPoolArray);

            int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-            _cache?.Add(text, encodedIds);
+            _cache.Add(text.ToString(), encodedIds);

            ArrayPool<byte>.Shared.Return(arrayPoolArray);
            return encodedIds.Length;
@@ -395,19 +381,22 @@ public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
        /// <summary>
        /// Map the token to encoded Id.
        /// </summary>
        /// <param name="token">The token to map to the Id.</param>
        /// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
        /// <returns>The mapped Id of the token.</returns>
-        public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
+        public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
        {
-            if (string.IsNullOrEmpty(token))
+            if (token.IsEmpty)
            {
                return 0;
            }

-            if (considerSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId))
+            if (considerSpecialTokens && _specialTokensEncoder is not null)
            {
-                return specialTokenId;
+                if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
+                {
+                    return specialTokenId;
+                }
            }

-            if (_cache?.Lookup(token, out int[] ids) is true)
+            if (_cache.TryGetValue(token, out int[] ids))
            {
                if (ids.Length == 1)
                {
@@ -425,10 +414,10 @@ public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecial
            byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
            try
            {
-                int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
+                int encodedLength = GetUtf8Bytes(token, arrayPoolArray);

                int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
-                _cache?.Add(token, idsToCache);
+                _cache.Add(token.ToString(), idsToCache);

                if (idsToCache.Length == 1)
                {
@@ -550,7 +539,12 @@ static void ArrayPoolGrow(ref Span<byte> utf8Bytes, ref byte[]? arrayPoolArray,
        /// Gets the dictionary mapping tokens to Ids.
        /// </summary>
        /// <remarks>This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.</remarks>
-        public IReadOnlyDictionary<string, int> Vocab => _vocab;
+        public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
+
+        /// <summary>
+        /// Gets the dictionary mapping special tokens to Ids.
+        /// </summary>
+        public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

        /// <summary>
        /// Gets the dictionary mapping token bytes to Ids.
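Aside: the miss path shared by EncodeToIds, CountTokens, and MapTokenToId above is: UTF-8-encode the span into a pooled buffer, byte-pair-encode, then cache the ids under a materialized string — the only allocation of the key. A condensed sketch of that flow; BytePairEncode here is a hypothetical stand-in (the real merging lives in BytePairEncoder), and the span overload of Encoding.GetBytes assumes .NET Core (the PR routes this through a GetUtf8Bytes helper for netstandard).

    using System;
    using System.Buffers;
    using System.Text;
    using Microsoft.ML.Tokenizers; // for the LruCache<TValue> rewritten below

    internal static class EncodeAndCacheDemo
    {
        // Hypothetical stand-in; the PR's BytePairEncoder.BytePairEncode does the real merging.
        private static int[] BytePairEncode(ReadOnlyMemory<byte> utf8) => new int[] { utf8.Length };

        internal static int[] Encode(ReadOnlySpan<char> text, LruCache<int[]> cache)
        {
            if (cache.TryGetValue(text, out int[]? ids))
            {
                return ids; // span-based hit: no string allocated
            }

            byte[] rented = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
            int byteCount = Encoding.UTF8.GetBytes(text, rented);
            int[] encoded = BytePairEncode(rented.AsMemory(0, byteCount));
            ArrayPool<byte>.Shared.Return(rented);

            cache.Add(text.ToString(), encoded); // the key string is materialized only on a miss
            return encoded;
        }
    }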
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index 0826a8b68e..c64ebf256e 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -104,7 +104,7 @@ public IReadOnlyList<int> EncodeToIds(string text, bool considerSpecialTokens = true)

            foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
            {
-                Model.EncodeToIds(split.TokenString, split.IsSpecialToken, idsList);
+                Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
            }

            return idsList;
@@ -130,7 +130,7 @@ public int CountTokens(string text, bool considerSpecialTokens = true)
            int idsCount = 0;
            foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
            {
-                idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
+                idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
            }

            return idsCount;
@@ -343,7 +343,7 @@ private static Task<Tokenizer> CreateByEncoderNameAsync(
            }
        }

-        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+        private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);

        /// <summary>
        /// Create tokenizer based on regex pattern, BPE rank file and special tokens
        /// </summary>
@@ -371,7 +371,7 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
                }
            }

-            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
+            if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
            {
                using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
                {
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
index b64531431f..0050c63f3d 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -37,5 +37,7 @@ public static byte[] FromBase64String(string base64String, int offset, int length)

        internal static bool TryParseInt32(string s, int offset, out int result)
            => int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
+
+        internal static int GetHashCode(ReadOnlySpan<char> span) => string.GetHashCode(span);
    }
}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
index 2979c99b6e..2d739e52e4 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -48,6 +48,17 @@ internal static bool TryParseInt32(string s, int offset, out int result)

            return true;
        }
+
+        internal static int GetHashCode(ReadOnlySpan<char> span)
+        {
+            int hash = 17;
+            foreach (char c in span)
+            {
+                hash = hash * 31 + c;
+            }
+
+            return hash;
+        }
    }
}
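One subtlety worth calling out: StringSpanOrdinalKey.GetHashCode (defined below) routes through these helpers, so a key stored as a string and the same characters queried as a span must hash identically — on .NET Core both forms go through string.GetHashCode(ReadOnlySpan<char>), and on netstandard both use the 17/31 polynomial above. A quick check of that invariant, using a mirror of the netstandard variant:

    using System;

    internal static class HashInvariantDemo
    {
        // Mirror of the netstandard Helpers.GetHashCode above.
        private static int GetHashCode(ReadOnlySpan<char> span)
        {
            int hash = 17;
            foreach (char c in span)
            {
                hash = hash * 31 + c; // unchecked overflow is fine; both forms wrap identically
            }
            return hash;
        }

        internal static void Main()
        {
            string stored = "hello";
            ReadOnlySpan<char> queried = "xxhelloxx".AsSpan(2, 5); // same chars, no new string

            // Both key forms must land in the same hash bucket for span lookups to find string entries.
            Console.WriteLine(GetHashCode(stored.AsSpan()) == GetHashCode(queried)); // True
        }
    }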
diff --git a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
index 9ad88e2f35..c11d79e1f5 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
@@ -2,47 +2,37 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using System;
 using System.Collections.Generic;

 namespace Microsoft.ML.Tokenizers
 {
-    internal class LruCache<TKey, TValue> where TKey : notnull where TValue : notnull
+    internal sealed class LruCache<TValue>
     {
        /// <summary>
        /// The default LRU cache size.
        /// </summary>
-        public const int DefaultCacheSize = 8192; // 4096;
+        public const int DefaultCacheSize = 8192;

-        private readonly object _lockObject = new object();
-
-        private class CacheItem
-        {
-            public readonly TKey Key;
-            public TValue Value;
-
-            public CacheItem(TKey key, TValue value)
-            {
-                Key = key;
-                Value = value;
-            }
-        }
-
-        private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
-        private readonly LinkedList<CacheItem> _lruList;
+        private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();
+        private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();
        private readonly int _cacheSize;

+        private object SyncObj => _cache;
+
        /// <summary>
-        /// Constructs an <see cref="LruCache{TKey, TValue}"/> object.
+        /// Constructs an <see cref="LruCache{TValue}"/> object.
        /// </summary>
        /// <param name="cacheSize">
-        /// The maximum number of <typeparamref name="TKey"/> to <typeparamref name="TValue"/> mappings
-        /// that can be cached. This defaults to <see cref="DefaultCacheSize"/>, which is set to
-        /// 4096.
+        /// The maximum number of mappings that can be cached. This defaults to <see cref="DefaultCacheSize"/>, which is set to 8192.
        /// </param>
        public LruCache(int cacheSize = DefaultCacheSize)
        {
-            _cache = new Dictionary<TKey, LinkedListNode<CacheItem>>();
-            _lruList = new LinkedList<CacheItem>();
+            if (cacheSize <= 0)
+            {
+                throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number.");
+            }
+
            _cacheSize = cacheSize;
        }
@@ -54,11 +44,11 @@ public LruCache(int cacheSize = DefaultCacheSize)
        /// </summary>
        /// <param name="key">The key of the item to be retrieved.</param>
        /// <param name="value">An out parameter that is set to the value of the key if key contains a mapping in the cache.</param>
        /// <returns>
        /// true if the cache contains a mapping for key, false otherwise.
        /// </returns>
-        public bool Lookup(TKey key, out TValue value)
+        public bool TryGetValue(string key, out TValue value)
        {
-            lock (_lockObject)
+            lock (SyncObj)
            {
-                if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
+                if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
                {
                    _lruList.Remove(cached);
                    _lruList.AddFirst(cached);
@@ -71,16 +61,31 @@ public bool TryGetValue(string key, out TValue value)
            }
        }

-        protected virtual void OnEviction(TValue evictedValue) { }
-
-        private void EvictIfNeeded()
+        /// <summary>
+        /// Retrieves the value associated with the specified key <see cref="ReadOnlySpan{Char}"/> object.
+        /// </summary>
+        /// <param name="key">The object to be used as a key.</param>
+        /// <param name="value">An out parameter that is set to the value of the key if key contains a mapping in the cache.</param>
+        /// <returns>
+        /// true if the cache contains a mapping for key, false otherwise.
+        /// </returns>
+        public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
        {
-            while (_cache.Count >= _cacheSize)
+            lock (SyncObj)
            {
-                LinkedListNode<CacheItem>? nodeToEvict = _lruList.Last;
-                _lruList.RemoveLast();
-                _cache.Remove(nodeToEvict!.Value.Key);
-                OnEviction(nodeToEvict.Value.Value);
+                fixed (char* ptr = key)
+                {
+                    if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+                    {
+                        _lruList.Remove(cached);
+                        _lruList.AddFirst(cached);
+                        value = cached.Value.Value;
+                        return true;
+                    }
+                }
+
+                value = default!;
+                return false;
            }
        }

@@ -89,46 +94,29 @@ private void EvictIfNeeded()
        /// </summary>
        /// <param name="key">The key whose mapped <paramref name="value"/> is to be created or replaced.</param>
        /// <param name="value">The new value to be mapped to the <paramref name="key"/>.</param>
-        public void Add(TKey key, TValue value) => Replace(key, value, out _);
-
-        public bool Replace(TKey key, TValue value, out TValue oldValue)
+        public void Add(string key, TValue value)
        {
-            lock (_lockObject)
+            lock (SyncObj)
            {
-                return ReplaceInternal(key, value, out oldValue);
-            }
-        }
-
-        private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
-        {
-            if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
-            {
-                oldValue = cached.Value.Value;
-                cached.Value.Value = value;
-                _lruList.Remove(cached);
-                _lruList.AddFirst(cached);
-                return true;
-            }
-            EvictIfNeeded();
-            var node = new LinkedListNode<CacheItem>(new CacheItem(key, value));
-            _cache[key] = node;
-            _lruList.AddFirst(node);
-            oldValue = default!;
-            return false;
-        }
+                if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+                {
+                    cached.Value = new KeyValuePair<string, TValue>(key, value);
+                    _lruList.Remove(cached);
+                    _lruList.AddFirst(cached);
+                    return;
+                }

-        /// <summary>
-        /// The number of entries currently present in the cache.
-        /// </summary>
-        public int Count => _cache.Count;
+                while (_cache.Count >= _cacheSize)
+                {
+                    LinkedListNode<KeyValuePair<string, TValue>>? nodeToEvict = _lruList.Last;
+                    _lruList.RemoveLast();
+                    _cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key));
+                }

-        /// <summary>
-        /// Clears the contents of this cache.
-        /// </summary>
-        public void Clear()
-        {
-            _cache.Clear();
-            _lruList.Clear();
+                var node = new LinkedListNode<KeyValuePair<string, TValue>>(new KeyValuePair<string, TValue>(key, value));
+                _cache[new StringSpanOrdinalKey(key)] = node;
+                _lruList.AddFirst(node);
            }
        }
    }
-}
\ No newline at end of file
+}
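Usage note: the rewritten cache keeps the public surface to Add plus the two TryGetValue overloads, with eviction folded into Add. A small sketch of the expected behavior:

    using System;
    using Microsoft.ML.Tokenizers;

    internal static class LruDemo
    {
        internal static void Run()
        {
            var cache = new LruCache<int[]>(cacheSize: 2);
            cache.Add("foo", new[] { 1 });
            cache.Add("bar", new[] { 2 });

            // Touch "foo" so "bar" becomes the least recently used entry.
            cache.TryGetValue("foo", out _);

            cache.Add("baz", new[] { 3 }); // at capacity: evicts "bar"

            bool fooAlive = cache.TryGetValue("foo".AsSpan(), out _); // true, via the span overload
            bool barAlive = cache.TryGetValue("bar", out _);          // false, evicted
        }
    }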
diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
new file mode 100644
index 0000000000..3cee62e318
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.ML.Tokenizers
+{
+    /// <summary>Used as a key in a dictionary to enable querying with either a string or a span.</summary>
+    /// <remarks>
+    /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
+    /// always be used with a string.
+    /// </remarks>
+    internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
+    {
+        public readonly char* Ptr;
+        public readonly int Length;
+        public readonly string? Data;
+
+        public StringSpanOrdinalKey(char* ptr, int length)
+        {
+            Ptr = ptr;
+            Length = length;
+        }
+
+        public StringSpanOrdinalKey(string data) =>
+            Data = data;
+
+        private ReadOnlySpan<char> Span => Ptr is not null ?
+            new ReadOnlySpan<char>(Ptr, Length) :
+            Data.AsSpan();
+
+        public override bool Equals(object? obj) =>
+            obj is StringSpanOrdinalKey wrapper && Equals(wrapper);
+
+        public bool Equals(StringSpanOrdinalKey other) =>
+            Span.SequenceEqual(other.Span);
+
+        public override int GetHashCode() => Helpers.GetHashCode(Span);
+    }
+
+    internal sealed class StringSpanOrdinalKeyCache<TValue>
+    {
+        private readonly int _capacity;
+        private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;
+
+        private object SyncObj => _map;
+
+        internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { }
+
+        internal StringSpanOrdinalKeyCache(int capacity)
+        {
+            _capacity = capacity;
+            _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
+        }
+
+        internal bool TryGetValue(string key, out TValue value)
+        {
+            lock (SyncObj)
+            {
+                return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
+            }
+        }
+
+        internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
+        {
+            lock (SyncObj)
+            {
+                fixed (char* ptr = key)
+                {
+                    return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
+                }
+            }
+        }
+
+        internal void Remove(string key)
+        {
+            lock (SyncObj)
+            {
+                _map.Remove(new StringSpanOrdinalKey(key));
+            }
+        }
+
+        internal void Set(string k, TValue v)
+        {
+            lock (SyncObj)
+            {
+                if (_map.Count < _capacity)
+                {
+                    _map[new StringSpanOrdinalKey(k)] = v;
+                }
+            }
+        }
+    }
+
+    /// <summary>
+    /// Custom JSON converter for <see cref="StringSpanOrdinalKey"/>.
+    /// </summary>
+    internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
+    {
+        public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();
+
+        public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
+            new StringSpanOrdinalKey(reader.GetString()!);
+
+        public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
+            writer.WriteStringValue(value.Data!);
+
+        public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!);
+        public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
+    }
+
+    /// <summary>
+    /// Extension methods for <see cref="StringSpanOrdinalKey"/>.
+    /// </summary>
+    internal static class StringSpanOrdinalKeyExtensions
+    {
+        public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
+        {
+            fixed (char* ptr = key)
+            {
+                return map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
+            }
+        }
+
+        public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
+            map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
+    }
+}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 810862322b..2959184b5d 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -156,7 +156,7 @@ public void SimpleTestWithUnknownToken(Dictionary<string, int> vocab, (string, s
                Assert.Equal(ids[i], encoding.Ids[i]);
                Assert.Equal(ids[i], idsList[i]);
                Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
-                Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
+                Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
                Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
            }
        }
diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
index d23f241319..ccf0e66ef9 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
@@ -201,7 +201,7 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call
                    Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
                }

-                Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
+                Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
            }
        }
    }