From dccbbc49f5d7f0fc3f694f202051db390af653b2 Mon Sep 17 00:00:00 2001 From: rvorster Date: Thu, 23 Mar 2023 00:37:49 -0700 Subject: [PATCH 1/5] Make this work with current format ggml models Breaking changes in ggml: - Scores added in vocabulary - Format version added - Magic updated --- llama-rs/src/ggml.rs | 5 ++++ llama-rs/src/lib.rs | 67 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/llama-rs/src/ggml.rs b/llama-rs/src/ggml.rs index 327d04d0..c9e9a26e 100644 --- a/llama-rs/src/ggml.rs +++ b/llama-rs/src/ggml.rs @@ -6,6 +6,11 @@ use std::{ pub use ggml_raw::ggml_type as Type; +pub const FILE_MAGIC: i32 = 0x67676d66; +pub const FILE_MAGIC_UNVERSIONED: i32 = 0x67676d6c; + +pub const FORMAT_VERSION: u32 = 1; + pub const TYPE_Q4_0: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_0; pub const TYPE_Q4_1: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_Q4_1; pub const TYPE_I32: ggml_raw::ggml_type = ggml_raw::GGML_TYPE_I32; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index fcf7e7a9..b1bc7fd5 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -180,11 +180,16 @@ impl Display for InferenceStats { type TokenId = i32; type Token = String; +type TokenScore = f32; pub struct Vocabulary { /// Maps every integer (index) token id to its corresponding token id_to_token: Vec, + /// Maps every integer (index) token id to corresponding score + #[allow(dead_code)] + id_to_token_score: Vec, + /// Maps a token to a token id token_to_id: HashMap, @@ -300,8 +305,12 @@ pub enum LoadError { #[error("invalid integer conversion")] InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("unversioned magic number, regenerate your ggml models")] + UnversionedMagic, #[error("invalid magic number for {path:?}")] InvalidMagic { path: PathBuf }, + #[error("invalid file format version {value}")] + InvalidFormatVersion { value: u32 }, #[error("invalid value {value} for `f16` in hyperparameters")] HyperparametersF16Invalid { value: i32 }, #[error("unknown tensor `{tensor_name}` in {path:?}")] @@ -345,6 +354,28 @@ macro_rules! mulf { }; } +trait FromLeBytes { + fn from_le_bytes(bytes: [u8; 4]) -> Self; +} + +impl FromLeBytes for u32 { + fn from_le_bytes(bytes: [u8; 4]) -> Self { + return u32::from_le_bytes(bytes); + } +} + +impl FromLeBytes for i32 { + fn from_le_bytes(bytes: [u8; 4]) -> Self { + return i32::from_le_bytes(bytes); + } +} + +impl FromLeBytes for f32 { + fn from_le_bytes(bytes: [u8; 4]) -> Self { + return f32::from_le_bytes(bytes); + } +} + impl Model { pub fn load( path: impl AsRef, @@ -364,8 +395,7 @@ impl Model { })?, ); - /// Helper function. Reads an int from the buffer and returns it. - fn read_i32(reader: &mut impl BufRead) -> Result { + fn read_int(reader: &mut impl BufRead) -> Result { let mut bytes = [0u8; 4]; reader .read_exact(&mut bytes) @@ -373,9 +403,13 @@ impl Model { source: e, bytes: bytes.len(), })?; - Ok(i32::from_le_bytes(bytes)) + Ok(T::from_le_bytes(bytes)) } + let read_i32 = read_int::; + let read_u32 = read_int::; + let read_f32 = read_int::; + /// Helper function. Reads a string from the buffer and returns it. fn read_string(reader: &mut BufReader, len: usize) -> Result { let mut buf = vec![0; len]; @@ -391,12 +425,20 @@ impl Model { // Verify magic { - let magic = read_i32(&mut reader)?; - if magic != 0x67676d6c { - return Err(LoadError::InvalidMagic { - path: main_path.to_owned(), - }); - } + match read_i32(&mut reader)? { + ggml::FILE_MAGIC => true, + ggml::FILE_MAGIC_UNVERSIONED => return Err(LoadError::UnversionedMagic), + _ => return Err(LoadError::InvalidMagic { path: main_path.to_owned() }), + }; + } + + // Load format version + { + #[allow(unused_variables)] + let version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; } // ================= @@ -426,6 +468,7 @@ impl Model { // =============== let vocab = { let mut id_to_token = vec![]; + let mut id_to_token_score = vec![]; let mut token_to_id = HashMap::new(); let mut max_token_length = 0; @@ -441,10 +484,16 @@ impl Model { }); id_to_token.push("�".to_string()); } + + // Token score, currently unused + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } } Vocabulary { id_to_token, + id_to_token_score, token_to_id, max_token_length, } From 42f26eab37bdc6ceef82ef5b478f7948c7338d94 Mon Sep 17 00:00:00 2001 From: rvorster Date: Thu, 23 Mar 2023 00:45:16 -0700 Subject: [PATCH 2/5] Add some semblance of support for legacy models --- llama-rs/src/lib.rs | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index b1bc7fd5..a74dfbf4 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -424,16 +424,14 @@ impl Model { } // Verify magic - { - match read_i32(&mut reader)? { - ggml::FILE_MAGIC => true, - ggml::FILE_MAGIC_UNVERSIONED => return Err(LoadError::UnversionedMagic), - _ => return Err(LoadError::InvalidMagic { path: main_path.to_owned() }), - }; - } + let is_legacy_model: bool = match read_i32(&mut reader)? { + ggml::FILE_MAGIC => false, + ggml::FILE_MAGIC_UNVERSIONED => true, + _ => return Err(LoadError::InvalidMagic { path: main_path.to_owned() }), + }; // Load format version - { + if !is_legacy_model { #[allow(unused_variables)] let version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, @@ -486,8 +484,13 @@ impl Model { } // Token score, currently unused - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); + if !is_legacy_model { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } else { + // Legacy model, set empty score + id_to_token_score.push(0.); } } From 2f57987df637b1536957cf140ffe60038365f636 Mon Sep 17 00:00:00 2001 From: rvorster Date: Thu, 23 Mar 2023 01:13:44 -0700 Subject: [PATCH 3/5] Cleaner handling of different types in byte reader - Generalize u32, i32 and f32 reading without a boilerplate-y trait --- llama-rs/src/lib.rs | 47 +++++++++++++++------------------------------ 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index a74dfbf4..f069e23f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -354,28 +354,6 @@ macro_rules! mulf { }; } -trait FromLeBytes { - fn from_le_bytes(bytes: [u8; 4]) -> Self; -} - -impl FromLeBytes for u32 { - fn from_le_bytes(bytes: [u8; 4]) -> Self { - return u32::from_le_bytes(bytes); - } -} - -impl FromLeBytes for i32 { - fn from_le_bytes(bytes: [u8; 4]) -> Self { - return i32::from_le_bytes(bytes); - } -} - -impl FromLeBytes for f32 { - fn from_le_bytes(bytes: [u8; 4]) -> Self { - return f32::from_le_bytes(bytes); - } -} - impl Model { pub fn load( path: impl AsRef, @@ -395,20 +373,25 @@ impl Model { })?, ); - fn read_int(reader: &mut impl BufRead) -> Result { - let mut bytes = [0u8; 4]; + fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; reader .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: bytes.len(), - })?; - Ok(T::from_le_bytes(bytes)) + .map_err(|e| LoadError::ReadExactFailed { source: e, bytes: N })?; + Ok(bytes) } - let read_i32 = read_int::; - let read_u32 = read_int::; - let read_f32 = read_int::; + fn read_i32(reader: &mut impl BufRead) -> Result { + return Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)); + } + + fn read_u32(reader: &mut impl BufRead) -> Result { + return Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)); + } + + fn read_f32(reader: &mut impl BufRead) -> Result { + return Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)); + } /// Helper function. Reads a string from the buffer and returns it. fn read_string(reader: &mut BufReader, len: usize) -> Result { From d59f9811de57bfa9a40b3e0a87f30551b052b8d5 Mon Sep 17 00:00:00 2001 From: rvorster Date: Thu, 23 Mar 2023 11:44:52 -0700 Subject: [PATCH 4/5] Fix formatting --- llama-rs/src/lib.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index f069e23f..abfb4f20 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -377,7 +377,10 @@ impl Model { let mut bytes = [0u8; N]; reader .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { source: e, bytes: N })?; + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; Ok(bytes) } @@ -410,7 +413,11 @@ impl Model { let is_legacy_model: bool = match read_i32(&mut reader)? { ggml::FILE_MAGIC => false, ggml::FILE_MAGIC_UNVERSIONED => true, - _ => return Err(LoadError::InvalidMagic { path: main_path.to_owned() }), + _ => { + return Err(LoadError::InvalidMagic { + path: main_path.to_owned(), + }) + } }; // Load format version From 1fa31d7cfa3303c64b30ad2541f35cc482f7b131 Mon Sep 17 00:00:00 2001 From: setzer22 Date: Thu, 23 Mar 2023 22:14:05 +0100 Subject: [PATCH 5/5] Fix clippy lints --- llama-rs/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index abfb4f20..316eff1f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -385,15 +385,15 @@ impl Model { } fn read_i32(reader: &mut impl BufRead) -> Result { - return Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)); + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) } fn read_u32(reader: &mut impl BufRead) -> Result { - return Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)); + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } fn read_f32(reader: &mut impl BufRead) -> Result { - return Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)); + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Helper function. Reads a string from the buffer and returns it.