diff --git a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp index ed7d34aca4d..7af2357d9be 100644 --- a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp @@ -146,10 +146,7 @@ BPETokenizer::~BPETokenizer() { * token. */ Result BPETokenizer::decode(uint64_t prev_token, uint64_t token) { - if (!initialized_) { - ET_LOG(Error, "Tokenizer not initialized"); - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token)); const char* piece = vocab_[token]; // following BOS token, sentencepiece decoder strips any leading // whitespace diff --git a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp index 1d1f83065cf..e9eada338d5 100644 --- a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp @@ -39,6 +39,14 @@ TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) { EXPECT_EQ(result.error(), Error::NotSupported); } +TEST_F(TokenizerExtensionTest, DecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + auto result = tokenizer_->decode(0, 64000); + // The vocab size is 32000, and token 64000 is out of vocab range. 
+ EXPECT_EQ(result.error(), Error::NotSupported); +} + TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) { Error res = tokenizer_->load(modelPath_.c_str()); EXPECT_EQ(res, Error::Ok); diff --git a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp index 2f08e2a1aa7..6130a9e858a 100644 --- a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp +++ b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp @@ -77,5 +77,14 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) { } } +TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + // The vocab size is 128256, adds 256 just so the token is out of vocab + // range. + Result out = tokenizer_->decode(0, 128256 + 256); + EXPECT_EQ(out.error(), Error::NotSupported); +} + } // namespace executor } // namespace torch diff --git a/examples/models/llama2/tokenizer/tiktoken.cpp b/examples/models/llama2/tokenizer/tiktoken.cpp index 849a2ff1e8d..79b61e5eb64 100644 --- a/examples/models/llama2/tokenizer/tiktoken.cpp +++ b/examples/models/llama2/tokenizer/tiktoken.cpp @@ -364,9 +364,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) { Result Tiktoken::decode(uint64_t prev, uint64_t cur) { (void)prev; - if (!initialized_) { - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur)); std::string ret; std::string token_bytes; diff --git a/examples/models/llama2/tokenizer/tokenizer.h b/examples/models/llama2/tokenizer/tokenizer.h index 5e9f0925823..7ad3b32bbb8 100644 --- a/examples/models/llama2/tokenizer/tokenizer.h +++ b/examples/models/llama2/tokenizer/tokenizer.h @@ -40,6 +40,22 @@ class Tokenizer { virtual Result> encode(const std::string& input, int8_t bos, int8_t eos) = 0; + Error decode_verify(uint64_t token) const { + if (!initialized_) { + ET_LOG(Error, "Tokenizer not initialized"); + return 
Error::NotSupported; + } + if (token >= vocab_size_) { + ET_LOG( + Error, + "token %" PRIu64 " is outside of vocab range %d", + token, + vocab_size_); + return Error::NotSupported; + } + return Error::Ok; + } + virtual Result decode(uint64_t prev_token, uint64_t token) = 0; // getters