diff --git a/llama-rs/src/ggml.rs b/llama-rs/src/ggml.rs index 327d04d0..c9e9a26e 100644 --- a/llama-rs/src/ggml.rs +++ b/llama-rs/src/ggml.rs @@ -6,6 +6,11 @@ use std::{ pub use ggml_raw::ggml_type as Type; +pub const FILE_MAGIC: i32 = 0x67676d66; +pub const FILE_MAGIC_UNVERSIONED: i32 = 0x67676d6c; + +pub const FORMAT_VERSION: u32 = 1; + pub const TYPE_Q4_0: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_0; pub const TYPE_Q4_1: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_1; pub const TYPE_I32: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_I32; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index fcf7e7a9..316eff1f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -180,11 +180,16 @@ impl Display for InferenceStats { type TokenId = i32; type Token = String; +type TokenScore = f32; pub struct Vocabulary { /// Maps every integer (index) token id to its corresponding token id_to_token: Vec, + /// Maps every integer (index) token id to corresponding score + #[allow(dead_code)] + id_to_token_score: Vec, + /// Maps a token to a token id token_to_id: HashMap, @@ -300,8 +305,12 @@ pub enum LoadError { #[error("invalid integer conversion")] InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("unversioned magic number, regenerate your ggml models")] + UnversionedMagic, #[error("invalid magic number for {path:?}")] InvalidMagic { path: PathBuf }, + #[error("invalid file format version {value}")] + InvalidFormatVersion { value: u32 }, #[error("invalid value {value} for `f16` in hyperparameters")] HyperparametersF16Invalid { value: i32 }, #[error("unknown tensor `{tensor_name}` in {path:?}")] @@ -364,16 +373,27 @@ impl Model { })?, ); - /// Helper function. Reads an int from the buffer and returns it. - fn read_i32(reader: &mut impl BufRead) -> Result { - let mut bytes = [0u8; 4]; + fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; reader .read_exact(&mut bytes) .map_err(|e| LoadError::ReadExactFailed { source: e, - bytes: bytes.len(), + bytes: N, })?; - Ok(i32::from_le_bytes(bytes)) + Ok(bytes) + } + + fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) + } + + fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) + } + + fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Helper function. Reads a string from the buffer and returns it. @@ -390,13 +410,23 @@ impl Model { } // Verify magic - { - let magic = read_i32(&mut reader)?; - if magic != 0x67676d6c { + let is_legacy_model: bool = match read_i32(&mut reader)? { + ggml::FILE_MAGIC => false, + ggml::FILE_MAGIC_UNVERSIONED => true, + _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), - }); + }) } + }; + + // Load format version + if !is_legacy_model { + #[allow(unused_variables)] + let version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; } // ================= @@ -426,6 +456,7 @@ impl Model { // =============== let vocab = { let mut id_to_token = vec![]; + let mut id_to_token_score = vec![]; let mut token_to_id = HashMap::new(); let mut max_token_length = 0; @@ -441,10 +472,21 @@ impl Model { }); id_to_token.push("�".to_string()); } + + // Token score, currently unused + if !is_legacy_model { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } else { + // Legacy model, set empty score + id_to_token_score.push(0.); + } } Vocabulary { id_to_token, + id_to_token_score, token_to_id, max_token_length, }