Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions examples/models/llama2/tokenizer/bpe_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,7 @@ BPETokenizer::~BPETokenizer() {
* token.
*/
Result<std::string> BPETokenizer::decode(uint64_t prev_token, uint64_t token) {
if (!initialized_) {
ET_LOG(Error, "Tokenizer not initialized");
return Error::NotSupported;
}
ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
const char* piece = vocab_[token];
// following BOS token, sentencepiece decoder strips any leading
// whitespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) {
EXPECT_EQ(result.error(), Error::NotSupported);
}

TEST_F(TokenizerExtensionTest, DecodeOutOfRangeFails) {
  // The tokenizer must be loaded successfully before decode can be exercised.
  const Error load_status = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(load_status, Error::Ok);
  // The vocab size is 32000, and token 64000 is out of vocab range.
  auto decoded = tokenizer_->decode(0, 64000);
  EXPECT_EQ(decoded.error(), Error::NotSupported);
}

TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) {
Error res = tokenizer_->load(modelPath_.c_str());
EXPECT_EQ(res, Error::Ok);
Expand Down
9 changes: 9 additions & 0 deletions examples/models/llama2/tokenizer/test/test_tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,14 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
}
}

TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  // The vocab size is 128256; add 256 so the requested token is out of the
  // vocab range and decode is expected to fail with NotSupported.
  Result<std::string> out = tokenizer_->decode(0, 128256 + 256);
  EXPECT_EQ(out.error(), Error::NotSupported);
}

} // namespace executor
} // namespace torch
4 changes: 1 addition & 3 deletions examples/models/llama2/tokenizer/tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) {

Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) {
(void)prev;
if (!initialized_) {
return Error::NotSupported;
}
ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
std::string ret;

std::string token_bytes;
Expand Down
16 changes: 16 additions & 0 deletions examples/models/llama2/tokenizer/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ class Tokenizer {
virtual Result<std::vector<uint64_t>>
encode(const std::string& input, int8_t bos, int8_t eos) = 0;

/**
 * Shared precondition check for decode() implementations: verifies the
 * tokenizer has been loaded and that the token indexes into the vocabulary.
 *
 * @param token The token id about to be decoded.
 * @return Error::Ok when decoding may proceed; Error::NotSupported when the
 *         tokenizer is uninitialized or the token is out of vocab range.
 */
Error decode_verify(uint64_t token) const {
  if (!initialized_) {
    ET_LOG(Error, "Tokenizer not initialized");
    return Error::NotSupported;
  }
  if (token >= vocab_size_) {
    // NOTE(review): %d assumes vocab_size_ is an int-sized type — confirm
    // against its declaration.
    ET_LOG(
        Error,
        "token %" PRIu64 " is out of vocab range %d",
        token,
        vocab_size_);
    return Error::NotSupported;
  }
  return Error::Ok;
}

virtual Result<std::string> decode(uint64_t prev_token, uint64_t token) = 0;

// getters
Expand Down