diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index d799d45a39..20cfe7f38b 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -3,8 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
@@ -34,20 +36,21 @@ private set
{
_unknownToken = value;
- if (value is null)
+ if (VocabReverse.TryGetValue(0, out string? v))
{
- if (VocabReverse.TryGetValue(0, out string? v))
+ if (v == value)
{
- VocabReverse.Remove(0);
- if (_vocab.TryGetValue(v, out int id))
- {
- _vocab.Remove(v);
- }
+ return;
}
+
+ VocabReverse.Remove(0);
+ _vocab.Remove(new StringSpanOrdinalKey(v));
}
- else
+
+
+ if (value is not null)
{
- _vocab[value] = 0;
+ _vocab[new StringSpanOrdinalKey(value)] = 0;
VocabReverse[0] = value;
}
}
@@ -68,7 +71,6 @@ private set
///
public bool FuseUnknownTokens { get; }
-
///
/// Construct a new Bpe model object to use for text encoding.
///
@@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;
- (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
- _vocab = vocab1 ?? new Dictionary();
- Cache = new Cache();
+ (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
+ _vocab = vocab1 ?? new Dictionary();
+ Cache = new StringSpanOrdinalKeyCache();
VocabReverse = new();
- foreach (KeyValuePair kvp in Vocab)
+ foreach (KeyValuePair kvp in _vocab)
{
- VocabReverse.Add(kvp.Value, kvp.Key);
+ VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}
- if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
- {
- unknownToken = unkToken;
- }
- UnknownToken = unknownToken;
+ UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);
int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
@@ -197,7 +195,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
/// The text to split.
/// Indicate if the token is a special token.
/// The list of accumulated encoded Ids.
- public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
+ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -205,7 +203,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
/// The text to encode.
/// Indicate if the token is special token.
/// The number of tokens that the input text will be encoded to.
- public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
+ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
///
/// Map the token to encoded Id.
@@ -213,15 +211,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
/// The token to map to the Id.
/// Indicate if want to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
- {
- if (_vocab.TryGetValue(token, out int value))
- {
- return value;
- }
-
- return null;
- }
+ public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
///
/// Map the encoded Id to the token.
@@ -242,24 +232,27 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
///
/// Gets the dictionary mapping tokens to Ids.
///
- public IReadOnlyDictionary Vocab => _vocab;
+ public IReadOnlyDictionary Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
/// Read the given files to extract the vocab and merges
- internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
+ internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
- Dictionary? dic = JsonSerializer.Deserialize>(vocab) as Dictionary;
+ JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
+ Dictionary? dic = JsonSerializer.Deserialize>(vocab, options) as Dictionary;
return (dic, ConvertMergesToHashmap(merges));
}
/// The vocabulary assigns a number to each token.
- private readonly Dictionary _vocab;
+ private readonly Dictionary _vocab;
+
+ private Dictionary? _vocabOriginal;
/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary, (int, int)> Merges { get; }
/// Contains the cache for optimizing the encoding step.
- internal Cache? Cache { get; }
+ internal StringSpanOrdinalKeyCache? Cache { get; }
internal static readonly int DefaultCacheCapacity = 10_000;
@@ -309,9 +302,6 @@ internal static (Dictionary?, Vec<(string, string)>) ReadModelData(
return merges;
}
- /// Reset the cache.
- internal void ClearCache() => Cache?.Clear();
-
private readonly Dictionary _charToString = new Dictionary();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -327,38 +317,68 @@ internal string CharToString(char c)
return s;
}
- internal Word MergeWord(string w)
+ internal Word MergeWord(ReadOnlySpan w)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;
+ Span buffer = stackalloc char[256];
+ scoped ReadOnlySpan s;
+
while (i < w.Length)
{
int length;
- string s;
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
- s = w.Substring(i, length);
+ s = w.Slice(i, 2);
}
else
{
length = 1;
- s = CharToString(w[i]);
+ s = w.Slice(i, 1);
}
// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
- s = $"{ContinuingSubwordPrefix}{s}";
+ if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
+ {
+ ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
+ s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
+ s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
+ }
+ else
+ {
+#if NETCOREAPP
+ s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
+#else
+ string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
+ s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
+#endif
+ }
}
// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
- s = $"{s}{EndOfWordSuffix}";
+ if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
+ {
+ s.CopyTo(buffer);
+ EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
+ s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
+ }
+ else
+ {
+#if NETCOREAPP
+ s = $"{s}{EndOfWordSuffix}".AsSpan();
+#else
+ string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
+ s = $"{s1}{EndOfWordSuffix}".AsSpan();
+#endif
+ }
}
if (_vocab.TryGetValue(s, out int id))
@@ -419,17 +439,17 @@ internal List EncodeWithCache(string text)
Word word;
if (Cache is not null)
{
- if (Cache.TryGet(text, out word))
+ if (Cache.TryGetValue(text, out word))
{
return WordToTokens(ref word);
}
- word = MergeWord(text);
+ word = MergeWord(text.AsSpan());
Cache.Set(text, word);
}
else
{
- word = MergeWord(text);
+ word = MergeWord(text.AsSpan());
}
return WordToTokens(ref word);
@@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList? accumulatedIds)
return word.SymbolsCount;
}
- internal int EncodeToIdsWithCache(string text, IList? accumulatedIds)
+ internal int EncodeToIdsWithCache(ReadOnlySpan text, IList? accumulatedIds)
{
Word word;
if (Cache is not null)
{
- if (Cache.TryGet(text, out Word hit))
+ if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}
word = MergeWord(text);
- Cache.Set(text, word);
+ Cache.Set(text.ToString(), word);
}
else
{
diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
index b10d211ea6..065676621e 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,112 +4,53 @@
using System;
using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading;
namespace Microsoft.ML.Tokenizers
{
internal sealed class Cache where TKey : notnull where TValue : notnull
{
+ private readonly int _capacity;
+ private readonly Dictionary _map;
+ private object SyncObj => _map;
+
internal Cache() : this(Bpe.DefaultCacheCapacity) { }
internal Cache(int capacity)
{
- Capacity = capacity;
- Map = new Dictionary(Capacity);
+ _capacity = capacity;
+ _map = new Dictionary(capacity);
}
- private readonly object _lock = new();
-
- internal Dictionary Map { get; set; }
-
- internal int Capacity { get; set; }
-
- internal void Fresh() => Map = new Dictionary(Capacity);
-
- internal void Clear()
+ internal bool TryGetValue(TKey key, out TValue value)
{
- lock (_lock)
+ lock (SyncObj)
{
- Map.Clear();
+ return _map.TryGetValue(key, out value!);
}
}
- internal List GetValues(IEnumerable keys)
- {
- List values = new();
- lock (_lock)
- {
- foreach (TKey key in keys)
- {
- if (Map.TryGetValue(key, out TValue? value))
- {
- values.Add(value);
- }
- }
- }
-
- return values;
- }
-
- internal bool TryGet(TKey key, out TValue value)
- {
- lock (_lock)
- {
- return Map.TryGetValue(key, out value!);
- }
- }
-
- internal void SetValues(IEnumerable<(TKey, TValue)> entries)
- {
- lock (_lock)
- {
- foreach ((TKey, TValue) entry in entries)
- {
- if (Capacity <= Map.Count)
- {
- break;
- }
- Map[entry.Item1] = entry.Item2;
- }
- }
- }
-
- internal void Set(TKey k, TValue v)
+ internal TValue GetOrAdd(TKey key, TValue value)
{
- lock (_lock)
+ lock (SyncObj)
{
- if (Capacity > Map.Count)
+ if (_map.TryGetValue(key, out TValue? v))
{
- Map[k] = v;
+ return v!;
}
- }
- }
- internal KeyValuePair[] ToArray()
- {
- lock (_lock)
- {
- return Map.ToArray();
+ _map[key] = value;
+ return value;
}
}
- internal TValue GetOrAdd(TKey key, TValue value)
+ internal void Set(TKey key, TValue value)
{
- lock (_lock)
+ lock (SyncObj)
{
- if (Map.TryGetValue(key, out TValue? v))
+ if (_map.Count < _capacity)
{
- return v;
+ _map[key] = value;
}
-
- if (Capacity > Map.Count)
- {
- Map[key] = value;
- }
-
- return value;
}
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index ea9fa884a8..3155c778ec 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers
public sealed class EnglishRoberta : Model
{
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
- private readonly IReadOnlyDictionary _vocab;
- private readonly SortedDictionary _vocabReverse;
+ private readonly Dictionary _vocab;
+ private Dictionary? _vocabOriginal;
+ private readonly SortedDictionary _vocabReverse;
private readonly Cache<(string, string), int> _mergeRanks;
private readonly IReadOnlyDictionary _byteToUnicode;
private readonly IReadOnlyDictionary _unicodeToByte;
private readonly string[] _charToString;
- private readonly Cache> _cache;
+ private readonly StringSpanOrdinalKeyCache> _cache;
///
/// Indicate if want to filter the unsupported characters during the decoding.
@@ -77,7 +78,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
}
_unicodeToByte = _byteToUnicode.Reverse();
- _cache = new Cache>();
+ _cache = new StringSpanOrdinalKeyCache>();
}
///
@@ -118,13 +119,13 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}
_unicodeToByte = _byteToUnicode.Reverse();
- _cache = new Cache>();
+ _cache = new StringSpanOrdinalKeyCache>();
}
///
/// Gets the dictionary mapping tokens to Ids.
///
- public IReadOnlyDictionary Vocab => _vocab;
+ public IReadOnlyDictionary Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
//
// Public Model interfaces implementation
@@ -145,14 +146,15 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
if (_vocabReverse.TryGetValue(id, out var value))
{
+ string v = value.Data!;
if (FilterUnsupportedChars)
{
- char[] buffer = ArrayPool.Shared.Rent(value.Length);
+ char[] buffer = ArrayPool.Shared.Rent(v.Length);
int i = 0;
- for (int j = 0; j < value.Length; j++)
+ for (int j = 0; j < v.Length; j++)
{
- if (_unicodeToByte.TryGetValue(value[j], out var c))
+ if (_unicodeToByte.TryGetValue(v[j], out var c))
{
buffer[i++] = c;
}
@@ -164,7 +166,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}
else
{
- return value;
+ return v;
}
}
@@ -205,7 +207,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
return Array.Empty();
}
- if (_cache.TryGet(text, out List? hit))
+ if (_cache.TryGetValue(text, out List? hit))
{
ArrayPool.Shared.Return(token);
ArrayPool.Shared.Return(indexMapping);
@@ -225,7 +227,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
/// The text to split.
/// Indicate if the token is a special token.
/// The list of accumulated encoded Ids.
- public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds);
+ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -233,16 +235,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
/// The text to encode.
/// Indicate if the token is special token.
/// The number of tokens that the input text will be encoded to.
- public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null);
+ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null);
- private int EncodeToIds(string text, IList? accumulatedIds)
+ private int EncodeToIds(ReadOnlySpan text, IList? accumulatedIds)
{
- if (string.IsNullOrEmpty(text))
+ if (text.IsEmpty)
{
return 0;
}
- if (_cache.TryGet(text, out List? hit))
+ if (_cache.TryGetValue(text, out List? hit))
{
if (accumulatedIds is not null)
{
@@ -255,17 +257,41 @@ private int EncodeToIds(string text, IList? accumulatedIds)
return hit.Count;
}
- // If the cache doesn't have the text, then encode it and add it to the cache
- IReadOnlyList tokens = Encode(text);
+ char[] token = ArrayPool.Shared.Rent(text.Length);
+ int[] indexMapping = ArrayPool.Shared.Rent(text.Length);
+
+ int newTokenIndex = 0;
+ for (int i = 0; i < text.Length; i++)
+ {
+ if (_byteToUnicode.TryGetValue(text[i], out var value))
+ {
+ token[newTokenIndex] = value;
+ indexMapping[newTokenIndex] = i;
+ newTokenIndex++;
+ }
+ }
+
+ if (newTokenIndex == 0)
+ {
+ ArrayPool.Shared.Return(token);
+ ArrayPool.Shared.Return(indexMapping);
+ return 0;
+ }
+
+ List result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+ _cache.Set(text.ToString(), result);
+ ArrayPool.Shared.Return(token);
+ ArrayPool.Shared.Return(indexMapping);
+
if (accumulatedIds is not null)
{
- foreach (var t in tokens)
+ foreach (var t in result)
{
accumulatedIds.Add(t.Id);
}
}
- return tokens.Count;
+ return result.Count;
}
///
@@ -274,7 +300,7 @@ private int EncodeToIds(string text, IList? accumulatedIds)
/// The token to map to the Id.
/// Indicate if want to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null;
+ public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;
///
/// Convert a list of tokens Ids to highest occurrence rankings.
@@ -397,12 +423,13 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList tokens,
private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) =>
HighestOccurrenceMapping.Load(highestOccurrenceMappingStream);
- private Dictionary GetVocabulary(Stream vocabularyStream)
+ private Dictionary GetVocabulary(Stream vocabularyStream)
{
- Dictionary? vocab;
+ Dictionary? vocab;
try
{
- vocab = JsonSerializer.Deserialize>(vocabularyStream) as Dictionary;
+ JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
+ vocab = JsonSerializer.Deserialize>(vocabularyStream, options) as Dictionary;
}
catch (Exception e)
{
@@ -416,22 +443,22 @@ private Dictionary GetVocabulary(Stream vocabularyStream)
if (_vocabIdToHighestOccurrence.BosWord is not null)
{
- vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex;
+ vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex;
}
if (_vocabIdToHighestOccurrence.EosWord is not null)
{
- vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex;
+ vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex;
}
if (_vocabIdToHighestOccurrence.UnkWord is not null)
{
- vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex;
+ vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex;
}
if (_vocabIdToHighestOccurrence.PadWord is not null)
{
- vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex;
+ vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex;
}
return vocab;
@@ -510,7 +537,7 @@ private List EncodeToTokens(Span token, Span indexMapping)
if (token.Length == 1)
{
string tokenValue = _charToString[token[0]];
- return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) };
+ return new List { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) };
}
List word = new(token.Length);
@@ -539,7 +566,7 @@ private List EncodeToTokens(Span token, Span indexMapping)
// get the most frequent bi-gram pair
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
- if (!_mergeRanks.TryGet((first, second), out int _))
+ if (!_mergeRanks.TryGetValue((first, second), out int _))
{
break;
}
@@ -599,7 +626,7 @@ private List EncodeToTokens(Span token, Span indexMapping)
foreach (string w in word)
{
- tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length)));
+ tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length)));
index += w.Length;
}
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
index 16eecc4aa4..815bd04a0b 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -31,14 +31,15 @@ public abstract class Model
/// This method does the default implementation that uses the Encode method to get the token's Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
///
- public virtual void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds)
+ public virtual void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds)
{
if (accumulatedIds is null)
{
throw new ArgumentNullException(nameof(accumulatedIds));
}
- var tokens = Encode(text);
+ // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance.
+ var tokens = Encode(text.ToString());
foreach (var token in tokens)
{
accumulatedIds.Add(token.Id);
@@ -55,7 +56,7 @@ public virtual void EncodeToIds(string text, bool isSpecialToken, IList acc
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
/// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation.
///
- public virtual int CountTokens(string text, bool isSpecialToken)
+ public virtual int CountTokens(ReadOnlySpan text, bool isSpecialToken)
{
var ids = new List();
EncodeToIds(text, isSpecialToken, ids);
@@ -68,7 +69,7 @@ public virtual int CountTokens(string text, bool isSpecialToken)
/// The token to map to Id
/// Indicate if want to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true);
+ public abstract int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true);
///
/// Map the encoded Id to the token.
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 0696efd9b0..60e9282a81 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -19,12 +19,14 @@ namespace Microsoft.ML.Tokenizers
///
public sealed class Tiktoken : Model
{
- private readonly Dictionary, int> _encoder = null!;
- private readonly Dictionary> _decoder = null!;
- private readonly LruCache? _cache;
- private readonly IReadOnlyDictionary? _specialTokensEncoder;
+ private readonly Dictionary, int> _encoder;
+ private readonly Dictionary> _decoder;
+ private readonly LruCache _cache;
+ private readonly Dictionary? _specialTokensEncoder;
+ private Dictionary? _specialTokensEncoderOriginal;
private readonly Dictionary? _specialTokensDecoder;
- private readonly Dictionary _vocab = null!;
+ private readonly Dictionary _vocab;
+ private IReadOnlyDictionary? _vocabOriginal;
///
/// Create a new Tiktoken tokenizer's model object.
@@ -34,7 +36,7 @@ public sealed class Tiktoken : Model
/// The size of the cache to use.
/// Thrown when is null or empty.
/// Thrown when failed to load the BPE vocab file.
- public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
+ public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
{
}
@@ -47,7 +49,7 @@ public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialT
/// The size of the cache to use.
/// Thrown when is null or empty.
/// Thrown when failed to load the BPE vocab file.
- public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
+ public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) :
this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
{
}
@@ -63,9 +65,9 @@ public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTok
internal Tiktoken(
Dictionary, int> encoder,
Dictionary> decoder,
- Dictionary vocab,
+ Dictionary vocab,
IReadOnlyDictionary? specialTokens,
- int cacheSize = LruCache.DefaultCacheSize) : this(cacheSize)
+ int cacheSize = LruCache.DefaultCacheSize)
{
_encoder = encoder ?? throw new ArgumentNullException(nameof(encoder));
_decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
@@ -73,24 +75,21 @@ internal Tiktoken(
Debug.Assert(encoder.Count == decoder.Count);
- _specialTokensEncoder = specialTokens;
- if (_specialTokensEncoder is not null)
- {
- _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
- }
+ _encoder = encoder!;
+ _decoder = decoder!;
+ _vocab = vocab!;
+ _cache = new LruCache(cacheSize);
+
+ (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
}
- private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
+ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream)
{
try
{
+ _cache = new LruCache(cacheSize);
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();
-
- _specialTokensEncoder = specialTokens;
- if (_specialTokensEncoder is not null)
- {
- _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
- }
+ (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens);
}
finally
{
@@ -101,17 +100,15 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo
}
}
- private Tiktoken(int cacheSize)
+ private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens)
{
- if (cacheSize < 0)
+ if (specialTokens is not null)
{
- throw new ArgumentOutOfRangeException(nameof(cacheSize));
+ var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value);
+ return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!));
}
- if (cacheSize > 0)
- {
- _cache = new LruCache(cacheSize);
- }
+ return (null, null);
}
///
@@ -125,7 +122,7 @@ private Tiktoken(int cacheSize)
public static async Task CreateAsync(
Stream vocabStream,
IReadOnlyDictionary? specialTokens = null,
- int cacheSize = LruCache.DefaultCacheSize,
+ int cacheSize = LruCache.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabStream is null)
@@ -133,7 +130,7 @@ public static async Task CreateAsync(
throw new ArgumentNullException(nameof(vocabStream));
}
- (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) =
+ (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) =
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
@@ -150,7 +147,7 @@ public static async Task CreateAsync(
public static async Task CreateAsync(
string vocabFilePath,
IReadOnlyDictionary? specialTokensEncoder = null,
- int cacheSize = LruCache.DefaultCacheSize,
+ int cacheSize = LruCache.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (vocabFilePath is null)
@@ -170,11 +167,11 @@ public static async Task CreateAsync(
/// used to request cancellation of the operation.
/// Map of byte[] to integer token id
///
- internal static async ValueTask<(Dictionary, int>, Dictionary, Dictionary>)> LoadTikTokenBpeAsync(
+ internal static async ValueTask<(Dictionary, int>, Dictionary, Dictionary>)> LoadTikTokenBpeAsync(
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary, int>(ReadOnlyMemoryByteComparer.Instance);
- var vocab = new Dictionary();
+ var vocab = new Dictionary();
var decoder = new Dictionary>();
try
@@ -212,7 +209,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
if (decodedToken.IndexOf('\uFFFD') < 0)
{
- vocab[decodedToken] = rank;
+ vocab[new StringSpanOrdinalKey(decodedToken)] = rank;
}
}
else
@@ -230,12 +227,6 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
return (encoder, vocab, decoder);
}
- ///
- /// Gets the dictionary mapping special tokens to Ids.
- ///
- /// The dictionary mapping special tokens to Ids.
- public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoder;
-
///
/// Encode a split text string to a list of tokens.
///
@@ -253,12 +244,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken)
if (isSpecialToken)
{
- if (_specialTokensEncoder is null)
- {
- throw new InvalidOperationException($"The tokenizer doesn't have special tokens");
- }
-
- if (_specialTokensEncoder.TryGetValue(text, out int id))
+ if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
return new List { new(id, text, (0, text.Length)) };
}
@@ -266,7 +252,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken)
throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer");
}
- if (_cache?.Lookup(text, out int[] ids) is true)
+ if (_cache.TryGetValue(text, out int[]? ids))
{
tokens = new Token[ids.Length];
tokens[0] = new Token(ids[0], text, (0, text.Length));
@@ -290,7 +276,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken)
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
Debug.Assert(encodedIds.Length > 0);
- _cache?.Add(text, encodedIds);
+ _cache.Add(text, encodedIds);
tokens = new Token[encodedIds.Length];
tokens[0] = new Token(encodedIds[0], text, (0, text.Length));
@@ -305,21 +291,21 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken)
}
///
- /// Encode a split text string to a list of Ids.
+ /// Encode text to a list of Ids.
///
/// The text to encode.
/// Indicate if the token is a special token.
/// The list of accumulated Ids.
- public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds)
+ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds)
{
- if (string.IsNullOrEmpty(text))
+ if (text.IsEmpty)
{
return;
}
if (isSpecialToken)
{
- if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(text, out int id))
+ if (_specialTokensEncoder?.TryGetValue(text, out int id) is true)
{
accumulatedIds.Add(id);
}
@@ -327,7 +313,7 @@ public override void EncodeToIds(string text, bool isSpecialToken, IList ac
return;
}
- if (_cache?.Lookup(text, out int[] tokenIds) is true)
+ if (_cache.TryGetValue(text, out int[]? tokenIds))
{
accumulatedIds.AddRange(tokenIds);
return;
@@ -340,10 +326,10 @@ public override void EncodeToIds(string text, bool isSpecialToken, IList ac
}
byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
- int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
+ int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
- _cache?.Add(text, encodedIds);
+ _cache.Add(text.ToString(), encodedIds);
accumulatedIds.AddRange(encodedIds);
@@ -354,12 +340,12 @@ public override void EncodeToIds(string text, bool isSpecialToken, IList ac
///
/// Get the number of tokens that the input text will be encoded to.
///
- /// The text to encode.
+ /// The text to tokenize.
/// Indicate if the token is special token.
/// The number of tokens that the input text will be encoded to.
- public override int CountTokens(string text, bool isSpecialToken)
+ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken)
{
- if (string.IsNullOrEmpty(text))
+ if (text.IsEmpty)
{
return 0;
}
@@ -369,7 +355,7 @@ public override int CountTokens(string text, bool isSpecialToken)
return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;
}
- if (_cache?.Lookup(text, out int[] ids) is true)
+ if (_cache.TryGetValue(text, out int[] ids))
{
return ids.Length;
}
@@ -380,10 +366,10 @@ public override int CountTokens(string text, bool isSpecialToken)
}
byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length));
- int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray);
+ int encodedLength = GetUtf8Bytes(text, arrayPoolArray);
int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
- _cache?.Add(text, encodedIds);
+ _cache.Add(text.ToString(), encodedIds);
ArrayPool.Shared.Return(arrayPoolArray);
return encodedIds.Length;
@@ -395,19 +381,22 @@ public override int CountTokens(string text, bool isSpecialToken)
/// The token to map to the Id.
/// Indicate if want to consider the special tokens during the encoding.
/// The mapped Id of the token.
- public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
+ public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true)
{
- if (string.IsNullOrEmpty(token))
+ if (token.IsEmpty)
{
return 0;
}
- if (considerSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId))
+ if (considerSpecialTokens && _specialTokensEncoder is not null)
{
- return specialTokenId;
+ if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId))
+ {
+ return specialTokenId;
+ }
}
- if (_cache?.Lookup(token, out int[] ids) is true)
+ if (_cache.TryGetValue(token, out int[] ids))
{
if (ids.Length == 1)
{
@@ -425,10 +414,10 @@ public override int CountTokens(string text, bool isSpecialToken)
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
try
{
- int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
+ int encodedLength = GetUtf8Bytes(token, arrayPoolArray);
int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
- _cache?.Add(token, idsToCache);
+ _cache.Add(token.ToString(), idsToCache);
if (idsToCache.Length == 1)
{
@@ -550,7 +539,12 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray,
/// Gets the dictionary mapping tokens to Ids.
///
/// This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary.
- public IReadOnlyDictionary<string, int> Vocab => _vocab;
+ public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
+
+ ///
+ /// Gets the dictionary mapping special tokens to Ids.
+ ///
+ public IReadOnlyDictionary<string, int>? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);
///
/// Gets the dictionary mapping token bytes to Ids.
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index 0826a8b68e..c64ebf256e 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -104,7 +104,7 @@ public IReadOnlyList EncodeToIds(string text, bool considerSpecialTokens =
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
{
- Model.EncodeToIds(split.TokenString, split.IsSpecialToken, idsList);
+ Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList);
}
return idsList;
@@ -130,7 +130,7 @@ public int CountTokens(string text, bool considerSpecialTokens = true)
int idsCount = 0;
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
{
- idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken);
+ idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
}
return idsCount;
@@ -343,7 +343,7 @@ private static Task CreateByEncoderNameAsync(
}
}
- private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, Dictionary<int, ReadOnlyMemory<byte>>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+ private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
///
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
@@ -371,7 +371,7 @@ private static async Task CreateTikTokenTokenizerAsync(
}
}
- if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
+ if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
{
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
index b64531431f..0050c63f3d 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -37,5 +37,7 @@ public static byte[] FromBase64String(string base64String, int offset, int lengt
internal static bool TryParseInt32(string s, int offset, out int result)
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
+
+ internal static int GetHashCode(ReadOnlySpan<char> span) => string.GetHashCode(span);
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
index 2979c99b6e..2d739e52e4 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -48,6 +48,17 @@ internal static bool TryParseInt32(string s, int offset, out int result)
return true;
}
+
+ internal static int GetHashCode(ReadOnlySpan<char> span)
+ {
+ int hash = 17;
+ foreach (char c in span)
+ {
+ hash = hash * 31 + c;
+ }
+
+ return hash;
+ }
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
index 9ad88e2f35..c11d79e1f5 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs
@@ -2,47 +2,37 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Collections.Generic;
namespace Microsoft.ML.Tokenizers
{
- internal class LruCache<TKey, TValue> where TKey : notnull where TValue : notnull
+ internal sealed class LruCache<TValue>
{
///
/// The default LRU cache size.
///
- public const int DefaultCacheSize = 8192; // 4096;
+ public const int DefaultCacheSize = 8192;
- private readonly object _lockObject = new object();
-
- private class CacheItem
- {
- public readonly TKey Key;
- public TValue Value;
-
- public CacheItem(TKey key, TValue value)
- {
- Key = key;
- Value = value;
- }
- }
-
- private readonly Dictionary<TKey, LinkedListNode<CacheItem>> _cache;
- private readonly LinkedList<CacheItem> _lruList;
+ private readonly Dictionary<StringSpanOrdinalKey, LinkedListNode<KeyValuePair<string, TValue>>> _cache = new();
+ private readonly LinkedList<KeyValuePair<string, TValue>> _lruList = new();
private readonly int _cacheSize;
+ private object SyncObj => _cache;
+
///
- /// Constructs an object.
+ /// Constructs an object.
///
///
- /// The maximum number of to mappings
- /// that can be cached. This defaults to , which is set to
- /// 4096.
+ /// The maximum number of mappings that can be cached. This defaults to , which is set to 8192.
///
public LruCache(int cacheSize = DefaultCacheSize)
{
- _cache = new Dictionary>();
- _lruList = new LinkedList();
+ if (cacheSize <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number.");
+ }
+
_cacheSize = cacheSize;
}
@@ -54,11 +44,11 @@ public LruCache(int cacheSize = DefaultCacheSize)
///
/// true if the cache contains a mapping for key, false otherwise.
///
- public bool Lookup(TKey key, out TValue value)
+ public bool TryGetValue(string key, out TValue value)
{
- lock (_lockObject)
+ lock (SyncObj)
{
- if (_cache.TryGetValue(key, out LinkedListNode<CacheItem>? cached))
+ if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
{
_lruList.Remove(cached);
_lruList.AddFirst(cached);
@@ -71,16 +61,31 @@ public bool Lookup(TKey key, out TValue value)
}
}
- protected virtual void OnEviction(TValue evictedValue) { }
-
- private void EvictIfNeeded()
+ ///
+ /// Retrieves the value associated with the specified key /> object.
+ ///
+ /// The object to be used as a key.
+ /// An out parameter that is set to the value of the key if key contains a mapping in the cache.
+ ///
+ /// true if the cache contains a mapping for key, false otherwise.
+ ///
+ public unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
{
- while (_cache.Count >= _cacheSize)
+ lock (SyncObj)
{
- LinkedListNode? nodeToEvict = _lruList.Last;
- _lruList.RemoveLast();
- _cache.Remove(nodeToEvict!.Value.Key);
- OnEviction(nodeToEvict.Value.Value);
+ fixed (char* ptr = key)
+ {
+ if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+ {
+ _lruList.Remove(cached);
+ _lruList.AddFirst(cached);
+ value = cached.Value.Value;
+ return true;
+ }
+ }
+
+ value = default!;
+ return false;
}
}
@@ -89,46 +94,29 @@ private void EvictIfNeeded()
///
/// The key whose mapped is to be created or replaced.
/// The new value to be mapped to the .
- public void Add(TKey key, TValue value) => Replace(key, value, out _);
-
- public bool Replace(TKey key, TValue value, out TValue oldValue)
+ public void Add(string key, TValue value)
{
- lock (_lockObject)
+ lock (SyncObj)
{
- return ReplaceInternal(key, value, out oldValue);
- }
- }
-
- private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue)
- {
- if (_cache.TryGetValue(key, out LinkedListNode? cached))
- {
- oldValue = cached.Value.Value;
- cached.Value.Value = value;
- _lruList.Remove(cached);
- _lruList.AddFirst(cached);
- return true;
- }
- EvictIfNeeded();
- var node = new LinkedListNode(new CacheItem(key, value));
- _cache[key] = node;
- _lruList.AddFirst(node);
- oldValue = default!;
- return false;
- }
+ if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode<KeyValuePair<string, TValue>>? cached))
+ {
+ cached.Value = new KeyValuePair<string, TValue>(key, value);
+ _lruList.Remove(cached);
+ _lruList.AddFirst(cached);
+ return;
+ }
- ///
- /// The number of entries currently present in the cache.
- ///
- public int Count => _cache.Count;
+ while (_cache.Count >= _cacheSize)
+ {
+ LinkedListNode<KeyValuePair<string, TValue>>? nodeToEvict = _lruList.Last;
+ _lruList.RemoveLast();
+ _cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key));
+ }
- ///
- /// Clears the contents of this cache.
- ///
- public void Clear()
- {
- _cache.Clear();
- _lruList.Clear();
+ var node = new LinkedListNode<KeyValuePair<string, TValue>>(new KeyValuePair<string, TValue>(key, value));
+ _cache[new StringSpanOrdinalKey(key)] = node;
+ _lruList.AddFirst(node);
+ }
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
new file mode 100644
index 0000000000..3cee62e318
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs
@@ -0,0 +1,132 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.ML.Tokenizers
+{
+ /// Used as a key in a dictionary to enable querying with either a string or a span.
+ ///
+ /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should
+ /// always be used with a string.
+ ///
+ internal unsafe readonly struct StringSpanOrdinalKey : IEquatable<StringSpanOrdinalKey>
+ {
+ public readonly char* Ptr;
+ public readonly int Length;
+ public readonly string? Data;
+
+ public StringSpanOrdinalKey(char* ptr, int length)
+ {
+ Ptr = ptr;
+ Length = length;
+ }
+
+ public StringSpanOrdinalKey(string data) =>
+ Data = data;
+
+ private ReadOnlySpan<char> Span => Ptr is not null ?
+ new ReadOnlySpan<char>(Ptr, Length) :
+ Data.AsSpan();
+
+ public override bool Equals(object? obj) =>
+ obj is StringSpanOrdinalKey wrapper && Equals(wrapper);
+
+ public bool Equals(StringSpanOrdinalKey other) =>
+ Span.SequenceEqual(other.Span);
+
+ public override int GetHashCode() => Helpers.GetHashCode(Span);
+ }
+
+ internal sealed class StringSpanOrdinalKeyCache<TValue>
+ {
+ private readonly int _capacity;
+ private readonly Dictionary<StringSpanOrdinalKey, TValue> _map;
+
+ private object SyncObj => _map;
+
+ internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { }
+
+ internal StringSpanOrdinalKeyCache(int capacity)
+ {
+ _capacity = capacity;
+ _map = new Dictionary<StringSpanOrdinalKey, TValue>(capacity);
+ }
+
+ internal bool TryGetValue(string key, out TValue value)
+ {
+ lock (SyncObj)
+ {
+ return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
+ }
+ }
+
+ internal unsafe bool TryGetValue(ReadOnlySpan<char> key, out TValue value)
+ {
+ lock (SyncObj)
+ {
+ fixed (char* ptr = key)
+ {
+ return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
+ }
+ }
+ }
+
+ internal void Remove(string key)
+ {
+ lock (SyncObj)
+ {
+ _map.Remove(new StringSpanOrdinalKey(key));
+ }
+ }
+
+ internal void Set(string k, TValue v)
+ {
+ lock (SyncObj)
+ {
+ if (_map.Count < _capacity)
+ {
+ _map[new StringSpanOrdinalKey(k)] = v;
+ }
+ }
+ }
+ }
+
+ ///
+ /// Custom JSON converter for .
+ ///
+ internal sealed class StringSpanOrdinalKeyConverter : JsonConverter<StringSpanOrdinalKey>
+ {
+ public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter();
+ public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
+ new StringSpanOrdinalKey(reader.GetString()!);
+
+ public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) =>
+ writer.WriteStringValue(value.Data!);
+
+ public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!);
+ public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!);
+ }
+
+ ///
+ /// Extension methods for .
+ ///
+ internal static class StringSpanOrdinalKeyExtensions
+ {
+ public unsafe static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, ReadOnlySpan<char> key, out TValue value)
+ {
+ fixed (char* ptr = key)
+ {
+ return map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!);
+ }
+ }
+
+ public static bool TryGetValue<TValue>(this Dictionary<StringSpanOrdinalKey, TValue> map, string key, out TValue value) =>
+ map.TryGetValue(new StringSpanOrdinalKey(key), out value!);
+ }
+}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 810862322b..2959184b5d 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -156,7 +156,7 @@ public void SimpleTestWithUnknownToken(Dictionary vocab, (string, s
Assert.Equal(ids[i], encoding.Ids[i]);
Assert.Equal(ids[i], idsList[i]);
Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i]));
- Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
+ Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
}
}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
index d23f241319..ccf0e66ef9 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
@@ -201,7 +201,7 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call
Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false));
}
- Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i]));
+ Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan()));
}
}
}