From 40dc491582446901e1e4de1f9ef1e43ab9a76479 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 18:22:44 +0000 Subject: [PATCH 01/10] Add loader stub for GGJT --- ggml/src/lib.rs | 6 +- llama-rs/src/lib.rs | 322 +++++------------------------------------ llama-rs/src/loader.rs | 293 +++++++++++++++++++++++++++++++++++++ 3 files changed, 333 insertions(+), 288 deletions(-) create mode 100644 llama-rs/src/loader.rs diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 22c9eee8..71016770 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -14,8 +14,10 @@ use std::{ sync::{Arc, Weak}, }; -/// Magic constant for `ggml` files (versioned). -pub const FILE_MAGIC: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggjt). +pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; /// Magic constant for `ggml` files (unversioned). pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index d5ef2a23..4be5cb20 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,6 +1,8 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. +mod loader; + use core::slice; use std::{ collections::HashMap, @@ -586,6 +588,7 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl Fn(LoadProgress), ) -> Result<(Model, Vocabulary), LoadError> { + use loader::*; use std::fs::File; use std::io::BufReader; @@ -599,46 +602,11 @@ impl Model { })?, ); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } - // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? { - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, + let model_type: ModelType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelType::GGMF, + ggml::FILE_MAGIC_GGJT => ModelType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -647,12 +615,14 @@ impl Model { }; // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; + match model_type { + ModelType::GGMF | ModelType::GGJT => { + let _version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; + } + ModelType::Unversioned => {} } // ================= @@ -687,8 +657,12 @@ impl Model { let mut max_token_length = 0; for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { + let len = match model_type { + // `read_i32` maybe a typo + ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, + ModelType::GGJT => read_u32(&mut reader)? as usize, + }; + if let Ok(word) = read_string(&mut reader, len) { max_token_length = max_token_length.max(word.len()); id_to_token.push(word.clone()); token_to_id.insert(word, TokenId::try_from(i)?); @@ -698,13 +672,16 @@ impl Model { } // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); + match model_type { + ModelType::GGMF | ModelType::GGJT => { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } + ModelType::Unversioned => { + // Legacy model, set empty score + id_to_token_score.push(0.); } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); } } @@ -831,240 +808,13 @@ impl Model { } }; - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. - let file_offset = reader.stream_position()?; - drop(reader); - - let paths = util::find_all_model_files(main_path)?; - let n_parts = paths.len(); - - for (i, part_path) in paths.into_iter().enumerate() { - let part_id = i; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: i, - total_parts: n_parts, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - let data = tensor.data(); - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) - != tensor.nbytes() / n_parts - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts; - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); + match model_type { + ModelType::GGMF | ModelType::Unversioned => { + load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)? + } + ModelType::GGJT => { + load_weights_ggjt(reader, main_path, load_progress_callback, &model)? } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); } Ok((model, vocab)) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs new file mode 100644 index 00000000..4269538d --- /dev/null +++ b/llama-rs/src/loader.rs @@ -0,0 +1,293 @@ +use std::{fs::File, io::BufReader}; + +use crate::*; + +pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) +} + +pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. +pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} + +#[derive(PartialEq)] +pub(crate) enum ModelType { + GGMF, + GGJT, + Unversioned, +} + +pub(crate) fn load_weights_ggmf_or_unversioned( + mut reader: std::io::BufReader, + main_path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> { + let file_offset = reader.stream_position()?; + drop(reader); + + let paths = util::find_all_model_files(main_path)?; + + let n_parts = paths.len(); + Ok(for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: i, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + // NOTE: Implementation from #![feature(buf_read_has_data_left)] + let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; + + if is_eof { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_u32(&mut part_reader)?; + + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(&mut part_reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + + let tensor_name = read_string(&mut part_reader, length as usize)?; + + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); + }; + + // split_type = 0: split by columns + // split_type = 1: split by rows + // + // split_type = 0: + // regex: + // - tok_embeddings.* + // - layers.*.attention.wo.weight + // - layers.*.feed_forward.w2.weight + + // split_type = 1: + // regex: + // - output.* + // - layers.*.attention.wq.weight + // - layers.*.attention.wk.weight + // - layers.*.attention.wv.weight + // - layers.*.feed_forward.w1.weight + // - layers.*.feed_forward.w3.weight + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] + || tensor.get_ne()[1] != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + } else if tensor.get_ne()[0] != ne[0] + || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + let bpe = match ftype { + 0 => ggml::type_size(ggml::Type::F32), + 1 => ggml::type_size(ggml::Type::F16), + 2 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_0) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + ggml::type_size(ggml::Type::Q4_1) + } + _ => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: part_path, + }) + } + }; + + if n_dims == 1 || n_parts == 1 { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size / n_parts); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + }) +} + +pub(crate) fn load_weights_ggjt( + mut reader: std::io::BufReader, + main_path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> { + todo!("GGJT load weights"); +} From 53ba1a97d4a9792001e33d4b3331487eaa002196 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 19:52:14 +0000 Subject: [PATCH 02/10] Add loading code for ggjt Now it can load the model, but it's not working --- Cargo.lock | 313 +++++++++++++++++++++++++-------------- ggml/src/lib.rs | 10 +- llama-rs/Cargo.toml | 3 +- llama-rs/src/lib.rs | 25 +++- llama-rs/src/loader.rs | 323 +++++++++++++++++++++++++---------------- 5 files changed, 435 insertions(+), 239 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ac1053c..38bd0498 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,46 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -103,40 +143,45 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", "strsim", - "termcolor", ] [[package]] name = "clap_derive" -version = "4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] name = "clap_lex" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" -dependencies = [ - "os_str_bytes", -] +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" [[package]] name = "clipboard-win" @@ -149,6 +194,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -261,13 +321,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -292,13 +352,13 @@ dependencies = [ [[package]] name = "fd-lock" -version = "3.0.10" +version = "3.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ef1a30ae415c3a691a4f41afddc2dbcd6d70baf338368d85ebc1e8ed92cedb9" +checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -310,9 +370,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -378,24 +438,25 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ + "hermit-abi 0.3.1", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -436,9 +497,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.140" +version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "libloading" @@ -452,9 +513,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "llama-cli" @@ -479,6 +540,7 @@ dependencies = [ "bincode", "bytemuck", "ggml", + "memmap2", "partial_sort", "protobuf", "rand", @@ -510,6 +572,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.8.0" @@ -572,12 +643,6 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" -[[package]] -name = "os_str_bytes" -version = "6.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" - [[package]] name = "partial_sort" version = "0.2.0" @@ -602,35 +667,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -643,9 +684,9 @@ checksum = "8e86d370532557ae7573551a1ec8235a0f8d6cb276c7c9e6aa490b511c447485" [[package]] name = "quote" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5308e8208729c3e1504a6cfad0d5daacc4614c9a2e65d1ea312a34b5cb00fe84" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -734,9 +775,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "aho-corasick", "memchr", @@ -745,9 +786,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "rust_tokenizers" @@ -776,16 +817,16 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.36.9" +version = "0.37.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -831,9 +872,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -849,20 +890,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn 2.0.10", + "syn 2.0.13", ] [[package]] name = "serde_json" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -945,9 +986,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.10" +version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aad1363ed6d37b84299588d62d3a7d95b5a5c2d9aad5c85609fda12afaa1f40" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" dependencies = [ "proc-macro2", "quote", @@ -965,22 +1006,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] @@ -1040,12 +1081,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1100,7 +1135,16 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", ] [[package]] @@ -1109,13 +1153,28 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] @@ -1124,42 +1183,84 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + [[package]] name = "windows_i686_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + [[package]] name = "windows_i686_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "zstd" version = "0.12.3+zstd.1.5.2" @@ -1171,9 +1272,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.4+zstd.1.5.4" +version = "6.0.5+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" dependencies = [ "libc", "zstd-sys", @@ -1181,9 +1282,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" +version = "2.0.8+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" dependencies = [ "cc", "libc", diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 71016770..4e11625e 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -354,13 +354,21 @@ impl Tensor { /// # Safety /// /// The data must not be mutated while being read from. - pub unsafe fn data(&self) -> *mut c_void { + pub unsafe fn data(&self) -> *const c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data }) } + /// Set the tensor's data pointer (useful for mmap-ed data) + pub unsafe fn set_data(&self, data_ptr: *mut c_void) { + self.with_alive_ctx(|| { + // SAFETY: The with_alive_call guarantees the context is alive + unsafe { *self.ptr.as_ptr() }.data = data_ptr; + }) + } + /// Number of elements in this tensor. pub fn nelements(&self) -> usize { self.with_alive_ctx(|| { diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 0e48cd58..b2e3aa15 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -16,6 +16,7 @@ rand = { workspace = true } serde = { version = "1.0.156", features = ["derive"] } serde_bytes = "0.11" bincode = "1.3.3" +memmap2 = "0.5.10" # Used for the `convert` feature serde_json = { version = "1.0.94", optional = true } @@ -23,4 +24,4 @@ protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 4be5cb20..d5e70a4f 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -14,6 +14,7 @@ use std::{ }; use serde::Deserialize; +use memmap2::Mmap; use thiserror::Error; use partial_sort::PartialSort; @@ -73,6 +74,8 @@ pub struct Model { tensors: HashMap, + mmap: Option, + // Must be kept alive for the model _context: ggml::Context, } @@ -511,7 +514,7 @@ pub enum LoadError { /// The name of the tensor. tensor_name: String, /// The format type that was encountered. - ftype: u32, + ftype: i32, /// The path that failed. path: PathBuf, }, @@ -594,12 +597,13 @@ impl Model { let main_path = path.as_ref(); - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { source: e, path: main_path.to_owned(), - })?, + })?; + let mut reader = + BufReader::new( + &file, ); // Verify magic @@ -741,7 +745,7 @@ impl Model { // Initialize the context let context = ggml::Context::init(ctx_size); - let model = { + let mut model = { let mut tensors = HashMap::new(); let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); @@ -805,15 +809,20 @@ impl Model { layers, tensors, _context: context, + mmap: None, } }; match model_type { ModelType::GGMF | ModelType::Unversioned => { - load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)? + let file_offset = reader.stream_position()?; + drop(reader); + load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)? } ModelType::GGJT => { - load_weights_ggjt(reader, main_path, load_progress_callback, &model)? + let mmap = unsafe { Mmap::map(&file)? }; + load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?; + model.mmap = Some(mmap); } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 4269538d..da65ec6c 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -26,7 +26,7 @@ pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { } /// Helper function. Reads a string from the buffer and returns it. -pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result { +pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { let mut buf = vec![0; len]; reader .read_exact(&mut buf) @@ -38,6 +38,11 @@ pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result Result { + reader.fill_buf().map(|b| !b.is_empty()) +} + #[derive(PartialEq)] pub(crate) enum ModelType { GGMF, @@ -46,14 +51,11 @@ pub(crate) enum ModelType { } pub(crate) fn load_weights_ggmf_or_unversioned( - mut reader: std::io::BufReader, + file_offset: u64, main_path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, ) -> Result<(), LoadError> { - let file_offset = reader.stream_position()?; - drop(reader); - let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); @@ -76,125 +78,23 @@ pub(crate) fn load_weights_ggmf_or_unversioned( // Load weights loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { + if !has_data_left(&mut part_reader)? { break; } let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; + let ftype = read_i32(&mut part_reader)?; + + let (nelements, ne, tensor_name, tensor, split_type, bpe) = load_tensor_header_ggmf( + n_dims, + &mut part_reader, + length, + model, + &part_path, + n_parts, + ftype, + )?; if n_dims == 1 || n_parts == 1 { if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { @@ -283,11 +183,188 @@ pub(crate) fn load_weights_ggmf_or_unversioned( }) } +fn load_tensor_header_ggmf<'a>( + n_dims: usize, + reader: &mut BufReader, + length: i32, + model: &'a Model, + path: &Path, + n_parts: usize, + ftype: i32, +) -> Result<(usize, [i64; 2], String, &'a ggml::Tensor, i32, usize), LoadError> { + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let bpe = tensor_type_size(ftype, ne); + let bpe = match bpe { + Some(x) => x, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + Ok((nelements, ne, tensor_name, tensor, split_type, bpe)) +} + +fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { + let bpe = match ftype { + 0 => Some(ggml::type_size(ggml::Type::F32)), + 1 => Some(ggml::type_size(ggml::Type::F16)), + 2 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_0)) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_1)) + } + _ => None, + }; + bpe +} + pub(crate) fn load_weights_ggjt( - mut reader: std::io::BufReader, - main_path: &Path, + reader: &mut std::io::BufReader<&File>, + mmap: &Mmap, + path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, -) -> Result<(), LoadError> { - todo!("GGJT load weights"); +) -> Result<(), LoadError> +// where R: std::io::Read +{ + let mut loop_i = 0; + let mut total_loaded_bytes = 0; + load_progress_callback(LoadProgress::PartLoading { + file: path, + current_part: 0, + total_parts: 1, + }); + + loop { + if !has_data_left(reader)? { + break; + } + + let n_dims = read_i32(reader)? as usize; + let length = read_i32(reader)?; + let ftype = read_i32(reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1]; + assert!(n_dims <= ne.len()); + for i in 0..n_dims { + let dim = read_i32(reader)? as usize; + ne[i] = dim as i64; + nelements *= dim; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let tensor_ne = tensor.get_ne(); + if tensor_ne[0] != ne[0] || tensor_ne[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + + _ = tensor_type_size(ftype, ne); + + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); + unsafe { + let ptr = mmap.as_ptr().offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + let tensor_data_size = tensor.nbytes() as u64; + reader.seek(SeekFrom::Start(offset_aligned + tensor_data_size))?; + total_loaded_bytes += tensor_data_size; + + load_progress_callback(LoadProgress::PartTensorLoaded { + file: path, + current_tensor: loop_i, + tensor_count: model.tensors.len(), + }); + + loop_i += 1; + } + + load_progress_callback(LoadProgress::PartLoaded { + file: path, + byte_size: total_loaded_bytes as usize, + tensor_count: loop_i, + }); + + return Ok(()); } From 071612ecce849cb60eecca09e99dbb6a89f347f0 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:09:40 +0000 Subject: [PATCH 03/10] code cleanup that doesn't change anything --- llama-rs/src/lib.rs | 46 ++++++++++++++++++++++++++++-------------- llama-rs/src/loader.rs | 11 +++++++++- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index d5e70a4f..5e44979c 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -598,13 +598,10 @@ impl Model { let main_path = path.as_ref(); let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?; - let mut reader = - BufReader::new( - &file, - ); + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); // Verify magic let model_type: ModelType = match read_u32(&mut reader)? { @@ -666,13 +663,21 @@ impl Model { ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, ModelType::GGJT => read_u32(&mut reader)? as usize, }; - if let Ok(word) = read_string(&mut reader, len) { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); + let maybe_word = if len > 0 { + read_string(&mut reader, len) } else { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); + Ok("".into()) + }; + match maybe_word { + Ok(word) => { + max_token_length = max_token_length.max(word.len()); + id_to_token.push(word.clone()); + token_to_id.insert(word, TokenId::try_from(i)?); + } + Err(_e) => { + load_progress_callback(LoadProgress::BadToken { index: i }); + id_to_token.push("�".to_string()); + } } // Token score, currently unused @@ -817,11 +822,22 @@ impl Model { ModelType::GGMF | ModelType::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); - load_weights_ggmf_or_unversioned(file_offset, main_path, load_progress_callback, &model)? + load_weights_ggmf_or_unversioned( + file_offset, + main_path, + load_progress_callback, + &model, + )? } ModelType::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; - load_weights_ggjt(&mut reader, &mmap, main_path, load_progress_callback, &model)?; + load_weights_ggjt( + &mut reader, + &mmap, + main_path, + load_progress_callback, + &model, + )?; model.mmap = Some(mmap); } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index da65ec6c..e7326e13 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -339,7 +339,16 @@ pub(crate) fn load_weights_ggjt( }); } - _ = tensor_type_size(ftype, ne); + match tensor_type_size(ftype, ne) { + Some(_) => {}, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); From 3b9e3fed10c952d53c14a5c6be41241b31ae606e Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:12:55 +0000 Subject: [PATCH 04/10] more code cleanup --- llama-rs/src/lib.rs | 36 ++++++++++++++++++++++++------------ llama-rs/src/loader.rs | 7 ------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 5e44979c..2c57c8db 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -60,6 +60,15 @@ struct Layer { w3: ggml::Tensor, } + +/// Model Version +#[derive(Debug, PartialEq, Clone, Copy)] +pub(crate) enum ModelVersion { + GGMF, + GGJT, + Unversioned, +} + /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. pub struct Model { @@ -75,6 +84,8 @@ pub struct Model { tensors: HashMap, mmap: Option, + + version: ModelVersion, // Must be kept alive for the model _context: ggml::Context, @@ -604,10 +615,10 @@ impl Model { let mut reader = BufReader::new(&file); // Verify magic - let model_type: ModelType = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ModelType::GGMF, - ggml::FILE_MAGIC_GGJT => ModelType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned, + let model_type: ModelVersion = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, + ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -617,13 +628,13 @@ impl Model { // Load format version match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion { value: version }), }; } - ModelType::Unversioned => {} + ModelVersion::Unversioned => {} } // ================= @@ -660,8 +671,8 @@ impl Model { for i in 0..hparams.n_vocab { let len = match model_type { // `read_i32` maybe a typo - ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize, - ModelType::GGJT => read_u32(&mut reader)? as usize, + ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize, + ModelVersion::GGJT => read_u32(&mut reader)? as usize, }; let maybe_word = if len > 0 { read_string(&mut reader, len) @@ -682,12 +693,12 @@ impl Model { // Token score, currently unused match model_type { - ModelType::GGMF | ModelType::GGJT => { + ModelVersion::GGMF | ModelVersion::GGJT => { if let Ok(score) = read_f32(&mut reader) { id_to_token_score.push(score); } } - ModelType::Unversioned => { + ModelVersion::Unversioned => { // Legacy model, set empty score id_to_token_score.push(0.); } @@ -815,11 +826,12 @@ impl Model { tensors, _context: context, mmap: None, + version: model_type, } }; match model_type { - ModelType::GGMF | ModelType::Unversioned => { + ModelVersion::GGMF | ModelVersion::Unversioned => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -829,7 +841,7 @@ impl Model { &model, )? } - ModelType::GGJT => { + ModelVersion::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; load_weights_ggjt( &mut reader, diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index e7326e13..602bd080 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -43,13 +43,6 @@ fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } -#[derive(PartialEq)] -pub(crate) enum ModelType { - GGMF, - GGJT, - Unversioned, -} - pub(crate) fn load_weights_ggmf_or_unversioned( file_offset: u64, main_path: &Path, From f6f0aa0bba04aa6f562a4a41552fb927093ba038 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 13:54:45 +0000 Subject: [PATCH 05/10] minor change in math, tensor loading --- llama-rs/src/loader.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 602bd080..fa5ce73a 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -333,7 +333,7 @@ pub(crate) fn load_weights_ggjt( } match tensor_type_size(ftype, ne) { - Some(_) => {}, + Some(_) => {} None => { return Err(LoadError::InvalidFtype { tensor_name, @@ -344,7 +344,7 @@ pub(crate) fn load_weights_ggjt( }; let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & (31 ^ u64::MAX); + let offset_aligned: u64 = (offset_curr + 31) & !31; unsafe { let ptr = mmap.as_ptr().offset(offset_aligned as isize); tensor.set_data(ptr as *mut std::ffi::c_void); From 264920eb5d5e5a253526e811f41cdbc3c0f902c3 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:23:07 +0000 Subject: [PATCH 06/10] Add non-mmap loader for GGJT --- ggml/src/lib.rs | 2 +- llama-rs/Cargo.toml | 5 ++++- llama-rs/src/lib.rs | 44 +++++++++++++++++++++++++++--------------- llama-rs/src/loader.rs | 39 +++++++++++++++++++++++++++---------- 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 4e11625e..a081548d 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -354,7 +354,7 @@ impl Tensor { /// # Safety /// /// The data must not be mutated while being read from. - pub unsafe fn data(&self) -> *const c_void { + pub unsafe fn data(&self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index b2e3aa15..302a1389 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -16,7 +16,7 @@ rand = { workspace = true } serde = { version = "1.0.156", features = ["derive"] } serde_bytes = "0.11" bincode = "1.3.3" -memmap2 = "0.5.10" +memmap2 = { version = "0.5.10", optional = true } # Used for the `convert` feature serde_json = { version = "1.0.94", optional = true } @@ -25,3 +25,6 @@ rust_tokenizers = { version = "3.1.2", optional = true } [features] convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] + +# broken atm, see https://github.com/rustformers/llama-rs/pull/114#issuecomment-1500337463 +mmap = ["dep:memmap2"] diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 2c57c8db..d8c98fcc 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -2,6 +2,9 @@ //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. mod loader; +mod util; +#[cfg(feature = "convert")] +pub mod convert; use core::slice; use std::{ @@ -14,18 +17,31 @@ use std::{ }; use serde::Deserialize; -use memmap2::Mmap; use thiserror::Error; - use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; - pub use ggml::Type as ElementType; -#[cfg(feature = "convert")] -pub mod convert; +#[cfg(feature = "mmap")] +use memmap2::Mmap; + +/// dummy struct +#[cfg(not(feature = "mmap"))] +pub(crate) struct Mmap; + +/// dummy impl +#[cfg(not(feature = "mmap"))] +impl Mmap { + pub(crate) unsafe fn map(_: &std::fs::File) -> Result { + Ok(Mmap) + } + pub(crate) fn as_ptr(&self) -> *const u8 { + std::ptr::null() + } +} +// map + -mod util; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) @@ -60,7 +76,6 @@ struct Layer { w3: ggml::Tensor, } - /// Model Version #[derive(Debug, PartialEq, Clone, Copy)] pub(crate) enum ModelVersion { @@ -84,7 +99,7 @@ pub struct Model { tensors: HashMap, mmap: Option, - + version: ModelVersion, // Must be kept alive for the model @@ -671,7 +686,9 @@ impl Model { for i in 0..hparams.n_vocab { let len = match model_type { // `read_i32` maybe a typo - ModelVersion::GGMF | ModelVersion::Unversioned => read_i32(&mut reader)? as usize, + ModelVersion::GGMF | ModelVersion::Unversioned => { + read_i32(&mut reader)? as usize + } ModelVersion::GGJT => read_u32(&mut reader)? as usize, }; let maybe_word = if len > 0 { @@ -843,14 +860,9 @@ impl Model { } ModelVersion::GGJT => { let mmap = unsafe { Mmap::map(&file)? }; - load_weights_ggjt( - &mut reader, - &mmap, - main_path, - load_progress_callback, - &model, - )?; + let ptr = mmap.as_ptr(); model.mmap = Some(mmap); + load_weights_ggjt(&mut reader, ptr, main_path, load_progress_callback, &model)?; } } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index fa5ce73a..9155f4e8 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -280,7 +280,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { pub(crate) fn load_weights_ggjt( reader: &mut std::io::BufReader<&File>, - mmap: &Mmap, + mmap_base: *const u8, path: &Path, load_progress_callback: impl Fn(LoadProgress), model: &Model, @@ -343,15 +343,9 @@ pub(crate) fn load_weights_ggjt( } }; - let offset_curr = reader.stream_position()?; - let offset_aligned: u64 = (offset_curr + 31) & !31; - unsafe { - let ptr = mmap.as_ptr().offset(offset_aligned as isize); - tensor.set_data(ptr as *mut std::ffi::c_void); - } - let tensor_data_size = tensor.nbytes() as u64; - reader.seek(SeekFrom::Start(offset_aligned + tensor_data_size))?; - total_loaded_bytes += tensor_data_size; + load_tensor(reader, mmap_base, tensor)?; + + total_loaded_bytes += tensor.nbytes() as u64; load_progress_callback(LoadProgress::PartTensorLoaded { file: path, @@ -370,3 +364,28 @@ pub(crate) fn load_weights_ggjt( return Ok(()); } + +#[cfg(feature = "mmap")] +fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggml::Tensor) -> Result<(), LoadError> { + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + unsafe { + let ptr = mmap_base.offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u8))?; + Ok(()) +} + +#[cfg(not(feature = "mmap"))] +fn load_tensor<'a>(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &'a ggml::Tensor) -> Result<(), LoadError> { + _ = mmap_base; + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + reader.seek(SeekFrom::Start(offset_aligned))?; + + let buf: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + reader.read_exact(buf)?; + + Ok(()) +} From 6a92d0eb4b3982b67baec01f365fb46f122fa144 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:37:36 +0000 Subject: [PATCH 07/10] Prefer traits in loader.rs --- llama-rs/src/loader.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 9155f4e8..78228ce4 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,5 +1,3 @@ -use std::{fs::File, io::BufReader}; - use crate::*; pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { @@ -49,6 +47,8 @@ pub(crate) fn load_weights_ggmf_or_unversioned( load_progress_callback: impl Fn(LoadProgress), model: &Model, ) -> Result<(), LoadError> { + use std::{fs::File, io::BufReader}; + let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); @@ -178,7 +178,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( fn load_tensor_header_ggmf<'a>( n_dims: usize, - reader: &mut BufReader, + reader: &mut impl BufRead, length: i32, model: &'a Model, path: &Path, @@ -279,7 +279,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { } pub(crate) fn load_weights_ggjt( - reader: &mut std::io::BufReader<&File>, + reader: &mut (impl BufRead + Seek), mmap_base: *const u8, path: &Path, load_progress_callback: impl Fn(LoadProgress), @@ -344,7 +344,7 @@ pub(crate) fn load_weights_ggjt( }; load_tensor(reader, mmap_base, tensor)?; - + total_loaded_bytes += tensor.nbytes() as u64; load_progress_callback(LoadProgress::PartTensorLoaded { @@ -366,7 +366,11 @@ pub(crate) fn load_weights_ggjt( } #[cfg(feature = "mmap")] -fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggml::Tensor) -> Result<(), LoadError> { +fn load_tensor( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &ggml::Tensor, +) -> Result<(), LoadError> { let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; unsafe { @@ -378,13 +382,18 @@ fn load_tensor(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &ggm } #[cfg(not(feature = "mmap"))] -fn load_tensor<'a>(reader: &mut BufReader<&File>, mmap_base: *const u8, tensor: &'a ggml::Tensor) -> Result<(), LoadError> { +fn load_tensor<'a>( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &'a ggml::Tensor, +) -> Result<(), LoadError> { _ = mmap_base; let offset_curr = reader.stream_position()?; let offset_aligned: u64 = (offset_curr + 31) & !31; reader.seek(SeekFrom::Start(offset_aligned))?; - let buf: &'a mut [u8] = unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + let buf: &'a mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; reader.read_exact(buf)?; Ok(()) From 4548aa242d3d1a54ee4526a2be49cf8297645206 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 19:51:49 +0000 Subject: [PATCH 08/10] cargo fmt --- llama-rs/src/lib.rs | 12 +++++------- llama-rs/src/loader.rs | 6 +++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index d8c98fcc..e13603d5 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,10 +1,10 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. -mod loader; -mod util; #[cfg(feature = "convert")] pub mod convert; +mod loader; +mod util; use core::slice; use std::{ @@ -16,11 +16,11 @@ use std::{ time, }; -use serde::Deserialize; -use thiserror::Error; +pub use ggml::Type as ElementType; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; -pub use ggml::Type as ElementType; +use serde::Deserialize; +use thiserror::Error; #[cfg(feature = "mmap")] use memmap2::Mmap; @@ -41,8 +41,6 @@ impl Mmap { } // map - - /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 78228ce4..912f924c 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -343,7 +343,7 @@ pub(crate) fn load_weights_ggjt( } }; - load_tensor(reader, mmap_base, tensor)?; + load_tensor_ggjt(reader, mmap_base, tensor)?; total_loaded_bytes += tensor.nbytes() as u64; @@ -366,7 +366,7 @@ pub(crate) fn load_weights_ggjt( } #[cfg(feature = "mmap")] -fn load_tensor( +fn load_tensor_ggjt( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, tensor: &ggml::Tensor, @@ -382,7 +382,7 @@ fn load_tensor( } #[cfg(not(feature = "mmap"))] -fn load_tensor<'a>( +fn load_tensor_ggjt<'a>( reader: &mut (impl BufRead + Seek), mmap_base: *const u8, tensor: &'a ggml::Tensor, From eab72357259b1e37058d11df91f3d1b08b1127dc Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 20:33:02 +0000 Subject: [PATCH 09/10] cargo clippy --fix --- ggml/src/lib.rs | 6 +++++- llama-rs/src/convert.rs | 4 ++-- llama-rs/src/lib.rs | 5 +++-- llama-rs/src/loader.rs | 15 +++++++++------ 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index a081548d..03ed698a 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -353,7 +353,7 @@ impl Tensor { /// /// # Safety /// - /// The data must not be mutated while being read from. + /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. pub unsafe fn data(&self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive @@ -362,6 +362,10 @@ impl Tensor { } /// Set the tensor's data pointer (useful for mmap-ed data) + /// + /// # Safety + /// + /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. pub unsafe fn set_data(&self, data_ptr: *mut c_void) { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 285cf3c0..f4f55996 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -28,12 +28,12 @@ pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { let model_files = util::find_all_model_files(model_directory).unwrap(); for (i, _file) in model_files.iter().enumerate() { - let fname_out = model_directory.join(format!("rust-model-{}.bin", element_type)); + let fname_out = model_directory.join(format!("rust-model-{element_type}.bin")); let mut file = File::create(fname_out).expect("Unable to create file"); write_header(file.borrow_mut(), &hparams).unwrap(); write_tokens(file.borrow_mut(), &vocab).unwrap(); - let _fname_model = model_directory.join(format!("consolidated.0{}.pth", i)); + let _fname_model = model_directory.join(format!("consolidated.0{i}.pth")); // Todo process and write variables } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index e13603d5..ddffce02 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -76,6 +76,7 @@ struct Layer { /// Model Version #[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] pub(crate) enum ModelVersion { GGMF, GGJT, @@ -98,7 +99,7 @@ pub struct Model { mmap: Option, - version: ModelVersion, + _version: ModelVersion, // Must be kept alive for the model _context: ggml::Context, @@ -841,7 +842,7 @@ impl Model { tensors, _context: context, mmap: None, - version: model_type, + _version: model_type, } }; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 912f924c..fc3b2c61 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -52,7 +52,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( let paths = util::find_all_model_files(main_path)?; let n_parts = paths.len(); - Ok(for (i, part_path) in paths.into_iter().enumerate() { + for (i, part_path) in paths.into_iter().enumerate() { let part_id = i; load_progress_callback(LoadProgress::PartLoading { @@ -173,9 +173,11 @@ pub(crate) fn load_weights_ggmf_or_unversioned( byte_size: total_size, tensor_count: n_tensors.try_into()?, }); - }) + }; + Ok(()) } +#[allow(clippy::type_complexity)] fn load_tensor_header_ggmf<'a>( n_dims: usize, reader: &mut impl BufRead, @@ -262,7 +264,8 @@ fn load_tensor_header_ggmf<'a>( } fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { - let bpe = match ftype { + + match ftype { 0 => Some(ggml::type_size(ggml::Type::F32)), 1 => Some(ggml::type_size(ggml::Type::F16)), 2 => { @@ -274,8 +277,7 @@ fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { Some(ggml::type_size(ggml::Type::Q4_1)) } _ => None, - }; - bpe + } } pub(crate) fn load_weights_ggjt( @@ -307,6 +309,7 @@ pub(crate) fn load_weights_ggjt( let mut nelements: usize = 1; let mut ne = [1i64, 1]; assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] for i in 0..n_dims { let dim = read_i32(reader)? as usize; ne[i] = dim as i64; @@ -362,7 +365,7 @@ pub(crate) fn load_weights_ggjt( tensor_count: loop_i, }); - return Ok(()); + Ok(()) } #[cfg(feature = "mmap")] From 32925e71e2d8322b4754b6aa7f8b174cbee6c797 Mon Sep 17 00:00:00 2001 From: Locria Cyber <74560659+iacore@users.noreply.github.com> Date: Fri, 7 Apr 2023 20:36:09 +0000 Subject: [PATCH 10/10] Remove ggml::Tensor::set_data --- ggml/src/lib.rs | 22 +++++++++++----------- llama-rs/src/loader.rs | 3 +-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 03ed698a..89a6a32f 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -361,17 +361,17 @@ impl Tensor { }) } - /// Set the tensor's data pointer (useful for mmap-ed data) - /// - /// # Safety - /// - /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. - pub unsafe fn set_data(&self, data_ptr: *mut c_void) { - self.with_alive_ctx(|| { - // SAFETY: The with_alive_call guarantees the context is alive - unsafe { *self.ptr.as_ptr() }.data = data_ptr; - }) - } + // /// Set the tensor's data pointer (useful for mmap-ed data) + // /// + // /// # Safety + // /// + // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. + // pub unsafe fn set_data(&self, data_ptr: *mut c_void) { + // self.with_alive_ctx(|| { + // // SAFETY: The with_alive_call guarantees the context is alive + // unsafe { *self.ptr.as_ptr() }.data = data_ptr; + // }) + // } /// Number of elements in this tensor. pub fn nelements(&self) -> usize { diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index fc3b2c61..e05483c8 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -173,7 +173,7 @@ pub(crate) fn load_weights_ggmf_or_unversioned( byte_size: total_size, tensor_count: n_tensors.try_into()?, }); - }; + } Ok(()) } @@ -264,7 +264,6 @@ fn load_tensor_header_ggmf<'a>( } fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { - match ftype { 0 => Some(ggml::type_size(ggml::Type::F32)), 1 => Some(ggml::type_size(ggml::Type::F16)),