Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llama-rs/src/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ use std::{

pub use ggml_raw::ggml_type as Type;

pub const FILE_MAGIC: i32 = 0x67676d66;
pub const FILE_MAGIC_UNVERSIONED: i32 = 0x67676d6c;

pub const FORMAT_VERSION: u32 = 1;

pub const TYPE_Q4_0: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_0;
pub const TYPE_Q4_1: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_1;
pub const TYPE_I32: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_I32;
Expand Down
60 changes: 51 additions & 9 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,16 @@ impl Display for InferenceStats {

type TokenId = i32;
type Token = String;
type TokenScore = f32;

pub struct Vocabulary {
/// Maps every integer (index) token id to its corresponding token
id_to_token: Vec<Token>,

/// Maps every integer (index) token id to corresponding score
#[allow(dead_code)]
id_to_token_score: Vec<TokenScore>,

/// Maps a token to a token id
token_to_id: HashMap<Token, TokenId>,

Expand Down Expand Up @@ -300,8 +305,12 @@ pub enum LoadError {
#[error("invalid integer conversion")]
InvalidIntegerConversion(#[from] std::num::TryFromIntError),

#[error("unversioned magic number, regenerate your ggml models")]
UnversionedMagic,
#[error("invalid magic number for {path:?}")]
InvalidMagic { path: PathBuf },
#[error("invalid file format version {value}")]
InvalidFormatVersion { value: u32 },
#[error("invalid value {value} for `f16` in hyperparameters")]
HyperparametersF16Invalid { value: i32 },
#[error("unknown tensor `{tensor_name}` in {path:?}")]
Expand Down Expand Up @@ -364,16 +373,27 @@ impl Model {
})?,
);

/// Helper function. Reads an int from the buffer and returns it.
fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
let mut bytes = [0u8; 4];
fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
reader
.read_exact(&mut bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: bytes.len(),
bytes: N,
})?;
Ok(i32::from_le_bytes(bytes))
Ok(bytes)
}

fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

/// Helper function. Reads a string from the buffer and returns it.
Expand All @@ -390,13 +410,23 @@ impl Model {
}

// Verify magic
{
let magic = read_i32(&mut reader)?;
if magic != 0x67676d6c {
let is_legacy_model: bool = match read_i32(&mut reader)? {
ggml::FILE_MAGIC => false,
ggml::FILE_MAGIC_UNVERSIONED => true,
_ => {
return Err(LoadError::InvalidMagic {
path: main_path.to_owned(),
});
})
}
};

// Load format version
if !is_legacy_model {
#[allow(unused_variables)]
let version: u32 = match read_u32(&mut reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion { value: version }),
};
}

// =================
Expand Down Expand Up @@ -426,6 +456,7 @@ impl Model {
// ===============
let vocab = {
let mut id_to_token = vec![];
let mut id_to_token_score = vec![];
let mut token_to_id = HashMap::new();
let mut max_token_length = 0;

Expand All @@ -441,10 +472,21 @@ impl Model {
});
id_to_token.push("�".to_string());
}

// Token score, currently unused
if !is_legacy_model {
if let Ok(score) = read_f32(&mut reader) {
id_to_token_score.push(score);
}
} else {
// Legacy model, set empty score
id_to_token_score.push(0.);
}
}

Vocabulary {
id_to_token,
id_to_token_score,
token_to_id,
max_token_length,
}
Expand Down