From 00d9e602f972f19058ee3873dd2dd869c19259f8 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Wed, 8 May 2024 10:38:21 -0700 Subject: [PATCH] Error out when token is outside of vocab size (#3535) Summary: Ideally it shouldn't happen, but if we post process the weight somehow too much it might happen. In Android, it just seg fault directly if it's outside of the range without error message. After this change, it's clearer: ``` E 00:00:00.180911 executorch:bpe_tokenizer.cpp:155] token 18446744073709551615 is out side of vacab range 512 Aborted ``` Reviewed By: larryliu0820 Differential Revision: D57057026 --- .../models/llama2/tokenizer/bpe_tokenizer.cpp | 5 +---- .../llama2/tokenizer/test/test_bpe_tokenizer.cpp | 8 ++++++++ .../llama2/tokenizer/test/test_tiktoken.cpp | 9 +++++++++ examples/models/llama2/tokenizer/tiktoken.cpp | 4 +--- examples/models/llama2/tokenizer/tokenizer.h | 16 ++++++++++++++++ 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp index ed7d34aca4d..7af2357d9be 100644 --- a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp @@ -146,10 +146,7 @@ BPETokenizer::~BPETokenizer() { * token. 
*/ Result BPETokenizer::decode(uint64_t prev_token, uint64_t token) { - if (!initialized_) { - ET_LOG(Error, "Tokenizer not initialized"); - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token)); const char* piece = vocab_[token]; // following BOS token, sentencepiece decoder strips any leading // whitespace diff --git a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp index 1d1f83065cf..e9eada338d5 100644 --- a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp @@ -39,6 +39,14 @@ TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) { EXPECT_EQ(result.error(), Error::NotSupported); } +TEST_F(TokenizerExtensionTest, DecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + auto result = tokenizer_->decode(0, 64000); + // The vocab size is 32000, and token 64000 is out of vocab range. + EXPECT_EQ(result.error(), Error::NotSupported); +} + TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) { Error res = tokenizer_->load(modelPath_.c_str()); EXPECT_EQ(res, Error::Ok); diff --git a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp index 2f08e2a1aa7..6130a9e858a 100644 --- a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp +++ b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp @@ -77,5 +77,14 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) { } } +TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + // The vocab size is 128256, adds 256 just so the token is out of vocab + // range. 
+ Result out = tokenizer_->decode(0, 128256 + 256); + EXPECT_EQ(out.error(), Error::NotSupported); +} + } // namespace executor } // namespace torch diff --git a/examples/models/llama2/tokenizer/tiktoken.cpp b/examples/models/llama2/tokenizer/tiktoken.cpp index 849a2ff1e8d..79b61e5eb64 100644 --- a/examples/models/llama2/tokenizer/tiktoken.cpp +++ b/examples/models/llama2/tokenizer/tiktoken.cpp @@ -364,9 +364,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) { Result Tiktoken::decode(uint64_t prev, uint64_t cur) { (void)prev; - if (!initialized_) { - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur)); std::string ret; std::string token_bytes; diff --git a/examples/models/llama2/tokenizer/tokenizer.h b/examples/models/llama2/tokenizer/tokenizer.h index 5e9f0925823..7ad3b32bbb8 100644 --- a/examples/models/llama2/tokenizer/tokenizer.h +++ b/examples/models/llama2/tokenizer/tokenizer.h @@ -40,6 +40,22 @@ class Tokenizer { virtual Result> encode(const std::string& input, int8_t bos, int8_t eos) = 0; + Error decode_verify(uint64_t token) const { + if (!initialized_) { + ET_LOG(Error, "Tokenizer not initialized"); + return Error::NotSupported; + } + if (token >= vocab_size_) { + ET_LOG( + Error, + "token %" PRIu64 " is outside of vocab range %d", + token, + vocab_size_); + return Error::NotSupported; + } + return Error::Ok; + } + virtual Result decode(uint64_t prev_token, uint64_t token) = 0; // getters