From d601a12b30a6f9cbf21202b3486d40105ac16cf5 Mon Sep 17 00:00:00 2001 From: Philpax Date: Fri, 7 Apr 2023 20:34:30 +0200 Subject: [PATCH 1/3] refactor(llama): remove bincode --- Cargo.lock | 2 +- llama-cli/Cargo.toml | 1 + llama-cli/src/main.rs | 51 +++------------------------- llama-cli/src/snapshot.rs | 71 ++++++++++++++++++++++++++++++--------- llama-rs/Cargo.toml | 8 ++--- llama-rs/src/lib.rs | 17 ---------- 6 files changed, 67 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ac1053c..f46addca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -460,6 +460,7 @@ checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" name = "llama-cli" version = "0.1.0" dependencies = [ + "bincode", "clap", "env_logger", "llama-rs", @@ -476,7 +477,6 @@ dependencies = [ name = "llama-rs" version = "0.1.0" dependencies = [ - "bincode", "bytemuck", "ggml", "partial_sort", diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 28dc4f7d..2eff43b7 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -10,6 +10,7 @@ llama-rs = { path = "../llama-rs", features = ["convert"] } rand = { workspace = true } +bincode = "1.3.3" clap = { version = "4.1.8", features = ["derive"] } env_logger = "0.10.0" log = "0.4" diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs index 631cf733..63c8b0fe 100644 --- a/llama-cli/src/main.rs +++ b/llama-cli/src/main.rs @@ -1,11 +1,8 @@ -use std::{convert::Infallible, io::Write, path::Path}; +use std::{convert::Infallible, io::Write}; use clap::Parser; use cli_args::Args; -use llama_rs::{ - convert::convert_pth_to_ggml, InferenceError, InferenceSession, InferenceSessionParameters, - Model, -}; +use llama_rs::{convert::convert_pth_to_ggml, InferenceError}; use rustyline::error::ReadlineError; mod cli_args; @@ -31,7 +28,7 @@ fn infer(args: &cli_args::Infer) { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); let inference_session_params = args.generate.inference_session_parameters(); let (model, vocabulary) = args.model_load.load(); - let (mut session, session_loaded) = load_session_from_disk( + let (mut session, session_loaded) = snapshot::read_or_create_session( &model, args.persist_session.as_deref(), args.generate.load_session.as_deref(), @@ -70,18 +67,7 @@ fn infer(args: &cli_args::Infer) { if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) { // Write the memory to the cache file - // SAFETY: no other model functions used inside the block - unsafe { - match snapshot::write_to_disk(&session.get_snapshot(), session_path) { - Ok(_) => { - log::info!("Successfully wrote session to {session_path:?}"); - } - Err(err) => { - log::error!("Could not write session at {session_path:?}: {err}"); - std::process::exit(1); - } - } - } + snapshot::write_session(session, session_path); } } @@ -121,7 +107,7 @@ fn interactive( let prompt_file = args.prompt_file.contents(); let inference_session_params = args.generate.inference_session_parameters(); let (model, vocabulary) = args.model_load.load(); - let (mut session, session_loaded) = load_session_from_disk( + let (mut session, session_loaded) = snapshot::read_or_create_session( &model, None, args.generate.load_session.as_deref(), @@ -209,33 +195,6 @@ fn load_prompt_file_with_prompt( } } -pub fn load_session_from_disk( - model: &Model, - persist_session: Option<&Path>, - load_session: Option<&Path>, - inference_session_params: InferenceSessionParameters, -) -> (InferenceSession, bool) { - fn load_snapshot_from_disk(model: &Model, path: &Path) -> InferenceSession { - let snapshot = snapshot::load_from_disk(path); - match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) { - Ok(session) => { - log::info!("Loaded inference session from {path:?}"); - session - } - Err(err) => { - eprintln!("Could not load inference session. Error: {err}"); - std::process::exit(1); - } - } - } - - match (persist_session, load_session) { - (Some(path), _) if path.exists() => (load_snapshot_from_disk(model, path), true), - (_, Some(path)) => (load_snapshot_from_disk(model, path), true), - _ => (model.start_session(inference_session_params), false), - } -} - fn process_prompt(raw_prompt: &str, prompt: &str) -> String { raw_prompt.replace("{{PROMPT}}", prompt) } diff --git a/llama-cli/src/snapshot.rs b/llama-cli/src/snapshot.rs index 0bbfec42..5107aebc 100644 --- a/llama-cli/src/snapshot.rs +++ b/llama-cli/src/snapshot.rs @@ -1,27 +1,68 @@ -use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError}; +use llama_rs::{InferenceSession, InferenceSessionParameters, Model}; use std::{ + error::Error, fs::File, io::{BufReader, BufWriter}, path::Path, }; -use zstd::zstd_safe::CompressionLevel; +use zstd::{ + stream::{read::Decoder, write::Encoder}, + zstd_safe::CompressionLevel, +}; const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1; -pub fn load_from_disk(path: impl AsRef) -> Result { - let mut reader = zstd::stream::read::Decoder::new(BufReader::new(File::open(path.as_ref())?))?; - InferenceSnapshot::read(&mut reader) +pub fn read_or_create_session( + model: &Model, + persist_session: Option<&Path>, + load_session: Option<&Path>, + inference_session_params: InferenceSessionParameters, +) -> (InferenceSession, bool) { + fn load(model: &Model, path: &Path) -> InferenceSession { + let file = unwrap_or_exit(File::open(path), || format!("Could not open file {path:?}")); + let decoder = unwrap_or_exit(Decoder::new(BufReader::new(file)), || { + format!("Could not create decoder for {path:?}") + }); + let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || { + format!("Could not deserialize inference session from {path:?}") + }); + let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || { + format!("Could not convert snapshot from {path:?} to session") + }); + log::info!("Loaded inference session from {path:?}"); + session + } + + match (persist_session, load_session) { + (Some(path), _) if path.exists() => (load(model, path), true), + (_, Some(path)) => (load(model, path), true), + _ => (model.start_session(inference_session_params), false), + } } -pub fn write_to_disk( - snap: &InferenceSnapshotRef<'_>, - path: impl AsRef, -) -> Result<(), SnapshotError> { - let mut writer = zstd::stream::write::Encoder::new( - BufWriter::new(File::create(path.as_ref())?), - SNAPSHOT_COMPRESSION_LEVEL, - )? - .auto_finish(); +pub fn write_session(mut session: llama_rs::InferenceSession, path: &Path) { + // SAFETY: the session is consumed here, so nothing else can access it. + let snapshot = unsafe { session.get_snapshot() }; + let file = unwrap_or_exit(File::create(path), || { + format!("Could not create file {path:?}") + }); + let encoder = unwrap_or_exit( + Encoder::new(BufWriter::new(file), SNAPSHOT_COMPRESSION_LEVEL), + || format!("Could not create encoder for {path:?}"), + ); + unwrap_or_exit( + bincode::serialize_into(encoder.auto_finish(), &snapshot), + || format!("Could not serialize inference session to {path:?}"), + ); + log::info!("Successfully wrote session to {path:?}"); +} - snap.write(&mut writer) +fn unwrap_or_exit(result: Result, error_message: impl Fn() -> String) -> T { + match result { + Ok(t) => t, + Err(err) => { + log::error!("{}. Error: {err}", error_message()); + std::process::exit(1); + } + } } diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 0e48cd58..076dd7bc 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -9,16 +9,16 @@ rust-version = "1.65" [dependencies] ggml = { path = "../ggml" } +rand = { workspace = true } + bytemuck = "1.13.1" partial_sort = "0.2.0" thiserror = "1.0" -rand = { workspace = true } -serde = { version = "1.0.156", features = ["derive"] } +serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" -bincode = "1.3.3" # Used for the `convert` feature -serde_json = { version = "1.0.94", optional = true } +serde_json = { version = "1.0", optional = true } protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 14553379..8be6759a 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -515,9 +515,6 @@ pub enum SnapshotError { /// Arbitrary I/O error. #[error("I/O error while reading or writing snapshot")] IO(#[from] std::io::Error), - /// Error during the serialization process. - #[error("error during snapshot serialization")] - Serialization(#[from] bincode::Error), /// Mismatch between the snapshotted memory and the in-memory memory. #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")] MemorySizeMismatch { @@ -1665,20 +1662,6 @@ impl InferenceSession { } } -impl<'a> InferenceSnapshotRef<'a> { - /// Write this snapshot to the given writer. - pub fn write(&self, writer: &mut impl std::io::Write) -> Result<(), SnapshotError> { - Ok(bincode::serialize_into(writer, &self)?) - } -} - -impl InferenceSnapshot { - /// Read a snapshot from the given reader. - pub fn read(reader: &mut impl std::io::Read) -> Result { - Ok(bincode::deserialize_from(reader)?) - } -} - impl Vocabulary { // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece /// Tokenize a `text` with this vocabulary. From b64db4fa626a4b0100ec6cfbf1bf1439026a1eda Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Apr 2023 00:56:38 +0200 Subject: [PATCH 2/3] docs(lib): improve snapshot docs, etc --- llama-rs/src/lib.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 8be6759a..acdaf75a 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -133,8 +133,13 @@ impl Clone for InferenceSession { } #[derive(serde::Serialize, Clone, PartialEq)] -/// A serializable snapshot of the inference process. Can be saved to disk. -// Keep in sync with [InferenceSession] and [InferenceSnapshot] +/// A serializable snapshot of the inference process. +/// Can be created by calling [InferenceSession::get_snapshot]. +/// +/// If serializing, ensure that your serializer is binary-efficient. +/// This type contains a large array of bytes; traditional textual serializers +/// are likely to serialize this as an array of numbers at extreme cost. +// Keep in sync with [InferenceSession] and [InferenceSnapshot]. pub struct InferenceSnapshotRef<'a> { /// How many tokens have been stored in the memory so far. pub npast: usize, @@ -153,9 +158,9 @@ pub struct InferenceSnapshotRef<'a> { } /// A serializable snapshot of the inference process. Can be restored by calling -/// `Model::restore_from_snapshot`. +/// [Model::session_from_snapshot]. #[derive(serde::Deserialize, Clone, PartialEq)] -// Keep in sync with [InferenceSession] and [InferenceSnapshotRef] +// Keep in sync with [InferenceSession] and [InferenceSnapshotRef]. pub struct InferenceSnapshot { /// How many tokens have been stored in the memory so far. pub npast: usize, @@ -548,10 +553,10 @@ pub enum InferenceError { #[derive(Default, Debug, Clone)] pub struct EvaluateOutputRequest { /// Returns all the logits for the provided batch of tokens. - /// Output shape is n_batch * n_vocab + /// Output shape is `n_batch * n_vocab`. pub all_logits: Option>, /// Returns the embeddings for the provided batch of tokens - /// Output shape is n_batch * n_embd + /// Output shape is `n_batch * n_embd`. pub embeddings: Option>, } @@ -1384,7 +1389,7 @@ impl Model { session.n_past += input_tokens.len(); } - /// Hydrates a previously obtained InferenceSnapshot for this model + /// Hydrates a previously obtained InferenceSnapshot for this model. pub fn session_from_snapshot( &self, snapshot: InferenceSnapshot, From 0b4ab400f96ef175f6572d0f375d9b667b0d5298 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Apr 2023 01:07:45 +0200 Subject: [PATCH 3/3] feat(lib): implement InferenceSnapshotRef::to_owned --- llama-rs/src/lib.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index acdaf75a..2e96a522 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -156,6 +156,21 @@ pub struct InferenceSnapshotRef<'a> { #[serde(with = "serde_bytes")] pub memory_v: &'a [u8], } +impl InferenceSnapshotRef<'_> { + /// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef]. + /// + /// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types. + pub fn to_owned(&self) -> InferenceSnapshot { + InferenceSnapshot { + npast: self.npast, + session_params: self.session_params, + tokens: self.tokens.clone(), + last_logits: self.logits.clone(), + memory_k: self.memory_k.to_vec(), + memory_v: self.memory_v.to_vec(), + } + } +} /// A serializable snapshot of the inference process. Can be restored by calling /// [Model::session_from_snapshot].