2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions llama-cli/Cargo.toml
@@ -10,6 +10,7 @@ llama-rs = { path = "../llama-rs", features = ["convert"] }

rand = { workspace = true }

bincode = "1.3.3"
clap = { version = "4.1.8", features = ["derive"] }
env_logger = "0.10.0"
log = "0.4"
51 changes: 5 additions & 46 deletions llama-cli/src/main.rs
@@ -1,11 +1,8 @@
use std::{convert::Infallible, io::Write, path::Path};
use std::{convert::Infallible, io::Write};

use clap::Parser;
use cli_args::Args;
use llama_rs::{
convert::convert_pth_to_ggml, InferenceError, InferenceSession, InferenceSessionParameters,
Model,
};
use llama_rs::{convert::convert_pth_to_ggml, InferenceError};
use rustyline::error::ReadlineError;

mod cli_args;
@@ -31,7 +28,7 @@ fn infer(args: &cli_args::Infer) {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref());
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let (mut session, session_loaded) = load_session_from_disk(
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
args.persist_session.as_deref(),
args.generate.load_session.as_deref(),
@@ -70,18 +67,7 @@ fn infer(args: &cli_args::Infer) {

if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
// Write the memory to the cache file
// SAFETY: no other model functions used inside the block
unsafe {
match snapshot::write_to_disk(&session.get_snapshot(), session_path) {
Ok(_) => {
log::info!("Successfully wrote session to {session_path:?}");
}
Err(err) => {
log::error!("Could not write session at {session_path:?}: {err}");
std::process::exit(1);
}
}
}
snapshot::write_session(session, session_path);
}
}

@@ -121,7 +107,7 @@ fn interactive(
let prompt_file = args.prompt_file.contents();
let inference_session_params = args.generate.inference_session_parameters();
let (model, vocabulary) = args.model_load.load();
let (mut session, session_loaded) = load_session_from_disk(
let (mut session, session_loaded) = snapshot::read_or_create_session(
&model,
None,
args.generate.load_session.as_deref(),
@@ -209,33 +195,6 @@ fn load_prompt_file_with_prompt(
}
}

pub fn load_session_from_disk(
model: &Model,
persist_session: Option<&Path>,
load_session: Option<&Path>,
inference_session_params: InferenceSessionParameters,
) -> (InferenceSession, bool) {
fn load_snapshot_from_disk(model: &Model, path: &Path) -> InferenceSession {
let snapshot = snapshot::load_from_disk(path);
match snapshot.and_then(|snapshot| model.session_from_snapshot(snapshot)) {
Ok(session) => {
log::info!("Loaded inference session from {path:?}");
session
}
Err(err) => {
eprintln!("Could not load inference session. Error: {err}");
std::process::exit(1);
}
}
}

match (persist_session, load_session) {
(Some(path), _) if path.exists() => (load_snapshot_from_disk(model, path), true),
(_, Some(path)) => (load_snapshot_from_disk(model, path), true),
_ => (model.start_session(inference_session_params), false),
}
}

fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
raw_prompt.replace("{{PROMPT}}", prompt)
}
71 changes: 56 additions & 15 deletions llama-cli/src/snapshot.rs
@@ -1,27 +1,68 @@
use llama_rs::{InferenceSnapshot, InferenceSnapshotRef, SnapshotError};
use llama_rs::{InferenceSession, InferenceSessionParameters, Model};
use std::{
error::Error,
fs::File,
io::{BufReader, BufWriter},
path::Path,
};
use zstd::zstd_safe::CompressionLevel;
use zstd::{
stream::{read::Decoder, write::Encoder},
zstd_safe::CompressionLevel,
};

const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1;

pub fn load_from_disk(path: impl AsRef<Path>) -> Result<InferenceSnapshot, SnapshotError> {
let mut reader = zstd::stream::read::Decoder::new(BufReader::new(File::open(path.as_ref())?))?;
InferenceSnapshot::read(&mut reader)
pub fn read_or_create_session(
model: &Model,
persist_session: Option<&Path>,
load_session: Option<&Path>,
inference_session_params: InferenceSessionParameters,
) -> (InferenceSession, bool) {
fn load(model: &Model, path: &Path) -> InferenceSession {
let file = unwrap_or_exit(File::open(path), || format!("Could not open file {path:?}"));
let decoder = unwrap_or_exit(Decoder::new(BufReader::new(file)), || {
format!("Could not create decoder for {path:?}")
});
let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || {
format!("Could not deserialize inference session from {path:?}")
});
let session = unwrap_or_exit(model.session_from_snapshot(snapshot), || {
format!("Could not convert snapshot from {path:?} to session")
});
log::info!("Loaded inference session from {path:?}");
session
}

match (persist_session, load_session) {
(Some(path), _) if path.exists() => (load(model, path), true),
(_, Some(path)) => (load(model, path), true),
_ => (model.start_session(inference_session_params), false),
}
}

pub fn write_to_disk(
snap: &InferenceSnapshotRef<'_>,
path: impl AsRef<Path>,
) -> Result<(), SnapshotError> {
let mut writer = zstd::stream::write::Encoder::new(
BufWriter::new(File::create(path.as_ref())?),
SNAPSHOT_COMPRESSION_LEVEL,
)?
.auto_finish();
pub fn write_session(mut session: llama_rs::InferenceSession, path: &Path) {
// SAFETY: the session is consumed here, so nothing else can access it.
let snapshot = unsafe { session.get_snapshot() };
let file = unwrap_or_exit(File::create(path), || {
format!("Could not create file {path:?}")
});
let encoder = unwrap_or_exit(
Encoder::new(BufWriter::new(file), SNAPSHOT_COMPRESSION_LEVEL),
|| format!("Could not create encoder for {path:?}"),
);
unwrap_or_exit(
bincode::serialize_into(encoder.auto_finish(), &snapshot),
|| format!("Could not serialize inference session to {path:?}"),
);
log::info!("Successfully wrote session to {path:?}");
}

snap.write(&mut writer)
fn unwrap_or_exit<T, E: Error>(result: Result<T, E>, error_message: impl Fn() -> String) -> T {
match result {
Ok(t) => t,
Err(err) => {
log::error!("{}. Error: {err}", error_message());
std::process::exit(1);
}
}
}
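
For orientation, here is a minimal sketch (not part of the diff) of how a CLI subcommand might wire the two new helpers together, reusing the argument names visible in main.rs above; the surrounding function is illustrative only:

fn run(args: &cli_args::Infer) {
    let (model, _vocabulary) = args.model_load.load();
    let params = args.generate.inference_session_parameters();

    // Reuse a session snapshot from disk when one exists; otherwise start fresh.
    let (session, _was_loaded) = snapshot::read_or_create_session(
        &model,
        args.persist_session.as_deref(),
        args.generate.load_session.as_deref(),
        params,
    );

    // ... run inference with `session` here ...

    // Persist the session; `write_session` consumes it and exits the process on error.
    if let Some(path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
        snapshot::write_session(session, path);
    }
}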
8 changes: 4 additions & 4 deletions llama-rs/Cargo.toml
@@ -9,16 +9,16 @@ rust-version = "1.65"
[dependencies]
ggml = { path = "../ggml" }

rand = { workspace = true }

bytemuck = "1.13.1"
partial_sort = "0.2.0"
thiserror = "1.0"
rand = { workspace = true }
serde = { version = "1.0.156", features = ["derive"] }
serde = { version = "1.0", features = ["derive"] }
serde_bytes = "0.11"
bincode = "1.3.3"

# Used for the `convert` feature
serde_json = { version = "1.0.94", optional = true }
serde_json = { version = "1.0", optional = true }
protobuf = { version = "= 2.14.0", optional = true }
rust_tokenizers = { version = "3.1.2", optional = true }

51 changes: 27 additions & 24 deletions llama-rs/src/lib.rs
@@ -133,8 +133,13 @@ impl Clone for InferenceSession {
}

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process. Can be saved to disk.
// Keep in sync with [InferenceSession] and [InferenceSnapshot]
/// A serializable snapshot of the inference process.
/// Can be created by calling [InferenceSession::get_snapshot].
///
/// If serializing, ensure that your serializer is binary-efficient.
/// This type contains a large array of bytes; traditional textual serializers
/// are likely to serialize this as an array of numbers at extreme cost.
// Keep in sync with [InferenceSession] and [InferenceSnapshot].
pub struct InferenceSnapshotRef<'a> {
/// How many tokens have been stored in the memory so far.
pub npast: usize,
@@ -151,11 +156,26 @@ pub struct InferenceSnapshotRef<'a> {
#[serde(with = "serde_bytes")]
pub memory_v: &'a [u8],
}
impl InferenceSnapshotRef<'_> {
/// Creates an owned [InferenceSnapshot] from this [InferenceSnapshotRef].
///
/// The [ToOwned] trait is not used due to its blanket implementation for all [Clone] types.
pub fn to_owned(&self) -> InferenceSnapshot {
InferenceSnapshot {
npast: self.npast,
session_params: self.session_params,
tokens: self.tokens.clone(),
last_logits: self.logits.clone(),
memory_k: self.memory_k.to_vec(),
memory_v: self.memory_v.to_vec(),
}
}
}
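
A hedged illustration of what the new `to_owned` enables: take a borrowed snapshot, upgrade it to an owned [InferenceSnapshot], and hydrate a second session from it. This assumes `get_snapshot` can be called through a mutable reference and that `session_from_snapshot` returns a `Result`, as the CLI usage above suggests; the function name is illustrative:

fn duplicate_session(
    model: &Model,
    session: &mut InferenceSession,
) -> Result<InferenceSession, SnapshotError> {
    // SAFETY: no other model functions touch `session` while the snapshot ref is alive.
    let snapshot: InferenceSnapshot = unsafe { session.get_snapshot() }.to_owned();
    // The owned snapshot no longer borrows `session`, so it can outlive it
    // before being rehydrated into a fresh session.
    model.session_from_snapshot(snapshot)
}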

/// A serializable snapshot of the inference process. Can be restored by calling
/// `Model::restore_from_snapshot`.
/// [Model::session_from_snapshot].
#[derive(serde::Deserialize, Clone, PartialEq)]
// Keep in sync with [InferenceSession] and [InferenceSnapshotRef]
// Keep in sync with [InferenceSession] and [InferenceSnapshotRef].
pub struct InferenceSnapshot {
/// How many tokens have been stored in the memory so far.
pub npast: usize,
@@ -515,9 +535,6 @@ pub enum SnapshotError {
/// Arbitrary I/O error.
#[error("I/O error while reading or writing snapshot")]
IO(#[from] std::io::Error),
/// Error during the serialization process.
#[error("error during snapshot serialization")]
Serialization(#[from] bincode::Error),
/// Mismatch between the snapshotted memory and the in-memory memory.
#[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")]
MemorySizeMismatch {
@@ -551,10 +568,10 @@ pub enum InferenceError {
#[derive(Default, Debug, Clone)]
pub struct EvaluateOutputRequest {
/// Returns all the logits for the provided batch of tokens.
/// Output shape is n_batch * n_vocab
/// Output shape is `n_batch * n_vocab`.
pub all_logits: Option<Vec<f32>>,
/// Returns the embeddings for the provided batch of tokens
/// Output shape is n_batch * n_embd
/// Output shape is `n_batch * n_embd`.
pub embeddings: Option<Vec<f32>>,
}

@@ -1387,7 +1404,7 @@ impl Model {
session.n_past += input_tokens.len();
}

/// Hydrates a previously obtained InferenceSnapshot for this model
/// Hydrates a previously obtained InferenceSnapshot for this model.
pub fn session_from_snapshot(
&self,
snapshot: InferenceSnapshot,
@@ -1665,20 +1682,6 @@ impl InferenceSession {
}
}

impl<'a> InferenceSnapshotRef<'a> {
/// Write this snapshot to the given writer.
pub fn write(&self, writer: &mut impl std::io::Write) -> Result<(), SnapshotError> {
Ok(bincode::serialize_into(writer, &self)?)
}
}

impl InferenceSnapshot {
/// Read a snapshot from the given reader.
pub fn read(reader: &mut impl std::io::Read) -> Result<Self, SnapshotError> {
Ok(bincode::deserialize_from(reader)?)
}
}

impl Vocabulary {
// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
/// Tokenize a `text` with this vocabulary.
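
With `InferenceSnapshotRef::write` and `InferenceSnapshot::read` removed (along with `SnapshotError::Serialization`), callers outside the CLI now serialize snapshots themselves. A minimal sketch using bincode directly, relying only on the serde derives shown above; the function names are illustrative:

use std::{
    fs::File,
    io::{BufReader, BufWriter},
    path::Path,
};

fn save_snapshot(
    snapshot: &llama_rs::InferenceSnapshotRef<'_>,
    path: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
    // InferenceSnapshotRef derives Serialize, so any binary-efficient format works.
    bincode::serialize_into(BufWriter::new(File::create(path)?), snapshot)?;
    Ok(())
}

fn load_snapshot(path: &Path) -> Result<llama_rs::InferenceSnapshot, Box<dyn std::error::Error>> {
    // InferenceSnapshot derives Deserialize; bincode keeps the KV memory as raw bytes.
    Ok(bincode::deserialize_from(BufReader::new(File::open(path)?))?)
}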