Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace base64 {
using tokenizers::Error;
using tokenizers::Result;

Result<std::string> decode(const std::string_view &input);
Result<std::string> decode(const std::string_view& input);

namespace detail {

Expand Down Expand Up @@ -68,9 +68,12 @@ inline Error validate(uint32_t v) {
return Error::Ok;
}

inline Error decode(const std::string_view &input, std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure,
"input length must be 4, got %zu", input.size());
inline Error decode(const std::string_view& input, std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 4,
Base64DecodeFailure,
"input length must be 4, got %zu",
input.size());

uint32_t val = 0;

Expand Down Expand Up @@ -100,10 +103,14 @@ inline Error decode(const std::string_view &input, std::string &output) {
return Error::Ok;
}

inline Error decode_1_padding(const std::string_view &input,
std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure,
"input length must be 3, got %zu", input.size());
inline Error decode_1_padding(
const std::string_view& input,
std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 3,
Base64DecodeFailure,
"input length must be 3, got %zu",
input.size());

uint32_t val = 0;

Expand All @@ -127,10 +134,14 @@ inline Error decode_1_padding(const std::string_view &input,
return Error::Ok;
}

inline Error decode_2_padding(const std::string_view &input,
std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure,
"input length must be 2, got %zu", input.size());
inline Error decode_2_padding(
const std::string_view& input,
std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 2,
Base64DecodeFailure,
"input length must be 2, got %zu",
input.size());

uint32_t val = 0;

Expand All @@ -150,12 +161,13 @@ inline Error decode_2_padding(const std::string_view &input,

} // namespace detail

inline tokenizers::Result<std::string> decode(const std::string_view &input) {
inline tokenizers::Result<std::string> decode(const std::string_view& input) {
TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input");

// Faster than `input.size() % 4`.
TK_CHECK_OR_RETURN_ERROR(
(input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure,
(input.size() & 3) == 0 && input.size() >= 4,
Base64DecodeFailure,
"input length must be larger than 4 and is multiple of 4, got %zu",
input.size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

#include <pytorch/tokenizers/map_utils.h>
Expand All @@ -24,47 +25,66 @@
namespace tokenizers {

class BPEModel : public Model {
public:
explicit BPEModel(detail::TokenMap token_map,
detail::TokenMap special_token_map,
std::optional<detail::TokenMap> merge_ranks,
std::unique_ptr<IRegex> special_token_regex,
bool byte_fallback, std::optional<uint64_t> unk_token_id,
std::optional<uint64_t> bos_token_id,
std::optional<uint64_t> eos_token_id);
public:
explicit BPEModel(
detail::TokenMap token_map,
detail::TokenMap special_token_map,
std::optional<detail::TokenMap> merge_ranks,
std::unique_ptr<IRegex> special_token_regex,
bool byte_fallback,
std::optional<uint64_t> unk_token_id,
std::optional<uint64_t> bos_token_id,
std::optional<uint64_t> eos_token_id,
std::unordered_set<std::string> rstrip_tokens = {},
std::unordered_set<std::string> lstrip_tokens = {});

~BPEModel() override = default;

Result<std::vector<uint64_t>>
tokenize(const std::string &piece) const override;
Result<std::vector<uint64_t>> tokenize(
const std::string& piece) const override;

Result<std::string> id_to_piece(uint64_t token) const override;
Result<uint64_t> piece_to_id(const std::string &token) const override;
Result<uint64_t> piece_to_id(const std::string& token) const override;

int32_t vocab_size() const override { return vocab_size_; }
int32_t vocab_size() const override {
return vocab_size_;
}

bool is_special_token(uint64_t token) const override;

bool is_loaded() const override { return initialized_; }
bool is_loaded() const override {
return initialized_;
}

std::pair<std::optional<std::string>, std::string>
split_with_allowed_special_token(const std::string &input,
size_t offset) const override;
split_with_allowed_special_token(const std::string& input, size_t offset)
const override;

uint64_t bos_token_id() const override { return bos_token_id_.value_or(0); }
bool special_token_has_rstrip(const std::string& token) const override {
return rstrip_tokens_.count(token) > 0;
}
bool special_token_has_lstrip(const std::string& token) const override {
return lstrip_tokens_.count(token) > 0;
}

uint64_t eos_token_id() const override { return eos_token_id_.value_or(0); }
uint64_t bos_token_id() const override {
return bos_token_id_.value_or(0);
}

private:
Result<std::pair<std::vector<uint64_t>, uint64_t>>
encode_with_special_token(const std::string &text) const;
uint64_t eos_token_id() const override {
return eos_token_id_.value_or(0);
}

Result<std::vector<uint64_t>>
byte_pair_encode(const std::string &piece) const;
private:
Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token(
const std::string& text) const;

std::vector<uint64_t>
byte_pair_merge(const std::string &piece, const detail::TokenMap &ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const;
Result<std::vector<uint64_t>> byte_pair_encode(const std::string& piece) const;

std::vector<uint64_t> byte_pair_merge(
const std::string& piece,
const detail::TokenMap& ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const;

// Real state
detail::TokenMap token_map_;
Expand All @@ -76,6 +96,8 @@ class BPEModel : public Model {
std::optional<uint64_t> unk_token_id_;
std::optional<uint64_t> bos_token_id_;
std::optional<uint64_t> eos_token_id_;
std::unordered_set<std::string> rstrip_tokens_;
std::unordered_set<std::string> lstrip_tokens_;

bool initialized_ = false;
int32_t vocab_size_ = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,53 +32,62 @@ namespace tokenizers {
namespace detail {

class BPETokenizerBase : public Tokenizer {
public:
Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
int8_t eos) const override;
public:
Result<std::vector<uint64_t>>
encode(const std::string& input, int8_t bos, int8_t eos) const override;

Result<std::string> id_to_piece(uint64_t token) const override;
Result<uint64_t> piece_to_id(const std::string &text) const override;
Result<uint64_t> piece_to_id(const std::string& text) const override;

Result<std::string> decode(uint64_t prev_token, uint64_t token,
bool skip_special_tokens = false) const override;
Result<std::string> decode(
uint64_t prev_token,
uint64_t token,
bool skip_special_tokens = false) const override;

protected:
protected:
explicit BPETokenizerBase() {}
virtual ~BPETokenizerBase() override {}

std::pair<std::optional<std::string>, std::string>
split_with_allowed_special_token_(const std::string &input,
const TokenMap &allowed_special) const;
split_with_allowed_special_token_(
const std::string& input,
const TokenMap& allowed_special) const;

std::pair<std::optional<std::string>, std::string>
split_with_allowed_special_token_(const std::string &input, size_t offset,
const TokenMap &allowed_special) const;
split_with_allowed_special_token_(
const std::string& input,
size_t offset,
const TokenMap& allowed_special) const;

Result<std::pair<std::vector<uint64_t>, uint64_t>>
encode_with_special_token_(const std::string &text,
const TokenMap &allowed_special) const;
Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
const std::string& text,
const TokenMap& allowed_special) const;

virtual Result<std::vector<uint64_t>>
byte_pair_encode_(const std::string &piece, const TokenMap &encoder) const;
virtual Result<std::vector<uint64_t>> byte_pair_encode_(
const std::string& piece,
const TokenMap& encoder) const;

// Virtual method for BPE merging - can be overridden by derived classes
// The passed in `ranks` param for the base impl is just a regular token map
// and that the actual ranks are derived implicitly from the regular token
// map. This is the same implementation as Tiktoken.
virtual std::vector<uint64_t>
_byte_pair_merge(const std::string &piece, const TokenMap &ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const;
virtual std::vector<uint64_t> _byte_pair_merge(
const std::string& piece,
const TokenMap& ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const;

// Protected members that can be overloaded by other BPE tokenizers
std::unique_ptr<IRegex> special_token_regex_;
std::optional<TokenMap> token_map_;
std::optional<TokenMap> special_token_map_;

private:
virtual Error _encode(const std::string &input, std::vector<uint64_t> &ret,
uint64_t &last_piece_token_len) const = 0;
private:
virtual Error _encode(
const std::string& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const = 0;

virtual void _decode(const std::string &input, std::string &ret) const = 0;
virtual void _decode(const std::string& input, std::string& ret) const = 0;
};

} // namespace detail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,23 @@ enum class Error : error_code_t {
* @param[in] message__ Format string for the log error message.
* @param[in] ... Optional additional arguments for the format string.
*/
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
{ \
if (!(cond__)) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return ::tokenizers::Error::error__; \
} \
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
{ \
if (!(cond__)) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return ::tokenizers::Error::error__; \
} \
}

/**
* If error__ is not Error::Ok, return the specified Error
* @param[in] error__ Error enum value to return without the `Error::` prefix,
* like `Base64DecodeFailure`.
*/
#define TK_CHECK_OK_OR_RETURN_ERROR(error__) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
return et_error__; \
} \
#define TK_CHECK_OK_OR_RETURN_ERROR(error__) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
return et_error__; \
} \
} while (0)
Loading
Loading