From 1af69ec16029f5f5b499b9a37ebb338f75934dff Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Mon, 27 Mar 2023 00:51:30 -0400 Subject: [PATCH 01/16] initial quantize implementation --- Cargo.lock | 16 ++ ggml-raw/ggml/ggml.c | 110 +++++++++++++ ggml-raw/ggml/ggml.h | 8 + ggml-raw/src/lib.rs | 20 +++ llama-rs/Cargo.toml | 1 + llama-rs/src/ggml.rs | 48 +++++- llama-rs/src/lib.rs | 11 +- llama-rs/src/quantize.rs | 343 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 554 insertions(+), 3 deletions(-) create mode 100644 llama-rs/src/quantize.rs diff --git a/Cargo.lock b/Cargo.lock index 5fce7727..8f08c65d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "dirs-next" version = "2.0.0" @@ -192,6 +198,15 @@ dependencies = [ "cc", ] +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + [[package]] name = "heck" version = "0.4.1" @@ -281,6 +296,7 @@ dependencies = [ "bincode", "bytemuck", "ggml-raw", + "half", "partial_sort", "rand", "serde", diff --git a/ggml-raw/ggml/ggml.c b/ggml-raw/ggml/ggml.c index 4eeaf28d..e636e725 100644 --- a/ggml-raw/ggml/ggml.c +++ b/ggml-raw/ggml/ggml.c @@ -397,6 +397,53 @@ static inline __m128i packNibbles( __m256i bytes ) } #endif +// method 5 +// blocks of QK elements +// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) + +// reference implementation for deterministic creation of model files +static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + const size_t bs = sizeof(float) + QK/2; + + uint8_t * restrict pd = ((uint8_t *)y + 0*bs); + uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float)); + + uint8_t pp[QK/2]; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 3) - 1); + const float id = d ? 1.0f/d : 0.0f; + + *(float *)pd = d; + pd += bs; + + for (int l = 0; l < QK; l += 2) { + const float v0 = x[i*QK + l + 0]*id; + const float v1 = x[i*QK + l + 1]*id; + + const uint8_t vi0 = ((int8_t) (round(v0))) + 8; + const uint8_t vi1 = ((int8_t) (round(v1))) + 8; + + assert(vi0 >= 0 && vi0 < 16); + assert(vi1 >= 0 && vi1 < 16); + + pp[l/2] = vi0 | (vi1 << 4); + } + + memcpy(pb, pp, sizeof(pp)); + pb += bs; + } +} // method 5 // blocks of QK elements @@ -10630,6 +10677,69 @@ enum ggml_opt_result ggml_opt( return result; } +//////////////////////////////////////////////////////////////////////////////// + +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist) { + const int nb = k / qk; + const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t row_size = nb*bs; + + assert(k % qk == 0); + + char * pdst = (char *) dst; + + for (int j = 0; j < n; j += k) { + uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); + uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float)); + + quantize_row_q4_0_reference(src + j, pd, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < qk; l += 2) { + const uint8_t vi0 = pb[l/2] & 0xF; + const uint8_t vi1 = pb[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + pb += bs; + } + } + + return (n/k)*row_size; +} + +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist) { + const int nb = k / qk; + const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t row_size = nb*bs; + + assert(k % qk == 0); + + char * pdst = (char *) dst; + + for (int j = 0; j < n; j += k) { + uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); + uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float)); + + quantize_row_q4_1(src + j, pd, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < qk; l += 2) { + const uint8_t vi0 = pb[l/2] & 0xF; + const uint8_t vi1 = pb[l/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + pb += bs; + } + } + + return (n/k)*row_size; +} + + //////////////////////////////////////////////////////////////////////////////// int ggml_cpu_has_avx(void) { diff --git a/ggml-raw/ggml/ggml.h b/ggml-raw/ggml/ggml.h index bac4fe65..e79e028f 100644 --- a/ggml-raw/ggml/ggml.h +++ b/ggml-raw/ggml/ggml.h @@ -741,6 +741,14 @@ enum ggml_opt_result ggml_opt( struct ggml_opt_params params, struct ggml_tensor * f); +// +// quantization +// + +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist); +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist); + + // // system info // diff --git a/ggml-raw/src/lib.rs b/ggml-raw/src/lib.rs index 3c5d49d0..e4f5e418 100644 --- a/ggml-raw/src/lib.rs +++ b/ggml-raw/src/lib.rs @@ -228,4 +228,24 @@ extern "C" { pub fn ggml_build_forward_expand(cgraph: *mut ggml_cgraph, tensor: *mut ggml_tensor); pub fn ggml_graph_compute(ctx: *mut ggml_context, cgraph: *mut ggml_cgraph); + + pub fn ggml_quantize_q4_0( + src: *mut f32, + work: *mut c_void, + n: i32, + k: i32, + qk: i32, + hist: *mut i64, + ) -> usize; + + pub fn ggml_quantize_q4_1( + src: *mut f32, + work: *mut c_void, + n: i32, + k: i32, + qk: i32, + hist: *mut i64, + ) -> usize; + + pub fn ggml_fp16_to_fp32(x: u16) -> f32; } diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index a91b13eb..e005d9f8 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,6 +8,7 @@ rust-version = "1.65" [dependencies] bytemuck = "1.13.1" +half = "2.2.1" ggml-raw = { path = "../ggml-raw" } partial_sort = "0.2.0" thiserror = "1.0" diff --git a/llama-rs/src/ggml.rs b/llama-rs/src/ggml.rs index c9e9a26e..5e9b182b 100644 --- a/llama-rs/src/ggml.rs +++ b/llama-rs/src/ggml.rs @@ -6,8 +6,8 @@ use std::{ pub use ggml_raw::ggml_type as Type; -pub const FILE_MAGIC: i32 = 0x67676d66; -pub const FILE_MAGIC_UNVERSIONED: i32 = 0x67676d6c; +pub const FILE_MAGIC: u32 = 0x67676d66; +pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; pub const FORMAT_VERSION: u32 = 1; @@ -291,3 +291,47 @@ pub fn type_sizef(x: ggml_raw::ggml_type) -> f64 { pub fn blck_size(t: Type) -> i32 { unsafe { ggml_raw::ggml_blck_size(t) } } + +pub fn quantize_q4_0( + src: &mut Vec, + work: &mut Vec, + n: i32, + k: i32, + qk: i32, + hist: &mut Vec, +) -> usize { + unsafe { + ggml_raw::ggml_quantize_q4_0( + src.as_mut_ptr(), + work.as_mut_ptr() as *mut c_void, + n, + k, + qk, + hist.as_mut_ptr(), + ) + } +} + +pub fn quantize_q4_1( + src: &mut Vec, + work: &mut Vec, + n: i32, + k: i32, + qk: i32, + hist: &mut Vec, +) -> usize { + unsafe { + ggml_raw::ggml_quantize_q4_1( + src.as_mut_ptr(), + work.as_mut_ptr() as *mut c_void, + n, + k, + qk, + hist.as_mut_ptr(), + ) + } +} + +pub fn ggml_fp16_to_fp32(x: u16) -> f32 { + unsafe { ggml_raw::ggml_fp16_to_fp32(x) } +} diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 4a5616ff..ac292d66 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,4 +1,5 @@ mod ggml; +mod quantize; use std::{ collections::{HashMap, VecDeque}, @@ -14,6 +15,7 @@ use thiserror::Error; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; +pub use quantize::llama_model_quantize; pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] @@ -345,6 +347,11 @@ pub enum LoadError { source: std::io::Error, path: PathBuf, }, + #[error("could not create file {path:?}")] + CreateFileFailed { + source: std::io::Error, + path: PathBuf, + }, #[error("no parent path for {path:?}")] NoParentPath { path: PathBuf }, #[error("unable to read exactly {bytes} bytes")] @@ -374,6 +381,8 @@ pub enum LoadError { TensorWrongSize { tensor_name: String, path: PathBuf }, #[error("invalid ftype {ftype} in {path:?}")] InvalidFtype { ftype: i32, path: PathBuf }, + #[error("itype supplied was invalid: {0}")] + InvalidItype(u8), } #[derive(Error, Debug)] @@ -465,7 +474,7 @@ impl Model { } // Verify magic - let is_legacy_model: bool = match read_i32(&mut reader)? { + let is_legacy_model: bool = match read_u32(&mut reader)? { ggml::FILE_MAGIC => false, ggml::FILE_MAGIC_UNVERSIONED => true, _ => { diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs new file mode 100644 index 00000000..b6f93f08 --- /dev/null +++ b/llama-rs/src/quantize.rs @@ -0,0 +1,343 @@ +use crate::ggml::{ + quantize_q4_0, quantize_q4_1, FILE_MAGIC, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, TYPE_Q4_0, + TYPE_Q4_1, +}; +use crate::{Hyperparameters, LoadError, Vocabulary}; +use half::f16; +use std::fs::File; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::path::Path; +use thiserror::Error; + +const FTYPE_STR: [&str; 4] = ["f32", "f16", "q4_0", "q4_1"]; + +pub fn llama_model_quantize( + file_name_in: impl AsRef, + file_name_out: impl AsRef, + itype: u8, + qk: u8, +) -> Result { + let mut otype = TYPE_Q4_1; + + match itype { + 2 => otype = TYPE_Q4_0, + 3 => otype = TYPE_Q4_1, + _ => { + return Err(LoadError::InvalidItype(itype)); + } + }; + + let file_in = file_name_in.as_ref(); + let mut finp = BufReader::new(File::open(file_in).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: file_in.to_owned(), + })?); + + let file_out = file_name_out.as_ref(); + let mut fout = + BufWriter::new( + File::create(file_out).map_err(|e| LoadError::CreateFileFailed { + source: e, + path: file_out.to_owned(), + })?, + ); + + // Verify magic + { + let mut magic_buffer: [u8; 4] = [0; 4]; + finp.read_exact(&mut magic_buffer).unwrap(); + + let magic = u32::from_le_bytes(magic_buffer); + if magic == FILE_MAGIC_UNVERSIONED { + return Err(LoadError::UnversionedMagic); + } + if magic != FILE_MAGIC { + return Err(LoadError::InvalidMagic { + path: file_in.to_owned(), + }); + } + + fout.write(&magic_buffer).unwrap(); + + let mut version_buffer: [u8; 4] = [0; 4]; + finp.read_exact(&mut version_buffer).unwrap(); + + let format_version = u32::from_le_bytes(version_buffer); + + if format_version != FORMAT_VERSION { + return Err(LoadError::InvalidFormatVersion { + value: format_version, + }); + } + + fout.write(&version_buffer).unwrap(); + } + + let mut hparams = Hyperparameters::default(); + + // Load parameters + { + let mut buffer: [u8; 4] = [0; 4]; + finp.read_exact(&mut buffer).unwrap(); + hparams.n_vocab = i32::from_le_bytes(buffer); + println!("n_vocab: {}", hparams.n_vocab); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.n_embd = i32::from_le_bytes(buffer); + println!("n_embd: {}", hparams.n_embd); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.n_mult = i32::from_le_bytes(buffer); + println!("n_mult: {}", hparams.n_mult); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.n_head = i32::from_le_bytes(buffer); + println!("n_head: {}", hparams.n_head); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.n_layer = i32::from_le_bytes(buffer); + println!("n_layer: {}", hparams.n_layer); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.n_rot = i32::from_le_bytes(buffer); + println!("n_rot: {}", hparams.n_rot); + fout.write(&buffer).unwrap(); + + finp.read_exact(&mut buffer).unwrap(); + hparams.f16_ = i32::from_le_bytes(buffer); + println!("f16_: {}", hparams.f16_); + fout.write(&(itype as i32).to_le_bytes()).unwrap(); + } + + // load vocab + let mut vocab = Vocabulary { + id_to_token: vec![], + id_to_token_score: vec![], + token_to_id: Default::default(), + max_token_length: 0, + }; + + { + let n_vocab = hparams.n_vocab; + + for i in 0..n_vocab { + let mut len_buffer = [0u8; 4]; + finp.read_exact(&mut len_buffer).unwrap(); + fout.write(&len_buffer).unwrap(); + let len = u32::from_le_bytes(len_buffer) as usize; + + let mut word_buffer = vec![0u8; len]; + finp.read_exact(word_buffer.as_mut_slice()).unwrap(); + fout.write(&word_buffer).unwrap(); + + let word = String::from_utf8_lossy(&word_buffer).to_string(); + + let mut score_buffer = [0u8; 4]; + finp.read_exact(&mut score_buffer).unwrap(); + fout.write(&score_buffer).unwrap(); + let score = f32::from_le_bytes(score_buffer); + + vocab.token_to_id.insert(word.clone(), i); + + vocab.id_to_token.push(word); + vocab.id_to_token_score.push(score); + } + } + + // Load weights + { + let mut total_size_org: usize = 0; + let mut total_size_new: usize = 0; + + let mut work: Vec = vec![]; + + let mut data_u8: Vec = vec![]; + let mut data_f16: Vec = vec![]; + let mut data_f32: Vec = vec![]; + + let mut hist_all: Vec = vec![0; 16]; + + loop { + let mut buffer = [0u8; 4]; + if finp.read_exact(&mut buffer).is_err() { + break; + }; + let n_dims = i32::from_le_bytes(buffer); + + if finp.read_exact(&mut buffer).is_err() { + break; + }; + let length = i32::from_le_bytes(buffer) as usize; + + if finp.read_exact(&mut buffer).is_err() { + break; + }; + let mut ftype = i32::from_le_bytes(buffer) as usize; + + println!("n_dims: {}, length: {}, ftype: {} ", n_dims, length, ftype); + + let mut nelements = 1i32; + let mut ne = [1i32, 1i32]; + for i in 0..n_dims { + finp.read_exact(&mut buffer).unwrap(); + ne[i as usize] = i32::from_le_bytes(buffer); + nelements *= ne[i as usize]; + } + + let mut name_buffer = vec![0u8; length]; + finp.read_exact(&mut name_buffer).unwrap(); + let name = String::from_utf8(name_buffer).unwrap(); + println!("Nelements: {}", nelements); + print!( + "{:>48} - [{:>5}, {:>5}], type = {:>6}", + format!("'{}'", name), + ne[0], + ne[1], + FTYPE_STR[ftype] + ); + + // Quantize only 2D tensors + let mut quantize = name.find("weight").is_some() && n_dims == 2; + + if quantize { + if ftype != 0 && ftype != 1 { + return Err(LoadError::InvalidFtype { + ftype: ftype as i32, + path: file_in.to_owned(), + }); + } + + data_f32.resize(nelements as usize, 0.0); + if ftype == 1 { + data_f16.resize(nelements as usize, 0); + + let mut buffer = vec![0u8; (nelements * 2) as usize]; + finp.read_exact(&mut buffer).unwrap(); + // Compute buffer + for (index, chunk) in buffer.chunks(2).enumerate() { + let i = u16::from_le_bytes([chunk[0], chunk[1]]); + data_f16[index] = i; + + //data_f32[index] = ggml_fp16_to_fp32(i); + data_f32[index] = f16::from_bits(i).to_f32(); + } + } else { + let mut buffer = vec![0u8; (nelements * 4) as usize]; + finp.read_exact(&mut buffer).unwrap(); + + for (index, chunk) in buffer.chunks(4).enumerate() { + data_f32[index] = + f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + } + } + + ftype = itype as usize; + } else { + // Determines the total bytes were dealing with + let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; + + data_u8.resize(bpe, 0); + finp.read_exact(&mut data_u8).unwrap(); + } + + // Write data + fout.write(&n_dims.to_le_bytes()).unwrap(); + fout.write(&(length as i32).to_le_bytes()).unwrap(); + println!(" new ftype: {}", ftype); + println!("{:?}", name.as_bytes()); + fout.write(&(ftype as i32).to_le_bytes()).unwrap(); + + for i in 0..n_dims { + fout.write(&ne[i as usize].to_le_bytes()).unwrap(); + } + fout.write(name.as_bytes()).unwrap(); + + if quantize { + print!("quantizing .. "); + work.resize(nelements as usize, 0.0); + + let curr_size; + let mut hist_cur = vec![0; 16]; + + match otype { + TYPE_Q4_0 => { + curr_size = quantize_q4_0( + &mut data_f32, + &mut work, + nelements, + ne[0], + qk as i32, + &mut hist_cur, + ) + } + TYPE_Q4_1 => { + curr_size = quantize_q4_1( + &mut data_f32, + &mut work, + nelements, + ne[0], + qk as i32, + &mut hist_cur, + ) + } + _ => { + println!("Unsupported type"); + return Ok(false); + } + } + + // We divide curr size by 4 + for i in 0..curr_size / 4 { + fout.write(&work[i].to_le_bytes()).unwrap(); + } + + total_size_new += curr_size; + + print!( + "size = {:>8.2} MB -> {:>8.2} MB | hist: ", + nelements as f32 * 4.0 / 1024.0 / 1024.0, + curr_size as f32 / 1024.0 / 1024.0 + ); + + for (i, val) in hist_cur.iter().enumerate() { + hist_all[i] += val; + print!("{:>5.3} ", *val as f32 / nelements as f32); + } + println!(); + } else { + fout.write(&data_u8).unwrap(); + println!("size = {:>8.3} MB", data_u8.len() as f64 / 1024.0 / 1024.0); + total_size_new += data_u8.len(); + } + + total_size_org += (nelements * 4) as usize; + } + + println!( + "model size: {:>8.2}", + total_size_org as f32 / 1024.0 / 1024.0 + ); + + println!( + "quant size: {:>8.2}", + total_size_new as f32 / 1024.0 / 1024.0 + ); + + { + let sum_all: i64 = hist_all.iter().sum(); + + print!("hist: "); + for hist in hist_all { + print!("{:>5.3} ", hist as f32 / sum_all as f32); + } + println!(); + } + } + + return Ok(true); +} From a43f07acae91bdb48a41c4769a77eeb2b077a0c2 Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Mon, 27 Mar 2023 18:33:21 -0400 Subject: [PATCH 02/16] removed unwraps --- llama-rs/src/quantize.rs | 74 ++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index b6f93f08..f7ee98fa 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -45,7 +45,7 @@ pub fn llama_model_quantize( // Verify magic { let mut magic_buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut magic_buffer).unwrap(); + finp.read_exact(&mut magic_buffer)?; let magic = u32::from_le_bytes(magic_buffer); if magic == FILE_MAGIC_UNVERSIONED { @@ -57,10 +57,10 @@ pub fn llama_model_quantize( }); } - fout.write(&magic_buffer).unwrap(); + fout.write(&magic_buffer)?; let mut version_buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut version_buffer).unwrap(); + finp.read_exact(&mut version_buffer)?; let format_version = u32::from_le_bytes(version_buffer); @@ -70,7 +70,7 @@ pub fn llama_model_quantize( }); } - fout.write(&version_buffer).unwrap(); + fout.write(&version_buffer)?; } let mut hparams = Hyperparameters::default(); @@ -78,40 +78,40 @@ pub fn llama_model_quantize( // Load parameters { let mut buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_vocab = i32::from_le_bytes(buffer); println!("n_vocab: {}", hparams.n_vocab); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_embd = i32::from_le_bytes(buffer); println!("n_embd: {}", hparams.n_embd); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_mult = i32::from_le_bytes(buffer); println!("n_mult: {}", hparams.n_mult); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_head = i32::from_le_bytes(buffer); println!("n_head: {}", hparams.n_head); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_layer = i32::from_le_bytes(buffer); println!("n_layer: {}", hparams.n_layer); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.n_rot = i32::from_le_bytes(buffer); println!("n_rot: {}", hparams.n_rot); - fout.write(&buffer).unwrap(); + fout.write(&buffer)?; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; hparams.f16_ = i32::from_le_bytes(buffer); println!("f16_: {}", hparams.f16_); - fout.write(&(itype as i32).to_le_bytes()).unwrap(); + fout.write(&(itype as i32).to_le_bytes())?; } // load vocab @@ -127,19 +127,19 @@ pub fn llama_model_quantize( for i in 0..n_vocab { let mut len_buffer = [0u8; 4]; - finp.read_exact(&mut len_buffer).unwrap(); - fout.write(&len_buffer).unwrap(); + finp.read_exact(&mut len_buffer)?; + fout.write(&len_buffer)?; let len = u32::from_le_bytes(len_buffer) as usize; let mut word_buffer = vec![0u8; len]; - finp.read_exact(word_buffer.as_mut_slice()).unwrap(); - fout.write(&word_buffer).unwrap(); + finp.read_exact(word_buffer.as_mut_slice())?; + fout.write(&word_buffer)?; let word = String::from_utf8_lossy(&word_buffer).to_string(); let mut score_buffer = [0u8; 4]; - finp.read_exact(&mut score_buffer).unwrap(); - fout.write(&score_buffer).unwrap(); + finp.read_exact(&mut score_buffer)?; + fout.write(&score_buffer)?; let score = f32::from_le_bytes(score_buffer); vocab.token_to_id.insert(word.clone(), i); @@ -184,14 +184,14 @@ pub fn llama_model_quantize( let mut nelements = 1i32; let mut ne = [1i32, 1i32]; for i in 0..n_dims { - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; ne[i as usize] = i32::from_le_bytes(buffer); nelements *= ne[i as usize]; } let mut name_buffer = vec![0u8; length]; - finp.read_exact(&mut name_buffer).unwrap(); - let name = String::from_utf8(name_buffer).unwrap(); + finp.read_exact(&mut name_buffer)?; + let name = String::from_utf8(name_buffer)?; println!("Nelements: {}", nelements); print!( "{:>48} - [{:>5}, {:>5}], type = {:>6}", @@ -217,7 +217,7 @@ pub fn llama_model_quantize( data_f16.resize(nelements as usize, 0); let mut buffer = vec![0u8; (nelements * 2) as usize]; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; // Compute buffer for (index, chunk) in buffer.chunks(2).enumerate() { let i = u16::from_le_bytes([chunk[0], chunk[1]]); @@ -228,7 +228,7 @@ pub fn llama_model_quantize( } } else { let mut buffer = vec![0u8; (nelements * 4) as usize]; - finp.read_exact(&mut buffer).unwrap(); + finp.read_exact(&mut buffer)?; for (index, chunk) in buffer.chunks(4).enumerate() { data_f32[index] = @@ -242,20 +242,20 @@ pub fn llama_model_quantize( let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; data_u8.resize(bpe, 0); - finp.read_exact(&mut data_u8).unwrap(); + finp.read_exact(&mut data_u8)?; } // Write data - fout.write(&n_dims.to_le_bytes()).unwrap(); - fout.write(&(length as i32).to_le_bytes()).unwrap(); + fout.write(&n_dims.to_le_bytes())?; + fout.write(&(length as i32).to_le_bytes())?; println!(" new ftype: {}", ftype); println!("{:?}", name.as_bytes()); - fout.write(&(ftype as i32).to_le_bytes()).unwrap(); + fout.write(&(ftype as i32).to_le_bytes())?; for i in 0..n_dims { - fout.write(&ne[i as usize].to_le_bytes()).unwrap(); + fout.write(&ne[i as usize].to_le_bytes())?; } - fout.write(name.as_bytes()).unwrap(); + fout.write(name.as_bytes())?; if quantize { print!("quantizing .. "); @@ -293,7 +293,7 @@ pub fn llama_model_quantize( // We divide curr size by 4 for i in 0..curr_size / 4 { - fout.write(&work[i].to_le_bytes()).unwrap(); + fout.write(&work[i].to_le_bytes())?; } total_size_new += curr_size; @@ -310,7 +310,7 @@ pub fn llama_model_quantize( } println!(); } else { - fout.write(&data_u8).unwrap(); + fout.write(&data_u8)?; println!("size = {:>8.3} MB", data_u8.len() as f64 / 1024.0 / 1024.0); total_size_new += data_u8.len(); } From 918956ee4a1162eacd20cf6f7370927da8be1648 Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Wed, 29 Mar 2023 21:53:23 -0400 Subject: [PATCH 03/16] removed unused functions --- ggml-raw/src/lib.rs | 2 -- llama-rs/src/ggml.rs | 4 ---- 2 files changed, 6 deletions(-) diff --git a/ggml-raw/src/lib.rs b/ggml-raw/src/lib.rs index e4f5e418..5b14ccdb 100644 --- a/ggml-raw/src/lib.rs +++ b/ggml-raw/src/lib.rs @@ -246,6 +246,4 @@ extern "C" { qk: i32, hist: *mut i64, ) -> usize; - - pub fn ggml_fp16_to_fp32(x: u16) -> f32; } diff --git a/llama-rs/src/ggml.rs b/llama-rs/src/ggml.rs index 5e9b182b..da5905a2 100644 --- a/llama-rs/src/ggml.rs +++ b/llama-rs/src/ggml.rs @@ -331,7 +331,3 @@ pub fn quantize_q4_1( ) } } - -pub fn ggml_fp16_to_fp32(x: u16) -> f32 { - unsafe { ggml_raw::ggml_fp16_to_fp32(x) } -} From d3f22deef7d88e50856a6d72b7f0d8b9c6ac9726 Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Wed, 29 Mar 2023 21:54:05 -0400 Subject: [PATCH 04/16] minor fixes with clippy --- llama-rs/src/quantize.rs | 108 +++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 62 deletions(-) diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index f7ee98fa..a2b5b69e 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -1,13 +1,11 @@ use crate::ggml::{ - quantize_q4_0, quantize_q4_1, FILE_MAGIC, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, TYPE_Q4_0, - TYPE_Q4_1, + quantize_q4_0, quantize_q4_1, FILE_MAGIC, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, }; use crate::{Hyperparameters, LoadError, Vocabulary}; use half::f16; use std::fs::File; use std::io::{BufReader, BufWriter, Read, Write}; use std::path::Path; -use thiserror::Error; const FTYPE_STR: [&str; 4] = ["f32", "f16", "q4_0", "q4_1"]; @@ -16,16 +14,10 @@ pub fn llama_model_quantize( file_name_out: impl AsRef, itype: u8, qk: u8, -) -> Result { - let mut otype = TYPE_Q4_1; - - match itype { - 2 => otype = TYPE_Q4_0, - 3 => otype = TYPE_Q4_1, - _ => { - return Err(LoadError::InvalidItype(itype)); - } - }; +) -> Result<(), LoadError> { + if itype != 2 && itype != 3 { + return Err(LoadError::InvalidItype(itype)); + } let file_in = file_name_in.as_ref(); let mut finp = BufReader::new(File::open(file_in).map_err(|e| LoadError::OpenFileFailed { @@ -57,7 +49,7 @@ pub fn llama_model_quantize( }); } - fout.write(&magic_buffer)?; + fout.write_all(&magic_buffer)?; let mut version_buffer: [u8; 4] = [0; 4]; finp.read_exact(&mut version_buffer)?; @@ -70,7 +62,7 @@ pub fn llama_model_quantize( }); } - fout.write(&version_buffer)?; + fout.write_all(&version_buffer)?; } let mut hparams = Hyperparameters::default(); @@ -81,37 +73,37 @@ pub fn llama_model_quantize( finp.read_exact(&mut buffer)?; hparams.n_vocab = i32::from_le_bytes(buffer); println!("n_vocab: {}", hparams.n_vocab); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.n_embd = i32::from_le_bytes(buffer); println!("n_embd: {}", hparams.n_embd); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.n_mult = i32::from_le_bytes(buffer); println!("n_mult: {}", hparams.n_mult); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.n_head = i32::from_le_bytes(buffer); println!("n_head: {}", hparams.n_head); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.n_layer = i32::from_le_bytes(buffer); println!("n_layer: {}", hparams.n_layer); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.n_rot = i32::from_le_bytes(buffer); println!("n_rot: {}", hparams.n_rot); - fout.write(&buffer)?; + fout.write_all(&buffer)?; finp.read_exact(&mut buffer)?; hparams.f16_ = i32::from_le_bytes(buffer); println!("f16_: {}", hparams.f16_); - fout.write(&(itype as i32).to_le_bytes())?; + fout.write_all(&(itype as i32).to_le_bytes())?; } // load vocab @@ -128,18 +120,18 @@ pub fn llama_model_quantize( for i in 0..n_vocab { let mut len_buffer = [0u8; 4]; finp.read_exact(&mut len_buffer)?; - fout.write(&len_buffer)?; + fout.write_all(&len_buffer)?; let len = u32::from_le_bytes(len_buffer) as usize; let mut word_buffer = vec![0u8; len]; finp.read_exact(word_buffer.as_mut_slice())?; - fout.write(&word_buffer)?; + fout.write_all(&word_buffer)?; let word = String::from_utf8_lossy(&word_buffer).to_string(); let mut score_buffer = [0u8; 4]; finp.read_exact(&mut score_buffer)?; - fout.write(&score_buffer)?; + fout.write_all(&score_buffer)?; let score = f32::from_le_bytes(score_buffer); vocab.token_to_id.insert(word.clone(), i); @@ -202,7 +194,7 @@ pub fn llama_model_quantize( ); // Quantize only 2D tensors - let mut quantize = name.find("weight").is_some() && n_dims == 2; + let quantize = name.contains("weight") && n_dims == 2; if quantize { if ftype != 0 && ftype != 1 { @@ -246,54 +238,46 @@ pub fn llama_model_quantize( } // Write data - fout.write(&n_dims.to_le_bytes())?; - fout.write(&(length as i32).to_le_bytes())?; + fout.write_all(&n_dims.to_le_bytes())?; + fout.write_all(&(length as i32).to_le_bytes())?; println!(" new ftype: {}", ftype); println!("{:?}", name.as_bytes()); - fout.write(&(ftype as i32).to_le_bytes())?; + fout.write_all(&(ftype as i32).to_le_bytes())?; for i in 0..n_dims { - fout.write(&ne[i as usize].to_le_bytes())?; + fout.write_all(&ne[i as usize].to_le_bytes())?; } - fout.write(name.as_bytes())?; + fout.write_all(name.as_bytes())?; if quantize { print!("quantizing .. "); work.resize(nelements as usize, 0.0); - let curr_size; let mut hist_cur = vec![0; 16]; - match otype { - TYPE_Q4_0 => { - curr_size = quantize_q4_0( - &mut data_f32, - &mut work, - nelements, - ne[0], - qk as i32, - &mut hist_cur, - ) - } - TYPE_Q4_1 => { - curr_size = quantize_q4_1( - &mut data_f32, - &mut work, - nelements, - ne[0], - qk as i32, - &mut hist_cur, - ) - } - _ => { - println!("Unsupported type"); - return Ok(false); - } - } + let curr_size = if itype == 2 { + quantize_q4_0( + &mut data_f32, + &mut work, + nelements, + ne[0], + qk as i32, + &mut hist_cur, + ) + } else { + quantize_q4_1( + &mut data_f32, + &mut work, + nelements, + ne[0], + qk as i32, + &mut hist_cur, + ) + }; // We divide curr size by 4 - for i in 0..curr_size / 4 { - fout.write(&work[i].to_le_bytes())?; + for i in work.iter().take(curr_size / 4) { + fout.write_all(&i.to_le_bytes())?; } total_size_new += curr_size; @@ -310,7 +294,7 @@ pub fn llama_model_quantize( } println!(); } else { - fout.write(&data_u8)?; + fout.write_all(&data_u8)?; println!("size = {:>8.3} MB", data_u8.len() as f64 / 1024.0 / 1024.0); total_size_new += data_u8.len(); } @@ -339,5 +323,5 @@ pub fn llama_model_quantize( } } - return Ok(true); + Ok(()) } From fea5b22915356c11bd0f3b4caabc72d50b5f5927 Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Wed, 29 Mar 2023 21:59:40 -0400 Subject: [PATCH 05/16] removed redundant checks --- ggml-raw/ggml/ggml.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-raw/ggml/ggml.c b/ggml-raw/ggml/ggml.c index e636e725..96b63070 100644 --- a/ggml-raw/ggml/ggml.c +++ b/ggml-raw/ggml/ggml.c @@ -434,8 +434,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric const uint8_t vi0 = ((int8_t) (round(v0))) + 8; const uint8_t vi1 = ((int8_t) (round(v1))) + 8; - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); + assert(vi0 < 16); + assert(vi1 < 16); pp[l/2] = vi0 | (vi1 << 4); } @@ -687,8 +687,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { const uint8_t vi0 = round(v0); const uint8_t vi1 = round(v1); - assert(vi0 >= 0 && vi0 < 16); - assert(vi1 >= 0 && vi1 < 16); + assert(vi0 < 16); + assert(vi1 < 16); pp[l/2] = vi0 | (vi1 << 4); } From 69d7ddce2b1c09028c34eee5c61027048b2fc2ff Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Sat, 1 Apr 2023 01:31:57 -0400 Subject: [PATCH 06/16] fixed according to comments --- llama-rs/src/file.rs | 82 ++++++++++++++ llama-rs/src/lib.rs | 40 +------ llama-rs/src/quantize.rs | 228 ++++++++++++++++----------------------- 3 files changed, 179 insertions(+), 171 deletions(-) create mode 100644 llama-rs/src/file.rs diff --git a/llama-rs/src/file.rs b/llama-rs/src/file.rs new file mode 100644 index 00000000..851d88f5 --- /dev/null +++ b/llama-rs/src/file.rs @@ -0,0 +1,82 @@ +use crate::LoadError; +pub use std::fs::File; +pub use std::io::{BufRead, BufReader, BufWriter, Read, Write}; + +fn read(reader: &mut impl BufRead, bytes: &mut [u8]) -> Result<(), LoadError> { + reader + .read_exact(bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: bytes.len(), + }) +} + +fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + read(reader, &mut bytes)?; + Ok(bytes) +} + +fn rw( + reader: &mut impl BufRead, + writer: &mut impl Write, +) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + read(reader, &mut bytes)?; + writer.write_all(&bytes)?; + Ok(bytes) +} + +pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(u32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(f32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. +pub(crate) fn read_string(reader: &mut BufReader, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} + +pub(crate) fn rw_string( + reader: &mut impl BufRead, + writer: &mut impl Write, + len: usize, +) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + writer.write_all(&buf)?; + let s = String::from_utf8_lossy(&buf); + Ok(s.into_owned()) +} diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index ac292d66..299644b0 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,3 +1,4 @@ +mod file; mod ggml; mod quantize; @@ -15,7 +16,7 @@ use thiserror::Error; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; -pub use quantize::llama_model_quantize; +pub use quantize::{quantize, QuantizeLoadProgress}; pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)] @@ -424,6 +425,7 @@ impl Model { n_ctx: i32, load_progress_callback: impl Fn(LoadProgress), ) -> Result<(Model, Vocabulary), LoadError> { + use crate::file::{read_f32, read_i32, read_string, read_u32}; use std::fs::File; use std::io::BufReader; @@ -437,42 +439,6 @@ impl Model { })?, ); - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } - // Verify magic let is_legacy_model: bool = match read_u32(&mut reader)? { ggml::FILE_MAGIC => false, diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index a2b5b69e..fcc0999a 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -3,18 +3,44 @@ use crate::ggml::{ }; use crate::{Hyperparameters, LoadError, Vocabulary}; use half::f16; -use std::fs::File; -use std::io::{BufReader, BufWriter, Read, Write}; use std::path::Path; const FTYPE_STR: [&str; 4] = ["f32", "f16", "q4_0", "q4_1"]; -pub fn llama_model_quantize( +#[derive(Clone, PartialEq, PartialOrd, Debug)] +pub enum QuantizeLoadProgress<'a> { + HyperparametersLoaded(&'a Hyperparameters), + LoadingWeight { + name: &'a str, + size: [i32; 2], + elements: i32, + ftype: &'a str, + }, + Quantizing, + Quantized { + original_size: f32, + reduced_size: f32, + history: Vec, + }, + Skipped { + size: f32, + }, + Finished { + original_size: f32, + reduced_size: f32, + history: Vec, + }, +} + +pub fn quantize( file_name_in: impl AsRef, file_name_out: impl AsRef, itype: u8, qk: u8, + load_progress_callback: impl Fn(QuantizeLoadProgress), ) -> Result<(), LoadError> { + use crate::file::*; + if itype != 2 && itype != 3 { return Err(LoadError::InvalidItype(itype)); } @@ -36,10 +62,7 @@ pub fn llama_model_quantize( // Verify magic { - let mut magic_buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut magic_buffer)?; - - let magic = u32::from_le_bytes(magic_buffer); + let magic = rw_u32(&mut finp, &mut fout)?; if magic == FILE_MAGIC_UNVERSIONED { return Err(LoadError::UnversionedMagic); } @@ -49,62 +72,27 @@ pub fn llama_model_quantize( }); } - fout.write_all(&magic_buffer)?; - - let mut version_buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut version_buffer)?; - - let format_version = u32::from_le_bytes(version_buffer); - + let format_version = rw_u32(&mut finp, &mut fout)?; if format_version != FORMAT_VERSION { return Err(LoadError::InvalidFormatVersion { value: format_version, }); } - - fout.write_all(&version_buffer)?; } let mut hparams = Hyperparameters::default(); // Load parameters { - let mut buffer: [u8; 4] = [0; 4]; - finp.read_exact(&mut buffer)?; - hparams.n_vocab = i32::from_le_bytes(buffer); - println!("n_vocab: {}", hparams.n_vocab); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.n_embd = i32::from_le_bytes(buffer); - println!("n_embd: {}", hparams.n_embd); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.n_mult = i32::from_le_bytes(buffer); - println!("n_mult: {}", hparams.n_mult); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.n_head = i32::from_le_bytes(buffer); - println!("n_head: {}", hparams.n_head); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.n_layer = i32::from_le_bytes(buffer); - println!("n_layer: {}", hparams.n_layer); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.n_rot = i32::from_le_bytes(buffer); - println!("n_rot: {}", hparams.n_rot); - fout.write_all(&buffer)?; - - finp.read_exact(&mut buffer)?; - hparams.f16_ = i32::from_le_bytes(buffer); - println!("f16_: {}", hparams.f16_); - fout.write_all(&(itype as i32).to_le_bytes())?; + hparams.n_vocab = rw_i32(&mut finp, &mut fout)?; + hparams.n_embd = rw_i32(&mut finp, &mut fout)?; + hparams.n_mult = rw_i32(&mut finp, &mut fout)?; + hparams.n_head = rw_i32(&mut finp, &mut fout)?; + hparams.n_layer = rw_i32(&mut finp, &mut fout)?; + hparams.n_rot = rw_i32(&mut finp, &mut fout)?; + hparams.f16_ = rw_i32(&mut finp, &mut fout)?; } + load_progress_callback(QuantizeLoadProgress::HyperparametersLoaded(&hparams)); // load vocab let mut vocab = Vocabulary { @@ -114,31 +102,14 @@ pub fn llama_model_quantize( max_token_length: 0, }; - { - let n_vocab = hparams.n_vocab; + for i in 0..hparams.n_vocab { + let len = rw_u32(&mut finp, &mut fout)? as usize; + let word = rw_string(&mut finp, &mut fout, len)?; + let score = rw_f32(&mut finp, &mut fout)?; - for i in 0..n_vocab { - let mut len_buffer = [0u8; 4]; - finp.read_exact(&mut len_buffer)?; - fout.write_all(&len_buffer)?; - let len = u32::from_le_bytes(len_buffer) as usize; - - let mut word_buffer = vec![0u8; len]; - finp.read_exact(word_buffer.as_mut_slice())?; - fout.write_all(&word_buffer)?; - - let word = String::from_utf8_lossy(&word_buffer).to_string(); - - let mut score_buffer = [0u8; 4]; - finp.read_exact(&mut score_buffer)?; - fout.write_all(&score_buffer)?; - let score = f32::from_le_bytes(score_buffer); - - vocab.token_to_id.insert(word.clone(), i); - - vocab.id_to_token.push(word); - vocab.id_to_token_score.push(score); - } + vocab.token_to_id.insert(word.clone(), i); + vocab.id_to_token.push(word); + vocab.id_to_token_score.push(score); } // Load weights @@ -155,43 +126,42 @@ pub fn llama_model_quantize( let mut hist_all: Vec = vec![0; 16]; loop { - let mut buffer = [0u8; 4]; - if finp.read_exact(&mut buffer).is_err() { + let n_dims: i32; + if let Ok(r) = read_i32(&mut finp) { + n_dims = r; + } else { break; - }; - let n_dims = i32::from_le_bytes(buffer); + } - if finp.read_exact(&mut buffer).is_err() { + let length: usize; + if let Ok(r) = read_i32(&mut finp) { + length = r as usize; + } else { break; - }; - let length = i32::from_le_bytes(buffer) as usize; + } - if finp.read_exact(&mut buffer).is_err() { + let mut ftype: i32; + if let Ok(r) = read_i32(&mut finp) { + ftype = r; + } else { break; - }; - let mut ftype = i32::from_le_bytes(buffer) as usize; - - println!("n_dims: {}, length: {}, ftype: {} ", n_dims, length, ftype); + } let mut nelements = 1i32; let mut ne = [1i32, 1i32]; for i in 0..n_dims { - finp.read_exact(&mut buffer)?; - ne[i as usize] = i32::from_le_bytes(buffer); + ne[i as usize] = read_i32(&mut finp)?; nelements *= ne[i as usize]; } - let mut name_buffer = vec![0u8; length]; - finp.read_exact(&mut name_buffer)?; - let name = String::from_utf8(name_buffer)?; - println!("Nelements: {}", nelements); - print!( - "{:>48} - [{:>5}, {:>5}], type = {:>6}", - format!("'{}'", name), - ne[0], - ne[1], - FTYPE_STR[ftype] - ); + let name = read_string(&mut finp, length)?; + + load_progress_callback(QuantizeLoadProgress::LoadingWeight { + name: &name, + size: ne, + elements: nelements, + ftype: FTYPE_STR[ftype as usize], + }); // Quantize only 2D tensors let quantize = name.contains("weight") && n_dims == 2; @@ -199,7 +169,7 @@ pub fn llama_model_quantize( if quantize { if ftype != 0 && ftype != 1 { return Err(LoadError::InvalidFtype { - ftype: ftype as i32, + ftype, path: file_in.to_owned(), }); } @@ -228,7 +198,7 @@ pub fn llama_model_quantize( } } - ftype = itype as usize; + ftype = itype as i32; } else { // Determines the total bytes were dealing with let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; @@ -240,9 +210,7 @@ pub fn llama_model_quantize( // Write data fout.write_all(&n_dims.to_le_bytes())?; fout.write_all(&(length as i32).to_le_bytes())?; - println!(" new ftype: {}", ftype); - println!("{:?}", name.as_bytes()); - fout.write_all(&(ftype as i32).to_le_bytes())?; + fout.write_all(&(ftype).to_le_bytes())?; for i in 0..n_dims { fout.write_all(&ne[i as usize].to_le_bytes())?; @@ -250,7 +218,7 @@ pub fn llama_model_quantize( fout.write_all(name.as_bytes())?; if quantize { - print!("quantizing .. "); + load_progress_callback(QuantizeLoadProgress::Quantizing); work.resize(nelements as usize, 0.0); let mut hist_cur = vec![0; 16]; @@ -282,45 +250,37 @@ pub fn llama_model_quantize( total_size_new += curr_size; - print!( - "size = {:>8.2} MB -> {:>8.2} MB | hist: ", - nelements as f32 * 4.0 / 1024.0 / 1024.0, - curr_size as f32 / 1024.0 / 1024.0 - ); - + let mut new_hist = vec![]; for (i, val) in hist_cur.iter().enumerate() { hist_all[i] += val; - print!("{:>5.3} ", *val as f32 / nelements as f32); + new_hist.push(*val as f32 / nelements as f32); } - println!(); + + load_progress_callback(QuantizeLoadProgress::Quantized { + original_size: nelements as f32 * 4.0 / 1024.0 / 1024.0, + reduced_size: curr_size as f32 / 1024.0 / 1024.0, + history: new_hist, + }); } else { fout.write_all(&data_u8)?; - println!("size = {:>8.3} MB", data_u8.len() as f64 / 1024.0 / 1024.0); + load_progress_callback(QuantizeLoadProgress::Skipped { + size: data_u8.len() as f32 / 1024.0 / 1024.0, + }); total_size_new += data_u8.len(); } total_size_org += (nelements * 4) as usize; } - println!( - "model size: {:>8.2}", - total_size_org as f32 / 1024.0 / 1024.0 - ); - - println!( - "quant size: {:>8.2}", - total_size_new as f32 / 1024.0 / 1024.0 - ); - - { - let sum_all: i64 = hist_all.iter().sum(); - - print!("hist: "); - for hist in hist_all { - print!("{:>5.3} ", hist as f32 / sum_all as f32); - } - println!(); - } + let sum_all: i64 = hist_all.iter().sum(); + load_progress_callback(QuantizeLoadProgress::Finished { + original_size: total_size_org as f32 / 1024.0 / 1024.0, + reduced_size: total_size_new as f32 / 1024.0 / 1024.0, + history: hist_all + .iter() + .map(|hist| *hist as f32 / sum_all as f32) + .collect(), + }) } Ok(()) From 600de36f61b11ae59ea26898e6f8dcb35fb96e30 Mon Sep 17 00:00:00 2001 From: Philpax Date: Fri, 7 Apr 2023 00:53:17 +0200 Subject: [PATCH 07/16] feat: wire up quantize for CLI --- .vscode/settings.json | 2 +- llama-cli/Cargo.toml | 2 +- llama-cli/src/cli_args.rs | 16 +++++++++++++++- llama-cli/src/main.rs | 13 +++++++++++++ llama-rs/Cargo.toml | 7 +++++-- llama-rs/src/file.rs | 2 ++ llama-rs/src/lib.rs | 4 +--- llama-rs/src/quantize.rs | 10 +++++----- 8 files changed, 43 insertions(+), 13 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ddda313c..ba494b2e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,3 @@ { - "rust-analyzer.cargo.features": ["convert"] + "rust-analyzer.cargo.features": ["convert", "quantize"] } diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 28dc4f7d..d3fd6f98 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -llama-rs = { path = "../llama-rs", features = ["convert"] } +llama-rs = { path = "../llama-rs", features = ["convert", "quantize"] } rand = { workspace = true } diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 86820ef3..e04f3c83 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -39,6 +39,9 @@ pub enum Args { /// /// For reference, see [the PR](https://github.com/rustformers/llama-rs/pull/83). Convert(Box), + + /// Quantize a GGML model to 4-bit. + Quantize(Box), } #[derive(Parser, Debug)] @@ -244,7 +247,7 @@ fn parse_bias(s: &str) -> Result { pub struct ModelLoad { /// Where to load the model path from #[arg(long, short = 'm')] - pub model_path: String, + pub model_path: PathBuf, /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. @@ -367,6 +370,17 @@ pub struct Convert { pub element_type: ElementType, } +#[derive(Parser, Debug)] +pub struct Quantize { + /// The path to the model to quantize + #[arg()] + pub source: PathBuf, + + /// The path to save the quantized model to + #[arg()] + pub destination: PathBuf, +} + #[derive(Parser, Debug, ValueEnum, Clone, Copy)] pub enum ElementType { /// Quantized 4-bit (type 0). diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 631cf733..dade09d6 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -24,6 +24,7 @@ fn main() { Args::Repl(args) => interactive(&args, false), Args::ChatExperimental(args) => interactive(&args, true), Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()), + Args::Quantize(args) => quantize(&args), } } @@ -191,6 +192,18 @@ fn interactive( } } +fn quantize(args: &cli_args::Quantize) { + llama_rs::quantize::quantize( + &args.source, + &args.destination, + llama_rs::ElementType::Q4_0, + |p| { + println!("{p:?}"); + }, + ) + .unwrap(); +} + fn load_prompt_file_with_prompt( prompt_file: &cli_args::PromptFile, prompt: Option<&str>, diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index c0089d63..a3357eb7 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -10,7 +10,6 @@ rust-version = "1.65" ggml = { path = "../ggml" } bytemuck = "1.13.1" -half = "2.2.1" partial_sort = "0.2.0" thiserror = "1.0" rand = { workspace = true } @@ -23,5 +22,9 @@ serde_json = { version = "1.0.94", optional = true } protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } +# Used for the `quantize` feature +half = { version = "2.2.1", optional = true } + [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] +quantize = ["dep:half"] \ No newline at end of file diff --git a/llama-rs/src/file.rs b/llama-rs/src/file.rs index 851d88f5..70305b26 100644 --- a/llama-rs/src/file.rs +++ b/llama-rs/src/file.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use crate::LoadError; pub use std::fs::File; pub use std::io::{BufRead, BufReader, BufWriter, Read, Write}; diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 9c5e8e43..20ada9e3 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -21,6 +21,7 @@ pub use ggml::Type as ElementType; #[cfg(feature = "convert")] pub mod convert; +#[cfg(feature = "quantize")] pub mod quantize; mod file; @@ -523,9 +524,6 @@ pub enum LoadError { /// The path that failed. path: PathBuf, }, - /// An invalid `itype` was encountered. - #[error("itype supplied was invalid: {0}")] - InvalidItype(u8), } #[derive(Error, Debug)] diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index 98d5316c..d089fa09 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -55,13 +55,13 @@ pub enum QuantizeProgress<'a> { pub fn quantize( file_name_in: impl AsRef, file_name_out: impl AsRef, - itype: u8, + ty: crate::ElementType, progress_callback: impl Fn(QuantizeProgress), ) -> Result<(), LoadError> { use crate::file::*; - if itype != 2 && itype != 3 { - return Err(LoadError::InvalidItype(itype)); + if !matches!(ty, crate::ElementType::Q4_0 | crate::ElementType::Q4_1) { + todo!("Unsupported quantization format. This should be an error.") } let file_in = file_name_in.as_ref(); @@ -218,7 +218,7 @@ pub fn quantize( } } - ftype = itype as u32; + ftype = ty.into(); } else { // Determines the total bytes were dealing with let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; @@ -243,7 +243,7 @@ pub fn quantize( let mut hist_cur = vec![0; 16]; - let curr_size = if itype == 2 { + let curr_size = if matches!(ty, crate::ElementType::Q4_0) { unsafe { quantize_q4_0(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) } } else { unsafe { quantize_q4_1(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) } From de518301b450e6fef0a9c69e8cc2c4a4873a0718 Mon Sep 17 00:00:00 2001 From: Guy Garcia Date: Sun, 9 Apr 2023 01:45:35 -0400 Subject: [PATCH 08/16] Fixed merge related bugs --- llama-rs/src/quantize.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index d089fa09..ff34bf9c 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -1,7 +1,9 @@ //! Implements quantization of weights. use crate::{Hyperparameters, LoadError, Vocabulary}; -use ggml::{quantize_q4_0, quantize_q4_1, FILE_MAGIC, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION}; +use ggml::{ + quantize_q4_0, quantize_q4_1, Type, FILE_MAGIC, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, +}; use half::f16; use std::path::Path; @@ -60,9 +62,11 @@ pub fn quantize( ) -> Result<(), LoadError> { use crate::file::*; - if !matches!(ty, crate::ElementType::Q4_0 | crate::ElementType::Q4_1) { - todo!("Unsupported quantization format. This should be an error.") - } + let itype: i32 = match ty { + Type::Q4_0 => 2, + Type::Q4_1 => 3, + _ => todo!("Unsupported quantization format. This should be an error."), + }; let file_in = file_name_in.as_ref(); let mut finp = BufReader::new(File::open(file_in).map_err(|e| LoadError::OpenFileFailed { @@ -109,8 +113,10 @@ pub fn quantize( hparams.n_head = rw_i32(&mut finp, &mut fout)?.try_into()?; hparams.n_layer = rw_i32(&mut finp, &mut fout)?.try_into()?; hparams.n_rot = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.f16_ = rw_i32(&mut finp, &mut fout)?.try_into()?; + hparams.f16_ = read_i32(&mut finp)?.try_into()?; + fout.write_all(&itype.to_le_bytes())?; } + progress_callback(QuantizeProgress::HyperparametersLoaded(&hparams)); // load vocab @@ -122,7 +128,7 @@ pub fn quantize( }; for i in 0..hparams.n_vocab { - let len = rw_u32(&mut finp, &mut fout)? as usize; + let len = rw_u32(&mut finp, &mut fout)?.try_into()?; let word = rw_string(&mut finp, &mut fout, len)?; let score = rw_f32(&mut finp, &mut fout)?; @@ -218,7 +224,7 @@ pub fn quantize( } } - ftype = ty.into(); + ftype = itype.try_into()?; } else { // Determines the total bytes were dealing with let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; @@ -230,7 +236,7 @@ pub fn quantize( // Write data fout.write_all(&n_dims.to_le_bytes())?; fout.write_all(&(length as i32).to_le_bytes())?; - fout.write_all(&(ftype).to_le_bytes())?; + fout.write_all(&(ftype as i32).to_le_bytes())?; for i in 0..n_dims { fout.write_all(&ne[i as usize].to_le_bytes())?; @@ -249,7 +255,7 @@ pub fn quantize( unsafe { quantize_q4_1(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) } }; - // We divide curr size by 4 + // We divide curr size by 4 since size refers to bytes for i in work.iter().take(curr_size / 4) { fout.write_all(&i.to_le_bytes())?; } From 5bc1e12af99c9a8a45ac964ed4dff9cd7ff1efd1 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 04:23:56 +0200 Subject: [PATCH 09/16] ggml-loader -> ggml-format --- Cargo.lock | 4 +- Cargo.toml | 2 +- {ggml-loader => ggml-format}/Cargo.toml | 2 +- ggml-format/src/lib.rs | 36 +++++++++++++++ .../src/lib.rs => ggml-format/src/loader.rs | 45 +++++-------------- {ggml-loader => ggml-format}/src/util.rs | 44 ++++++++++++++---- llama-rs/Cargo.toml | 2 +- llama-rs/src/file.rs | 40 ----------------- llama-rs/src/lib.rs | 1 - llama-rs/src/loader.rs | 6 ++- llama-rs/src/loader2.rs | 40 ++++++++--------- llama-rs/src/loader_common.rs | 2 +- llama-rs/src/quantize.rs | 13 ++++-- 13 files changed, 121 insertions(+), 116 deletions(-) rename {ggml-loader => ggml-format}/Cargo.toml (90%) create mode 100644 ggml-format/src/lib.rs rename ggml-loader/src/lib.rs => ggml-format/src/loader.rs (88%) rename {ggml-loader => ggml-format}/src/util.rs (50%) delete mode 100644 llama-rs/src/file.rs diff --git a/Cargo.lock b/Cargo.lock index ad5040c1..9e3a7a33 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -446,7 +446,7 @@ dependencies = [ ] [[package]] -name = "ggml-loader" +name = "ggml-format" version = "0.1.0" dependencies = [ "ggml", @@ -629,7 +629,7 @@ version = "0.1.0" dependencies = [ "bytemuck", "ggml", - "ggml-loader", + "ggml-format", "half", "memmap2", "partial_sort", diff --git a/Cargo.toml b/Cargo.toml index f579b1c6..20eca429 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = [ "ggml-sys", "ggml", - "ggml-loader", + "ggml-format", "llama-rs", "llama-cli", "generate-ggml-bindings" diff --git a/ggml-loader/Cargo.toml b/ggml-format/Cargo.toml similarity index 90% rename from ggml-loader/Cargo.toml rename to ggml-format/Cargo.toml index 2d088758..99ba7216 100644 --- a/ggml-loader/Cargo.toml +++ b/ggml-format/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "ggml-loader" +name = "ggml-format" version = "0.1.0" edition = "2021" diff --git a/ggml-format/src/lib.rs b/ggml-format/src/lib.rs new file mode 100644 index 00000000..4e509090 --- /dev/null +++ b/ggml-format/src/lib.rs @@ -0,0 +1,36 @@ +//! standalone model loader +//! +//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM. +#![allow(clippy::nonminimal_bool)] + +pub mod util; + +mod loader; + +pub use loader::{ + load_model_from_reader, LoadError, LoadHandler, PartialHyperparameters, TensorDataTreatment, + TensorInfo, +}; + +pub type ElementType = ggml::Type; + +/// the format of the file containing the model +#[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] +pub enum ContainerType { + /// legacy format, oldest ggml tensor file format + GGML, + /// also legacy format, newer than GGML, older than GGJT + GGMF, + /// mmap-able format + GGJT, +} +impl ContainerType { + pub fn support_mmap(&self) -> bool { + match self { + ContainerType::GGML => false, + ContainerType::GGMF => false, + ContainerType::GGJT => true, + } + } +} diff --git a/ggml-loader/src/lib.rs b/ggml-format/src/loader.rs similarity index 88% rename from ggml-loader/src/lib.rs rename to ggml-format/src/loader.rs index 00416902..c3da081c 100644 --- a/ggml-loader/src/lib.rs +++ b/ggml-format/src/loader.rs @@ -1,35 +1,14 @@ -//! standalone model loader -//! -//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM. -#![allow(clippy::nonminimal_bool)] - -pub mod util; - -use std::ops::ControlFlow; -use util::*; - -pub type ElementType = ggml::Type; - -/// the format of the file containing the model -#[derive(Debug, PartialEq, Clone, Copy)] -#[allow(clippy::upper_case_acronyms)] -pub enum ContainerType { - /// legacy format, oldest ggml tensor file format - GGML, - /// also legacy format, newer than GGML, older than GGJT - GGMF, - /// mmap-able format - GGJT, -} -impl ContainerType { - pub fn support_mmap(&self) -> bool { - match self { - ContainerType::GGML => false, - ContainerType::GGMF => false, - ContainerType::GGJT => true, - } - } -} +use std::{ + io::{BufRead, Seek, SeekFrom}, + ops::ControlFlow, +}; + +use crate::{ + util::{ + controlflow_to_result, has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32, + }, + ContainerType, ElementType, +}; #[derive(Debug, thiserror::Error)] pub enum LoadError { @@ -169,7 +148,7 @@ pub fn load_model_from_reader( /// /// `align` /// align to 4 bytes before reading tensor weights -pub fn load_weights( +fn load_weights( reader: &mut R, handler: &mut impl LoadHandler, align: bool, diff --git a/ggml-loader/src/util.rs b/ggml-format/src/util.rs similarity index 50% rename from ggml-loader/src/util.rs rename to ggml-format/src/util.rs index 9a759aac..20c8ce05 100644 --- a/ggml-loader/src/util.rs +++ b/ggml-format/src/util.rs @@ -1,4 +1,5 @@ -pub use std::io::{BufRead, Seek, SeekFrom}; +pub use std::fs::File; +pub use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; use std::ops::ControlFlow; use crate::LoadError; @@ -30,21 +31,46 @@ pub fn read_bytes_with_len( Ok(bytes) } +pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +pub fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(u32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +pub fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { + Ok(f32::from_le_bytes(rw::<4>(reader, writer)?)) +} + +pub fn rw_bytes_with_len( + reader: &mut impl BufRead, + writer: &mut impl Write, + len: usize, +) -> Result, std::io::Error> { + let mut buf = vec![0; len]; + reader.read_exact(&mut buf)?; + writer.write_all(&buf)?; + Ok(buf) +} + +fn rw( + reader: &mut impl BufRead, + writer: &mut impl Write, +) -> Result<[u8; N], std::io::Error> { + let bytes: [u8; N] = read_bytes(reader)?; + writer.write_all(&bytes)?; + Ok(bytes) +} + // NOTE: Implementation from #![feature(buf_read_has_data_left)] pub fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } -pub fn controlflow_to_result(x: ControlFlow) -> Result> { +pub(crate) fn controlflow_to_result(x: ControlFlow) -> Result> { match x { ControlFlow::Continue(x) => Ok(x), ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), } } - -pub fn result_to_controlflow>(x: Result) -> ControlFlow { - match x { - Ok(x) => ControlFlow::Continue(x), - Err(y) => ControlFlow::Break(y.into()), - } -} diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index c9af7d10..7b3e2b6e 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,7 +8,7 @@ rust-version = "1.65" [dependencies] ggml = { path = "../ggml" } -ggml-loader = { path = "../ggml-loader" } +ggml-format = { path = "../ggml-format" } rand = { workspace = true } diff --git a/llama-rs/src/file.rs b/llama-rs/src/file.rs deleted file mode 100644 index bf2edf10..00000000 --- a/llama-rs/src/file.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::LoadError; -pub use std::fs::File; -pub use std::io::{BufRead, BufReader, BufWriter, Read, Write}; - -pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -pub fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(u32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -pub fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(f32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -pub fn rw_bytes_with_len( - reader: &mut impl BufRead, - writer: &mut impl Write, - len: usize, -) -> Result, LoadError> { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - writer.write_all(&buf)?; - Ok(buf) -} - -fn rw( - reader: &mut impl BufRead, - writer: &mut impl Write, -) -> Result<[u8; N], LoadError> { - let bytes: [u8; N] = ggml_loader::util::read_bytes(reader)?; - writer.write_all(&bytes)?; - Ok(bytes) -} diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 30059eb3..802229ac 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -8,7 +8,6 @@ pub mod convert; #[cfg(feature = "quantize")] pub mod quantize; -mod file; mod inference_session; mod loader; mod loader2; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 3a872090..b0ba2380 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -12,8 +12,10 @@ use crate::{ LoadError, LoadProgress, Model, TokenId, Vocabulary, }; use crate::{ElementType, Hyperparameters}; -use ggml_loader::util::*; -use ggml_loader::ContainerType; +use ggml_format::{ + util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32}, + ContainerType, +}; use memmap2::Mmap; pub(crate) fn load( diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 89f2c963..160f2601 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -1,11 +1,12 @@ -use ggml_loader::util::*; -use ggml_loader::*; +use ggml_format::{ + util::read_i32, ContainerType, PartialHyperparameters, TensorDataTreatment, TensorInfo, +}; use memmap2::Mmap; use std::{ collections::HashMap, fs::File, - io::{BufRead, BufReader, Read, Seek}, + io::{BufRead, BufReader, Read, Seek, SeekFrom}, ops::ControlFlow, path::{Path, PathBuf}, }; @@ -16,25 +17,25 @@ use crate::{ }; impl LoadError { - pub(crate) fn from_ggml_loader_error( - value: ggml_loader::LoadError, + pub(crate) fn from_format_error( + value: ggml_format::LoadError, path: PathBuf, ) -> Self { match value { - ggml_loader::LoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, - ggml_loader::LoadError::InvalidFormatVersion(container_type, version) => { + ggml_format::LoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, + ggml_format::LoadError::InvalidFormatVersion(container_type, version) => { LoadError::InvalidFormatVersion { container_type, version, } } - ggml_loader::LoadError::Io(err) => LoadError::Io(err), - ggml_loader::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), - ggml_loader::LoadError::UserInterrupted(err) => err, - ggml_loader::LoadError::UnsupportedElementType(ty) => { + ggml_format::LoadError::Io(err) => LoadError::Io(err), + ggml_format::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), + ggml_format::LoadError::UserInterrupted(err) => err, + ggml_format::LoadError::UnsupportedElementType(ty) => { LoadError::HyperparametersF16Invalid { ftype: ty } } - ggml_loader::LoadError::InvariantBroken(invariant) => { + ggml_format::LoadError::InvariantBroken(invariant) => { LoadError::InvariantBroken { path, invariant } } } @@ -76,8 +77,8 @@ pub(crate) fn load( ); let use_mmap = loader.mmap_active(); - ggml_loader::load_model_from_reader(&mut reader, &mut loader) - .map_err(|err| LoadError::from_ggml_loader_error(err, path.clone()))?; + ggml_format::load_model_from_reader(&mut reader, &mut loader) + .map_err(|err| LoadError::from_format_error(err, path.clone()))?; let Loader { hyperparameters, @@ -222,7 +223,7 @@ impl Loader { } } -impl ggml_loader::LoadHandler> for Loader { +impl ggml_format::LoadHandler> for Loader { fn load_hyper_parameters( &mut self, reader: &mut BufReader<&File>, @@ -230,10 +231,7 @@ impl ggml_loader::LoadHandler t, Err(err) => { - return ControlFlow::Break(LoadError::from_ggml_loader_error( - err, - self.path.clone(), - )) + return ControlFlow::Break(LoadError::from_format_error(err, self.path.clone())) } }; self.hyperparameters = hyperparameters; @@ -278,7 +276,7 @@ impl Loader { fn load_hyperparameters( reader: &mut R, n_ctx: usize, -) -> Result<(Hyperparameters, PartialHyperparameters), ggml_loader::LoadError> { +) -> Result<(Hyperparameters, PartialHyperparameters), ggml_format::LoadError> { // NOTE: Field order matters! Data is laid out in the file exactly in this order. let hparams = Hyperparameters { n_vocab: read_i32(reader)?.try_into()?, @@ -290,7 +288,7 @@ fn load_hyperparameters( file_type: { let ftype = read_i32(reader)?; FileType::try_from(ftype).map_err(|_| { - ggml_loader::LoadError::UserInterrupted(LoadError::UnsupportedFileType(ftype)) + ggml_format::LoadError::UserInterrupted(LoadError::UnsupportedFileType(ftype)) })? }, n_ctx, diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index bd0cff35..716fd9fa 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -3,7 +3,7 @@ use std::{ path::{Path, PathBuf}, }; -use ggml_loader::ContainerType; +use ggml_format::ContainerType; use thiserror::Error; use crate::{util::FindAllModelFilesError, Hyperparameters}; diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index 474a7cdc..e7725c96 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -4,9 +4,16 @@ use crate::{loader::read_string, FileType, Hyperparameters, LoadError, Vocabular use ggml::{ quantize_q4_0, quantize_q4_1, Type, FILE_MAGIC_GGMF, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, }; -use ggml_loader::{util::read_i32, ContainerType}; +use ggml_format::{ + util::{read_i32, rw_bytes_with_len, rw_f32, rw_i32, rw_u32}, + ContainerType, +}; use half::f16; -use std::path::{Path, PathBuf}; +use std::{ + fs::File, + io::{BufReader, BufWriter, Read, Write}, + path::{Path, PathBuf}, +}; use thiserror::Error; const FTYPE_STR: [&str; 4] = ["f32", "f16", "q4_0", "q4_1"]; @@ -87,8 +94,6 @@ pub fn quantize( ty: crate::ElementType, progress_callback: impl Fn(QuantizeProgress), ) -> Result<(), QuantizeError> { - use crate::file::*; - let itype: i32 = match ty { Type::Q4_0 => 2, Type::Q4_1 => 3, From 224321141c8eefe5599d8a38b9224ef24d1831ee Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 04:30:39 +0200 Subject: [PATCH 10/16] refactor(ggml-format): remove TensorDataTreatment --- ggml-format/src/lib.rs | 3 +-- ggml-format/src/loader.rs | 29 ++++------------------------- llama-rs/src/loader2.rs | 8 +++----- 3 files changed, 8 insertions(+), 32 deletions(-) diff --git a/ggml-format/src/lib.rs b/ggml-format/src/lib.rs index 4e509090..f79aa18d 100644 --- a/ggml-format/src/lib.rs +++ b/ggml-format/src/lib.rs @@ -8,8 +8,7 @@ pub mod util; mod loader; pub use loader::{ - load_model_from_reader, LoadError, LoadHandler, PartialHyperparameters, TensorDataTreatment, - TensorInfo, + load_model_from_reader, LoadError, LoadHandler, PartialHyperparameters, TensorInfo, }; pub type ElementType = ggml::Type; diff --git a/ggml-format/src/loader.rs b/ggml-format/src/loader.rs index c3da081c..d89b8b05 100644 --- a/ggml-format/src/loader.rs +++ b/ggml-format/src/loader.rs @@ -63,11 +63,6 @@ pub struct PartialHyperparameters { pub n_vocab: usize, } -pub enum TensorDataTreatment<'a> { - CopyInto(&'a mut [u8]), - Skip, -} - #[allow(unused_variables)] pub trait LoadHandler { fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow { @@ -80,13 +75,8 @@ pub trait LoadHandler { fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; - /// callback to get tensor buffer to populate - /// - /// # Returns - /// - /// `None` to skip copying - /// `Some(buf)` to provide a buffer for copying weights into - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; + /// Called when a new tensor is found. + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; } #[test] @@ -204,19 +194,8 @@ fn load_weights( start_offset: offset_aligned, }; let n_bytes = tensor_info.calc_size(); - - match controlflow_to_result(handler.tensor_buffer(tensor_info))? { - TensorDataTreatment::CopyInto(buf) => { - if align { - reader.seek(SeekFrom::Start(offset_aligned))?; - } - reader.read_exact(buf)?; - } - TensorDataTreatment::Skip => { - // skip if no buffer is given - reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; - } - } + controlflow_to_result(handler.tensor_buffer(tensor_info))?; + reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; } Ok(()) diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 160f2601..d4bf5404 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -1,6 +1,4 @@ -use ggml_format::{ - util::read_i32, ContainerType, PartialHyperparameters, TensorDataTreatment, TensorInfo, -}; +use ggml_format::{util::read_i32, ContainerType, PartialHyperparameters, TensorInfo}; use memmap2::Mmap; use std::{ @@ -255,14 +253,14 @@ impl ggml_format::LoadHandler ControlFlow { + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow { let tensor_name = match String::from_utf8(info.name.clone()) { Ok(n) => n, Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), }; self.tensors.insert(tensor_name, info); - ControlFlow::Continue(TensorDataTreatment::Skip) + ControlFlow::Continue(()) } } From a8b8cd2561742aae9fc37165d3255c6d71bd4542 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 04:51:06 +0200 Subject: [PATCH 11/16] refactor(ggml-format): no ControlFlow, revise IF --- ggml-format/src/loader.rs | 64 +++++++++--------- ggml-format/src/util.rs | 10 --- llama-rs/src/loader2.rs | 136 +++++++++++++++----------------------- 3 files changed, 87 insertions(+), 123 deletions(-) diff --git a/ggml-format/src/loader.rs b/ggml-format/src/loader.rs index d89b8b05..2cd1d32a 100644 --- a/ggml-format/src/loader.rs +++ b/ggml-format/src/loader.rs @@ -1,17 +1,15 @@ use std::{ + error::Error, io::{BufRead, Seek, SeekFrom}, - ops::ControlFlow, }; use crate::{ - util::{ - controlflow_to_result, has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32, - }, + util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32}, ContainerType, ElementType, }; #[derive(Debug, thiserror::Error)] -pub enum LoadError { +pub enum LoadError { #[error("invalid file magic number: {0}")] InvalidMagic(u32), @@ -24,9 +22,8 @@ pub enum LoadError { #[error("{0}")] FailedCast(#[from] std::num::TryFromIntError), - /// return `ControlFlow::Break` from any of the `cb_*` function to trigger this error - #[error("user requested interrupt: {0}")] - UserInterrupted(T), + #[error("implementation returned error: {0}")] + ImplementationError(E), #[error("unsupported tensor dtype/f16_: {0}")] UnsupportedElementType(i32), @@ -63,20 +60,16 @@ pub struct PartialHyperparameters { pub n_vocab: usize, } -#[allow(unused_variables)] -pub trait LoadHandler { - fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow { - ControlFlow::Continue(()) - } - - fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { - ControlFlow::Continue(()) - } - - fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; - +pub trait LoadHandler { + /// Called when the container type is read. + fn container_type(&mut self, container_type: ContainerType) -> Result<(), E>; + /// Called when a vocabulary token is read. + fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), E>; + /// Called when the hyperparameters need to be read. + /// You must read the hyperparameters for your model here. + fn read_hyperparameters(&mut self, reader: &mut R) -> Result; /// Called when a new tensor is found. - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; + fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), E>; } #[test] @@ -85,10 +78,10 @@ fn can_be_vtable() { let _a: MaybeUninit>> = MaybeUninit::uninit(); } -pub fn load_model_from_reader( +pub fn load_model_from_reader( reader: &mut R, - handler: &mut impl LoadHandler, -) -> Result<(), LoadError> { + handler: &mut impl LoadHandler, +) -> Result<(), LoadError> { // Verify magic let container_type: ContainerType = match read_u32(reader)? { ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, @@ -96,7 +89,9 @@ pub fn load_model_from_reader( ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, magic => return Err(LoadError::InvalidMagic(magic)), }; - controlflow_to_result(handler.got_container_type(container_type))?; + handler + .container_type(container_type) + .map_err(LoadError::ImplementationError)?; // Load format version match container_type { @@ -110,7 +105,9 @@ pub fn load_model_from_reader( } // Load hyper params - let hparams = controlflow_to_result(handler.load_hyper_parameters(reader))?; + let hparams = handler + .read_hyperparameters(reader) + .map_err(LoadError::ImplementationError)?; let n_vocab = hparams.n_vocab; // Load vocabulary @@ -124,7 +121,9 @@ pub fn load_model_from_reader( 0. } }; - controlflow_to_result(handler.got_vocab_token(i, token, token_score))?; + handler + .vocabulary_token(i, token, token_score) + .map_err(LoadError::ImplementationError)?; } // Load tensor data @@ -138,11 +137,11 @@ pub fn load_model_from_reader( /// /// `align` /// align to 4 bytes before reading tensor weights -fn load_weights( +fn load_weights( reader: &mut R, - handler: &mut impl LoadHandler, + handler: &mut impl LoadHandler, align: bool, -) -> Result<(), LoadError> { +) -> Result<(), LoadError> { while has_data_left(reader)? { // load tensor header let n_dims: usize = read_i32(reader)?.try_into()?; @@ -157,6 +156,7 @@ fn load_weights( if !(n_dims <= ne_len) { return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}"))); } + #[allow(clippy::needless_range_loop)] for i in 0..n_dims { let dim: usize = read_i32(reader)?.try_into()?; @@ -194,7 +194,9 @@ fn load_weights( start_offset: offset_aligned, }; let n_bytes = tensor_info.calc_size(); - controlflow_to_result(handler.tensor_buffer(tensor_info))?; + handler + .tensor_buffer(tensor_info) + .map_err(LoadError::ImplementationError)?; reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; } diff --git a/ggml-format/src/util.rs b/ggml-format/src/util.rs index 20c8ce05..ec3fb31d 100644 --- a/ggml-format/src/util.rs +++ b/ggml-format/src/util.rs @@ -1,8 +1,5 @@ pub use std::fs::File; pub use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; -use std::ops::ControlFlow; - -use crate::LoadError; pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { let mut bytes = [0u8; N]; @@ -67,10 +64,3 @@ fn rw( pub fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } - -pub(crate) fn controlflow_to_result(x: ControlFlow) -> Result> { - match x { - ControlFlow::Continue(x) => Ok(x), - ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), - } -} diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index d4bf5404..a2a4a623 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -1,11 +1,12 @@ -use ggml_format::{util::read_i32, ContainerType, PartialHyperparameters, TensorInfo}; +use ggml_format::{ + util::read_i32, ContainerType, LoadError as FormatLoadError, PartialHyperparameters, TensorInfo, +}; use memmap2::Mmap; use std::{ collections::HashMap, fs::File, - io::{BufRead, BufReader, Read, Seek, SeekFrom}, - ops::ControlFlow, + io::{BufReader, Read, Seek, SeekFrom}, path::{Path, PathBuf}, }; @@ -15,25 +16,22 @@ use crate::{ }; impl LoadError { - pub(crate) fn from_format_error( - value: ggml_format::LoadError, - path: PathBuf, - ) -> Self { + pub(crate) fn from_format_error(value: FormatLoadError, path: PathBuf) -> Self { match value { - ggml_format::LoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, - ggml_format::LoadError::InvalidFormatVersion(container_type, version) => { + FormatLoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, + FormatLoadError::InvalidFormatVersion(container_type, version) => { LoadError::InvalidFormatVersion { container_type, version, } } - ggml_format::LoadError::Io(err) => LoadError::Io(err), - ggml_format::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), - ggml_format::LoadError::UserInterrupted(err) => err, - ggml_format::LoadError::UnsupportedElementType(ty) => { + FormatLoadError::Io(err) => LoadError::Io(err), + FormatLoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), + FormatLoadError::ImplementationError(err) => err, + FormatLoadError::UnsupportedElementType(ty) => { LoadError::HyperparametersF16Invalid { ftype: ty } } - ggml_format::LoadError::InvariantBroken(invariant) => { + FormatLoadError::InvariantBroken(invariant) => { LoadError::InvariantBroken { path, invariant } } } @@ -67,12 +65,7 @@ pub(crate) fn load( total_parts: 1, }); - let mut loader = Loader::new( - path.clone(), - n_context_tokens, - prefer_mmap, - load_progress_callback, - ); + let mut loader = Loader::new(n_context_tokens, prefer_mmap, load_progress_callback); let use_mmap = loader.mmap_active(); ggml_format::load_model_from_reader(&mut reader, &mut loader) @@ -194,7 +187,6 @@ pub(crate) fn load( struct Loader { // Input - path: PathBuf, n_ctx: usize, prefer_mmap: bool, load_progress_callback: F, @@ -206,9 +198,8 @@ struct Loader { tensors: HashMap, } impl Loader { - fn new(path: PathBuf, n_ctx: usize, prefer_mmap: bool, load_progress_callback: F) -> Self { + fn new(n_ctx: usize, prefer_mmap: bool, load_progress_callback: F) -> Self { Self { - path, n_ctx, prefer_mmap, load_progress_callback, @@ -219,80 +210,61 @@ impl Loader { tensors: HashMap::default(), } } -} - -impl ggml_format::LoadHandler> for Loader { - fn load_hyper_parameters( - &mut self, - reader: &mut BufReader<&File>, - ) -> ControlFlow { - let (hyperparameters, partial) = match load_hyperparameters(reader, self.n_ctx) { - Ok(t) => t, - Err(err) => { - return ControlFlow::Break(LoadError::from_format_error(err, self.path.clone())) - } - }; - self.hyperparameters = hyperparameters; - (self.load_progress_callback)(LoadProgress::HyperparametersLoaded(&self.hyperparameters)); - ControlFlow::Continue(partial) + fn mmap_active(&mut self) -> bool { + self.prefer_mmap && self.container_type.support_mmap() } - - fn got_container_type(&mut self, t: ContainerType) -> ControlFlow { - self.container_type = t; - ControlFlow::Continue(()) +} +impl ggml_format::LoadHandler> for Loader { + fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> { + self.container_type = container_type; + Ok(()) } - fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), LoadError> { let id = match TokenId::try_from(i) { Ok(id) => id, - Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), + Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), }; self.vocabulary.push_token(id, token, score); - ControlFlow::Continue(()) + Ok(()) + } + + fn read_hyperparameters( + &mut self, + reader: &mut BufReader<&File>, + ) -> Result { + // NOTE: Field order matters! Data is laid out in the file exactly in this order. + let hyperparameters = Hyperparameters { + n_vocab: read_i32(reader)?.try_into()?, + n_embd: read_i32(reader)?.try_into()?, + n_mult: read_i32(reader)?.try_into()?, + n_head: read_i32(reader)?.try_into()?, + n_layer: read_i32(reader)?.try_into()?, + n_rot: read_i32(reader)?.try_into()?, + file_type: { + let ftype = read_i32(reader)?; + FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))? + }, + n_ctx: self.n_ctx, + }; + let partial = PartialHyperparameters { + n_vocab: hyperparameters.n_vocab, + }; + self.hyperparameters = hyperparameters; + (self.load_progress_callback)(LoadProgress::HyperparametersLoaded(&self.hyperparameters)); + + Ok(partial) } - fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow { + fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), LoadError> { let tensor_name = match String::from_utf8(info.name.clone()) { Ok(n) => n, - Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), + Err(err) => return Err(LoadError::InvalidUtf8(err)), }; self.tensors.insert(tensor_name, info); - ControlFlow::Continue(()) - } -} - -impl Loader { - fn mmap_active(&mut self) -> bool { - self.prefer_mmap && self.container_type.support_mmap() + Ok(()) } } - -/// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] -fn load_hyperparameters( - reader: &mut R, - n_ctx: usize, -) -> Result<(Hyperparameters, PartialHyperparameters), ggml_format::LoadError> { - // NOTE: Field order matters! Data is laid out in the file exactly in this order. - let hparams = Hyperparameters { - n_vocab: read_i32(reader)?.try_into()?, - n_embd: read_i32(reader)?.try_into()?, - n_mult: read_i32(reader)?.try_into()?, - n_head: read_i32(reader)?.try_into()?, - n_layer: read_i32(reader)?.try_into()?, - n_rot: read_i32(reader)?.try_into()?, - file_type: { - let ftype = read_i32(reader)?; - FileType::try_from(ftype).map_err(|_| { - ggml_format::LoadError::UserInterrupted(LoadError::UnsupportedFileType(ftype)) - })? - }, - n_ctx, - }; - let partial = PartialHyperparameters { - n_vocab: hparams.n_vocab, - }; - Ok((hparams, partial)) -} From ab504b05eb279f36325333ad8e8f824b3c515957 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 05:17:28 +0200 Subject: [PATCH 12/16] feat: document ggml-format, fix details --- ggml-format/src/lib.rs | 34 +++++++------ ggml-format/src/loader.rs | 95 +++++++++++++++++++++-------------- ggml-format/src/util.rs | 11 ++++ llama-rs/src/loader.rs | 26 +++++----- llama-rs/src/loader2.rs | 24 +++++---- llama-rs/src/loader_common.rs | 4 +- llama-rs/src/quantize.rs | 5 +- 7 files changed, 121 insertions(+), 78 deletions(-) diff --git a/ggml-format/src/lib.rs b/ggml-format/src/lib.rs index f79aa18d..f81fce2d 100644 --- a/ggml-format/src/lib.rs +++ b/ggml-format/src/lib.rs @@ -1,8 +1,13 @@ -//! standalone model loader +#![deny(missing_docs)] +//! A reader and writer for the `ggml` model format. //! -//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM. -#![allow(clippy::nonminimal_bool)] +//! The reader supports the GGML, GGMF and GGJT container formats, but +//! only single-part models. +//! +//! The writer isn't implemented yet. It will support the GGJT container +//! format only. +/// Utilities for reading and writing. pub mod util; mod loader; @@ -11,25 +16,26 @@ pub use loader::{ load_model_from_reader, LoadError, LoadHandler, PartialHyperparameters, TensorInfo, }; +/// The type of a tensor element. pub type ElementType = ggml::Type; -/// the format of the file containing the model #[derive(Debug, PartialEq, Clone, Copy)] -#[allow(clippy::upper_case_acronyms)] +/// The format of the file containing the model. pub enum ContainerType { - /// legacy format, oldest ggml tensor file format - GGML, - /// also legacy format, newer than GGML, older than GGJT - GGMF, - /// mmap-able format - GGJT, + /// `GGML`: legacy format, oldest ggml tensor file format + Ggml, + /// `GGMF`: also legacy format. Introduces versioning. Newer than GGML, older than GGJT. + Ggmf, + /// `GGJT`: mmap-able format. + Ggjt, } impl ContainerType { + /// Does this container type support mmap? pub fn support_mmap(&self) -> bool { match self { - ContainerType::GGML => false, - ContainerType::GGMF => false, - ContainerType::GGJT => true, + ContainerType::Ggml => false, + ContainerType::Ggmf => false, + ContainerType::Ggjt => true, } } } diff --git a/ggml-format/src/loader.rs b/ggml-format/src/loader.rs index 2cd1d32a..e42279c8 100644 --- a/ggml-format/src/loader.rs +++ b/ggml-format/src/loader.rs @@ -9,57 +9,79 @@ use crate::{ }; #[derive(Debug, thiserror::Error)] +/// Errors that can occur while loading a model. pub enum LoadError { #[error("invalid file magic number: {0}")] + /// The file magic number is invalid. InvalidMagic(u32), - #[error("invalid ggml format: format={0:?} version={1}")] + /// An unsupported format version was found. InvalidFormatVersion(ContainerType, u32), - - #[error("{0}")] + #[error("non-specific I/O error")] + /// A non-specific IO error. Io(#[from] std::io::Error), - - #[error("{0}")] - FailedCast(#[from] std::num::TryFromIntError), - - #[error("implementation returned error: {0}")] - ImplementationError(E), - - #[error("unsupported tensor dtype/f16_: {0}")] - UnsupportedElementType(i32), - - /// sanity check failed + #[error("could not convert bytes to a UTF-8 string")] + /// One of the strings encountered was not valid UTF-8. + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("implementation error")] + /// An error `E` was returned by the implementation of the loader. + ImplementationError(#[source] E), + #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] + /// One of the tensors encountered had an unsupported data type. + UnsupportedElementType { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: i32, + }, #[error("invariant broken: {0}")] + /// An invariant was broken. InvariantBroken(String), } #[derive(Debug, Clone)] +/// Information about a tensor that is read. pub struct TensorInfo { - pub name: Vec, + /// The name of the tensor. + pub name: String, + /// The number of dimensions in the tensor. pub n_dims: usize, + /// The dimensions of the tensor. pub dims: [usize; 2], + /// The number of elements in the tensor. pub n_elements: usize, + /// The type of the elements in the tensor. pub element_type: ElementType, /// start of tensor - start of file pub start_offset: u64, } impl TensorInfo { + /// Get the dimensions of the tensor. + pub fn dims(&self) -> &[usize] { + &self.dims[0..self.n_dims] + } + + /// Calculate the size of the tensor's values in bytes. pub fn calc_size(&self) -> usize { let mut size = ggml::type_size(self.element_type); - for &dim in &self.dims[0..self.n_dims] { + for &dim in self.dims() { size *= dim; } size / ggml::blck_size(self.element_type) } } -/// Info in hyperparameter used for later loading tasks. Used in callback. -/// see [`LoadHandler::load_hyper_parameters`] #[derive(Debug, Clone)] +/// Information present within the hyperparameters that is required to continue loading the model. pub struct PartialHyperparameters { + /// The number of vocabulary tokens. pub n_vocab: usize, } +/// A handler for loading a model. pub trait LoadHandler { /// Called when the container type is read. fn container_type(&mut self, container_type: ContainerType) -> Result<(), E>; @@ -72,21 +94,16 @@ pub trait LoadHandler { fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), E>; } -#[test] -fn can_be_vtable() { - use std::mem::MaybeUninit; - let _a: MaybeUninit>> = MaybeUninit::uninit(); -} - +/// Load a model from a `reader` with the `handler`, which will be called when certain events occur. pub fn load_model_from_reader( reader: &mut R, handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic let container_type: ContainerType = match read_u32(reader)? { - ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, - ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, + ggml::FILE_MAGIC_GGMF => ContainerType::Ggmf, + ggml::FILE_MAGIC_GGJT => ContainerType::Ggjt, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::Ggml, magic => return Err(LoadError::InvalidMagic(magic)), }; handler @@ -95,13 +112,13 @@ pub fn load_model_from_reader( // Load format version match container_type { - ContainerType::GGMF | ContainerType::GGJT => { + ContainerType::Ggmf | ContainerType::Ggjt => { let _version: u32 = match read_u32(reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => return Err(LoadError::InvalidFormatVersion(container_type, version)), }; } - ContainerType::GGML => {} + ContainerType::Ggml => {} } // Load hyper params @@ -115,8 +132,8 @@ pub fn load_model_from_reader( let len = read_u32(reader)?.try_into()?; let token = read_bytes_with_len(reader, len)?; let token_score = match container_type { - ContainerType::GGMF | ContainerType::GGJT => read_f32(reader)?, - ContainerType::GGML => { + ContainerType::Ggmf | ContainerType::Ggjt => read_f32(reader)?, + ContainerType::Ggml => { // Legacy model, set empty score 0. } @@ -128,8 +145,8 @@ pub fn load_model_from_reader( // Load tensor data match container_type { - ContainerType::GGMF | ContainerType::GGML => load_weights(reader, handler, false), - ContainerType::GGJT => load_weights(reader, handler, true), + ContainerType::Ggmf | ContainerType::Ggml => load_weights(reader, handler, false), + ContainerType::Ggjt => load_weights(reader, handler, true), } } @@ -147,13 +164,11 @@ fn load_weights( let n_dims: usize = read_i32(reader)?.try_into()?; let name_len = read_i32(reader)?; let ftype = read_i32(reader)?; - let ftype = - ggml::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType(ftype))?; let mut n_elements: usize = 1; let mut dims = [1usize, 1]; let ne_len = dims.len(); - if !(n_dims <= ne_len) { + if n_dims > ne_len { return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}"))); } @@ -165,12 +180,16 @@ fn load_weights( } // load tensor name - let name = read_bytes_with_len(reader, name_len.try_into()?)?; + let name = String::from_utf8(read_bytes_with_len(reader, name_len.try_into()?)?)?; + let ftype = ggml::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType { + tensor_name: name.clone(), + ftype, + })?; // sanity check match ftype { ElementType::Q4_0 | ElementType::Q4_1 => { - if !(dims[0] % 64 == 0) { + if dims[0] % 64 != 0 { return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); } } diff --git a/ggml-format/src/util.rs b/ggml-format/src/util.rs index ec3fb31d..143af53a 100644 --- a/ggml-format/src/util.rs +++ b/ggml-format/src/util.rs @@ -1,24 +1,29 @@ pub use std::fs::File; pub use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; +/// Read a fixed-size array of bytes from a reader. pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { let mut bytes = [0u8; N]; reader.read_exact(&mut bytes)?; Ok(bytes) } +/// Read a `i32` from a reader. pub fn read_i32(reader: &mut impl BufRead) -> Result { Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `u32` from a reader. pub fn read_u32(reader: &mut impl BufRead) -> Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a `f32` from a reader. pub fn read_f32(reader: &mut impl BufRead) -> Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } +/// Read a variable-length array of bytes from a reader. pub fn read_bytes_with_len( reader: &mut impl BufRead, len: usize, @@ -28,18 +33,22 @@ pub fn read_bytes_with_len( Ok(bytes) } +/// Read and write a `i32` from a reader to a writer. pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) } +/// Read and write a `u32` from a reader to a writer. pub fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { Ok(u32::from_le_bytes(rw::<4>(reader, writer)?)) } +/// Read and write a `f32` from a reader to a writer. pub fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { Ok(f32::from_le_bytes(rw::<4>(reader, writer)?)) } +/// Read and write a variable-length array of bytes from a reader to a writer. pub fn rw_bytes_with_len( reader: &mut impl BufRead, writer: &mut impl Write, @@ -51,6 +60,7 @@ pub fn rw_bytes_with_len( Ok(buf) } +/// Read and write a fixed-size array of bytes from a reader to a writer. fn rw( reader: &mut impl BufRead, writer: &mut impl Write, @@ -61,6 +71,7 @@ fn rw( } // NOTE: Implementation from #![feature(buf_read_has_data_left)] +/// Check if there is any data left in the reader. pub fn has_data_left(reader: &mut impl BufRead) -> Result { reader.fill_buf().map(|b| !b.is_empty()) } diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index b0ba2380..4174eb74 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -36,20 +36,22 @@ pub(crate) fn load( let mut reader = BufReader::new(&file); // Verify magic - let model_type: ContainerType = match read_u32(&mut reader)? { - ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, - ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, - ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, + let magic = read_u32(&mut reader)?; + let model_type: ContainerType = match magic { + ggml::FILE_MAGIC_GGMF => ContainerType::Ggmf, + ggml::FILE_MAGIC_GGJT => ContainerType::Ggjt, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::Ggml, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), + magic, }) } }; // Load format version match model_type { - ContainerType::GGMF | ContainerType::GGJT => { + ContainerType::Ggmf | ContainerType::Ggjt => { let _version: u32 = match read_u32(&mut reader)? { ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, version => { @@ -60,7 +62,7 @@ pub(crate) fn load( } }; } - ContainerType::GGML => {} + ContainerType::Ggml => {} } // ================= @@ -100,8 +102,8 @@ pub(crate) fn load( let token = read_bytes_with_len(&mut reader, len.try_into()?)?; let score = match model_type { - ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?, - ContainerType::GGML => { + ContainerType::Ggmf | ContainerType::Ggjt => read_f32(&mut reader)?, + ContainerType::Ggml => { // Legacy model, set empty score 0. } @@ -175,7 +177,7 @@ pub(crate) fn load( let mut model = Model::new_loader1(context, hparams, vocabulary, n_ff, wtype, mmap); match model_type { - ContainerType::GGMF | ContainerType::GGML => { + ContainerType::Ggmf | ContainerType::Ggml => { let file_offset = reader.stream_position()?; drop(reader); load_weights_ggmf_or_unversioned( @@ -185,7 +187,7 @@ pub(crate) fn load( model.tensors_mut(), )? } - ContainerType::GGJT => { + ContainerType::Ggjt => { load_weights_ggjt( &mut reader, mmap_ptr, @@ -424,7 +426,7 @@ fn load_tensor_header_ggmf<'a>( let bpe = match bpe { Some(x) => x, None => { - return Err(LoadError::InvalidFtype { + return Err(LoadError::UnsupportedElementType { tensor_name, ftype, path: path.to_owned(), @@ -503,7 +505,7 @@ fn load_weights_ggjt( match tensor_type_size(ftype, ne) { Some(_) => {} None => { - return Err(LoadError::InvalidFtype { + return Err(LoadError::UnsupportedElementType { tensor_name, ftype, path: path.to_owned(), diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index a2a4a623..eb8cd666 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -18,7 +18,7 @@ use crate::{ impl LoadError { pub(crate) fn from_format_error(value: FormatLoadError, path: PathBuf) -> Self { match value { - FormatLoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, + FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic }, FormatLoadError::InvalidFormatVersion(container_type, version) => { LoadError::InvalidFormatVersion { container_type, @@ -26,10 +26,17 @@ impl LoadError { } } FormatLoadError::Io(err) => LoadError::Io(err), - FormatLoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), + FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err), + FormatLoadError::InvalidIntegerConversion(err) => { + LoadError::InvalidIntegerConversion(err) + } FormatLoadError::ImplementationError(err) => err, - FormatLoadError::UnsupportedElementType(ty) => { - LoadError::HyperparametersF16Invalid { ftype: ty } + FormatLoadError::UnsupportedElementType { tensor_name, ftype } => { + LoadError::UnsupportedElementType { + path, + tensor_name, + ftype, + } } FormatLoadError::InvariantBroken(invariant) => { LoadError::InvariantBroken { path, invariant } @@ -204,7 +211,7 @@ impl Loader { prefer_mmap, load_progress_callback, - container_type: ContainerType::GGJT, + container_type: ContainerType::Ggjt, hyperparameters: Hyperparameters::default(), vocabulary: Vocabulary::default(), tensors: HashMap::default(), @@ -259,12 +266,7 @@ impl ggml_format::LoadHandler Result<(), LoadError> { - let tensor_name = match String::from_utf8(info.name.clone()) { - Ok(n) => n, - Err(err) => return Err(LoadError::InvalidUtf8(err)), - }; - - self.tensors.insert(tensor_name, info); + self.tensors.insert(info.name.clone(), info); Ok(()) } } diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs index 716fd9fa..a40c45de 100644 --- a/llama-rs/src/loader_common.rs +++ b/llama-rs/src/loader_common.rs @@ -153,6 +153,8 @@ pub enum LoadError { InvalidMagic { /// The path that failed. path: PathBuf, + /// The magic number that was encountered. + magic: u32, }, #[error("invalid file format version {version}")] /// The version of the format is not supported by this version of `llama-rs`. @@ -187,7 +189,7 @@ pub enum LoadError { }, /// The tensor `tensor_name` did not have the expected format type. #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] - InvalidFtype { + UnsupportedElementType { /// The name of the tensor. tensor_name: String, /// The format type that was encountered. diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index e7725c96..47d39a53 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -124,6 +124,7 @@ pub fn quantize( if magic != FILE_MAGIC_GGMF { return Err(LoadError::InvalidMagic { path: file_in.to_owned(), + magic, } .into()); } @@ -131,7 +132,7 @@ pub fn quantize( let format_version = rw_u32(&mut finp, &mut fout)?; if format_version != FORMAT_VERSION { return Err(LoadError::InvalidFormatVersion { - container_type: ContainerType::GGMF, + container_type: ContainerType::Ggmf, version: format_version, } .into()); @@ -231,7 +232,7 @@ pub fn quantize( if quantize { if ftype != 0 && ftype != 1 { - return Err(LoadError::InvalidFtype { + return Err(LoadError::UnsupportedElementType { ftype, tensor_name: name, path: file_in.to_owned(), From ca993871f08250d09c5e59084bec09f2f9794588 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 05:25:53 +0200 Subject: [PATCH 13/16] chore: remove some clippy ignores --- ggml/src/lib.rs | 20 ++++++-------------- llama-rs/src/loader.rs | 36 ++++++++++++++++++++++++++++-------- llama-rs/src/model.rs | 10 +++------- 3 files changed, 37 insertions(+), 29 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index bc6f781b..b6d471b1 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -325,14 +325,8 @@ impl Context { } /// Creates a 2D view over `a`. - pub fn op_view_2d( - &self, - a: &Tensor, - ne0: usize, - ne1: usize, - nb1: usize, - offset: usize, - ) -> Tensor { + pub fn op_view_2d(&self, a: &Tensor, ne: (usize, usize), nb1: usize, offset: usize) -> Tensor { + let (ne0, ne1) = ne; let tensor = unsafe { ggml_sys::ggml_view_2d( self.ptr.as_ptr(), @@ -347,17 +341,15 @@ impl Context { } /// Creates a 3d view over `a`. - #[allow(clippy::too_many_arguments)] pub fn op_view_3d( &self, a: &Tensor, - ne0: usize, - ne1: usize, - ne2: usize, - nb1: usize, - nb2: usize, + ne: (usize, usize, usize), + nb: (usize, usize), offset: usize, ) -> Tensor { + let (ne0, ne1, ne2) = ne; + let (nb1, nb2) = nb; let tensor = unsafe { ggml_sys::ggml_view_3d( self.ptr.as_ptr(), diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index 4174eb74..2b72a517 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -252,7 +252,14 @@ fn load_weights_ggmf_or_unversioned( let length = read_i32(&mut part_reader)?; let ftype = read_i32(&mut part_reader)?; - let (nelements, ne, tensor_name, tensor, split_type, bpe) = load_tensor_header_ggmf( + let TensorHeaderGgmf { + nelements, + ne, + tensor_name, + tensor, + split_type, + bpe, + } = load_tensor_header_ggmf( n_dims, &mut part_reader, length, @@ -350,7 +357,14 @@ fn load_weights_ggmf_or_unversioned( Ok(()) } -#[allow(clippy::type_complexity)] +struct TensorHeaderGgmf<'a> { + nelements: usize, + ne: [i64; 2], + tensor_name: String, + tensor: &'a mut ggml::Tensor, + split_type: i32, + bpe: usize, +} fn load_tensor_header_ggmf<'a>( n_dims: usize, reader: &mut impl BufRead, @@ -359,7 +373,7 @@ fn load_tensor_header_ggmf<'a>( path: &Path, n_parts: usize, ftype: i32, -) -> Result<(usize, [i64; 2], String, &'a mut ggml::Tensor, i32, usize), LoadError> { +) -> Result, LoadError> { let mut nelements = 1; let mut ne = [1i64, 1i64]; assert!(n_dims <= ne.len()); @@ -373,13 +387,12 @@ fn load_tensor_header_ggmf<'a>( else { return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); }; - #[allow(clippy::if_same_then_else)] let split_type = if tensor_name.contains("tok_embeddings") { 0 } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { + if tensor_name.contains("attention.wo.weight") + || tensor_name.contains("feed_forward.w2.weight") + { 0 } else { 1 @@ -433,7 +446,14 @@ fn load_tensor_header_ggmf<'a>( }); } }; - Ok((nelements, ne, tensor_name, tensor, split_type, bpe)) + Ok(TensorHeaderGgmf { + nelements, + ne, + tensor_name, + tensor, + split_type, + bpe, + }) } fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index d7cfa91f..730afc14 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -340,8 +340,7 @@ impl Model { let v = ctx0.op_view_2d( &session.memory_v, - n, - n_embd, + (n, n_embd), n_ctx * memv_elsize, (il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize, ); @@ -388,11 +387,8 @@ impl Model { // split cached V into n_head heads let v = ctx0.op_view_3d( &session.memory_v, - n_past + n, - n_embd / n_head, - n_head, - n_ctx * memv_elsize, - n_ctx * memv_elsize * n_embd / n_head, + (n_past + n, n_embd / n_head, n_head), + (n_ctx * memv_elsize, n_ctx * memv_elsize * n_embd / n_head), il * n_ctx * memv_elsize * n_embd, ); From 196d4f380c5d2f8e26b3a0fcc67062848132137f Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 20:53:00 +0200 Subject: [PATCH 14/16] feat(ggml-format): implement writer --- Cargo.lock | 1 + ggml-format/Cargo.toml | 3 + ggml-format/src/lib.rs | 6 +- ggml-format/src/loader.rs | 24 ++--- ggml-format/src/saver.rs | 119 ++++++++++++++++++++++++ ggml-format/src/tests.rs | 184 ++++++++++++++++++++++++++++++++++++++ ggml-format/src/util.rs | 25 ++++-- llama-rs/src/loader2.rs | 8 +- 8 files changed, 350 insertions(+), 20 deletions(-) create mode 100644 ggml-format/src/saver.rs create mode 100644 ggml-format/src/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 9e3a7a33..58b2c4ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,6 +450,7 @@ name = "ggml-format" version = "0.1.0" dependencies = [ "ggml", + "rand", "thiserror", ] diff --git a/ggml-format/Cargo.toml b/ggml-format/Cargo.toml index 99ba7216..91daca22 100644 --- a/ggml-format/Cargo.toml +++ b/ggml-format/Cargo.toml @@ -8,3 +8,6 @@ edition = "2021" [dependencies] ggml = { path = "../ggml" } thiserror = "1.0" + +[dev-dependencies] +rand = "0.8" diff --git a/ggml-format/src/lib.rs b/ggml-format/src/lib.rs index f81fce2d..b26aa0f2 100644 --- a/ggml-format/src/lib.rs +++ b/ggml-format/src/lib.rs @@ -11,10 +11,14 @@ pub mod util; mod loader; +mod saver; +#[cfg(test)] +mod tests; pub use loader::{ - load_model_from_reader, LoadError, LoadHandler, PartialHyperparameters, TensorInfo, + data_size, load_model, LoadError, LoadHandler, PartialHyperparameters, TensorInfo, }; +pub use saver::{save_model, SaveError, SaveHandler, TensorData}; /// The type of a tensor element. pub type ElementType = ggml::Type; diff --git a/ggml-format/src/loader.rs b/ggml-format/src/loader.rs index e42279c8..785d6aea 100644 --- a/ggml-format/src/loader.rs +++ b/ggml-format/src/loader.rs @@ -66,14 +66,15 @@ impl TensorInfo { /// Calculate the size of the tensor's values in bytes. pub fn calc_size(&self) -> usize { - let mut size = ggml::type_size(self.element_type); - for &dim in self.dims() { - size *= dim; - } - size / ggml::blck_size(self.element_type) + data_size(self.element_type, self.dims().iter().product()) } } +/// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. +pub fn data_size(element_type: ElementType, n_elements: usize) -> usize { + (ggml::type_size(element_type) * n_elements) / ggml::blck_size(element_type) +} + #[derive(Debug, Clone)] /// Information present within the hyperparameters that is required to continue loading the model. pub struct PartialHyperparameters { @@ -82,22 +83,25 @@ pub struct PartialHyperparameters { } /// A handler for loading a model. -pub trait LoadHandler { +pub trait LoadHandler { /// Called when the container type is read. fn container_type(&mut self, container_type: ContainerType) -> Result<(), E>; /// Called when a vocabulary token is read. fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), E>; /// Called when the hyperparameters need to be read. /// You must read the hyperparameters for your model here. - fn read_hyperparameters(&mut self, reader: &mut R) -> Result; + fn read_hyperparameters( + &mut self, + reader: &mut dyn BufRead, + ) -> Result; /// Called when a new tensor is found. fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), E>; } /// Load a model from a `reader` with the `handler`, which will be called when certain events occur. -pub fn load_model_from_reader( +pub fn load_model( reader: &mut R, - handler: &mut impl LoadHandler, + handler: &mut impl LoadHandler, ) -> Result<(), LoadError> { // Verify magic let container_type: ContainerType = match read_u32(reader)? { @@ -156,7 +160,7 @@ pub fn load_model_from_reader( /// align to 4 bytes before reading tensor weights fn load_weights( reader: &mut R, - handler: &mut impl LoadHandler, + handler: &mut impl LoadHandler, align: bool, ) -> Result<(), LoadError> { while has_data_left(reader)? { diff --git a/ggml-format/src/saver.rs b/ggml-format/src/saver.rs new file mode 100644 index 00000000..565032a3 --- /dev/null +++ b/ggml-format/src/saver.rs @@ -0,0 +1,119 @@ +use std::{ + error::Error, + io::{Seek, Write}, +}; + +use crate::{util, ElementType}; + +#[derive(Debug, thiserror::Error)] +/// Errors that can occur while writing a model. +pub enum SaveError { + #[error("non-specific I/O error")] + /// A non-specific IO error. + Io(#[from] std::io::Error), + #[error("invalid integer conversion")] + /// One of the integers encountered could not be converted to a more appropriate type. + InvalidIntegerConversion(#[from] std::num::TryFromIntError), + #[error("implementation error")] + /// An error `E` was returned by the implementation of the loader. + ImplementationError(#[source] E), + #[error("invariant broken: {0}")] + /// An invariant was broken. + InvariantBroken(String), +} + +/// A handler for saving a model. +pub trait SaveHandler { + /// Called when the hyperparameters are to be written. + /// You must write the hyperparameters to the given writer. + fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), E>; + + /// Called when a tensor is to be written. + /// You must return data for the tensor to be saved. + fn tensor_data(&mut self, tensor_name: &str) -> Result; +} + +/// Information about a tensor that is to be saved. +#[derive(Clone, PartialEq, Debug)] +pub struct TensorData { + /// The number of dimensions in the tensor. + pub n_dims: usize, + /// The dimensions of the tensor. + pub dims: [usize; 2], + /// The type of the elements in the tensor. + pub element_type: ElementType, + /// The data to save to disk. + // TODO: This can be done more efficiently by borrowing the data, but + // I wanted to avoid the lifetime parameter for now, especially as + // the naive solution would borrow `TensorData` for the lifetime of the + // handler, which is obviously not ideal if you're trying to transcode + // an existing file tensor-by-tensor. + pub data: Vec, +} + +/// Saves a model to the given writer. +/// +/// Only GGJT is supported. +pub fn save_model( + writer: &mut W, + handler: &mut dyn SaveHandler, + vocabulary: &[(Vec, f32)], + tensor_names: &[String], +) -> Result<(), SaveError> { + // Write header and hyperparameters + util::write_u32(writer, ggml::FILE_MAGIC_GGJT)?; + util::write_u32(writer, ggml::FORMAT_VERSION)?; + handler + .write_hyperparameters(writer) + .map_err(SaveError::ImplementationError)?; + + // Write vocabulary + for (token, score) in vocabulary { + util::write_u32(writer, token.len().try_into()?)?; + writer.write_all(token)?; + util::write_f32(writer, *score)?; + } + + // Write tensors + for name in tensor_names { + let TensorData { + n_dims, + dims, + element_type, + data, + } = handler + .tensor_data(name) + .map_err(SaveError::ImplementationError)?; + + match element_type { + ElementType::Q4_0 | ElementType::Q4_1 => { + if dims[0] % 64 != 0 { + return Err(SaveError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); + } + } + _ => {} + } + + // Write tensor header + util::write_i32(writer, n_dims.try_into()?)?; + util::write_i32(writer, name.len().try_into()?)?; + util::write_i32(writer, element_type.into())?; + for &dim in &dims[0..n_dims] { + util::write_i32(writer, dim.try_into()?)?; + } + + // Write tensor name + writer.write_all(name.as_bytes())?; + + // Align to nearest 32 bytes + let offset_curr = writer.stream_position()?; + let offset_aligned = (offset_curr + 31) & !31; + let padding = usize::try_from(offset_aligned - offset_curr)?; + writer.write_all(&vec![0; padding])?; + + // Write tensor data + writer.write_all(&data)?; + } + + Ok(()) +} diff --git a/ggml-format/src/tests.rs b/ggml-format/src/tests.rs new file mode 100644 index 00000000..b78d4710 --- /dev/null +++ b/ggml-format/src/tests.rs @@ -0,0 +1,184 @@ +use std::{ + collections::BTreeMap, + error::Error, + io::{BufRead, Write}, +}; + +use crate::*; +use rand::{distributions::Uniform, prelude::*}; + +#[derive(Debug)] +struct DummyError; +impl std::fmt::Display for DummyError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&self, f) + } +} +impl Error for DummyError {} + +#[test] +fn can_roundtrip_loader_and_saver() { + let vocabulary = vec![ + ("blazingly".as_bytes().to_vec(), 0.1), + ("fast".as_bytes().to_vec(), 0.2), + ("memory".as_bytes().to_vec(), 0.3), + ("efficient".as_bytes().to_vec(), 0.4), + ]; + + let mut rng = rand::thread_rng(); + let element_type = ggml::Type::F16; + let model = Model { + hyperparameters: Hyperparameters { + some_hyperparameter: random(), + some_other_hyperparameter: random(), + vocabulary_size: vocabulary.len().try_into().unwrap(), + }, + vocabulary, + tensors: (0..10) + .map(|i| { + let n_dims = Uniform::from(1..3).sample(&mut rng); + let dims = (0..n_dims) + .map(|_| Uniform::from(1..10).sample(&mut rng)) + .chain(std::iter::repeat(1).take(2 - n_dims)) + .collect::>(); + + let n_elements = dims.iter().product::(); + let data = (0..data_size(element_type, n_elements)) + .map(|_| random()) + .collect::>(); + + ( + format!("tensor_{}", i), + TensorData { + n_dims, + dims: dims.try_into().unwrap(), + element_type, + data, + }, + ) + }) + .collect(), + }; + + // Save the model. + let mut buffer = Vec::new(); + let mut cursor = std::io::Cursor::new(&mut buffer); + let mut save_handler = MockSaveHandler { model: &model }; + save_model( + &mut cursor, + &mut save_handler, + &model.vocabulary, + &model.tensors.keys().cloned().collect::>(), + ) + .unwrap(); + + // Load the model and confirm that it is the same as the original. + let mut cursor = std::io::Cursor::new(&buffer); + let mut load_handler = MockLoadHandler { + data: &buffer, + loaded_model: Model::default(), + }; + load_model(&mut cursor, &mut load_handler).unwrap(); + assert_eq!(load_handler.loaded_model, model); +} + +#[derive(Default, PartialEq, Debug)] +struct Hyperparameters { + some_hyperparameter: u32, + some_other_hyperparameter: u32, + vocabulary_size: u32, +} +impl Hyperparameters { + fn read(reader: &mut dyn BufRead) -> Result { + Ok(Self { + some_hyperparameter: util::read_u32(reader)?, + some_other_hyperparameter: util::read_u32(reader)? as u32, + vocabulary_size: util::read_u32(reader)?, + }) + } + + fn write(&self, writer: &mut dyn Write) -> Result<(), std::io::Error> { + util::write_u32(writer, self.some_hyperparameter)?; + util::write_u32(writer, self.some_other_hyperparameter as u32)?; + util::write_u32(writer, self.vocabulary_size)?; + Ok(()) + } +} + +#[derive(Default, PartialEq, Debug)] +struct Model { + hyperparameters: Hyperparameters, + vocabulary: Vec<(Vec, f32)>, + tensors: BTreeMap, +} + +struct MockSaveHandler<'a> { + model: &'a Model, +} +impl SaveHandler for MockSaveHandler<'_> { + fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), DummyError> { + self.model.hyperparameters.write(writer).unwrap(); + Ok(()) + } + + fn tensor_data(&mut self, tensor_name: &str) -> Result { + self.model + .tensors + .get(tensor_name) + .cloned() + .ok_or(DummyError) + } +} + +struct MockLoadHandler<'a> { + data: &'a [u8], + loaded_model: Model, +} +impl LoadHandler for MockLoadHandler<'_> { + fn container_type(&mut self, container_type: ContainerType) -> Result<(), DummyError> { + assert_eq!(container_type, ContainerType::Ggjt); + Ok(()) + } + + fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), DummyError> { + assert_eq!(i, self.loaded_model.vocabulary.len()); + self.loaded_model.vocabulary.push((token, score)); + Ok(()) + } + + fn read_hyperparameters( + &mut self, + reader: &mut dyn BufRead, + ) -> Result { + self.loaded_model.hyperparameters = Hyperparameters::read(reader).unwrap(); + Ok(PartialHyperparameters { + n_vocab: self + .loaded_model + .hyperparameters + .vocabulary_size + .try_into() + .unwrap(), + }) + } + + fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), DummyError> { + self.loaded_model.tensors.insert( + info.name, + TensorData { + n_dims: info.n_dims, + dims: info.dims, + element_type: info.element_type, + data: { + let n_bytes = info.n_elements * ggml::type_size(info.element_type); + let mut data = vec![0; n_bytes]; + data.copy_from_slice( + &self.data + [info.start_offset as usize..info.start_offset as usize + n_bytes], + ); + data + }, + }, + ); + Ok(()) + } +} diff --git a/ggml-format/src/util.rs b/ggml-format/src/util.rs index 143af53a..9800117f 100644 --- a/ggml-format/src/util.rs +++ b/ggml-format/src/util.rs @@ -2,30 +2,30 @@ pub use std::fs::File; pub use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write}; /// Read a fixed-size array of bytes from a reader. -pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { +pub fn read_bytes(reader: &mut dyn BufRead) -> Result<[u8; N], std::io::Error> { let mut bytes = [0u8; N]; reader.read_exact(&mut bytes)?; Ok(bytes) } /// Read a `i32` from a reader. -pub fn read_i32(reader: &mut impl BufRead) -> Result { +pub fn read_i32(reader: &mut dyn BufRead) -> Result { Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a `u32` from a reader. -pub fn read_u32(reader: &mut impl BufRead) -> Result { +pub fn read_u32(reader: &mut dyn BufRead) -> Result { Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a `f32` from a reader. -pub fn read_f32(reader: &mut impl BufRead) -> Result { +pub fn read_f32(reader: &mut dyn BufRead) -> Result { Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) } /// Read a variable-length array of bytes from a reader. pub fn read_bytes_with_len( - reader: &mut impl BufRead, + reader: &mut dyn BufRead, len: usize, ) -> Result, std::io::Error> { let mut bytes = vec![0u8; len]; @@ -33,6 +33,21 @@ pub fn read_bytes_with_len( Ok(bytes) } +/// Write a `i32` from a writer. +pub fn write_i32(writer: &mut dyn Write, value: i32) -> Result<(), std::io::Error> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `u32` from a writer. +pub fn write_u32(writer: &mut dyn Write, value: u32) -> Result<(), std::io::Error> { + writer.write_all(&value.to_le_bytes()) +} + +/// Write a `f32` from a writer. +pub fn write_f32(writer: &mut dyn Write, value: f32) -> Result<(), std::io::Error> { + writer.write_all(&value.to_le_bytes()) +} + /// Read and write a `i32` from a reader to a writer. pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index eb8cd666..993d4e5c 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -6,7 +6,7 @@ use memmap2::Mmap; use std::{ collections::HashMap, fs::File, - io::{BufReader, Read, Seek, SeekFrom}, + io::{BufRead, BufReader, Read, Seek, SeekFrom}, path::{Path, PathBuf}, }; @@ -75,7 +75,7 @@ pub(crate) fn load( let mut loader = Loader::new(n_context_tokens, prefer_mmap, load_progress_callback); let use_mmap = loader.mmap_active(); - ggml_format::load_model_from_reader(&mut reader, &mut loader) + ggml_format::load_model(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.clone()))?; let Loader { @@ -222,7 +222,7 @@ impl Loader { self.prefer_mmap && self.container_type.support_mmap() } } -impl ggml_format::LoadHandler> for Loader { +impl ggml_format::LoadHandler for Loader { fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> { self.container_type = container_type; Ok(()) @@ -240,7 +240,7 @@ impl ggml_format::LoadHandler, + reader: &mut dyn BufRead, ) -> Result { // NOTE: Field order matters! Data is laid out in the file exactly in this order. let hyperparameters = Hyperparameters { From 6f86e32d8c5c90cd9640f9f201edcdbe37d93492 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 22:36:27 +0200 Subject: [PATCH 15/16] feat(quantize): rewrite to use ggml-format --- ggml-format/src/loader.rs | 13 + ggml-format/src/tests.rs | 26 +- ggml-format/src/util.rs | 37 --- ggml/src/lib.rs | 20 +- llama-cli/src/cli_args.rs | 43 +++- llama-cli/src/main.rs | 42 +++- llama-rs/src/loader2.rs | 24 +- llama-rs/src/quantize.rs | 509 ++++++++++++++++++++------------------ 8 files changed, 369 insertions(+), 345 deletions(-) diff --git a/ggml-format/src/loader.rs b/ggml-format/src/loader.rs index 785d6aea..ffc99c9b 100644 --- a/ggml-format/src/loader.rs +++ b/ggml-format/src/loader.rs @@ -68,6 +68,19 @@ impl TensorInfo { pub fn calc_size(&self) -> usize { data_size(self.element_type, self.dims().iter().product()) } + + /// Reads the tensor's data from the given reader in an owned fashion. + /// + /// The behaviour is undefined if the reader does not correspond to this info. + /// + /// Do not use this if loading with `mmap`. + pub fn read_data(&self, reader: &mut R) -> std::io::Result> { + let n_bytes = self.n_elements * ggml::type_size(self.element_type); + let mut data = vec![0; n_bytes]; + reader.seek(SeekFrom::Start(self.start_offset))?; + reader.read_exact(&mut data)?; + Ok(data) + } } /// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. diff --git a/ggml-format/src/tests.rs b/ggml-format/src/tests.rs index b78d4710..91d925bb 100644 --- a/ggml-format/src/tests.rs +++ b/ggml-format/src/tests.rs @@ -162,23 +162,15 @@ impl LoadHandler for MockLoadHandler<'_> { } fn tensor_buffer(&mut self, info: TensorInfo) -> Result<(), DummyError> { - self.loaded_model.tensors.insert( - info.name, - TensorData { - n_dims: info.n_dims, - dims: info.dims, - element_type: info.element_type, - data: { - let n_bytes = info.n_elements * ggml::type_size(info.element_type); - let mut data = vec![0; n_bytes]; - data.copy_from_slice( - &self.data - [info.start_offset as usize..info.start_offset as usize + n_bytes], - ); - data - }, - }, - ); + let data = TensorData { + n_dims: info.n_dims, + dims: info.dims, + element_type: info.element_type, + data: info + .read_data(&mut std::io::Cursor::new(self.data)) + .unwrap(), + }; + self.loaded_model.tensors.insert(info.name, data); Ok(()) } } diff --git a/ggml-format/src/util.rs b/ggml-format/src/util.rs index 9800117f..ac215feb 100644 --- a/ggml-format/src/util.rs +++ b/ggml-format/src/util.rs @@ -48,43 +48,6 @@ pub fn write_f32(writer: &mut dyn Write, value: f32) -> Result<(), std::io::Erro writer.write_all(&value.to_le_bytes()) } -/// Read and write a `i32` from a reader to a writer. -pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(i32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -/// Read and write a `u32` from a reader to a writer. -pub fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(u32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -/// Read and write a `f32` from a reader to a writer. -pub fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result { - Ok(f32::from_le_bytes(rw::<4>(reader, writer)?)) -} - -/// Read and write a variable-length array of bytes from a reader to a writer. -pub fn rw_bytes_with_len( - reader: &mut impl BufRead, - writer: &mut impl Write, - len: usize, -) -> Result, std::io::Error> { - let mut buf = vec![0; len]; - reader.read_exact(&mut buf)?; - writer.write_all(&buf)?; - Ok(buf) -} - -/// Read and write a fixed-size array of bytes from a reader to a writer. -fn rw( - reader: &mut impl BufRead, - writer: &mut impl Write, -) -> Result<[u8; N], std::io::Error> { - let bytes: [u8; N] = read_bytes(reader)?; - writer.write_all(&bytes)?; - Ok(bytes) -} - // NOTE: Implementation from #![feature(buf_read_has_data_left)] /// Check if there is any data left in the reader. pub fn has_data_left(reader: &mut impl BufRead) -> Result { diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index b6d471b1..06bdf64f 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -697,17 +697,17 @@ fn i64_to_usize(val: i64) -> usize { /// You must ensure the arrays passed in are of the correct size. pub unsafe fn quantize_q4_0( src: &[f32], - dst: &mut [f32], - n: i32, - k: i32, + dst: &mut [u8], + n: usize, + k: usize, hist: &mut [i64], ) -> usize { unsafe { ggml_sys::ggml_quantize_q4_0( src.as_ptr(), dst.as_mut_ptr() as *mut c_void, - n, - k, + n.try_into().unwrap(), + k.try_into().unwrap(), hist.as_mut_ptr(), ) } @@ -720,17 +720,17 @@ pub unsafe fn quantize_q4_0( /// You must ensure the arrays passed in are of the correct size. pub unsafe fn quantize_q4_1( src: &[f32], - dst: &mut [f32], - n: i32, - k: i32, + dst: &mut [u8], + n: usize, + k: usize, hist: &mut [i64], ) -> usize { unsafe { ggml_sys::ggml_quantize_q4_1( src.as_ptr(), dst.as_mut_ptr() as *mut c_void, - n, - k, + n.try_into().unwrap(), + k.try_into().unwrap(), hist.as_mut_ptr(), ) } diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index 826acfaf..b14a3e73 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -379,18 +379,6 @@ pub struct Convert { #[arg(long, short = 't', value_enum, default_value_t = FileType::Q4_0)] pub file_type: FileType, } - -#[derive(Parser, Debug)] -pub struct Quantize { - /// The path to the model to quantize - #[arg()] - pub source: PathBuf, - - /// The path to save the quantized model to - #[arg()] - pub destination: PathBuf, -} - #[derive(Parser, Debug, ValueEnum, Clone, Copy)] pub enum FileType { /// Quantized 4-bit (type 0). @@ -412,3 +400,34 @@ impl From for llama_rs::FileType { } } } + +#[derive(Parser, Debug)] +pub struct Quantize { + /// The path to the model to quantize + #[arg()] + pub source: PathBuf, + + /// The path to save the quantized model to + #[arg()] + pub destination: PathBuf, + + /// The format to convert to + pub target: QuantizationTarget, +} + +#[derive(Parser, Debug, ValueEnum, Clone, Copy)] +#[clap(rename_all = "snake_case")] +pub enum QuantizationTarget { + /// Quantized 4-bit (type 0). + Q4_0, + /// Quantized 4-bit (type 1). + Q4_1, +} +impl From for llama_rs::ElementType { + fn from(t: QuantizationTarget) -> Self { + match t { + QuantizationTarget::Q4_0 => llama_rs::ElementType::Q4_0, + QuantizationTarget::Q4_1 => llama_rs::ElementType::Q4_1, + } + } +} diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 3ea48296..cc142875 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, io::Write}; use clap::Parser; use cli_args::Args; -use color_eyre::eyre::Result; +use color_eyre::eyre::{Context, Result}; use llama_rs::{convert::convert_pth_to_ggml, InferenceError}; use rustyline::error::ReadlineError; @@ -23,7 +23,7 @@ fn main() -> Result<()> { Args::Repl(args) => interactive(&args, false)?, Args::ChatExperimental(args) => interactive(&args, true)?, Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.file_type.into()), - Args::Quantize(args) => quantize(&args), + Args::Quantize(args) => quantize(&args)?, } Ok(()) @@ -185,16 +185,42 @@ fn interactive( Ok(()) } -fn quantize(args: &cli_args::Quantize) { - llama_rs::quantize::quantize( +fn quantize(args: &cli_args::Quantize) -> Result<()> { + use llama_rs::quantize::{quantize, QuantizeProgress::*}; + quantize( &args.source, &args.destination, - llama_rs::ElementType::Q4_0, - |p| { - println!("{p:?}"); + args.target.into(), + |progress| match progress { + HyperparametersLoaded(_) => log::info!("Loaded hyperparameters"), + TensorLoading { + name, + dims, + element_type, + n_elements, + } => log::info!( + "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)" + ), + TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"), + TensorQuantized { + name, + original_size, + reduced_size, + history, + } => log::info!( + "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})" + ), + TensorSkipped { name, size } => log::info!("Skipped tensor `{name}` ({size} bytes)"), + Finished { + original_size, + reduced_size, + history, + } => log::info!( + "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})" + ), }, ) - .unwrap(); + .wrap_err("failed to quantize model") } fn load_prompt_file_with_prompt( diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs index 993d4e5c..aec7377c 100644 --- a/llama-rs/src/loader2.rs +++ b/llama-rs/src/loader2.rs @@ -72,8 +72,7 @@ pub(crate) fn load( total_parts: 1, }); - let mut loader = Loader::new(n_context_tokens, prefer_mmap, load_progress_callback); - let use_mmap = loader.mmap_active(); + let mut loader = Loader::new(n_context_tokens, load_progress_callback); ggml_format::load_model(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.clone()))?; @@ -83,12 +82,15 @@ pub(crate) fn load( vocabulary, tensors, mut load_progress_callback, + container_type, .. } = loader; let Hyperparameters { n_embd, n_mult, .. } = hyperparameters; let n_ff = ((2 * (4 * n_embd) / 3 + n_mult - 1) / n_mult) * n_mult; + let use_mmap = prefer_mmap && container_type.support_mmap(); + let ctx_size = tensors .values() .map(|ti| { @@ -192,23 +194,21 @@ pub(crate) fn load( Ok(model) } -struct Loader { +pub(crate) struct Loader { // Input n_ctx: usize, - prefer_mmap: bool, load_progress_callback: F, // Output - container_type: ContainerType, - hyperparameters: Hyperparameters, - vocabulary: Vocabulary, - tensors: HashMap, + pub(crate) container_type: ContainerType, + pub(crate) hyperparameters: Hyperparameters, + pub(crate) vocabulary: Vocabulary, + pub(crate) tensors: HashMap, } impl Loader { - fn new(n_ctx: usize, prefer_mmap: bool, load_progress_callback: F) -> Self { + pub(crate) fn new(n_ctx: usize, load_progress_callback: F) -> Self { Self { n_ctx, - prefer_mmap, load_progress_callback, container_type: ContainerType::Ggjt, @@ -217,10 +217,6 @@ impl Loader { tensors: HashMap::default(), } } - - fn mmap_active(&mut self) -> bool { - self.prefer_mmap && self.container_type.support_mmap() - } } impl ggml_format::LoadHandler for Loader { fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> { diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index 47d39a53..541313ac 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -1,23 +1,17 @@ //! Implements quantization of weights. -use crate::{loader::read_string, FileType, Hyperparameters, LoadError, Vocabulary}; -use ggml::{ - quantize_q4_0, quantize_q4_1, Type, FILE_MAGIC_GGMF, FILE_MAGIC_UNVERSIONED, FORMAT_VERSION, -}; -use ggml_format::{ - util::{read_i32, rw_bytes_with_len, rw_f32, rw_i32, rw_u32}, - ContainerType, -}; +use crate::{loader2::Loader, Hyperparameters, LoadError, LoadProgress}; +use ggml_format::{util::write_i32, SaveError, SaveHandler, TensorData, TensorInfo}; use half::f16; use std::{ + collections::HashMap, fs::File, - io::{BufReader, BufWriter, Read, Write}, + io::{BufReader, BufWriter, Write}, path::{Path, PathBuf}, + sync::Arc, }; use thiserror::Error; -const FTYPE_STR: [&str; 4] = ["f32", "f16", "q4_0", "q4_1"]; - #[derive(Clone, Debug)] /// Progress of quantization. @@ -29,29 +23,36 @@ pub enum QuantizeProgress<'a> { /// Name of the tensor. name: &'a str, /// Size of the tensor. - size: [i32; 2], + dims: [usize; 2], /// Type of the tensor. - ftype: &'a str, + element_type: ggml::Type, /// Number of elements in the tensor. - elements: i32, + n_elements: usize, }, /// A tensor is being quantized. - Quantizing, + TensorQuantizing { + /// Name of the tensor. + name: &'a str, + }, /// A tensor has been quantized. - Quantized { + TensorQuantized { + /// Name of the tensor. + name: &'a str, /// The original size of the tensor. - original_size: f32, + original_size: usize, /// The reduced size of the tensor. - reduced_size: f32, + reduced_size: usize, /// The history of the quantization. history: Vec, }, /// A tensor has been skipped. - Skipped { - /// The original size of the tensor. - size: f32, + TensorSkipped { + /// Name of the tensor. + name: &'a str, + /// The original size (in bytes) of the tensor data. + size: usize, }, - /// A model is being quantized. + /// A model has been quantized. Finished { /// The original size of the model. original_size: f32, @@ -70,7 +71,7 @@ pub enum QuantizeError { Load(#[from] LoadError), #[error("non-specific I/O error")] /// A non-specific IO error. - IO(#[from] std::io::Error), + Io(#[from] std::io::Error), #[error("could not convert bytes to a UTF-8 string")] /// One of the strings encountered was not valid UTF-8. InvalidUtf8(#[from] std::string::FromUtf8Error), @@ -85,255 +86,269 @@ pub enum QuantizeError { /// The path that failed. path: PathBuf, }, + /// An invariant was broken. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("invariant broken: {invariant} in {path:?}")] + InvariantBroken { + /// The path that failed. + path: PathBuf, + /// The invariant that was broken. + invariant: String, + }, + /// Attempted to quantize to an invalid target. + #[error("invalid quantization target {element_type:?}")] + InvalidQuantizationTarget { + /// The quantization target. + element_type: ggml::Type, + }, + /// The quantization process encountered an unsupported element type. + #[error("unsupported element type {element_type:?}")] + UnsupportedElementType { + /// The element type. + element_type: ggml::Type, + }, +} +impl QuantizeError { + pub(crate) fn from_format_error(value: SaveError, path: PathBuf) -> Self { + match value { + SaveError::Io(io) => QuantizeError::Io(io), + SaveError::InvalidIntegerConversion(e) => QuantizeError::InvalidIntegerConversion(e), + SaveError::ImplementationError(e) => e, + SaveError::InvariantBroken(invariant) => { + QuantizeError::InvariantBroken { path, invariant } + } + } + } } /// Quantizes a model. pub fn quantize( - file_name_in: impl AsRef, - file_name_out: impl AsRef, - ty: crate::ElementType, + path_in: impl AsRef, + path_out: impl AsRef, + desired_type: ggml::Type, progress_callback: impl Fn(QuantizeProgress), ) -> Result<(), QuantizeError> { - let itype: i32 = match ty { - Type::Q4_0 => 2, - Type::Q4_1 => 3, - _ => todo!("Unsupported quantization format. This should be an error."), - }; - - let file_in = file_name_in.as_ref(); - let mut finp = BufReader::new(File::open(file_in).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: file_in.to_owned(), - })?); - - let file_out = file_name_out.as_ref(); - let mut fout = - BufWriter::new( - File::create(file_out).map_err(|e| QuantizeError::CreateFileFailed { - source: e, - path: file_out.to_owned(), - })?, - ); + // Sanity check + if !matches!(desired_type, ggml::Type::Q4_0 | ggml::Type::Q4_1) { + return Err(QuantizeError::InvalidQuantizationTarget { + element_type: desired_type, + }); + } - // Verify magic - { - let magic = rw_u32(&mut finp, &mut fout)?; - if magic == FILE_MAGIC_UNVERSIONED { - todo!("Unversioned files are not supported yet") - } - if magic != FILE_MAGIC_GGMF { - return Err(LoadError::InvalidMagic { - path: file_in.to_owned(), - magic, - } - .into()); - } + // Load the model + let progress_callback = Arc::new(progress_callback); - let format_version = rw_u32(&mut finp, &mut fout)?; - if format_version != FORMAT_VERSION { - return Err(LoadError::InvalidFormatVersion { - container_type: ContainerType::Ggmf, - version: format_version, + let path_in = path_in.as_ref(); + let mut file_in = File::open(path_in).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: path_in.to_owned(), + })?; + let mut reader = BufReader::new(&file_in); + let mut loader = Loader::new(0, { + let progress_callback = progress_callback.clone(); + move |p| { + if let LoadProgress::HyperparametersLoaded(h) = p { + progress_callback(QuantizeProgress::HyperparametersLoaded(h)) } - .into()); } - } + }); + ggml_format::load_model(&mut reader, &mut loader) + .map_err(|err| LoadError::from_format_error(err, path_in.to_owned()))?; + + // Save the quantized model, quantizing as we go + let Loader { + hyperparameters, + vocabulary, + tensors, + .. + } = loader; + + let vocabulary = vocabulary + .id_to_token + .iter() + .cloned() + .zip(vocabulary.id_to_token_score) + .collect::>(); + + let path_out = path_out.as_ref(); + let mut writer = BufWriter::new(File::create(path_out)?); + let mut saver = QuantizeSaver::new( + desired_type, + &hyperparameters, + &tensors, + &mut file_in, + |p| progress_callback(p), + ); + ggml_format::save_model( + &mut writer, + &mut saver, + &vocabulary, + &tensors.keys().cloned().collect::>(), + ) + .map_err(|err| QuantizeError::from_format_error(err, path_out.to_owned()))?; + + // Final report + let sum_all: i64 = saver.history_all.iter().sum(); + progress_callback(QuantizeProgress::Finished { + original_size: saver.total_size_original as f32 / 1024.0 / 1024.0, + reduced_size: saver.total_size_new as f32 / 1024.0 / 1024.0, + history: saver + .history_all + .iter() + .map(|hist| *hist as f32 / sum_all as f32) + .collect(), + }); - let mut hparams = Hyperparameters::default(); + Ok(()) +} - // Load parameters - { - hparams.n_vocab = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.n_embd = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.n_mult = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.n_head = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.n_layer = rw_i32(&mut finp, &mut fout)?.try_into()?; - hparams.n_rot = rw_i32(&mut finp, &mut fout)?.try_into()?; - let ftype = rw_i32(&mut finp, &mut fout)?; - hparams.file_type = - FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?; - fout.write_all(&ftype.to_le_bytes())?; - fout.write_all(&itype.to_le_bytes())?; +struct QuantizeSaver<'a, F: Fn(QuantizeProgress)> { + // Input + quantization_type: ggml::Type, + hyperparameters: &'a Hyperparameters, + tensors: &'a HashMap, + source_file: &'a mut File, + progress_callback: F, + + // Output + total_size_original: usize, + total_size_new: usize, + history_all: Vec, +} +impl<'a, F: Fn(QuantizeProgress)> QuantizeSaver<'a, F> { + fn new( + quantization_type: ggml::Type, + hyperparameters: &'a Hyperparameters, + tensors: &'a HashMap, + source_file: &'a mut File, + progress_callback: F, + ) -> Self { + Self { + quantization_type, + hyperparameters, + tensors, + source_file, + progress_callback, + + total_size_original: 0, + total_size_new: 0, + history_all: vec![0; 16], + } } - - progress_callback(QuantizeProgress::HyperparametersLoaded(&hparams)); - - // load vocab - let mut vocab = Vocabulary { - id_to_token: vec![], - id_to_token_score: vec![], - token_to_id: Default::default(), - max_token_length: 0, - }; - - for i in 0..hparams.n_vocab { - let len = rw_u32(&mut finp, &mut fout)?.try_into()?; - let word = rw_bytes_with_len(&mut finp, &mut fout, len)?; - let score = rw_f32(&mut finp, &mut fout)?; - - vocab.token_to_id.insert(word.clone(), i.try_into()?); - vocab.id_to_token.push(word); - vocab.id_to_token_score.push(score); +} +impl SaveHandler for QuantizeSaver<'_, F> { + fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), QuantizeError> { + let h = self.hyperparameters; + write_i32(writer, h.n_vocab.try_into()?)?; + write_i32(writer, h.n_embd.try_into()?)?; + write_i32(writer, h.n_mult.try_into()?)?; + write_i32(writer, h.n_head.try_into()?)?; + write_i32(writer, h.n_layer.try_into()?)?; + write_i32(writer, h.n_rot.try_into()?)?; + write_i32(writer, h.file_type.into())?; + Ok(()) } - // Load weights - { - let mut total_size_org: usize = 0; - let mut total_size_new: usize = 0; - - let mut work: Vec = vec![]; - - let mut data_u8: Vec = vec![]; - let mut data_f16: Vec = vec![]; - let mut data_f32: Vec = vec![]; - - let mut hist_all: Vec = vec![0; 16]; - - loop { - let n_dims: i32; - if let Ok(r) = read_i32(&mut finp) { - n_dims = r; - } else { - break; - } - - let length: usize; - if let Ok(r) = read_i32(&mut finp) { - length = r as usize; - } else { - break; - } - - let mut ftype: i32; - if let Ok(r) = read_i32(&mut finp) { - ftype = r; - } else { - break; - } + fn tensor_data(&mut self, tensor_name: &str) -> Result { + let tensor = self.tensors.get(tensor_name).expect( + "tensor not found; should be impossible due to handler being populated from loader", + ); - let mut nelements = 1i32; - let mut ne = [1i32, 1i32]; - for i in 0..n_dims { - ne[i as usize] = read_i32(&mut finp)?; - nelements *= ne[i as usize]; - } + (self.progress_callback)(QuantizeProgress::TensorLoading { + name: tensor_name, + dims: tensor.dims, + n_elements: tensor.n_elements, + element_type: tensor.element_type, + }); - let name = read_string(&mut finp, length)?; + // Quantize only 2D tensors + let quantize = tensor_name.contains("weight") && tensor.n_dims == 2; + let raw_data = tensor.read_data(&mut BufReader::new(&mut self.source_file))?; - progress_callback(QuantizeProgress::TensorLoading { - name: &name, - size: ne, - elements: nelements, - ftype: FTYPE_STR[ftype as usize], + if quantize && !matches!(tensor.element_type, ggml::Type::F32 | ggml::Type::F16) { + return Err(QuantizeError::UnsupportedElementType { + element_type: tensor.element_type, }); + } - // Quantize only 2D tensors - let quantize = name.contains("weight") && n_dims == 2; - - if quantize { - if ftype != 0 && ftype != 1 { - return Err(LoadError::UnsupportedElementType { - ftype, - tensor_name: name, - path: file_in.to_owned(), - } - .into()); - } - - data_f32.resize(nelements as usize, 0.0); - if ftype == 1 { - data_f16.resize(nelements as usize, 0); - - let mut buffer = vec![0u8; (nelements * 2) as usize]; - finp.read_exact(&mut buffer)?; - // Compute buffer - for (index, chunk) in buffer.chunks(2).enumerate() { - let i = u16::from_le_bytes([chunk[0], chunk[1]]); - data_f16[index] = i; - - //data_f32[index] = ggml_fp16_to_fp32(i); - data_f32[index] = f16::from_bits(i).to_f32(); - } - } else { - let mut buffer = vec![0u8; (nelements * 4) as usize]; - finp.read_exact(&mut buffer)?; - - for (index, chunk) in buffer.chunks(4).enumerate() { - data_f32[index] = - f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); - } - } - - ftype = itype; - } else { - // Determines the total bytes were dealing with - let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize; - - data_u8.resize(bpe, 0); - finp.read_exact(&mut data_u8)?; - } - - // Write data - fout.write_all(&n_dims.to_le_bytes())?; - fout.write_all(&(length as i32).to_le_bytes())?; - fout.write_all(&ftype.to_le_bytes())?; - - for i in 0..n_dims { - fout.write_all(&ne[i as usize].to_le_bytes())?; + self.total_size_original += raw_data.len(); + + let (element_type, data) = if quantize { + (self.progress_callback)(QuantizeProgress::TensorQuantizing { name: tensor_name }); + + let data_f32: Vec = match tensor.element_type { + ggml::Type::F32 => raw_data + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) + .collect(), + ggml::Type::F16 => raw_data + .chunks_exact(2) + .map(|chunk| { + f16::from_bits(u16::from_le_bytes(chunk.try_into().unwrap())).to_f32() + }) + .collect(), + _ => unreachable!(), + }; + + let mut history_current = vec![0; 16]; + + // A conservative multiplier of 4 is used here. + let mut work = vec![0u8; tensor.n_elements * 4]; + let curr_size = match self.quantization_type { + ggml::Type::Q4_0 => unsafe { + ggml::quantize_q4_0( + &data_f32, + &mut work, + tensor.n_elements, + tensor.dims[0], + &mut history_current, + ) + }, + ggml::Type::Q4_1 => unsafe { + ggml::quantize_q4_1( + &data_f32, + &mut work, + tensor.n_elements, + tensor.dims[0], + &mut history_current, + ) + }, + _ => unreachable!(), + }; + + let mut history_new = vec![]; + for (i, val) in history_current.iter().enumerate() { + self.history_all[i] += val; + history_new.push(*val as f32 / tensor.n_elements as f32); } - fout.write_all(name.as_bytes())?; - - if quantize { - progress_callback(QuantizeProgress::Quantizing); - work.resize(nelements as usize, 0.0); - - let mut hist_cur = vec![0; 16]; - let curr_size = if matches!(ty, crate::ElementType::Q4_0) { - unsafe { quantize_q4_0(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) } - } else { - unsafe { quantize_q4_1(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) } - }; + let new_data = &work[0..curr_size]; - // We divide curr size by 4 since size refers to bytes - for i in work.iter().take(curr_size / 4) { - fout.write_all(&i.to_le_bytes())?; - } - - total_size_new += curr_size; + (self.progress_callback)(QuantizeProgress::TensorQuantized { + name: tensor_name, + original_size: raw_data.len(), + reduced_size: new_data.len(), + history: history_new, + }); - let mut new_hist = vec![]; - for (i, val) in hist_cur.iter().enumerate() { - hist_all[i] += val; - new_hist.push(*val as f32 / nelements as f32); - } + self.total_size_new += new_data.len(); - progress_callback(QuantizeProgress::Quantized { - original_size: nelements as f32 * 4.0 / 1024.0 / 1024.0, - reduced_size: curr_size as f32 / 1024.0 / 1024.0, - history: new_hist, - }); - } else { - fout.write_all(&data_u8)?; - progress_callback(QuantizeProgress::Skipped { - size: data_u8.len() as f32 / 1024.0 / 1024.0, - }); - total_size_new += data_u8.len(); - } - - total_size_org += (nelements * 4) as usize; - } - - let sum_all: i64 = hist_all.iter().sum(); - progress_callback(QuantizeProgress::Finished { - original_size: total_size_org as f32 / 1024.0 / 1024.0, - reduced_size: total_size_new as f32 / 1024.0 / 1024.0, - history: hist_all - .iter() - .map(|hist| *hist as f32 / sum_all as f32) - .collect(), + (self.quantization_type, new_data.to_owned()) + } else { + (self.progress_callback)(QuantizeProgress::TensorSkipped { + name: tensor_name, + size: raw_data.len(), + }); + self.total_size_new += raw_data.len(); + (tensor.element_type, raw_data) + }; + + Ok(TensorData { + n_dims: tensor.n_dims, + dims: tensor.dims, + element_type, + data, }) } - - Ok(()) } From d968bfa3892a0cd9e08ce10d99a7f436aa1063b2 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 25 Apr 2023 23:04:37 +0200 Subject: [PATCH 16/16] feat(ggml): make quantizatin safe --- ggml/src/lib.rs | 85 ++++++++++++++++++++-------------------- llama-rs/src/quantize.rs | 37 +++++------------ 2 files changed, 52 insertions(+), 70 deletions(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 06bdf64f..6d8905f8 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -9,7 +9,7 @@ //! All [Tensor]s are nodes in this computational graph, and values cannot be retrieved until computation is completed. use std::{ - ffi::c_void, + os::raw::{c_int, c_void}, ptr::NonNull, sync::{Arc, Weak}, }; @@ -272,7 +272,7 @@ impl Context { pub unsafe fn op_map_unary( &self, a: &Tensor, - fun: unsafe extern "C" fn(cnt: ::std::os::raw::c_int, dst: *mut f32, src: *const f32), + fun: unsafe extern "C" fn(cnt: c_int, dst: *mut f32, src: *const f32), ) -> Tensor { let tensor = unsafe { ggml_sys::ggml_map_unary_f32(self.ptr.as_ptr(), a.ptr.as_ptr(), Some(fun)) }; @@ -298,12 +298,7 @@ impl Context { &self, a: &Tensor, b: &Tensor, - fun: unsafe extern "C" fn( - cnt: ::std::os::raw::c_int, - dst: *mut f32, - src0: *const f32, - src1: *const f32, - ), + fun: unsafe extern "C" fn(cnt: c_int, dst: *mut f32, src0: *const f32, src1: *const f32), ) -> Tensor { let tensor = unsafe { ggml_sys::ggml_map_binary_f32( @@ -690,48 +685,52 @@ fn i64_to_usize(val: i64) -> usize { usize::try_from(val).unwrap() } +/// Contains the result of a quantization operation. +pub struct QuantizationResult { + /// The quantized output. + pub output: Vec, + /// The quantization history. + pub history: Vec, +} + /// Quantizes `src` into `dst` using `q4_0` quantization. /// -/// # Safety -/// -/// You must ensure the arrays passed in are of the correct size. -pub unsafe fn quantize_q4_0( - src: &[f32], - dst: &mut [u8], - n: usize, - k: usize, - hist: &mut [i64], -) -> usize { - unsafe { - ggml_sys::ggml_quantize_q4_0( - src.as_ptr(), - dst.as_mut_ptr() as *mut c_void, - n.try_into().unwrap(), - k.try_into().unwrap(), - hist.as_mut_ptr(), - ) - } +/// You must ensure that `src.len() == n_elements`, and `n_elements_0` +/// is the first dimension of `src`. +pub fn quantize_q4_0(src: &[f32], n_elements: usize, n_elements_0: usize) -> QuantizationResult { + quantize_impl(src, n_elements, n_elements_0, ggml_sys::ggml_quantize_q4_0) } /// Quantizes `src` into `dst` using `q4_1` quantization. /// -/// # Safety -/// -/// You must ensure the arrays passed in are of the correct size. -pub unsafe fn quantize_q4_1( +/// You must ensure that `src.len() == n_elements`, and `n_elements_0` +/// is the first dimension of `src`. +pub fn quantize_q4_1(src: &[f32], n_elements: usize, n_elements_0: usize) -> QuantizationResult { + quantize_impl(src, n_elements, n_elements_0, ggml_sys::ggml_quantize_q4_1) +} + +fn quantize_impl( src: &[f32], - dst: &mut [u8], - n: usize, - k: usize, - hist: &mut [i64], -) -> usize { - unsafe { - ggml_sys::ggml_quantize_q4_1( + n_elements: usize, + n_elements_0: usize, + quantizer: unsafe extern "C" fn(*const f32, *mut c_void, c_int, c_int, *mut i64) -> usize, +) -> QuantizationResult { + assert_eq!(src.len(), n_elements); + assert_eq!(n_elements % n_elements_0, 0); + + // A conservative multiplier of 4 is used here. + let mut output = vec![0u8; n_elements * 4]; + let mut history = vec![0i64; 16]; + let output_size = unsafe { + quantizer( src.as_ptr(), - dst.as_mut_ptr() as *mut c_void, - n.try_into().unwrap(), - k.try_into().unwrap(), - hist.as_mut_ptr(), + output.as_mut_ptr() as *mut c_void, + n_elements.try_into().unwrap(), + n_elements_0.try_into().unwrap(), + history.as_mut_ptr(), ) - } + }; + + output.resize(output_size, 0u8); + QuantizationResult { output, history } } diff --git a/llama-rs/src/quantize.rs b/llama-rs/src/quantize.rs index 541313ac..dd7ec58b 100644 --- a/llama-rs/src/quantize.rs +++ b/llama-rs/src/quantize.rs @@ -291,40 +291,23 @@ impl SaveHandler for QuantizeSaver<'_, F _ => unreachable!(), }; - let mut history_current = vec![0; 16]; - - // A conservative multiplier of 4 is used here. - let mut work = vec![0u8; tensor.n_elements * 4]; - let curr_size = match self.quantization_type { - ggml::Type::Q4_0 => unsafe { - ggml::quantize_q4_0( - &data_f32, - &mut work, - tensor.n_elements, - tensor.dims[0], - &mut history_current, - ) - }, - ggml::Type::Q4_1 => unsafe { - ggml::quantize_q4_1( - &data_f32, - &mut work, - tensor.n_elements, - tensor.dims[0], - &mut history_current, - ) - }, + let result = match self.quantization_type { + ggml::Type::Q4_0 => { + ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0]) + } + ggml::Type::Q4_1 => { + ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0]) + } _ => unreachable!(), }; + let new_data = result.output; let mut history_new = vec![]; - for (i, val) in history_current.iter().enumerate() { + for (i, val) in result.history.iter().enumerate() { self.history_all[i] += val; history_new.push(*val as f32 / tensor.n_elements as f32); } - let new_data = &work[0..curr_size]; - (self.progress_callback)(QuantizeProgress::TensorQuantized { name: tensor_name, original_size: raw_data.len(), @@ -334,7 +317,7 @@ impl SaveHandler for QuantizeSaver<'_, F self.total_size_new += new_data.len(); - (self.quantization_type, new_data.to_owned()) + (self.quantization_type, new_data) } else { (self.progress_callback)(QuantizeProgress::TensorSkipped { name: tensor_name,